mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-10 18:05:44 +02:00
feat(search): return scores (#1870)
* feat(search): return scores --------- Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
19783face5
commit
fc0943989e
11 changed files with 106 additions and 30 deletions
|
@ -56,11 +56,18 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
|
||||||
tags.push_back(move(tag));
|
tags.push_back(move(tag));
|
||||||
}
|
}
|
||||||
|
|
||||||
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec)
|
AstKnnNode::AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec,
|
||||||
: filter{make_unique<AstNode>(std::move(filter))},
|
std::string_view score_alias)
|
||||||
|
: filter{nullptr},
|
||||||
limit{limit},
|
limit{limit},
|
||||||
field{field.substr(1)},
|
field{field.substr(1)},
|
||||||
vec{std::move(vec)} {
|
vec{std::move(vec)},
|
||||||
|
score_alias{score_alias} {
|
||||||
|
}
|
||||||
|
|
||||||
|
AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
|
||||||
|
*this = std::move(self);
|
||||||
|
this->filter = make_unique<AstNode>(std::move(filter));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dfly::search
|
} // namespace dfly::search
|
||||||
|
|
|
@ -74,12 +74,19 @@ struct AstTagsNode {
|
||||||
|
|
||||||
// Applies nearest neighbor search to the final result set
|
// Applies nearest neighbor search to the final result set
|
||||||
struct AstKnnNode {
|
struct AstKnnNode {
|
||||||
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec);
|
AstKnnNode() = default;
|
||||||
|
AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias);
|
||||||
|
AstKnnNode(AstNode&& sub, AstKnnNode&& self);
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) {
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<AstNode> filter;
|
std::unique_ptr<AstNode> filter;
|
||||||
size_t limit;
|
size_t limit;
|
||||||
std::string field;
|
std::string field;
|
||||||
OwnedFtVector vec;
|
OwnedFtVector vec;
|
||||||
|
std::string score_alias;
|
||||||
};
|
};
|
||||||
|
|
||||||
using NodeVariants =
|
using NodeVariants =
|
||||||
|
@ -89,6 +96,10 @@ using NodeVariants =
|
||||||
struct AstNode : public NodeVariants {
|
struct AstNode : public NodeVariants {
|
||||||
using variant::variant;
|
using variant::variant;
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& stream, const AstNode& matrix) {
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
const NodeVariants& Variant() const& {
|
const NodeVariants& Variant() const& {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
@ -99,11 +110,4 @@ using AstExpr = AstNode;
|
||||||
} // namespace search
|
} // namespace search
|
||||||
} // namespace dfly
|
} // namespace dfly
|
||||||
|
|
||||||
namespace std {
|
namespace std {} // namespace std
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& os, const dfly::search::AstExpr& ast) {
|
|
||||||
// os << "ast{" << ast->Debug() << "}";
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace std
|
|
||||||
|
|
|
@ -62,6 +62,7 @@ term_char [_]|\w
|
||||||
"}" return Parser::make_RCURLBR (loc());
|
"}" return Parser::make_RCURLBR (loc());
|
||||||
"|" return Parser::make_OR_OP (loc());
|
"|" return Parser::make_OR_OP (loc());
|
||||||
"KNN" return Parser::make_KNN (loc());
|
"KNN" return Parser::make_KNN (loc());
|
||||||
|
"AS" return Parser::make_AS (loc());
|
||||||
|
|
||||||
-?[0-9]+ return make_INT64(matched_view(), loc());
|
-?[0-9]+ return make_INT64(matched_view(), loc());
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,7 @@ using namespace std;
|
||||||
RCURLBR "}"
|
RCURLBR "}"
|
||||||
OR_OP "|"
|
OR_OP "|"
|
||||||
KNN "KNN"
|
KNN "KNN"
|
||||||
|
AS "AS"
|
||||||
;
|
;
|
||||||
|
|
||||||
%token AND_OP
|
%token AND_OP
|
||||||
|
@ -75,6 +76,9 @@ using namespace std;
|
||||||
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr
|
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr
|
||||||
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
|
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
|
||||||
|
|
||||||
|
%nterm <AstKnnNode> knn_query
|
||||||
|
%nterm <std::string> opt_knn_alias
|
||||||
|
|
||||||
%printer { yyo << $$; } <*>;
|
%printer { yyo << $$; } <*>;
|
||||||
|
|
||||||
%%
|
%%
|
||||||
|
@ -82,8 +86,16 @@ using namespace std;
|
||||||
final_query:
|
final_query:
|
||||||
filter
|
filter
|
||||||
{ driver->Set(move($1)); }
|
{ driver->Set(move($1)); }
|
||||||
| filter ARROW LBRACKET KNN INT64 FIELD TERM RBRACKET
|
| filter ARROW knn_query
|
||||||
{ driver->Set(AstKnnNode(move($1), $5, $6, BytesToFtVector($7))); }
|
{ driver->Set(AstKnnNode(move($1), move($3))); }
|
||||||
|
|
||||||
|
knn_query:
|
||||||
|
LBRACKET KNN INT64 FIELD TERM opt_knn_alias RBRACKET
|
||||||
|
{ $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); }
|
||||||
|
|
||||||
|
opt_knn_alias:
|
||||||
|
AS TERM { $$ = move($2); }
|
||||||
|
| { $$ = std::string{}; }
|
||||||
|
|
||||||
filter:
|
filter:
|
||||||
search_expr { $$ = move($1); }
|
search_expr { $$ = move($1); }
|
||||||
|
|
|
@ -488,10 +488,10 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
|
||||||
return bs.Search(*query_);
|
return bs.Search(*query_);
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<size_t> SearchAlgorithm::HasKnn() const {
|
optional<pair<size_t, string_view>> SearchAlgorithm::HasKnn() const {
|
||||||
DCHECK(query_);
|
DCHECK(query_);
|
||||||
if (holds_alternative<AstKnnNode>(*query_))
|
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
|
||||||
return get<AstKnnNode>(*query_).limit;
|
return make_pair(knn->limit, string_view{knn->score_alias});
|
||||||
return nullopt;
|
return nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -107,8 +107,8 @@ class SearchAlgorithm {
|
||||||
|
|
||||||
SearchResult Search(const FieldIndices* index) const;
|
SearchResult Search(const FieldIndices* index) const;
|
||||||
|
|
||||||
// Return KNN limit if it is enabled
|
// If the query is a KNN query, return its limit and score alias
|
||||||
std::optional<size_t> HasKnn() const;
|
std::optional<std::pair<size_t /*limit*/, std::string_view /*alias*/>> HasKnn() const;
|
||||||
|
|
||||||
void EnableProfiling();
|
void EnableProfiling();
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,11 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
bool SearchParams::ShouldReturnField(std::string_view field) const {
|
||||||
|
auto cb = [field](const auto& entry) { return entry.first == field; };
|
||||||
|
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
|
||||||
|
}
|
||||||
|
|
||||||
optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
|
optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
|
||||||
auto it = kSchemaTypes.find(name);
|
auto it = kSchemaTypes.find(name);
|
||||||
return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt;
|
return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt;
|
||||||
|
|
|
@ -63,6 +63,8 @@ struct SearchParams {
|
||||||
bool IdsOnly() const {
|
bool IdsOnly() const {
|
||||||
return return_fields && return_fields->empty();
|
return return_fields && return_fields->empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ShouldReturnField(std::string_view field) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Stores basic info about a document index.
|
// Stores basic info about a document index.
|
||||||
|
|
|
@ -241,11 +241,11 @@ void ReplyWithResults(const SearchParams& params, absl::Span<SearchResult> resul
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchResult> results,
|
void ReplyKnn(size_t knn_limit, string_view knn_score_alias, const SearchParams& params,
|
||||||
ConnectionContext* cntx) {
|
absl::Span<SearchResult> results, ConnectionContext* cntx) {
|
||||||
vector<const SerializedSearchDoc*> docs;
|
vector<SerializedSearchDoc*> docs;
|
||||||
for (const auto& shard_results : results) {
|
for (auto& shard_results : results) {
|
||||||
for (const auto& doc : shard_results.docs) {
|
for (auto& doc : shard_results.docs) {
|
||||||
docs.push_back(&doc);
|
docs.push_back(&doc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -262,15 +262,24 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
|
||||||
bool ids_only = params.IdsOnly();
|
bool ids_only = params.IdsOnly();
|
||||||
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);
|
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);
|
||||||
|
|
||||||
|
// Clear the KNN score alias if it is excluded from the returned fields
|
||||||
|
if (!params.ShouldReturnField(knn_score_alias))
|
||||||
|
knn_score_alias = "";
|
||||||
|
|
||||||
facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()};
|
facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()};
|
||||||
|
|
||||||
(*cntx)->StartArray(reply_size);
|
(*cntx)->StartArray(reply_size);
|
||||||
(*cntx)->SendLong(docs.size());
|
(*cntx)->SendLong(docs.size());
|
||||||
for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) {
|
for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) {
|
||||||
if (ids_only)
|
if (ids_only) {
|
||||||
(*cntx)->SendBulkString(doc->key);
|
(*cntx)->SendBulkString(doc->key);
|
||||||
else
|
continue;
|
||||||
SendSerializedDoc(*doc, cntx);
|
}
|
||||||
|
|
||||||
|
if (!knn_score_alias.empty())
|
||||||
|
doc->values[knn_score_alias] = absl::StrCat(doc->knn_distance);
|
||||||
|
|
||||||
|
SendSerializedDoc(*doc, cntx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -433,8 +442,8 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
|
||||||
return (*cntx)->SendError(std::move(*res.error));
|
return (*cntx)->SendError(std::move(*res.error));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto knn_limit = search_algo.HasKnn(); knn_limit)
|
if (auto knn_params = search_algo.HasKnn(); knn_params)
|
||||||
ReplyKnn(*knn_limit, *params, absl::MakeSpan(docs), cntx);
|
ReplyKnn(knn_params->first, knn_params->second, *params, absl::MakeSpan(docs), cntx);
|
||||||
else
|
else
|
||||||
ReplyWithResults(*params, absl::MakeSpan(docs), cntx);
|
ReplyWithResults(*params, absl::MakeSpan(docs), cntx);
|
||||||
}
|
}
|
||||||
|
|
|
@ -903,7 +903,7 @@ void Transaction::EnableShard(ShardId sid) {
|
||||||
|
|
||||||
void Transaction::EnableAllShards() {
|
void Transaction::EnableAllShards() {
|
||||||
unique_shard_cnt_ = shard_set->size();
|
unique_shard_cnt_ = shard_set->size();
|
||||||
unique_shard_id_ = kInvalidSid;
|
unique_shard_id_ = unique_shard_cnt_ == 1 ? 0 : kInvalidSid;
|
||||||
shard_data_.resize(shard_set->size());
|
shard_data_.resize(shard_set->size());
|
||||||
for (auto& sd : shard_data_)
|
for (auto& sd : shard_data_)
|
||||||
sd.local_mask |= ACTIVE;
|
sd.local_mask |= ACTIVE;
|
||||||
|
|
|
@ -304,6 +304,42 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type, algo_type)
|
||||||
await i3.dropindex()
|
await i3.dropindex()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_knn_score_return(async_client: aioredis.Redis):
|
||||||
|
i1 = async_client.ft("i1")
|
||||||
|
vector_field = VectorField(
|
||||||
|
"pos",
|
||||||
|
algorithm="FLAT",
|
||||||
|
attributes={
|
||||||
|
"DIM": 1,
|
||||||
|
"DISTANCE_METRIC": "L2",
|
||||||
|
"INITICAL_CAP": 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await i1.create_index(
|
||||||
|
[vector_field],
|
||||||
|
definition=IndexDefinition(index_type=IndexType.HASH),
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe = async_client.pipeline()
|
||||||
|
for i in range(100):
|
||||||
|
pipe.hset(f"k{i}", mapping={"pos": np.array(i, dtype=np.float32).tobytes()})
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
params = {"vec": np.array([1.0], dtype=np.float32).tobytes()}
|
||||||
|
result = await i1.search("* => [KNN 3 @pos $vec AS distance]", params)
|
||||||
|
|
||||||
|
assert result.total == 3
|
||||||
|
assert [d["distance"] for d in result.docs] == ["0", "1", "1"]
|
||||||
|
|
||||||
|
result = await i1.search(
|
||||||
|
Query("* => [KNN 3 @pos $vec AS distance]").return_fields("pos"), params
|
||||||
|
)
|
||||||
|
assert not any(hasattr(d, "distance") for d in result.docs)
|
||||||
|
|
||||||
|
await i1.dropindex()
|
||||||
|
|
||||||
|
|
||||||
@dfly_args({"proactor_threads": 4, "dbfilename": "search-data"})
|
@dfly_args({"proactor_threads": 4, "dbfilename": "search-data"})
|
||||||
async def test_index_persistence(df_server):
|
async def test_index_persistence(df_server):
|
||||||
client = aioredis.Redis(port=df_server.port)
|
client = aioredis.Redis(port=df_server.port)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue