feat(search): return scores (#1870)

* feat(search): return scores

---------

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-09-25 10:03:17 +03:00 committed by GitHub
parent 19783face5
commit fc0943989e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 106 additions and 30 deletions

View file

@ -56,11 +56,18 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
tags.push_back(move(tag));
}
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec)
: filter{make_unique<AstNode>(std::move(filter))},
AstKnnNode::AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec,
std::string_view score_alias)
: filter{nullptr},
limit{limit},
field{field.substr(1)},
vec{std::move(vec)} {
vec{std::move(vec)},
score_alias{score_alias} {
}
AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
*this = std::move(self);
this->filter = make_unique<AstNode>(std::move(filter));
}
} // namespace dfly::search

View file

@ -74,12 +74,19 @@ struct AstTagsNode {
// Applies nearest neighbor search to the final result set
struct AstKnnNode {
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec);
AstKnnNode() = default;
AstKnnNode(size_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias);
AstKnnNode(AstNode&& sub, AstKnnNode&& self);
friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) {
return stream;
}
std::unique_ptr<AstNode> filter;
size_t limit;
std::string field;
OwnedFtVector vec;
std::string score_alias;
};
using NodeVariants =
@ -89,6 +96,10 @@ using NodeVariants =
struct AstNode : public NodeVariants {
using variant::variant;
friend std::ostream& operator<<(std::ostream& stream, const AstNode& matrix) {
return stream;
}
const NodeVariants& Variant() const& {
return *this;
}
@ -99,11 +110,4 @@ using AstExpr = AstNode;
} // namespace search
} // namespace dfly
namespace std {
inline std::ostream& operator<<(std::ostream& os, const dfly::search::AstExpr& ast) {
// os << "ast{" << ast->Debug() << "}";
return os;
}
} // namespace std
namespace std {} // namespace std

View file

@ -62,6 +62,7 @@ term_char [_]|\w
"}" return Parser::make_RCURLBR (loc());
"|" return Parser::make_OR_OP (loc());
"KNN" return Parser::make_KNN (loc());
"AS" return Parser::make_AS (loc());
-?[0-9]+ return make_INT64(matched_view(), loc());

View file

@ -57,6 +57,7 @@ using namespace std;
RCURLBR "}"
OR_OP "|"
KNN "KNN"
AS "AS"
;
%token AND_OP
@ -75,6 +76,9 @@ using namespace std;
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
%nterm <AstKnnNode> knn_query
%nterm <std::string> opt_knn_alias
%printer { yyo << $$; } <*>;
%%
@ -82,8 +86,16 @@ using namespace std;
final_query:
filter
{ driver->Set(move($1)); }
| filter ARROW LBRACKET KNN INT64 FIELD TERM RBRACKET
{ driver->Set(AstKnnNode(move($1), $5, $6, BytesToFtVector($7))); }
| filter ARROW knn_query
{ driver->Set(AstKnnNode(move($1), move($3))); }
knn_query:
LBRACKET KNN INT64 FIELD TERM opt_knn_alias RBRACKET
{ $$ = AstKnnNode($3, $4, BytesToFtVector($5), $6); }
opt_knn_alias:
AS TERM { $$ = move($2); }
| { $$ = std::string{}; }
filter:
search_expr { $$ = move($1); }

View file

@ -488,10 +488,10 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
return bs.Search(*query_);
}
optional<size_t> SearchAlgorithm::HasKnn() const {
optional<pair<size_t, string_view>> SearchAlgorithm::HasKnn() const {
DCHECK(query_);
if (holds_alternative<AstKnnNode>(*query_))
return get<AstKnnNode>(*query_).limit;
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
return make_pair(knn->limit, string_view{knn->score_alias});
return nullopt;
}

View file

@ -107,8 +107,8 @@ class SearchAlgorithm {
SearchResult Search(const FieldIndices* index) const;
// Return KNN limit if it is enabled
std::optional<size_t> HasKnn() const;
// if enabled, return limit & alias for knn query
std::optional<std::pair<size_t /*limit*/, std::string_view /*alias*/>> HasKnn() const;
void EnableProfiling();

View file

@ -58,6 +58,11 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
} // namespace
bool SearchParams::ShouldReturnField(std::string_view field) const {
auto cb = [field](const auto& entry) { return entry.first == field; };
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
}
optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
auto it = kSchemaTypes.find(name);
return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt;

View file

@ -63,6 +63,8 @@ struct SearchParams {
bool IdsOnly() const {
return return_fields && return_fields->empty();
}
bool ShouldReturnField(std::string_view field) const;
};
// Stores basic info about a document index.

View file

@ -241,11 +241,11 @@ void ReplyWithResults(const SearchParams& params, absl::Span<SearchResult> resul
}
}
void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchResult> results,
ConnectionContext* cntx) {
vector<const SerializedSearchDoc*> docs;
for (const auto& shard_results : results) {
for (const auto& doc : shard_results.docs) {
void ReplyKnn(size_t knn_limit, string_view knn_score_alias, const SearchParams& params,
absl::Span<SearchResult> results, ConnectionContext* cntx) {
vector<SerializedSearchDoc*> docs;
for (auto& shard_results : results) {
for (auto& doc : shard_results.docs) {
docs.push_back(&doc);
}
}
@ -262,15 +262,24 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
bool ids_only = params.IdsOnly();
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);
// Clear knn score alias if its excluded from return values
if (!params.ShouldReturnField(knn_score_alias))
knn_score_alias = "";
facade::SinkReplyBuilder::ReplyAggregator agg{cntx->reply_builder()};
(*cntx)->StartArray(reply_size);
(*cntx)->SendLong(docs.size());
for (auto* doc : absl::MakeSpan(docs).subspan(params.limit_offset, result_count)) {
if (ids_only)
if (ids_only) {
(*cntx)->SendBulkString(doc->key);
else
SendSerializedDoc(*doc, cntx);
continue;
}
if (!knn_score_alias.empty())
doc->values[knn_score_alias] = absl::StrCat(doc->knn_distance);
SendSerializedDoc(*doc, cntx);
}
}
@ -433,8 +442,8 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(std::move(*res.error));
}
if (auto knn_limit = search_algo.HasKnn(); knn_limit)
ReplyKnn(*knn_limit, *params, absl::MakeSpan(docs), cntx);
if (auto knn_params = search_algo.HasKnn(); knn_params)
ReplyKnn(knn_params->first, knn_params->second, *params, absl::MakeSpan(docs), cntx);
else
ReplyWithResults(*params, absl::MakeSpan(docs), cntx);
}

View file

@ -903,7 +903,7 @@ void Transaction::EnableShard(ShardId sid) {
void Transaction::EnableAllShards() {
unique_shard_cnt_ = shard_set->size();
unique_shard_id_ = kInvalidSid;
unique_shard_id_ = unique_shard_cnt_ == 1 ? 0 : kInvalidSid;
shard_data_.resize(shard_set->size());
for (auto& sd : shard_data_)
sd.local_mask |= ACTIVE;

View file

@ -304,6 +304,42 @@ async def test_multidim_knn(async_client: aioredis.Redis, index_type, algo_type)
await i3.dropindex()
async def test_knn_score_return(async_client: aioredis.Redis):
i1 = async_client.ft("i1")
vector_field = VectorField(
"pos",
algorithm="FLAT",
attributes={
"DIM": 1,
"DISTANCE_METRIC": "L2",
"INITICAL_CAP": 100,
},
)
await i1.create_index(
[vector_field],
definition=IndexDefinition(index_type=IndexType.HASH),
)
pipe = async_client.pipeline()
for i in range(100):
pipe.hset(f"k{i}", mapping={"pos": np.array(i, dtype=np.float32).tobytes()})
await pipe.execute()
params = {"vec": np.array([1.0], dtype=np.float32).tobytes()}
result = await i1.search("* => [KNN 3 @pos $vec AS distance]", params)
assert result.total == 3
assert [d["distance"] for d in result.docs] == ["0", "1", "1"]
result = await i1.search(
Query("* => [KNN 3 @pos $vec AS distance]").return_fields("pos"), params
)
assert not any(hasattr(d, "distance") for d in result.docs)
await i1.dropindex()
@dfly_args({"proactor_threads": 4, "dbfilename": "search-data"})
async def test_index_persistence(df_server):
client = aioredis.Redis(port=df_server.port)