diff --git a/src/core/search/ast_expr.cc b/src/core/search/ast_expr.cc index 86c763d97..5f21003a5 100644 --- a/src/core/search/ast_expr.cc +++ b/src/core/search/ast_expr.cc @@ -56,11 +56,18 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) { tags.push_back(move(tag)); } -AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec) - : filter{make_unique(std::move(filter))}, +AstKnnNode::AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, + std::string_view score_alias) + : filter{nullptr}, limit{limit}, field{field.substr(1)}, - vec{std::move(vec)} { + vec{std::move(vec)}, + score_alias{score_alias} { +} + +AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) { + *this = std::move(self); + this->filter = make_unique(std::move(filter)); } } // namespace dfly::search diff --git a/src/core/search/ast_expr.h b/src/core/search/ast_expr.h index bd7331ceb..3518399ef 100644 --- a/src/core/search/ast_expr.h +++ b/src/core/search/ast_expr.h @@ -74,12 +74,19 @@ struct AstTagsNode { // Applies nearest neighbor search to the final result set struct AstKnnNode { - AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec); + AstKnnNode() = default; + AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias); + AstKnnNode(AstNode&& sub, AstKnnNode&& self); + + friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) { + return stream; + } std::unique_ptr filter; size_t limit; std::string field; OwnedFtVector vec; + std::string score_alias; }; using NodeVariants = @@ -89,6 +96,10 @@ using NodeVariants = struct AstNode : public NodeVariants { using variant::variant; + friend std::ostream& operator<<(std::ostream& stream, const AstNode& matrix) { + return stream; + } + const NodeVariants& Variant() const& { return *this; } @@ -99,11 +110,4 @@ using AstExpr = AstNode; } // namespace search } // namespace dfly -namespace std { - -inline std::ostream& operator<<(std::ostream& os, const dfly::search::AstExpr& ast) { - // os << "ast{" << ast->Debug() << "}"; - return os; -} - -} // namespace std +namespace std {} // namespace std diff --git a/src/core/search/lexer.lex b/src/core/search/lexer.lex index 2966607da..efffc7b09 100644 --- a/src/core/search/lexer.lex +++ b/src/core/search/lexer.lex @@ -62,6 +62,7 @@ term_char [_]|\w "}" return Parser::make_RCURLBR (loc()); "|" return Parser::make_OR_OP (loc()); "KNN" return Parser::make_KNN (loc()); +"AS" return Parser::make_AS (loc()); -?[0-9]+ return make_INT64(matched_view(), loc()); diff --git a/src/core/search/parser.y b/src/core/search/parser.y index 77aadbe34..1bd9f3f9b 100644 --- a/src/core/search/parser.y +++ b/src/core/search/parser.y @@ -57,6 +57,7 @@ using namespace std; RCURLBR "}" OR_OP "|" KNN "KNN" + AS "AS" ; %token AND_OP @@ -75,6 +76,9 @@ using namespace std; %nterm final_query filter search_expr search_unary_expr search_or_expr search_and_expr %nterm field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list +%nterm knn_query +%nterm opt_knn_alias + %printer { yyo << $$; } <*>; %% @@ -82,8 +86,16 @@ using namespace std; final_query: filter { driver->Set(move($1)); } - | filter ARROW LBRACKET KNN INT64 FIELD TERM RBRACKET - { driver->Set(AstKnnNode(move($1), $5, $6, BytesToFtVector($7))); } + | filter ARROW knn_query + { driver->Set(AstKnnNode(move($1), move($3))); } + +knn_query: + LBRACKET KNN INT64 FIELD TERM opt_knn_alias RBRACKET + { $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); } + +opt_knn_alias: + AS TERM { $$ = move($2); } + | { $$ = std::string{}; } filter: search_expr { $$ = move($1); } diff --git a/src/core/search/search.cc b/src/core/search/search.cc index d05766788..d9b6d6e41 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -488,10 +488,10 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index) const { return bs.Search(*query_); } -optional SearchAlgorithm::HasKnn() const { +optional> SearchAlgorithm::HasKnn() const { DCHECK(query_); - if (holds_alternative(*query_)) - return get(*query_).limit; + if (auto* knn = get_if(query_.get()); knn) + return make_pair(knn->limit, string_view{knn->score_alias}); return nullopt; } diff --git a/src/core/search/search.h b/src/core/search/search.h index 8679fb3ee..ffce5274b 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -107,8 +107,8 @@ class SearchAlgorithm { SearchResult Search(const FieldIndices* index) const; - // Return KNN limit if it is enabled - std::optional HasKnn() const; + // if enabled, return limit & alias for knn query + std::optional> HasKnn() const; void EnableProfiling(); diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 6a36b98fe..741c3f5e0 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -58,6 +58,11 @@ const absl::flat_hash_map kSchemaTy } // namespace +bool SearchParams::ShouldReturnField(std::string_view field) const { + auto cb = [field](const auto& entry) { return entry.first == field; }; + return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb); +} + optional ParseSearchFieldType(string_view name) { auto it = kSchemaTypes.find(name); return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt; diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index d5d353da8..48129efc0 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -63,6 +63,8 @@ struct SearchParams { bool IdsOnly() const { return return_fields && return_fields->empty(); } + + bool ShouldReturnField(std::string_view field) const; }; // Stores basic info about a document index. diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index f8715ba27..b71f6d92b 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -241,11 +241,11 @@ void ReplyWithResults(const SearchParams& params, absl::Span resul } } -void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span results, - ConnectionContext* cntx) { - vector docs; - for (const auto& shard_results : results) { - for (const auto& doc : shard_results.docs) { +void ReplyKnn(size_t knn_limit, string_view knn_score_alias, const SearchParams& params, + absl::Span results, ConnectionContext* cntx) { + vector docs; + for (auto& shard_results : results) { + for (auto& doc : shard_results.docs) { docs.push_back(&doc); } } @@ -262,15 +262,24 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Spanreply_builder()}; (*cntx)->StartArray(reply_size); (*cntx)->SendLong(docs.size()); for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) { - if (ids_only) + if (ids_only) { (*cntx)->SendBulkString(doc->key); - else - SendSerializedDoc(*doc, cntx); + continue; + } + + if (!knn_score_alias.empty()) + doc->values[knn_score_alias] = absl::StrCat(doc->knn_distance); + + SendSerializedDoc(*doc, cntx); } } @@ -433,8 +442,8 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendError(std::move(*res.error)); } - if (auto knn_limit = search_algo.HasKnn(); knn_limit) - ReplyKnn(*knn_limit, *params, absl::MakeSpan(docs), cntx); + if (auto knn_params = search_algo.HasKnn(); knn_params) + ReplyKnn(knn_params->first, knn_params->second, *params, absl::MakeSpan(docs), cntx); else ReplyWithResults(*params, absl::MakeSpan(docs), cntx); } diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 1c1d1a094..903d22be4 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -903,7 +903,7 @@ void Transaction::EnableShard(ShardId sid) { void Transaction::EnableAllShards() { unique_shard_cnt_ = shard_set->size(); - unique_shard_id_ = kInvalidSid; + unique_shard_id_ = unique_shard_cnt_ == 1 ? 0 : kInvalidSid; shard_data_.resize(shard_set->size()); for (auto& sd : shard_data_) sd.local_mask |= ACTIVE; diff --git a/tests/dragonfly/search_test.py b/tests/dragonfly/search_test.py index a110d0bb7..cb9b8ab08 100644 --- a/tests/dragonfly/search_test.py +++ b/tests/dragonfly/search_test.py @@ -304,6 +304,42 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type, algo_type) await i3.dropindex() +async def test_knn_score_return(async_client: aioredis.Redis): + i1 = async_client.ft("i1") + vector_field = VectorField( + "pos", + algorithm="FLAT", + attributes={ + "DIM": 1, + "DISTANCE_METRIC": "L2", + "INITICAL_CAP": 100, + }, + ) + + await i1.create_index( + [vector_field], + definition=IndexDefinition(index_type=IndexType.HASH), + ) + + pipe = async_client.pipeline() + for i in range(100): + pipe.hset(f"k{i}", mapping={"pos": np.array(i, dtype=np.float32).tobytes()}) + await pipe.execute() + + params = {"vec": np.array([1.0], dtype=np.float32).tobytes()} + result = await i1.search("* => [KNN 3 @pos $vec AS distance]", params) + + assert result.total == 3 + assert [d["distance"] for d in result.docs] == ["0", "1", "1"] + + result = await i1.search( + Query("* => [KNN 3 @pos $vec AS distance]").return_fields("pos"), params + ) + assert not any(hasattr(d, "distance") for d in result.docs) + + await i1.dropindex() + + @dfly_args({"proactor_threads": 4, "dbfilename": "search-data"}) async def test_index_persistence(df_server): client = aioredis.Redis(port=df_server.port)