mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 10:25:47 +02:00
fix(search_family): Fix SORTBY option in FT.SEARCH for non-sortable fields and KNN search (#4942)
* fix(search_family): Fix SORTBY option in FT.SEARCH for non-sortable fields and KNN search
  fixes dragonflydb#4939, dragonflydb#4934
  Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
* fix compilation error
  Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
---------
Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
Co-authored-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
parent
36e6d4527c
commit
c81d99037d
11 changed files with 237 additions and 166 deletions
|
@ -108,15 +108,9 @@ struct AstKnnNode {
|
|||
std::optional<float> ef_runtime;
|
||||
};
|
||||
|
||||
struct AstSortNode {
|
||||
std::unique_ptr<AstNode> filter;
|
||||
std::string field;
|
||||
bool descending = false;
|
||||
};
|
||||
|
||||
using NodeVariants =
|
||||
std::variant<std::monostate, AstStarNode, AstTermNode, AstPrefixNode, AstRangeNode,
|
||||
AstNegateNode, AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode, AstSortNode>;
|
||||
AstNegateNode, AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode>;
|
||||
|
||||
struct AstNode : public NodeVariants {
|
||||
using variant::variant;
|
||||
|
|
|
@ -37,11 +37,6 @@ struct QueryParams {
|
|||
absl::flat_hash_map<std::string, std::string> params;
|
||||
};
|
||||
|
||||
struct SortOption {
|
||||
std::string field;
|
||||
bool descending = false;
|
||||
};
|
||||
|
||||
// Comparable string stored as char[]. Used to reduce size of std::variant with strings.
|
||||
struct WrappedStrPtr {
|
||||
// Intentionally implicit and const std::string& for use in templates
|
||||
|
|
|
@ -143,7 +143,6 @@ struct ProfileBuilder {
|
|||
[](const AstKnnNode& n) { return absl::StrCat("KNN{l=", n.limit, "}"); },
|
||||
[](const AstNegateNode& n) { return absl::StrCat("Negate{}"); },
|
||||
[](const AstStarNode& n) { return absl::StrCat("Star{}"); },
|
||||
[](const AstSortNode& n) { return absl::StrCat("Sort{f", n.field, "}"); },
|
||||
};
|
||||
return visit(node_info, node.Variant());
|
||||
}
|
||||
|
@ -177,8 +176,7 @@ struct ProfileBuilder {
|
|||
struct BasicSearch {
|
||||
using LogicOp = AstLogicalNode::LogicOp;
|
||||
|
||||
BasicSearch(const FieldIndices* indices, size_t limit)
|
||||
: indices_{indices}, limit_{limit}, tmp_vec_{} {
|
||||
BasicSearch(const FieldIndices* indices) : indices_{indices}, tmp_vec_{} {
|
||||
}
|
||||
|
||||
void EnableProfiling() {
|
||||
|
@ -372,27 +370,6 @@ struct BasicSearch {
|
|||
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
|
||||
}
|
||||
|
||||
// SORTBY field [DESC]: Sort by field. Part of params and not "core query".
|
||||
IndexResult Search(const AstSortNode& node, string_view active_field) {
|
||||
auto sub_results = SearchGeneric(*node.filter, active_field);
|
||||
|
||||
// Skip sorting again for KNN queries, reverse if needed will be applied on aggregation
|
||||
if (auto knn = get_if<AstKnnNode>(&node.filter->Variant());
|
||||
knn && (knn->score_alias == node.field || "__vector_score" == node.field)) {
|
||||
return sub_results;
|
||||
}
|
||||
|
||||
preagg_total_ = sub_results.Size();
|
||||
|
||||
if (auto* sort_index = GetSortIndex(node.field); sort_index) {
|
||||
auto ids_vec = sub_results.Take();
|
||||
scores_ = sort_index->Sort(&ids_vec, limit_, node.descending);
|
||||
return ids_vec;
|
||||
}
|
||||
|
||||
return IndexResult{};
|
||||
}
|
||||
|
||||
void SearchKnnFlat(FlatVectorIndex* vec_index, const AstKnnNode& knn, IndexResult&& sub_results) {
|
||||
knn_distances_.reserve(sub_results.Size());
|
||||
auto cb = [&](auto* set) {
|
||||
|
@ -482,14 +459,13 @@ struct BasicSearch {
|
|||
size_t total = result.Size();
|
||||
return SearchResult{total,
|
||||
max(total, preagg_total_),
|
||||
result.Take(limit_),
|
||||
result.Take(),
|
||||
std::move(scores_),
|
||||
std::move(profile),
|
||||
std::move(error_)};
|
||||
}
|
||||
|
||||
const FieldIndices* indices_;
|
||||
size_t limit_;
|
||||
|
||||
size_t preagg_total_ = 0;
|
||||
string error_;
|
||||
|
@ -677,7 +653,7 @@ const Synonyms* FieldIndices::GetSynonyms() const {
|
|||
SearchAlgorithm::SearchAlgorithm() = default;
|
||||
SearchAlgorithm::~SearchAlgorithm() = default;
|
||||
|
||||
bool SearchAlgorithm::Init(string_view query, const QueryParams* params, const SortOption* sort) {
|
||||
bool SearchAlgorithm::Init(string_view query, const QueryParams* params) {
|
||||
try {
|
||||
query_ = make_unique<AstExpr>(ParseQuery(query, params));
|
||||
} catch (const Parser::syntax_error& se) {
|
||||
|
@ -693,35 +669,22 @@ bool SearchAlgorithm::Init(string_view query, const QueryParams* params, const S
|
|||
return false;
|
||||
}
|
||||
|
||||
if (sort != nullptr)
|
||||
query_ = make_unique<AstNode>(AstSortNode{std::move(query_), sort->field, sort->descending});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t limit) const {
|
||||
auto bs = BasicSearch{index, limit};
|
||||
SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
|
||||
auto bs = BasicSearch{index};
|
||||
if (profiling_enabled_)
|
||||
bs.EnableProfiling();
|
||||
return bs.Search(*query_);
|
||||
}
|
||||
|
||||
optional<AggregationInfo> SearchAlgorithm::GetAggregationInfo() const {
|
||||
optional<KnnScoreSortOption> SearchAlgorithm::GetKnnScoreSortOption() const {
|
||||
DCHECK(query_);
|
||||
|
||||
// KNN query
|
||||
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
|
||||
return AggregationInfo{string_view{knn->score_alias}, false, knn->limit};
|
||||
|
||||
// SEARCH query with SORTBY option
|
||||
if (auto* sort = get_if<AstSortNode>(query_.get()); sort) {
|
||||
string_view alias = "";
|
||||
if (auto* knn = get_if<AstKnnNode>(&sort->filter->Variant());
|
||||
knn && knn->score_alias == sort->field)
|
||||
alias = knn->score_alias;
|
||||
|
||||
return AggregationInfo{alias, sort->descending};
|
||||
}
|
||||
return KnnScoreSortOption{string_view{knn->score_alias}, knn->limit};
|
||||
|
||||
return nullopt;
|
||||
}
|
||||
|
|
|
@ -140,9 +140,8 @@ struct SearchResult {
|
|||
std::string error;
|
||||
};
|
||||
|
||||
struct AggregationInfo {
|
||||
std::string_view alias;
|
||||
bool descending;
|
||||
struct KnnScoreSortOption {
|
||||
std::string_view score_field_alias;
|
||||
size_t limit = std::numeric_limits<size_t>::max();
|
||||
};
|
||||
|
||||
|
@ -153,13 +152,12 @@ class SearchAlgorithm {
|
|||
~SearchAlgorithm();
|
||||
|
||||
// Init with query and return true if successful.
|
||||
bool Init(std::string_view query, const QueryParams* params, const SortOption* sort = nullptr);
|
||||
bool Init(std::string_view query, const QueryParams* params);
|
||||
|
||||
SearchResult Search(const FieldIndices* index,
|
||||
size_t limit = std::numeric_limits<size_t>::max()) const;
|
||||
SearchResult Search(const FieldIndices* index) const;
|
||||
|
||||
// if enabled, return limit & alias for knn query
|
||||
std::optional<AggregationInfo> GetAggregationInfo() const;
|
||||
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const;
|
||||
|
||||
void EnableProfiling();
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "server/search/aggregator.h"
|
||||
|
||||
#include "base/logging.h"
|
||||
#include "server/search/doc_index.h"
|
||||
|
||||
namespace dfly::aggregate {
|
||||
|
||||
|
@ -97,7 +98,7 @@ void Aggregator::DoSort(const SortParams& sort_params) {
|
|||
if (lv == rv) {
|
||||
continue;
|
||||
}
|
||||
return order == SortParams::SortOrder::ASC ? lv < rv : lv > rv;
|
||||
return order == SortOrder::ASC ? lv < rv : lv > rv;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
|
|
@ -15,6 +15,10 @@
|
|||
#include "facade/reply_builder.h"
|
||||
#include "io/io.h"
|
||||
|
||||
namespace dfly {
|
||||
enum class SortOrder;
|
||||
}
|
||||
|
||||
namespace dfly::aggregate {
|
||||
|
||||
struct Reducer;
|
||||
|
@ -34,8 +38,6 @@ struct AggregationResult {
|
|||
};
|
||||
|
||||
struct SortParams {
|
||||
enum class SortOrder { ASC, DESC };
|
||||
|
||||
constexpr static int64_t kSortAll = -1;
|
||||
|
||||
bool SortAll() const {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "server/search/aggregator.h"
|
||||
|
||||
#include "base/gtest.h"
|
||||
#include "server/search/doc_index.h"
|
||||
|
||||
namespace dfly::aggregate {
|
||||
|
||||
|
@ -20,7 +21,7 @@ TEST(AggregatorTest, Sort) {
|
|||
};
|
||||
|
||||
SortParams params;
|
||||
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
|
||||
params.fields.emplace_back("a", SortOrder::ASC);
|
||||
StepsList steps = {MakeSortStep(std::move(params))};
|
||||
|
||||
auto result = Process(values, {"a"}, steps);
|
||||
|
|
|
@ -50,6 +50,68 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) {
|
|||
} while (cursor);
|
||||
}
|
||||
|
||||
// Reports whether `field_identifier` has an entry in `schema.fields` whose flags include
// search::SchemaField::SORTABLE. An identifier that is not in the schema yields false.
bool IsSortableField(std::string_view field_identifier, const search::Schema& schema) {
|
||||
auto it = schema.fields.find(field_identifier);
|
||||
// The flag is inspected only when the lookup succeeded (short-circuit &&).
return it != schema.fields.end() && (it->second.flags & search::SchemaField::SORTABLE);
|
||||
}
|
||||
|
||||
// Builds a new list holding `field.View()` for every entry of `fields`, in the same order.
// Returns an empty list when `fields` holds no value.
// Note: `schema` is not referenced by the body.
SearchFieldsList ToSV(const search::Schema& schema, const std::optional<SearchFieldsList>& fields) {
|
||||
SearchFieldsList sv_fields;
|
||||
if (fields) {
|
||||
// Size is known up front, so allocate once.
sv_fields.reserve(fields->size());
|
||||
for (const auto& field : fields.value()) {
|
||||
sv_fields.push_back(field.View());
|
||||
}
|
||||
}
|
||||
return sv_fields;
|
||||
}
|
||||
|
||||
// List of (field identifier, alias) pairs for fields that IsSortableField() accepts.
using SortIndiciesFieldsList =
|
||||
std::vector<std::pair<string_view /*identifier*/, string_view /*alias*/>>;
|
||||
|
||||
// Splits the fields relevant to an aggregation into two lists:
//   first  - SearchField entries for identifiers that are NOT sortable per IsSortableField();
//   second - (identifier, alias) pairs for identifiers that ARE sortable.
// Every entry of `schema.field_names` is classified first; entries of `load_fields` (when set)
// are classified afterwards and replace an earlier entry stored under the same identifier.
// Both result lists are produced by iterating hash maps, so their order follows map iteration.
// Note: `params` is not referenced by the body.
std::pair<SearchFieldsList, SortIndiciesFieldsList> PreprocessAggregateFields(
|
||||
const search::Schema& schema, const AggregateParams& params,
|
||||
const std::optional<SearchFieldsList>& load_fields) {
|
||||
// Both maps are keyed by field identifier so later assignments override earlier ones.
absl::flat_hash_map<std::string_view, SearchField> fields_by_identifier;
|
||||
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
|
||||
fields_by_identifier.reserve(schema.field_names.size());
|
||||
sort_indicies_aliases.reserve(schema.field_names.size());
|
||||
|
||||
// Pass 1: classify every (name, identifier) pair declared in the schema.
for (const auto& [fname, fident] : schema.field_names) {
|
||||
if (!IsSortableField(fident, schema)) {
|
||||
fields_by_identifier[fident] = {StringOrView::FromView(fident), true,
|
||||
StringOrView::FromView(fname)};
|
||||
} else {
|
||||
sort_indicies_aliases[fident] = fname;
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: explicitly requested fields, classified by their resolved identifier.
if (load_fields) {
|
||||
for (const auto& field : load_fields.value()) {
|
||||
const auto& fident = field.GetIdentifier(schema, false);
|
||||
if (!IsSortableField(fident, schema)) {
|
||||
fields_by_identifier[fident] = field.View();
|
||||
} else {
|
||||
sort_indicies_aliases[fident] = field.GetShortName();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flatten the non-sortable map into the first result list (values are moved out).
SearchFieldsList fields;
|
||||
fields.reserve(fields_by_identifier.size());
|
||||
for (auto& [_, field] : fields_by_identifier) {
|
||||
fields.emplace_back(std::move(field));
|
||||
}
|
||||
|
||||
// Flatten the sortable map into (identifier, alias) pairs.
SortIndiciesFieldsList sort_fields;
|
||||
sort_fields.reserve(sort_indicies_aliases.size());
|
||||
for (auto& [fident, fname] : sort_indicies_aliases) {
|
||||
sort_fields.emplace_back(fident, fname);
|
||||
}
|
||||
|
||||
return {std::move(fields), std::move(sort_fields)};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const {
|
||||
|
@ -271,21 +333,10 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
|
|||
return base_->Matches(key, obj_code);
|
||||
}
|
||||
|
||||
SearchFieldsList ToSV(const search::Schema& schema, const std::optional<SearchFieldsList>& fields) {
|
||||
SearchFieldsList sv_fields;
|
||||
if (fields) {
|
||||
sv_fields.reserve(fields->size());
|
||||
for (const auto& field : fields.value()) {
|
||||
sv_fields.push_back(field.View());
|
||||
}
|
||||
}
|
||||
return sv_fields;
|
||||
}
|
||||
|
||||
SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
|
||||
search::SearchAlgorithm* search_algo) const {
|
||||
auto& db_slice = op_args.GetDbSlice();
|
||||
auto search_results = search_algo->Search(&*indices_, params.limit_offset + params.limit_total);
|
||||
auto search_results = search_algo->Search(&*indices_);
|
||||
|
||||
if (!search_results.error.empty())
|
||||
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
|
||||
|
@ -298,7 +349,8 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
|
|||
|
||||
size_t expired_count = 0;
|
||||
for (size_t i = 0; i < search_results.ids.size(); i++) {
|
||||
auto key = key_index_.Get(search_results.ids[i]);
|
||||
const DocId doc = search_results.ids[i];
|
||||
auto key = key_index_.Get(doc);
|
||||
auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode());
|
||||
|
||||
if (!it || !IsValid(*it)) { // Item must have expired
|
||||
|
@ -315,13 +367,24 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
|
|||
For JSON indexes it would be {"$", <the whole document as string>}
|
||||
*/
|
||||
doc_data = accessor->SerializeDocument(base_->schema);
|
||||
}
|
||||
|
||||
SearchDocData loaded_fields = accessor->Serialize(base_->schema, fields_to_load);
|
||||
doc_data.insert(std::make_move_iterator(loaded_fields.begin()),
|
||||
std::make_move_iterator(loaded_fields.end()));
|
||||
} else {
|
||||
/* Load only specific fields */
|
||||
doc_data = accessor->Serialize(base_->schema, fields_to_load);
|
||||
SearchDocData loaded_fields = accessor->Serialize(base_->schema, fields_to_load);
|
||||
doc_data.insert(std::make_move_iterator(loaded_fields.begin()),
|
||||
std::make_move_iterator(loaded_fields.end()));
|
||||
|
||||
if (params.sort_option) {
|
||||
auto& field = params.sort_option->field;
|
||||
auto fident = field.GetIdentifier(base_->schema, false);
|
||||
if (IsSortableField(fident, base_->schema)) {
|
||||
doc_data[field.NameView()] = indices_->GetSortIndexValue(doc, fident);
|
||||
} else {
|
||||
SearchDocData sort_field_data = accessor->Serialize(base_->schema, {field});
|
||||
DCHECK_LE(sort_field_data.size(), 1u);
|
||||
if (!sort_field_data.empty()) {
|
||||
doc_data[field.NameView()] = sort_field_data.begin()->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]);
|
||||
|
@ -332,57 +395,6 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
|
|||
std::move(search_results.profile)};
|
||||
}
|
||||
|
||||
using SortIndiciesFieldsList =
|
||||
std::vector<std::pair<string_view /*identifier*/, string_view /*alias*/>>;
|
||||
|
||||
std::pair<SearchFieldsList, SortIndiciesFieldsList> PreprocessAggregateFields(
|
||||
const search::Schema& schema, const AggregateParams& params,
|
||||
const std::optional<SearchFieldsList>& load_fields) {
|
||||
auto is_sortable = [&schema](std::string_view fident) {
|
||||
auto it = schema.fields.find(fident);
|
||||
return it != schema.fields.end() && (it->second.flags & search::SchemaField::SORTABLE);
|
||||
};
|
||||
|
||||
absl::flat_hash_map<std::string_view, SearchField> fields_by_identifier;
|
||||
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
|
||||
fields_by_identifier.reserve(schema.field_names.size());
|
||||
sort_indicies_aliases.reserve(schema.field_names.size());
|
||||
|
||||
for (const auto& [fname, fident] : schema.field_names) {
|
||||
if (!is_sortable(fident)) {
|
||||
fields_by_identifier[fident] = {StringOrView::FromView(fident), true,
|
||||
StringOrView::FromView(fname)};
|
||||
} else {
|
||||
sort_indicies_aliases[fident] = fname;
|
||||
}
|
||||
}
|
||||
|
||||
if (load_fields) {
|
||||
for (const auto& field : load_fields.value()) {
|
||||
const auto& fident = field.GetIdentifier(schema, false);
|
||||
if (!is_sortable(fident)) {
|
||||
fields_by_identifier[fident] = field.View();
|
||||
} else {
|
||||
sort_indicies_aliases[fident] = field.GetShortName();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SearchFieldsList fields;
|
||||
fields.reserve(fields_by_identifier.size());
|
||||
for (auto& [_, field] : fields_by_identifier) {
|
||||
fields.emplace_back(std::move(field));
|
||||
}
|
||||
|
||||
SortIndiciesFieldsList sort_fields;
|
||||
sort_fields.reserve(sort_indicies_aliases.size());
|
||||
for (auto& [fident, fname] : sort_indicies_aliases) {
|
||||
sort_fields.emplace_back(fident, fname);
|
||||
}
|
||||
|
||||
return {std::move(fields), std::move(sort_fields)};
|
||||
}
|
||||
|
||||
vector<SearchDocData> ShardDocIndex::SearchForAggregator(
|
||||
const OpArgs& op_args, const AggregateParams& params,
|
||||
search::SearchAlgorithm* search_algo) const {
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include <absl/container/flat_hash_map.h>
|
||||
#include <absl/container/flat_hash_set.h>
|
||||
#include <absl/strings/match.h>
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
@ -71,6 +72,10 @@ class SearchField {
|
|||
public:
|
||||
SearchField() = default;
|
||||
|
||||
// Constructs a field from its name alone. `is_short_name_` is derived from the name:
// it is set to the negation of IsJsonPath(NameView()).
explicit SearchField(StringOrView name) : name_(std::move(name)) {
|
||||
is_short_name_ = !IsJsonPath(NameView());
|
||||
}
|
||||
|
||||
SearchField(StringOrView name, bool is_short_name)
|
||||
: name_(std::move(name)), is_short_name_(is_short_name) {
|
||||
}
|
||||
|
@ -111,15 +116,15 @@ class SearchField {
|
|||
return SearchField{StringOrView::FromView(NameView()), is_short_name_};
|
||||
}
|
||||
|
||||
// Returns the string view exposed by `name_`.
std::string_view NameView() const {
|
||||
return name_.view();
|
||||
}
|
||||
|
||||
private:
|
||||
bool HasNewAlias() const {
|
||||
return !new_alias_.empty();
|
||||
}
|
||||
|
||||
std::string_view NameView() const {
|
||||
return name_.view();
|
||||
}
|
||||
|
||||
std::string_view AliasView() const {
|
||||
return new_alias_.view();
|
||||
}
|
||||
|
@ -132,7 +137,14 @@ class SearchField {
|
|||
|
||||
using SearchFieldsList = std::vector<SearchField>;
|
||||
|
||||
// Direction of a sort: ASC (ascending) or DESC (descending).
enum class SortOrder { ASC, DESC };
|
||||
|
||||
struct SearchParams {
|
||||
// Sort request attached to search parameters: the field to sort by and the direction.
struct SortOption {
|
||||
SearchField field;
|
||||
// Ascending unless stated otherwise.
SortOrder order = SortOrder::ASC;
|
||||
};
|
||||
|
||||
// Parameters for "LIMIT offset total": return `total` documents, starting at position `offset`
|
||||
// within the whole result set.
|
||||
size_t limit_offset = 0;
|
||||
|
@ -153,7 +165,7 @@ struct SearchParams {
|
|||
std::optional<SearchFieldsList> load_fields;
|
||||
bool no_content = false;
|
||||
|
||||
std::optional<search::SortOption> sort_option;
|
||||
std::optional<SortOption> sort_option;
|
||||
search::QueryParams query_params;
|
||||
|
||||
bool ShouldReturnAllFields() const {
|
||||
|
|
|
@ -377,8 +377,10 @@ ParseResult<SearchParams> ParseSearchParams(CmdArgParser* parser) {
|
|||
} else if (parser->Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
|
||||
params.query_params = ParseQueryParams(parser);
|
||||
} else if (parser->Check("SORTBY")) {
|
||||
params.sort_option =
|
||||
search::SortOption{parser->Next<std::string>(), bool(parser->Check("DESC"))};
|
||||
auto parsed_field = ParseField(parser);
|
||||
StringOrView field = StringOrView::FromString(std::string{parsed_field});
|
||||
params.sort_option = SearchParams::SortOption{
|
||||
SearchField{std::move(field)}, parser->Check("DESC") ? SortOrder::DESC : SortOrder::ASC};
|
||||
} else {
|
||||
// Unsupported parameters are ignored for now
|
||||
parser->Skip(1);
|
||||
|
@ -389,8 +391,6 @@ ParseResult<SearchParams> ParseSearchParams(CmdArgParser* parser) {
|
|||
}
|
||||
|
||||
std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
|
||||
using SortOrder = aggregate::SortParams::SortOrder;
|
||||
|
||||
size_t strings_num = parser->Next<size_t>();
|
||||
|
||||
aggregate::SortParams sort_params;
|
||||
|
@ -537,7 +537,8 @@ void SendSerializedDoc(const SerializedSearchDoc& doc, SinkReplyBuilder* builder
|
|||
}
|
||||
}
|
||||
|
||||
void SearchReply(const SearchParams& params, std::optional<search::AggregationInfo> agg_info,
|
||||
void SearchReply(const SearchParams& params,
|
||||
std::optional<search::KnnScoreSortOption> knn_sort_option,
|
||||
absl::Span<SearchResult> results, SinkReplyBuilder* builder) {
|
||||
size_t total_hits = 0;
|
||||
absl::InlinedVector<SerializedSearchDoc*, 5> docs;
|
||||
|
@ -552,17 +553,14 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
|
|||
size_t size = docs.size();
|
||||
bool should_add_score_field = false;
|
||||
|
||||
if (agg_info) {
|
||||
size = std::min(size, agg_info->limit);
|
||||
total_hits = std::min(total_hits, agg_info->limit);
|
||||
should_add_score_field = params.ShouldReturnField(agg_info->alias);
|
||||
if (knn_sort_option) {
|
||||
size = std::min(size, knn_sort_option->limit);
|
||||
total_hits = std::min(total_hits, knn_sort_option->limit);
|
||||
should_add_score_field = params.ShouldReturnField(knn_sort_option->score_field_alias);
|
||||
|
||||
using Comparator = bool (*)(const SerializedSearchDoc*, const SerializedSearchDoc*);
|
||||
auto comparator =
|
||||
!agg_info->descending
|
||||
? static_cast<Comparator>([](const SerializedSearchDoc* l,
|
||||
const SerializedSearchDoc* r) { return *l < *r; })
|
||||
: [](const SerializedSearchDoc* l, const SerializedSearchDoc* r) { return *r < *l; };
|
||||
auto comparator = [](const SerializedSearchDoc* l, const SerializedSearchDoc* r) {
|
||||
return *l < *r;
|
||||
};
|
||||
|
||||
const size_t prefix_size_to_sort = std::min(params.limit_offset + params.limit_total, size);
|
||||
if (prefix_size_to_sort == docs.size()) {
|
||||
|
@ -575,6 +573,38 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
|
|||
|
||||
const size_t offset = std::min(params.limit_offset, size);
|
||||
const size_t limit = std::min(size - offset, params.limit_total);
|
||||
const size_t end = offset + limit;
|
||||
DCHECK(end <= docs.size());
|
||||
|
||||
if (params.sort_option) {
|
||||
auto field_alias = params.sort_option->field.NameView();
|
||||
|
||||
auto comparator = [&](const SerializedSearchDoc* l_doc, const SerializedSearchDoc* r_doc) {
|
||||
auto& l = l_doc->values;
|
||||
auto& r = r_doc->values;
|
||||
|
||||
auto l_it = l.find(field_alias);
|
||||
auto r_it = r.find(field_alias);
|
||||
|
||||
// If either value is missing, a document that has the field sorts before one that lacks it
|
||||
if (l_it == l.end() || r_it == r.end()) {
|
||||
return l_it != l.end();
|
||||
}
|
||||
|
||||
const auto& lv = l_it->second;
|
||||
const auto& rv = r_it->second;
|
||||
return params.sort_option->order == SortOrder::ASC ? lv < rv : lv > rv;
|
||||
};
|
||||
|
||||
auto sort_begin = docs.begin();
|
||||
auto sort_end = docs.end();
|
||||
// If we first sorted by knn, we need to sort only the result of knn
|
||||
if (knn_sort_option) {
|
||||
sort_begin = docs.begin() + offset;
|
||||
sort_end = docs.begin() + end;
|
||||
}
|
||||
std::sort(sort_begin, sort_end, std::move(comparator));
|
||||
}
|
||||
|
||||
const bool reply_with_ids_only = params.IdsOnly();
|
||||
const size_t reply_size = reply_with_ids_only ? (limit + 1) : (limit * 2 + 1);
|
||||
|
@ -585,7 +615,6 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
|
|||
rb->StartArray(reply_size);
|
||||
rb->SendLong(total_hits);
|
||||
|
||||
const size_t end = offset + limit;
|
||||
for (size_t i = offset; i < end; i++) {
|
||||
if (reply_with_ids_only) {
|
||||
rb->SendBulkString(docs[i]->key);
|
||||
|
@ -593,7 +622,8 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
|
|||
}
|
||||
|
||||
if (should_add_score_field && holds_alternative<float>(docs[i]->score))
|
||||
docs[i]->values[agg_info->alias] = absl::StrCat(get<float>(docs[i]->score));
|
||||
docs[i]->values[knn_sort_option->score_field_alias] =
|
||||
absl::StrCat(get<float>(docs[i]->score));
|
||||
|
||||
SendSerializedDoc(*docs[i], builder);
|
||||
}
|
||||
|
@ -811,8 +841,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
|
|||
return;
|
||||
|
||||
search::SearchAlgorithm search_algo;
|
||||
search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
|
||||
if (!search_algo.Init(query_str, &params->query_params, sort_opt))
|
||||
if (!search_algo.Init(query_str, &params->query_params))
|
||||
return builder->SendError("Query syntax error");
|
||||
|
||||
// Because our coordinator thread may not have a shard, we can't check ahead if the index exists.
|
||||
|
@ -835,7 +864,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
|
|||
return builder->SendError(*res.error);
|
||||
}
|
||||
|
||||
SearchReply(*params, search_algo.GetAggregationInfo(), absl::MakeSpan(docs), builder);
|
||||
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(docs), builder);
|
||||
}
|
||||
|
||||
void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
|
||||
|
@ -858,8 +887,7 @@ void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
|
|||
return;
|
||||
|
||||
search::SearchAlgorithm search_algo;
|
||||
search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
|
||||
if (!search_algo.Init(query_str, &params->query_params, sort_opt))
|
||||
if (!search_algo.Init(query_str, &params->query_params))
|
||||
return rb->SendError("query syntax error");
|
||||
|
||||
search_algo.EnableProfiling();
|
||||
|
@ -911,7 +939,7 @@ void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
|
|||
|
||||
// Result of the search command
|
||||
if (!result_is_empty) {
|
||||
SearchReply(*params, search_algo.GetAggregationInfo(), absl::MakeSpan(search_results), rb);
|
||||
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(search_results), rb);
|
||||
} else {
|
||||
rb->StartArray(1);
|
||||
rb->SendLong(0);
|
||||
|
@ -1010,7 +1038,7 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
|
|||
return;
|
||||
|
||||
search::SearchAlgorithm search_algo;
|
||||
if (!search_algo.Init(params->query, &params->params, nullptr))
|
||||
if (!search_algo.Init(params->query, &params->params))
|
||||
return builder->SendError("Query syntax error");
|
||||
|
||||
using ResultContainer = decltype(declval<ShardDocIndex>().SearchForAggregator(
|
||||
|
|
|
@ -2278,4 +2278,69 @@ TEST_F(SearchFamilyTest, PrefixSearchWithSynonyms) {
|
|||
EXPECT_THAT(resp, AreDocIds("doc:6")); // Should only find macintosh
|
||||
}
|
||||
|
||||
// FT.SEARCH ... SORTBY on a JSON index whose TEXT field is declared without SORTABLE.
// Checks that results come back ordered by the field value and that the sort field is
// included in each returned document.
TEST_F(SearchFamilyTest, SearchSortByOptionNonSortableFieldJson) {
|
||||
// json1 holds the larger value, so an ascending sort must list json2 first.
Run({"JSON.SET", "json1", "$", R"({"text":"2"})"});
|
||||
Run({"JSON.SET", "json2", "$", R"({"text":"1"})"});
|
||||
|
||||
auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.text", "AS", "text", "TEXT"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
// Expected reply, parameterised by the key under which the sort field is reported.
// The sort value is expected as a quoted JSON string next to the whole document under "$".
auto expect_expr = [](std::string_view text_field) {
|
||||
return IsArray(2, "json2", IsMap(text_field, "\"1\"", "$", R"({"text":"1"})"), "json1",
|
||||
IsMap(text_field, "\"2\"", "$", R"({"text":"2"})"));
|
||||
};
|
||||
|
||||
// Sorting by the alias.
resp = Run({"FT.SEARCH", "index", "*", "SORTBY", "text"});
|
||||
EXPECT_THAT(resp, expect_expr("text"sv));
|
||||
|
||||
// Sorting by the '@'-prefixed alias: the field is still reported as "text".
resp = Run({"FT.SEARCH", "index", "*", "SORTBY", "@text"});
|
||||
EXPECT_THAT(resp, expect_expr("text"sv));
|
||||
|
||||
// Sorting by the JSON path: the field is reported under the path itself.
resp = Run({"FT.SEARCH", "index", "*", "SORTBY", "$.text"});
|
||||
EXPECT_THAT(resp, expect_expr("$.text"sv));
|
||||
}
|
||||
|
||||
// FT.SEARCH ... SORTBY on a HASH index whose TEXT field is declared without SORTABLE.
TEST_F(SearchFamilyTest, SearchSortByOptionNonSortableFieldHash) {
|
||||
// h1 holds the larger value, so an ascending sort must list h2 first.
Run({"HSET", "h1", "text", "2"});
|
||||
Run({"HSET", "h2", "text", "1"});
|
||||
|
||||
auto resp = Run({"FT.CREATE", "index", "ON", "HASH", "SCHEMA", "text", "TEXT"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
auto expected_expr = IsArray(2, "h2", IsMap("text", "1"), "h1", IsMap("text", "2"));
|
||||
|
||||
// The same reply is expected with and without the '@' prefix on the field name.
resp = Run({"FT.SEARCH", "index", "*", "SORTBY", "text"});
|
||||
EXPECT_THAT(resp, expected_expr);
|
||||
|
||||
resp = Run({"FT.SEARCH", "index", "*", "SORTBY", "@text"});
|
||||
EXPECT_THAT(resp, expected_expr);
|
||||
}
|
||||
|
||||
// Combines a KNN query with SORTBY on a separate SORTABLE numeric field and checks that the
// KNN result set is reordered by that field.
TEST_F(SearchFamilyTest, KnnSearchWithSortby) {
|
||||
// Copies exactly 16 bytes (DIM 4 x FLOAT32). An explicit length is required because the
// literals contain embedded \x00 bytes.
auto to_vector = [](const char* value) { return std::string(value, 16); };
|
||||
|
||||
Run({"HSET", "doc:1", "timestamp", "1713100000", "embedding",
|
||||
to_vector("\x3d\xcc\xcc\x3d\x00\x00\x80\x3f\xcd\xcc\x4c\x3e\x9a\x99\x19\x3f")});
|
||||
Run({"HSET", "doc:2", "timestamp", "1713200000", "embedding",
|
||||
to_vector("\x9a\x99\x19\x3f\xcd\xcc\x4c\x3e\x00\x00\x80\x3f\x3d\xcc\xcc\x3d")});
|
||||
Run({"HSET", "doc:3", "timestamp", "1713300000", "embedding",
|
||||
to_vector("\x00\x00\x80\x3f\x3d\xcc\xcc\x3d\xcd\xcc\x4c\x3e\x9a\x99\x19\x3f")});
|
||||
|
||||
Run({"FT.CREATE", "my_index", "ON", "HASH", "SCHEMA", "timestamp", "NUMERIC", "SORTABLE",
|
||||
"embedding", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "4", "DISTANCE_METRIC",
|
||||
"COSINE"});
|
||||
|
||||
// The query vector is byte-identical to doc:1's embedding.
auto search_vector =
|
||||
to_vector("\x3d\xcc\xcc\x3d\x00\x00\x80\x3f\xcd\xcc\x4c\x3e\x9a\x99\x19\x3f");
|
||||
|
||||
// Baseline: plain KNN 2 is expected to return doc:1 then doc:3.
auto resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 2 @embedding $vec]", "PARAMS", "2", "vec",
|
||||
search_vector, "NOCONTENT"});
|
||||
EXPECT_THAT(resp, IsArray(2, "doc:1", "doc:3"));
|
||||
|
||||
// FT.SEARCH with KNN + SORTBY
// Same two documents, now ordered by timestamp descending: doc:3 (1713300000) precedes
// doc:1 (1713100000).
resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 2 @embedding $vec]", "PARAMS", "2", "vec",
|
||||
search_vector, "SORTBY", "timestamp", "DESC", "NOCONTENT"});
|
||||
EXPECT_THAT(resp, IsArray(2, "doc:3", "doc:1"));
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue