fix(search): Fix score alias for knn wrapped in sort (#2215)

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-11-25 16:11:59 +03:00 committed by GitHub
parent 7d53d196aa
commit cc6210d077
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 13 deletions

View file

@ -314,11 +314,14 @@ struct BasicSearch {
// SORTBY field [DESC]: Sort by field. Part of params and not "core query".
IndexResult Search(const AstSortNode& node, string_view active_field) {
auto sub_results = SearchGeneric(*node.filter, active_field);
preagg_total_ = sub_results.Size();
// Skip sorting again for KNN queries
if (holds_alternative<AstKnnNode>(node.filter->Variant()))
// Skip sorting again for KNN queries, reverse if needed will be applied on aggregation
if (auto knn = get_if<AstKnnNode>(&node.filter->Variant());
knn && knn->score_alias == node.field) {
return sub_results;
}
preagg_total_ = sub_results.Size();
if (auto* sort_index = GetSortIndex(node.field); sort_index) {
auto ids_vec = sub_results.Take();
@ -398,7 +401,7 @@ struct BasicSearch {
// Top level results don't need to be sorted, because they will be scored, sorted by fields or
// used by knn
DCHECK(top_level ||
DCHECK(top_level || holds_alternative<AstKnnNode>(node.Variant()) ||
visit([](auto* set) { return is_sorted(set->begin(), set->end()); }, result.Borrowed()));
if (profile_builder_)
@ -590,8 +593,16 @@ optional<AggregationInfo> SearchAlgorithm::HasAggregation() const {
DCHECK(query_);
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
return AggregationInfo{knn->limit, string_view{knn->score_alias}, false};
if (holds_alternative<AstSortNode>(*query_))
return AggregationInfo{nullopt, "", get<AstSortNode>(*query_.get()).descending};
if (auto* sort = get_if<AstSortNode>(query_.get()); sort) {
string_view alias = "";
if (auto* knn = get_if<AstKnnNode>(&sort->filter->Variant());
knn && knn->score_alias == sort->field)
alias = knn->score_alias;
return AggregationInfo{nullopt, alias, sort->descending};
}
return nullopt;
}

View file

@ -362,20 +362,28 @@ TEST_F(SearchFamilyTest, TestLimit) {
}
TEST_F(SearchFamilyTest, TestReturn) {
for (unsigned i = 0; i < 20; i++)
Run({"hset", "k"s + to_string(i), "longA", to_string(i), "longB", to_string(i + 1), "longC",
to_string(i + 2), "secret", to_string(i + 3)});
auto floatsv = [](const float* f) -> string_view {
return {reinterpret_cast<const char*>(f), sizeof(float)};
};
Run({"ft.create", "i1", "SCHEMA", "longA", "AS", "justA", "TEXT", "longB", "AS", "justB",
"NUMERIC", "longC", "AS", "justC", "NUMERIC"});
for (unsigned i = 0; i < 20; i++) {
const float score = i;
Run({"hset", "k"s + to_string(i), "longA", to_string(i), "longB", to_string(i + 1), "longC",
to_string(i + 2), "secret", to_string(i + 3), "vector", floatsv(&score)});
}
Run({"ft.create", "i1", "SCHEMA", "longA", "AS", "justA", "TEXT",
"longB", "AS", "justB", "NUMERIC", "longC", "AS", "justC",
"NUMERIC", "vector", "VECTOR", "FLAT", "2", "DIM", "1"});
auto MatchEntry = [](string key, auto... fields) {
return RespArray(ElementsAre(IntArg(1), "k0", RespArray(UnorderedElementsAre(fields...))));
return RespArray(ElementsAre(IntArg(1), key, RespArray(UnorderedElementsAre(fields...))));
};
// Check all fields are returned
auto resp = Run({"ft.search", "i1", "@justA:0"});
EXPECT_THAT(resp, MatchEntry("k0", "longA", "0", "longB", "1", "longC", "2", "secret", "3"));
EXPECT_THAT(resp, MatchEntry("k0", "longA", "0", "longB", "1", "longC", "2", "secret", "3",
"vector", "[0]"));
// Check no fields are returned
resp = Run({"ft.search", "i1", "@justA:0", "return", "0"});
@ -399,6 +407,13 @@ TEST_F(SearchFamilyTest, TestReturn) {
// Check non-existing field
resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"});
EXPECT_THAT(resp, MatchEntry("k0", "nothere", ""));
// Check sort doesn't shadow knn return alias
const float score = 20;
resp = Run({"ft.search", "i1", "@justA:0 => [KNN 20 @vector $vector AS vec_return]", "SORTBY",
"vec_return", "DESC", "RETURN", "1", "vec_return", "PARAMS", "2", "vector",
floatsv(&score)});
EXPECT_THAT(resp, MatchEntry("k0", "vec_return", "20"));
}
TEST_F(SearchFamilyTest, SimpleUpdates) {