mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-10 18:05:44 +02:00
feat(search): return scores (#1870)
* feat(search): return scores --------- Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
19783face5
commit
fc0943989e
11 changed files with 106 additions and 30 deletions
|
@ -56,11 +56,18 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
|
|||
tags.push_back(move(tag));
|
||||
}
|
||||
|
||||
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec)
|
||||
: filter{make_unique<AstNode>(std::move(filter))},
|
||||
AstKnnNode::AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec,
|
||||
std::string_view score_alias)
|
||||
: filter{nullptr},
|
||||
limit{limit},
|
||||
field{field.substr(1)},
|
||||
vec{std::move(vec)} {
|
||||
vec{std::move(vec)},
|
||||
score_alias{score_alias} {
|
||||
}
|
||||
|
||||
AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
|
||||
*this = std::move(self);
|
||||
this->filter = make_unique<AstNode>(std::move(filter));
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -74,12 +74,19 @@ struct AstTagsNode {
|
|||
|
||||
// Applies nearest neighbor search to the final result set
|
||||
struct AstKnnNode {
|
||||
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec);
|
||||
AstKnnNode() = default;
|
||||
AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias);
|
||||
AstKnnNode(AstNode&& sub, AstKnnNode&& self);
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) {
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<AstNode> filter;
|
||||
size_t limit;
|
||||
std::string field;
|
||||
OwnedFtVector vec;
|
||||
std::string score_alias;
|
||||
};
|
||||
|
||||
using NodeVariants =
|
||||
|
@ -89,6 +96,10 @@ using NodeVariants =
|
|||
struct AstNode : public NodeVariants {
|
||||
using variant::variant;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& stream, const AstNode& matrix) {
|
||||
return stream;
|
||||
}
|
||||
|
||||
const NodeVariants& Variant() const& {
|
||||
return *this;
|
||||
}
|
||||
|
@ -99,11 +110,4 @@ using AstExpr = AstNode;
|
|||
} // namespace search
|
||||
} // namespace dfly
|
||||
|
||||
namespace std {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const dfly::search::AstExpr& ast) {
|
||||
// os << "ast{" << ast->Debug() << "}";
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace std
|
||||
namespace std {} // namespace std
|
||||
|
|
|
@ -62,6 +62,7 @@ term_char [_]|\w
|
|||
"}" return Parser::make_RCURLBR (loc());
|
||||
"|" return Parser::make_OR_OP (loc());
|
||||
"KNN" return Parser::make_KNN (loc());
|
||||
"AS" return Parser::make_AS (loc());
|
||||
|
||||
-?[0-9]+ return make_INT64(matched_view(), loc());
|
||||
|
||||
|
|
|
@ -57,6 +57,7 @@ using namespace std;
|
|||
RCURLBR "}"
|
||||
OR_OP "|"
|
||||
KNN "KNN"
|
||||
AS "AS"
|
||||
;
|
||||
|
||||
%token AND_OP
|
||||
|
@ -75,6 +76,9 @@ using namespace std;
|
|||
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr
|
||||
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
|
||||
|
||||
%nterm <AstKnnNode> knn_query
|
||||
%nterm <std::string> opt_knn_alias
|
||||
|
||||
%printer { yyo << $$; } <*>;
|
||||
|
||||
%%
|
||||
|
@ -82,8 +86,16 @@ using namespace std;
|
|||
final_query:
|
||||
filter
|
||||
{ driver->Set(move($1)); }
|
||||
| filter ARROW LBRACKET KNN INT64 FIELD TERM RBRACKET
|
||||
{ driver->Set(AstKnnNode(move($1), $5, $6, BytesToFtVector($7))); }
|
||||
| filter ARROW knn_query
|
||||
{ driver->Set(AstKnnNode(move($1), move($3))); }
|
||||
|
||||
knn_query:
|
||||
LBRACKET KNN INT64 FIELD TERM opt_knn_alias RBRACKET
|
||||
{ $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); }
|
||||
|
||||
opt_knn_alias:
|
||||
AS TERM { $$ = move($2); }
|
||||
| { $$ = std::string{}; }
|
||||
|
||||
filter:
|
||||
search_expr { $$ = move($1); }
|
||||
|
|
|
@ -488,10 +488,10 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
|
|||
return bs.Search(*query_);
|
||||
}
|
||||
|
||||
optional<size_t> SearchAlgorithm::HasKnn() const {
|
||||
optional<pair<size_t, string_view>> SearchAlgorithm::HasKnn() const {
|
||||
DCHECK(query_);
|
||||
if (holds_alternative<AstKnnNode>(*query_))
|
||||
return get<AstKnnNode>(*query_).limit;
|
||||
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
|
||||
return make_pair(knn->limit, string_view{knn->score_alias});
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
|
|
|
@ -107,8 +107,8 @@ class SearchAlgorithm {
|
|||
|
||||
SearchResult Search(const FieldIndices* index) const;
|
||||
|
||||
// Return KNN limit if it is enabled
|
||||
std::optional<size_t> HasKnn() const;
|
||||
// if enabled, return limit & alias for knn query
|
||||
std::optional<std::pair<size_t /*limit*/, std::string_view /*alias*/>> HasKnn() const;
|
||||
|
||||
void EnableProfiling();
|
||||
|
||||
|
|
|
@ -58,6 +58,11 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
|
|||
|
||||
} // namespace
|
||||
|
||||
bool SearchParams::ShouldReturnField(std::string_view field) const {
|
||||
auto cb = [field](const auto& entry) { return entry.first == field; };
|
||||
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
|
||||
}
|
||||
|
||||
optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
|
||||
auto it = kSchemaTypes.find(name);
|
||||
return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt;
|
||||
|
|
|
@ -63,6 +63,8 @@ struct SearchParams {
|
|||
bool IdsOnly() const {
|
||||
return return_fields && return_fields->empty();
|
||||
}
|
||||
|
||||
bool ShouldReturnField(std::string_view field) const;
|
||||
};
|
||||
|
||||
// Stores basic info about a document index.
|
||||
|
|
|
@ -241,11 +241,11 @@ void ReplyWithResults(const SearchParams& params, absl::Span<SearchResult> resul
|
|||
}
|
||||
}
|
||||
|
||||
void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchResult> results,
|
||||
ConnectionContext* cntx) {
|
||||
vector<const SerializedSearchDoc*> docs;
|
||||
for (const auto& shard_results : results) {
|
||||
for (const auto& doc : shard_results.docs) {
|
||||
void ReplyKnn(size_t knn_limit, string_view knn_score_alias, const SearchParams& params,
|
||||
absl::Span<SearchResult> results, ConnectionContext* cntx) {
|
||||
vector<SerializedSearchDoc*> docs;
|
||||
for (auto& shard_results : results) {
|
||||
for (auto& doc : shard_results.docs) {
|
||||
docs.push_back(&doc);
|
||||
}
|
||||
}
|
||||
|
@ -262,15 +262,24 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
|
|||
bool ids_only = params.IdsOnly();
|
||||
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);
|
||||
|
||||
// Clear knn score alias if its excluded from return values
|
||||
if (!params.ShouldReturnField(knn_score_alias))
|
||||
knn_score_alias = "";
|
||||
|
||||
facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()};
|
||||
|
||||
(*cntx)->StartArray(reply_size);
|
||||
(*cntx)->SendLong(docs.size());
|
||||
for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) {
|
||||
if (ids_only)
|
||||
if (ids_only) {
|
||||
(*cntx)->SendBulkString(doc->key);
|
||||
else
|
||||
SendSerializedDoc(*doc, cntx);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!knn_score_alias.empty())
|
||||
doc->values[knn_score_alias] = absl::StrCat(doc->knn_distance);
|
||||
|
||||
SendSerializedDoc(*doc, cntx);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -433,8 +442,8 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
|
|||
return (*cntx)->SendError(std::move(*res.error));
|
||||
}
|
||||
|
||||
if (auto knn_limit = search_algo.HasKnn(); knn_limit)
|
||||
ReplyKnn(*knn_limit, *params, absl::MakeSpan(docs), cntx);
|
||||
if (auto knn_params = search_algo.HasKnn(); knn_params)
|
||||
ReplyKnn(knn_params->first, knn_params->second, *params, absl::MakeSpan(docs), cntx);
|
||||
else
|
||||
ReplyWithResults(*params, absl::MakeSpan(docs), cntx);
|
||||
}
|
||||
|
|
|
@ -903,7 +903,7 @@ void Transaction::EnableShard(ShardId sid) {
|
|||
|
||||
void Transaction::EnableAllShards() {
|
||||
unique_shard_cnt_ = shard_set->size();
|
||||
unique_shard_id_ = kInvalidSid;
|
||||
unique_shard_id_ = unique_shard_cnt_ == 1 ? 0 : kInvalidSid;
|
||||
shard_data_.resize(shard_set->size());
|
||||
for (auto& sd : shard_data_)
|
||||
sd.local_mask |= ACTIVE;
|
||||
|
|
|
@ -304,6 +304,42 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type, algo_type)
|
|||
await i3.dropindex()
|
||||
|
||||
|
||||
async def test_knn_score_return(async_client: aioredis.Redis):
|
||||
i1 = async_client.ft("i1")
|
||||
vector_field = VectorField(
|
||||
"pos",
|
||||
algorithm="FLAT",
|
||||
attributes={
|
||||
"DIM": 1,
|
||||
"DISTANCE_METRIC": "L2",
|
||||
"INITICAL_CAP": 100,
|
||||
},
|
||||
)
|
||||
|
||||
await i1.create_index(
|
||||
[vector_field],
|
||||
definition=IndexDefinition(index_type=IndexType.HASH),
|
||||
)
|
||||
|
||||
pipe = async_client.pipeline()
|
||||
for i in range(100):
|
||||
pipe.hset(f"k{i}", mapping={"pos": np.array(i, dtype=np.float32).tobytes()})
|
||||
await pipe.execute()
|
||||
|
||||
params = {"vec": np.array([1.0], dtype=np.float32).tobytes()}
|
||||
result = await i1.search("* => [KNN 3 @pos $vec AS distance]", params)
|
||||
|
||||
assert result.total == 3
|
||||
assert [d["distance"] for d in result.docs] == ["0", "1", "1"]
|
||||
|
||||
result = await i1.search(
|
||||
Query("* => [KNN 3 @pos $vec AS distance]").return_fields("pos"), params
|
||||
)
|
||||
assert not any(hasattr(d, "distance") for d in result.docs)
|
||||
|
||||
await i1.dropindex()
|
||||
|
||||
|
||||
@dfly_args({"proactor_threads": 4, "dbfilename": "search-data"})
|
||||
async def test_index_persistence(df_server):
|
||||
client = aioredis.Redis(port=df_server.port)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue