feat(search): return scores (#1870)

* feat(search): return scores

---------

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-09-25 10:03:17 +03:00 committed by GitHub
parent 19783face5
commit fc0943989e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 106 additions and 30 deletions

View file

@ -56,11 +56,18 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
tags.push_back(move(tag)); tags.push_back(move(tag));
} }
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec) AstKnnNode::AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec,
: filter{make_unique<AstNode>(std::move(filter))}, std::string_view score_alias)
: filter{nullptr},
limit{limit}, limit{limit},
field{field.substr(1)}, field{field.substr(1)},
vec{std::move(vec)} { vec{std::move(vec)},
score_alias{score_alias} {
}
AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
*this = std::move(self);
this->filter = make_unique<AstNode>(std::move(filter));
} }
} // namespace dfly::search } // namespace dfly::search

View file

@ -74,12 +74,19 @@ struct AstTagsNode {
// Applies nearest neighbor search to the final result set // Applies nearest neighbor search to the final result set
struct AstKnnNode { struct AstKnnNode {
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec); AstKnnNode() = default;
AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias);
AstKnnNode(AstNode&& sub, AstKnnNode&& self);
friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) {
return stream;
}
std::unique_ptr<AstNode> filter; std::unique_ptr<AstNode> filter;
size_t limit; size_t limit;
std::string field; std::string field;
OwnedFtVector vec; OwnedFtVector vec;
std::string score_alias;
}; };
using NodeVariants = using NodeVariants =
@ -89,6 +96,10 @@ using NodeVariants =
struct AstNode : public NodeVariants { struct AstNode : public NodeVariants {
using variant::variant; using variant::variant;
friend std::ostream& operator<<(std::ostream& stream, const AstNode& matrix) {
return stream;
}
const NodeVariants& Variant() const& { const NodeVariants& Variant() const& {
return *this; return *this;
} }
@ -99,11 +110,4 @@ using AstExpr = AstNode;
} // namespace search } // namespace search
} // namespace dfly } // namespace dfly
namespace std { namespace std {} // namespace std
inline std::ostream& operator<<(std::ostream& os, const dfly::search::AstExpr& ast) {
// os << "ast{" << ast->Debug() << "}";
return os;
}
} // namespace std

View file

@ -62,6 +62,7 @@ term_char [_]|\w
"}" return Parser::make_RCURLBR (loc()); "}" return Parser::make_RCURLBR (loc());
"|" return Parser::make_OR_OP (loc()); "|" return Parser::make_OR_OP (loc());
"KNN" return Parser::make_KNN (loc()); "KNN" return Parser::make_KNN (loc());
"AS" return Parser::make_AS (loc());
-?[0-9]+ return make_INT64(matched_view(), loc()); -?[0-9]+ return make_INT64(matched_view(), loc());

View file

@ -57,6 +57,7 @@ using namespace std;
RCURLBR "}" RCURLBR "}"
OR_OP "|" OR_OP "|"
KNN "KNN" KNN "KNN"
AS "AS"
; ;
%token AND_OP %token AND_OP
@ -75,6 +76,9 @@ using namespace std;
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr %nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list %nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
%nterm <AstKnnNode> knn_query
%nterm <std::string> opt_knn_alias
%printer { yyo << $$; } <*>; %printer { yyo << $$; } <*>;
%% %%
@ -82,8 +86,16 @@ using namespace std;
final_query: final_query:
filter filter
{ driver->Set(move($1)); } { driver->Set(move($1)); }
| filter ARROW LBRACKET KNN INT64 FIELD TERM RBRACKET | filter ARROW knn_query
{ driver->Set(AstKnnNode(move($1), $5, $6, BytesToFtVector($7))); } { driver->Set(AstKnnNode(move($1), move($3))); }
knn_query:
LBRACKET KNN INT64 FIELD TERM opt_knn_alias RBRACKET
{ $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); }
opt_knn_alias:
AS TERM { $$ = move($2); }
| { $$ = std::string{}; }
filter: filter:
search_expr { $$ = move($1); } search_expr { $$ = move($1); }

View file

@ -488,10 +488,10 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
return bs.Search(*query_); return bs.Search(*query_);
} }
optional<size_t> SearchAlgorithm::HasKnn() const { optional<pair<size_t, string_view>> SearchAlgorithm::HasKnn() const {
DCHECK(query_); DCHECK(query_);
if (holds_alternative<AstKnnNode>(*query_)) if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
return get<AstKnnNode>(*query_).limit; return make_pair(knn->limit, string_view{knn->score_alias});
return nullopt; return nullopt;
} }

View file

@ -107,8 +107,8 @@ class SearchAlgorithm {
SearchResult Search(const FieldIndices* index) const; SearchResult Search(const FieldIndices* index) const;
// Return KNN limit if it is enabled // if enabled, return limit & alias for knn query
std::optional<size_t> HasKnn() const; std::optional<std::pair<size_t /*limit*/, std::string_view /*alias*/>> HasKnn() const;
void EnableProfiling(); void EnableProfiling();

View file

@ -58,6 +58,11 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
} // namespace } // namespace
bool SearchParams::ShouldReturnField(std::string_view field) const {
auto cb = [field](const auto& entry) { return entry.first == field; };
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
}
optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) { optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
auto it = kSchemaTypes.find(name); auto it = kSchemaTypes.find(name);
return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt; return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt;

View file

@ -63,6 +63,8 @@ struct SearchParams {
bool IdsOnly() const { bool IdsOnly() const {
return return_fields && return_fields->empty(); return return_fields && return_fields->empty();
} }
bool ShouldReturnField(std::string_view field) const;
}; };
// Stores basic info about a document index. // Stores basic info about a document index.

View file

@ -241,11 +241,11 @@ void ReplyWithResults(const SearchParams& params, absl::Span<SearchResult> resul
} }
} }
void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchResult> results, void ReplyKnn(size_t knn_limit, string_view knn_score_alias, const SearchParams& params,
ConnectionContext* cntx) { absl::Span<SearchResult> results, ConnectionContext* cntx) {
vector<const SerializedSearchDoc*> docs; vector<SerializedSearchDoc*> docs;
for (const auto& shard_results : results) { for (auto& shard_results : results) {
for (const auto& doc : shard_results.docs) { for (auto& doc : shard_results.docs) {
docs.push_back(&doc); docs.push_back(&doc);
} }
} }
@ -262,15 +262,24 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
bool ids_only = params.IdsOnly(); bool ids_only = params.IdsOnly();
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1); size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);
// Clear knn score alias if its excluded from return values
if (!params.ShouldReturnField(knn_score_alias))
knn_score_alias = "";
facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()}; facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()};
(*cntx)->StartArray(reply_size); (*cntx)->StartArray(reply_size);
(*cntx)->SendLong(docs.size()); (*cntx)->SendLong(docs.size());
for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) { for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) {
if (ids_only) if (ids_only) {
(*cntx)->SendBulkString(doc->key); (*cntx)->SendBulkString(doc->key);
else continue;
SendSerializedDoc(*doc, cntx); }
if (!knn_score_alias.empty())
doc->values[knn_score_alias] = absl::StrCat(doc->knn_distance);
SendSerializedDoc(*doc, cntx);
} }
} }
@ -433,8 +442,8 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(std::move(*res.error)); return (*cntx)->SendError(std::move(*res.error));
} }
if (auto knn_limit = search_algo.HasKnn(); knn_limit) if (auto knn_params = search_algo.HasKnn(); knn_params)
ReplyKnn(*knn_limit, *params, absl::MakeSpan(docs), cntx); ReplyKnn(knn_params->first, knn_params->second, *params, absl::MakeSpan(docs), cntx);
else else
ReplyWithResults(*params, absl::MakeSpan(docs), cntx); ReplyWithResults(*params, absl::MakeSpan(docs), cntx);
} }

View file

@ -903,7 +903,7 @@ void Transaction::EnableShard(ShardId sid) {
void Transaction::EnableAllShards() { void Transaction::EnableAllShards() {
unique_shard_cnt_ = shard_set->size(); unique_shard_cnt_ = shard_set->size();
unique_shard_id_ = kInvalidSid; unique_shard_id_ = unique_shard_cnt_ == 1 ? 0 : kInvalidSid;
shard_data_.resize(shard_set->size()); shard_data_.resize(shard_set->size());
for (auto& sd : shard_data_) for (auto& sd : shard_data_)
sd.local_mask |= ACTIVE; sd.local_mask |= ACTIVE;

View file

@ -304,6 +304,42 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type, algo_type)
await i3.dropindex() await i3.dropindex()
async def test_knn_score_return(async_client: aioredis.Redis):
i1 = async_client.ft("i1")
vector_field = VectorField(
"pos",
algorithm="FLAT",
attributes={
"DIM": 1,
"DISTANCE_METRIC": "L2",
"INITICAL_CAP": 100,
},
)
await i1.create_index(
[vector_field],
definition=IndexDefinition(index_type=IndexType.HASH),
)
pipe = async_client.pipeline()
for i in range(100):
pipe.hset(f"k{i}", mapping={"pos": np.array(i, dtype=np.float32).tobytes()})
await pipe.execute()
params = {"vec": np.array([1.0], dtype=np.float32).tobytes()}
result = await i1.search("* => [KNN 3 @pos $vec AS distance]", params)
assert result.total == 3
assert [d["distance"] for d in result.docs] == ["0", "1", "1"]
result = await i1.search(
Query("* => [KNN 3 @pos $vec AS distance]").return_fields("pos"), params
)
assert not any(hasattr(d, "distance") for d in result.docs)
await i1.dropindex()
@dfly_args({"proactor_threads": 4, "dbfilename": "search-data"}) @dfly_args({"proactor_threads": 4, "dbfilename": "search-data"})
async def test_index_persistence(df_server): async def test_index_persistence(df_server):
client = aioredis.Redis(port=df_server.port) client = aioredis.Redis(port=df_server.port)