mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 10:25:47 +02:00
fix(search): Fix score alias for knn wrapped in sort (#2215)
Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
7d53d196aa
commit
cc6210d077
2 changed files with 39 additions and 13 deletions
|
@ -314,11 +314,14 @@ struct BasicSearch {
|
|||
// SORTBY field [DESC]: Sort by field. Part of params and not "core query".
|
||||
IndexResult Search(const AstSortNode& node, string_view active_field) {
|
||||
auto sub_results = SearchGeneric(*node.filter, active_field);
|
||||
preagg_total_ = sub_results.Size();
|
||||
|
||||
// Skip sorting again for KNN queries
|
||||
if (holds_alternative<AstKnnNode>(node.filter->Variant()))
|
||||
// Skip sorting again for KNN queries, reverse if needed will be applied on aggregation
|
||||
if (auto knn = get_if<AstKnnNode>(&node.filter->Variant());
|
||||
knn && knn->score_alias == node.field) {
|
||||
return sub_results;
|
||||
}
|
||||
|
||||
preagg_total_ = sub_results.Size();
|
||||
|
||||
if (auto* sort_index = GetSortIndex(node.field); sort_index) {
|
||||
auto ids_vec = sub_results.Take();
|
||||
|
@ -398,7 +401,7 @@ struct BasicSearch {
|
|||
|
||||
// Top level results don't need to be sorted, because they will be scored, sorted by fields or
|
||||
// used by knn
|
||||
DCHECK(top_level ||
|
||||
DCHECK(top_level || holds_alternative<AstKnnNode>(node.Variant()) ||
|
||||
visit([](auto* set) { return is_sorted(set->begin(), set->end()); }, result.Borrowed()));
|
||||
|
||||
if (profile_builder_)
|
||||
|
@ -590,8 +593,16 @@ optional<AggregationInfo> SearchAlgorithm::HasAggregation() const {
|
|||
DCHECK(query_);
|
||||
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
|
||||
return AggregationInfo{knn->limit, string_view{knn->score_alias}, false};
|
||||
if (holds_alternative<AstSortNode>(*query_))
|
||||
return AggregationInfo{nullopt, "", get<AstSortNode>(*query_.get()).descending};
|
||||
|
||||
if (auto* sort = get_if<AstSortNode>(query_.get()); sort) {
|
||||
string_view alias = "";
|
||||
if (auto* knn = get_if<AstKnnNode>(&sort->filter->Variant());
|
||||
knn && knn->score_alias == sort->field)
|
||||
alias = knn->score_alias;
|
||||
|
||||
return AggregationInfo{nullopt, alias, sort->descending};
|
||||
}
|
||||
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
|
|
|
@ -362,20 +362,28 @@ TEST_F(SearchFamilyTest, TestLimit) {
|
|||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, TestReturn) {
|
||||
for (unsigned i = 0; i < 20; i++)
|
||||
Run({"hset", "k"s + to_string(i), "longA", to_string(i), "longB", to_string(i + 1), "longC",
|
||||
to_string(i + 2), "secret", to_string(i + 3)});
|
||||
auto floatsv = [](const float* f) -> string_view {
|
||||
return {reinterpret_cast<const char*>(f), sizeof(float)};
|
||||
};
|
||||
|
||||
Run({"ft.create", "i1", "SCHEMA", "longA", "AS", "justA", "TEXT", "longB", "AS", "justB",
|
||||
"NUMERIC", "longC", "AS", "justC", "NUMERIC"});
|
||||
for (unsigned i = 0; i < 20; i++) {
|
||||
const float score = i;
|
||||
Run({"hset", "k"s + to_string(i), "longA", to_string(i), "longB", to_string(i + 1), "longC",
|
||||
to_string(i + 2), "secret", to_string(i + 3), "vector", floatsv(&score)});
|
||||
}
|
||||
|
||||
Run({"ft.create", "i1", "SCHEMA", "longA", "AS", "justA", "TEXT",
|
||||
"longB", "AS", "justB", "NUMERIC", "longC", "AS", "justC",
|
||||
"NUMERIC", "vector", "VECTOR", "FLAT", "2", "DIM", "1"});
|
||||
|
||||
auto MatchEntry = [](string key, auto... fields) {
|
||||
return RespArray(ElementsAre(IntArg(1), "k0", RespArray(UnorderedElementsAre(fields...))));
|
||||
return RespArray(ElementsAre(IntArg(1), key, RespArray(UnorderedElementsAre(fields...))));
|
||||
};
|
||||
|
||||
// Check all fields are returned
|
||||
auto resp = Run({"ft.search", "i1", "@justA:0"});
|
||||
EXPECT_THAT(resp, MatchEntry("k0", "longA", "0", "longB", "1", "longC", "2", "secret", "3"));
|
||||
EXPECT_THAT(resp, MatchEntry("k0", "longA", "0", "longB", "1", "longC", "2", "secret", "3",
|
||||
"vector", "[0]"));
|
||||
|
||||
// Check no fields are returned
|
||||
resp = Run({"ft.search", "i1", "@justA:0", "return", "0"});
|
||||
|
@ -399,6 +407,13 @@ TEST_F(SearchFamilyTest, TestReturn) {
|
|||
// Check non-existing field
|
||||
resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"});
|
||||
EXPECT_THAT(resp, MatchEntry("k0", "nothere", ""));
|
||||
|
||||
// Check sort doesn't shadow knn return alias
|
||||
const float score = 20;
|
||||
resp = Run({"ft.search", "i1", "@justA:0 => [KNN 20 @vector $vector AS vec_return]", "SORTBY",
|
||||
"vec_return", "DESC", "RETURN", "1", "vec_return", "PARAMS", "2", "vector",
|
||||
floatsv(&score)});
|
||||
EXPECT_THAT(resp, MatchEntry("k0", "vec_return", "20"));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, SimpleUpdates) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue