fix(search_family): Fix SORTBY option in FT.SEARCH for non-sortable fields and KNN search (#4942)

* fix(search_family): Fix SORTBY option in FT.SEARCH for non-sortable fields and KNN search

fixes dragonflydb#4939, dragonflydb#4934

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* fix compilation error

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
Co-authored-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2025-04-17 13:43:25 +02:00 committed by GitHub
parent 36e6d4527c
commit c81d99037d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 237 additions and 166 deletions

View file

@ -108,15 +108,9 @@ struct AstKnnNode {
std::optional<float> ef_runtime;
};
// AST node for a sorted query: wraps a filter sub-query together with the field
// its results are ordered by.
struct AstSortNode {
std::unique_ptr<AstNode> filter;  // sub-query whose matches get sorted
std::string field;  // name of the field to sort by
bool descending = false;  // sort direction; ascending unless set
};
using NodeVariants =
std::variant<std::monostate, AstStarNode, AstTermNode, AstPrefixNode, AstRangeNode,
AstNegateNode, AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode, AstSortNode>;
AstNegateNode, AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode>;
struct AstNode : public NodeVariants {
using variant::variant;

View file

@ -37,11 +37,6 @@ struct QueryParams {
absl::flat_hash_map<std::string, std::string> params;
};
// Sort request: the field to order by and the direction.
struct SortOption {
std::string field;  // name of the field to sort by
bool descending = false;  // ascending unless set
};
// Comparable string stored as char[]. Used to reduce size of std::variant with strings.
struct WrappedStrPtr {
// Intentionally implicit and const std::string& for use in templates

View file

@ -143,7 +143,6 @@ struct ProfileBuilder {
[](const AstKnnNode& n) { return absl::StrCat("KNN{l=", n.limit, "}"); },
[](const AstNegateNode& n) { return absl::StrCat("Negate{}"); },
[](const AstStarNode& n) { return absl::StrCat("Star{}"); },
[](const AstSortNode& n) { return absl::StrCat("Sort{f", n.field, "}"); },
};
return visit(node_info, node.Variant());
}
@ -177,8 +176,7 @@ struct ProfileBuilder {
struct BasicSearch {
using LogicOp = AstLogicalNode::LogicOp;
BasicSearch(const FieldIndices* indices, size_t limit)
: indices_{indices}, limit_{limit}, tmp_vec_{} {
BasicSearch(const FieldIndices* indices) : indices_{indices}, tmp_vec_{} {
}
void EnableProfiling() {
@ -372,27 +370,6 @@ struct BasicSearch {
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}
// SORTBY field [DESC]: Sort by field. Part of params and not "core query".
// Evaluates the wrapped filter first and then orders the matching ids through the sort
// index registered for node.field. Returns an empty result when GetSortIndex finds nothing
// for that field.
IndexResult Search(const AstSortNode& node, string_view active_field) {
auto sub_results = SearchGeneric(*node.filter, active_field);
// Skip sorting again for KNN queries, reverse if needed will be applied on aggregation
if (auto knn = get_if<AstKnnNode>(&node.filter->Variant());
knn && (knn->score_alias == node.field || "__vector_score" == node.field)) {
return sub_results;
}
// Remember the number of matches before Sort() is handed limit_
// NOTE(review): Sort() presumably truncates ids_vec to limit_ - confirm
preagg_total_ = sub_results.Size();
if (auto* sort_index = GetSortIndex(node.field); sort_index) {
auto ids_vec = sub_results.Take();
// The values returned by Sort() are kept in scores_
scores_ = sort_index->Sort(&ids_vec, limit_, node.descending);
return ids_vec;
}
// No sort index is available for the requested field
return IndexResult{};
}
void SearchKnnFlat(FlatVectorIndex* vec_index, const AstKnnNode& knn, IndexResult&& sub_results) {
knn_distances_.reserve(sub_results.Size());
auto cb = [&](auto* set) {
@ -482,14 +459,13 @@ struct BasicSearch {
size_t total = result.Size();
return SearchResult{total,
max(total, preagg_total_),
result.Take(limit_),
result.Take(),
std::move(scores_),
std::move(profile),
std::move(error_)};
}
const FieldIndices* indices_;
size_t limit_;
size_t preagg_total_ = 0;
string error_;
@ -677,7 +653,7 @@ const Synonyms* FieldIndices::GetSynonyms() const {
SearchAlgorithm::SearchAlgorithm() = default;
SearchAlgorithm::~SearchAlgorithm() = default;
bool SearchAlgorithm::Init(string_view query, const QueryParams* params, const SortOption* sort) {
bool SearchAlgorithm::Init(string_view query, const QueryParams* params) {
try {
query_ = make_unique<AstExpr>(ParseQuery(query, params));
} catch (const Parser::syntax_error& se) {
@ -693,35 +669,22 @@ bool SearchAlgorithm::Init(string_view query, const QueryParams* params, const S
return false;
}
if (sort != nullptr)
query_ = make_unique<AstNode>(AstSortNode{std::move(query_), sort->field, sort->descending});
return true;
}
SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t limit) const {
auto bs = BasicSearch{index, limit};
SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
auto bs = BasicSearch{index};
if (profiling_enabled_)
bs.EnableProfiling();
return bs.Search(*query_);
}
optional<AggregationInfo> SearchAlgorithm::GetAggregationInfo() const {
optional<KnnScoreSortOption> SearchAlgorithm::GetKnnScoreSortOption() const {
DCHECK(query_);
// KNN query
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
return AggregationInfo{string_view{knn->score_alias}, false, knn->limit};
// SEARCH query with SORTBY option
if (auto* sort = get_if<AstSortNode>(query_.get()); sort) {
string_view alias = "";
if (auto* knn = get_if<AstKnnNode>(&sort->filter->Variant());
knn && knn->score_alias == sort->field)
alias = knn->score_alias;
return AggregationInfo{alias, sort->descending};
}
return KnnScoreSortOption{string_view{knn->score_alias}, knn->limit};
return nullopt;
}

View file

@ -140,9 +140,8 @@ struct SearchResult {
std::string error;
};
struct AggregationInfo {
std::string_view alias;
bool descending;
struct KnnScoreSortOption {
std::string_view score_field_alias;
size_t limit = std::numeric_limits<size_t>::max();
};
@ -153,13 +152,12 @@ class SearchAlgorithm {
~SearchAlgorithm();
// Init with query and return true if successful.
bool Init(std::string_view query, const QueryParams* params, const SortOption* sort = nullptr);
bool Init(std::string_view query, const QueryParams* params);
SearchResult Search(const FieldIndices* index,
size_t limit = std::numeric_limits<size_t>::max()) const;
SearchResult Search(const FieldIndices* index) const;
// if enabled, return limit & alias for knn query
std::optional<AggregationInfo> GetAggregationInfo() const;
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const;
void EnableProfiling();

View file

@ -5,6 +5,7 @@
#include "server/search/aggregator.h"
#include "base/logging.h"
#include "server/search/doc_index.h"
namespace dfly::aggregate {
@ -97,7 +98,7 @@ void Aggregator::DoSort(const SortParams& sort_params) {
if (lv == rv) {
continue;
}
return order == SortParams::SortOrder::ASC ? lv < rv : lv > rv;
return order == SortOrder::ASC ? lv < rv : lv > rv;
}
return false;
};

View file

@ -15,6 +15,10 @@
#include "facade/reply_builder.h"
#include "io/io.h"
namespace dfly {
// Opaque declaration of the scoped enum; no enumerators are listed here.
enum class SortOrder;
}
namespace dfly::aggregate {
struct Reducer;
@ -34,8 +38,6 @@ struct AggregationResult {
};
struct SortParams {
enum class SortOrder { ASC, DESC };
constexpr static int64_t kSortAll = -1;
bool SortAll() const {

View file

@ -5,6 +5,7 @@
#include "server/search/aggregator.h"
#include "base/gtest.h"
#include "server/search/doc_index.h"
namespace dfly::aggregate {
@ -20,7 +21,7 @@ TEST(AggregatorTest, Sort) {
};
SortParams params;
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
params.fields.emplace_back("a", SortOrder::ASC);
StepsList steps = {MakeSortStep(std::move(params))};
auto result = Process(values, {"a"}, steps);

View file

@ -50,6 +50,68 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) {
} while (cursor);
}
// Tells whether `field_identifier` is present in schema.fields and carries the SORTABLE flag.
bool IsSortableField(std::string_view field_identifier, const search::Schema& schema) {
  const auto found = schema.fields.find(field_identifier);
  if (found == schema.fields.end()) {
    return false;  // unknown identifier
  }
  return static_cast<bool>(found->second.flags & search::SchemaField::SORTABLE);
}
// Produces a list with the View() of every field in `fields`. The list is empty when
// `fields` holds no value. `schema` is not used.
SearchFieldsList ToSV(const search::Schema& schema, const std::optional<SearchFieldsList>& fields) {
  SearchFieldsList views;
  if (!fields.has_value()) {
    return views;
  }

  views.reserve(fields->size());
  for (const auto& owned : *fields) {
    views.push_back(owned.View());
  }
  return views;
}
using SortIndiciesFieldsList =
    std::vector<std::pair<string_view /*identifier*/, string_view /*alias*/>>;

// Splits the fields needed by an aggregation into two lists:
//  - first:  fields that are not SORTABLE, as SearchField objects;
//  - second: (identifier, alias) pairs for SORTABLE fields.
// Every field of the schema is included; entries of `load_fields` are applied afterwards and
// overwrite the schema entry that has the same identifier. `params` is not used.
std::pair<SearchFieldsList, SortIndiciesFieldsList> PreprocessAggregateFields(
    const search::Schema& schema, const AggregateParams& params,
    const std::optional<SearchFieldsList>& load_fields) {
  const auto schema_fields_count = schema.field_names.size();

  absl::flat_hash_map<std::string_view, SearchField> regular_by_ident;
  absl::flat_hash_map<std::string_view, std::string_view> sortable_alias_by_ident;
  regular_by_ident.reserve(schema_fields_count);
  sortable_alias_by_ident.reserve(schema_fields_count);

  // Pass 1: every field known to the schema.
  for (const auto& [short_name, ident] : schema.field_names) {
    if (IsSortableField(ident, schema)) {
      sortable_alias_by_ident[ident] = short_name;
      continue;
    }
    regular_by_ident[ident] = {StringOrView::FromView(ident), true,
                               StringOrView::FromView(short_name)};
  }

  // Pass 2: requested fields replace the entry stored for the same identifier.
  if (load_fields.has_value()) {
    for (const auto& requested : *load_fields) {
      const auto& ident = requested.GetIdentifier(schema, false);
      if (IsSortableField(ident, schema)) {
        sortable_alias_by_ident[ident] = requested.GetShortName();
      } else {
        regular_by_ident[ident] = requested.View();
      }
    }
  }

  // Flatten both maps into the returned lists.
  SearchFieldsList regular_fields;
  regular_fields.reserve(regular_by_ident.size());
  for (auto& [unused_ident, search_field] : regular_by_ident) {
    regular_fields.emplace_back(std::move(search_field));
  }

  SortIndiciesFieldsList sortable_fields;
  sortable_fields.reserve(sortable_alias_by_ident.size());
  for (auto& [ident, alias] : sortable_alias_by_ident) {
    sortable_fields.emplace_back(ident, alias);
  }

  return {std::move(regular_fields), std::move(sortable_fields)};
}
} // namespace
bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const {
@ -271,21 +333,10 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}
// Builds a list holding field.View() for every entry of `fields`; the list stays empty
// when `fields` has no value. `schema` is not used.
SearchFieldsList ToSV(const search::Schema& schema, const std::optional<SearchFieldsList>& fields) {
SearchFieldsList sv_fields;
if (fields) {
// Size is known up front, so allocate once
sv_fields.reserve(fields->size());
for (const auto& field : fields.value()) {
sv_fields.push_back(field.View());
}
}
return sv_fields;
}
SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.GetDbSlice();
auto search_results = search_algo->Search(&*indices_, params.limit_offset + params.limit_total);
auto search_results = search_algo->Search(&*indices_);
if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
@ -298,7 +349,8 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
size_t expired_count = 0;
for (size_t i = 0; i < search_results.ids.size(); i++) {
auto key = key_index_.Get(search_results.ids[i]);
const DocId doc = search_results.ids[i];
auto key = key_index_.Get(doc);
auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode());
if (!it || !IsValid(*it)) { // Item must have expired
@ -315,13 +367,24 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
For JSON indexes it would be {"$", <the whole document as string>}
*/
doc_data = accessor->SerializeDocument(base_->schema);
}
SearchDocData loaded_fields = accessor->Serialize(base_->schema, fields_to_load);
doc_data.insert(std::make_move_iterator(loaded_fields.begin()),
std::make_move_iterator(loaded_fields.end()));
} else {
/* Load only specific fields */
doc_data = accessor->Serialize(base_->schema, fields_to_load);
SearchDocData loaded_fields = accessor->Serialize(base_->schema, fields_to_load);
doc_data.insert(std::make_move_iterator(loaded_fields.begin()),
std::make_move_iterator(loaded_fields.end()));
if (params.sort_option) {
auto& field = params.sort_option->field;
auto fident = field.GetIdentifier(base_->schema, false);
if (IsSortableField(fident, base_->schema)) {
doc_data[field.NameView()] = indices_->GetSortIndexValue(doc, fident);
} else {
SearchDocData sort_field_data = accessor->Serialize(base_->schema, {field});
DCHECK_LE(sort_field_data.size(), 1u);
if (!sort_field_data.empty()) {
doc_data[field.NameView()] = sort_field_data.begin()->second;
}
}
}
auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]);
@ -332,57 +395,6 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
std::move(search_results.profile)};
}
using SortIndiciesFieldsList =
std::vector<std::pair<string_view /*identifier*/, string_view /*alias*/>>;
// Splits the fields needed by an aggregation into two lists: SearchField objects for fields
// that are not SORTABLE, and (identifier, alias) pairs for SORTABLE fields.
// All schema fields are included; entries of `load_fields` are applied afterwards and
// overwrite the schema entry with the same identifier. `params` is not used.
std::pair<SearchFieldsList, SortIndiciesFieldsList> PreprocessAggregateFields(
const search::Schema& schema, const AggregateParams& params,
const std::optional<SearchFieldsList>& load_fields) {
// True when the identifier exists in schema.fields and has the SORTABLE flag
auto is_sortable = [&schema](std::string_view fident) {
auto it = schema.fields.find(fident);
return it != schema.fields.end() && (it->second.flags & search::SchemaField::SORTABLE);
};
// Both maps are keyed by field identifier
absl::flat_hash_map<std::string_view, SearchField> fields_by_identifier;
absl::flat_hash_map<std::string_view, std::string_view> sort_indicies_aliases;
fields_by_identifier.reserve(schema.field_names.size());
sort_indicies_aliases.reserve(schema.field_names.size());
// First pass: every field known to the schema
for (const auto& [fname, fident] : schema.field_names) {
if (!is_sortable(fident)) {
fields_by_identifier[fident] = {StringOrView::FromView(fident), true,
StringOrView::FromView(fname)};
} else {
sort_indicies_aliases[fident] = fname;
}
}
// Second pass: requested fields replace the entry stored for the same identifier
if (load_fields) {
for (const auto& field : load_fields.value()) {
const auto& fident = field.GetIdentifier(schema, false);
if (!is_sortable(fident)) {
fields_by_identifier[fident] = field.View();
} else {
sort_indicies_aliases[fident] = field.GetShortName();
}
}
}
// Flatten both maps into the returned lists
SearchFieldsList fields;
fields.reserve(fields_by_identifier.size());
for (auto& [_, field] : fields_by_identifier) {
fields.emplace_back(std::move(field));
}
SortIndiciesFieldsList sort_fields;
sort_fields.reserve(sort_indicies_aliases.size());
for (auto& [fident, fname] : sort_indicies_aliases) {
sort_fields.emplace_back(fident, fname);
}
return {std::move(fields), std::move(sort_fields)};
}
vector<SearchDocData> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, const AggregateParams& params,
search::SearchAlgorithm* search_algo) const {

View file

@ -6,6 +6,7 @@
#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/strings/match.h>
#include <memory>
#include <optional>
@ -71,6 +72,10 @@ class SearchField {
public:
SearchField() = default;
// Builds a field from its name alone. The name counts as a short name unless IsJsonPath
// reports it to be a JSON path.
explicit SearchField(StringOrView name) : name_(std::move(name)) {
is_short_name_ = !IsJsonPath(NameView());
}
// Builds a field from its name with the short-name flag supplied by the caller.
SearchField(StringOrView name, bool is_short_name)
: name_(std::move(name)), is_short_name_(is_short_name) {
}
@ -111,15 +116,15 @@ class SearchField {
return SearchField{StringOrView::FromView(NameView()), is_short_name_};
}
// Returns a view of the stored field name (name_).
std::string_view NameView() const {
return name_.view();
}
private:
// True when new_alias_ is not empty.
bool HasNewAlias() const {
return !new_alias_.empty();
}
// Returns a view of the stored field name (name_).
std::string_view NameView() const {
return name_.view();
}
// Returns a view of the stored alias (new_alias_).
std::string_view AliasView() const {
return new_alias_.view();
}
@ -132,7 +137,14 @@ class SearchField {
using SearchFieldsList = std::vector<SearchField>;
// Direction of a sort: ascending (ASC) or descending (DESC).
enum class SortOrder { ASC, DESC };
struct SearchParams {
// Sort settings: the field to order results by and the direction.
struct SortOption {
SearchField field;  // field to sort by
SortOrder order = SortOrder::ASC;  // ascending unless stated otherwise
};
// Parameters for "LIMIT offset total": select `total` documents, starting at `offset`, from
// the whole result set
size_t limit_offset = 0;
@ -153,7 +165,7 @@ struct SearchParams {
std::optional<SearchFieldsList> load_fields;
bool no_content = false;
std::optional<search::SortOption> sort_option;
std::optional<SortOption> sort_option;
search::QueryParams query_params;
bool ShouldReturnAllFields() const {

View file

@ -377,8 +377,10 @@ ParseResult<SearchParams> ParseSearchParams(CmdArgParser* parser) {
} else if (parser->Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
params.query_params = ParseQueryParams(parser);
} else if (parser->Check("SORTBY")) {
params.sort_option =
search::SortOption{parser->Next<std::string>(), bool(parser->Check("DESC"))};
auto parsed_field = ParseField(parser);
StringOrView field = StringOrView::FromString(std::string{parsed_field});
params.sort_option = SearchParams::SortOption{
SearchField{std::move(field)}, parser->Check("DESC") ? SortOrder::DESC : SortOrder::ASC};
} else {
// Unsupported parameters are ignored for now
parser->Skip(1);
@ -389,8 +391,6 @@ ParseResult<SearchParams> ParseSearchParams(CmdArgParser* parser) {
}
std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
using SortOrder = aggregate::SortParams::SortOrder;
size_t strings_num = parser->Next<size_t>();
aggregate::SortParams sort_params;
@ -537,7 +537,8 @@ void SendSerializedDoc(const SerializedSearchDoc& doc, SinkReplyBuilder* builder
}
}
void SearchReply(const SearchParams& params, std::optional<search::AggregationInfo> agg_info,
void SearchReply(const SearchParams& params,
std::optional<search::KnnScoreSortOption> knn_sort_option,
absl::Span<SearchResult> results, SinkReplyBuilder* builder) {
size_t total_hits = 0;
absl::InlinedVector<SerializedSearchDoc*, 5> docs;
@ -552,17 +553,14 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
size_t size = docs.size();
bool should_add_score_field = false;
if (agg_info) {
size = std::min(size, agg_info->limit);
total_hits = std::min(total_hits, agg_info->limit);
should_add_score_field = params.ShouldReturnField(agg_info->alias);
if (knn_sort_option) {
size = std::min(size, knn_sort_option->limit);
total_hits = std::min(total_hits, knn_sort_option->limit);
should_add_score_field = params.ShouldReturnField(knn_sort_option->score_field_alias);
using Comparator = bool (*)(const SerializedSearchDoc*, const SerializedSearchDoc*);
auto comparator =
!agg_info->descending
? static_cast<Comparator>([](const SerializedSearchDoc* l,
const SerializedSearchDoc* r) { return *l < *r; })
: [](const SerializedSearchDoc* l, const SerializedSearchDoc* r) { return *r < *l; };
auto comparator = [](const SerializedSearchDoc* l, const SerializedSearchDoc* r) {
return *l < *r;
};
const size_t prefix_size_to_sort = std::min(params.limit_offset + params.limit_total, size);
if (prefix_size_to_sort == docs.size()) {
@ -575,6 +573,38 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
const size_t offset = std::min(params.limit_offset, size);
const size_t limit = std::min(size - offset, params.limit_total);
const size_t end = offset + limit;
DCHECK(end <= docs.size());
if (params.sort_option) {
auto field_alias = params.sort_option->field.NameView();
auto comparator = [&](const SerializedSearchDoc* l_doc, const SerializedSearchDoc* r_doc) {
auto& l = l_doc->values;
auto& r = r_doc->values;
auto l_it = l.find(field_alias);
auto r_it = r.find(field_alias);
// If either value is missing, a document that has the value is ordered first
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end();
}
const auto& lv = l_it->second;
const auto& rv = r_it->second;
return params.sort_option->order == SortOrder::ASC ? lv < rv : lv > rv;
};
auto sort_begin = docs.begin();
auto sort_end = docs.end();
// If we first sorted by knn, we need to sort only the result of knn
if (knn_sort_option) {
sort_begin = docs.begin() + offset;
sort_end = docs.begin() + end;
}
std::sort(sort_begin, sort_end, std::move(comparator));
}
const bool reply_with_ids_only = params.IdsOnly();
const size_t reply_size = reply_with_ids_only ? (limit + 1) : (limit * 2 + 1);
@ -585,7 +615,6 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
rb->StartArray(reply_size);
rb->SendLong(total_hits);
const size_t end = offset + limit;
for (size_t i = offset; i < end; i++) {
if (reply_with_ids_only) {
rb->SendBulkString(docs[i]->key);
@ -593,7 +622,8 @@ void SearchReply(const SearchParams& params, std::optional<search::AggregationIn
}
if (should_add_score_field && holds_alternative<float>(docs[i]->score))
docs[i]->values[agg_info->alias] = absl::StrCat(get<float>(docs[i]->score));
docs[i]->values[knn_sort_option->score_field_alias] =
absl::StrCat(get<float>(docs[i]->score));
SendSerializedDoc(*docs[i], builder);
}
@ -811,8 +841,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
return;
search::SearchAlgorithm search_algo;
search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
if (!search_algo.Init(query_str, &params->query_params, sort_opt))
if (!search_algo.Init(query_str, &params->query_params))
return builder->SendError("Query syntax error");
// Because our coordinator thread may not have a shard, we can't check ahead if the index exists.
@ -835,7 +864,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
return builder->SendError(*res.error);
}
SearchReply(*params, search_algo.GetAggregationInfo(), absl::MakeSpan(docs), builder);
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(docs), builder);
}
void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
@ -858,8 +887,7 @@ void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
return;
search::SearchAlgorithm search_algo;
search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
if (!search_algo.Init(query_str, &params->query_params, sort_opt))
if (!search_algo.Init(query_str, &params->query_params))
return rb->SendError("query syntax error");
search_algo.EnableProfiling();
@ -911,7 +939,7 @@ void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
// Result of the search command
if (!result_is_empty) {
SearchReply(*params, search_algo.GetAggregationInfo(), absl::MakeSpan(search_results), rb);
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(search_results), rb);
} else {
rb->StartArray(1);
rb->SendLong(0);
@ -1010,7 +1038,7 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
return;
search::SearchAlgorithm search_algo;
if (!search_algo.Init(params->query, &params->params, nullptr))
if (!search_algo.Init(params->query, &params->params))
return builder->SendError("Query syntax error");
using ResultContainer = decltype(declval<ShardDocIndex>().SearchForAggregator(

View file

@ -2278,4 +2278,69 @@ TEST_F(SearchFamilyTest, PrefixSearchWithSynonyms) {
EXPECT_THAT(resp, AreDocIds("doc:6")); // Should only find macintosh
}
// SORTBY on a JSON text field that is not declared SORTABLE. The field is addressed by its
// alias, by @alias and by its JSON path.
TEST_F(SearchFamilyTest, SearchSortByOptionNonSortableFieldJson) {
  Run({"JSON.SET", "json1", "$", R"({"text":"2"})"});
  Run({"JSON.SET", "json2", "$", R"({"text":"1"})"});

  auto create_reply =
      Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.text", "AS", "text", "TEXT"});
  EXPECT_EQ(create_reply, "OK");

  // json2 (value "1") is expected before json1 (value "2"); `sort_key` is the name under
  // which the sort field shows up in each document.
  auto sorted_docs = [](std::string_view sort_key) {
    return IsArray(2, "json2", IsMap(sort_key, "\"1\"", "$", R"({"text":"1"})"), "json1",
                   IsMap(sort_key, "\"2\"", "$", R"({"text":"2"})"));
  };

  auto by_alias = Run({"FT.SEARCH", "index", "*", "SORTBY", "text"});
  EXPECT_THAT(by_alias, sorted_docs("text"sv));

  auto by_at_alias = Run({"FT.SEARCH", "index", "*", "SORTBY", "@text"});
  EXPECT_THAT(by_at_alias, sorted_docs("text"sv));

  auto by_json_path = Run({"FT.SEARCH", "index", "*", "SORTBY", "$.text"});
  EXPECT_THAT(by_json_path, sorted_docs("$.text"sv));
}
// SORTBY on a HASH text field that is not declared SORTABLE, addressed both as `text` and
// as `@text`.
TEST_F(SearchFamilyTest, SearchSortByOptionNonSortableFieldHash) {
  Run({"HSET", "h1", "text", "2"});
  Run({"HSET", "h2", "text", "1"});

  auto create_reply = Run({"FT.CREATE", "index", "ON", "HASH", "SCHEMA", "text", "TEXT"});
  EXPECT_EQ(create_reply, "OK");

  // h2 (value "1") is expected before h1 (value "2").
  auto sorted_docs = IsArray(2, "h2", IsMap("text", "1"), "h1", IsMap("text", "2"));

  auto by_name = Run({"FT.SEARCH", "index", "*", "SORTBY", "text"});
  EXPECT_THAT(by_name, sorted_docs);

  auto by_at_name = Run({"FT.SEARCH", "index", "*", "SORTBY", "@text"});
  EXPECT_THAT(by_at_name, sorted_docs);
}
// KNN query combined with SORTBY: the two KNN hits must come back ordered by the sortable
// `timestamp` field.
TEST_F(SearchFamilyTest, KnnSearchWithSortby) {
  // The literals contain embedded zero bytes, so the length (16 bytes) is passed explicitly.
  auto as_blob = [](const char* raw) { return std::string(raw, 16); };

  const std::string doc1_embedding =
      as_blob("\x3d\xcc\xcc\x3d\x00\x00\x80\x3f\xcd\xcc\x4c\x3e\x9a\x99\x19\x3f");
  const std::string doc2_embedding =
      as_blob("\x9a\x99\x19\x3f\xcd\xcc\x4c\x3e\x00\x00\x80\x3f\x3d\xcc\xcc\x3d");
  const std::string doc3_embedding =
      as_blob("\x00\x00\x80\x3f\x3d\xcc\xcc\x3d\xcd\xcc\x4c\x3e\x9a\x99\x19\x3f");

  Run({"HSET", "doc:1", "timestamp", "1713100000", "embedding", doc1_embedding});
  Run({"HSET", "doc:2", "timestamp", "1713200000", "embedding", doc2_embedding});
  Run({"HSET", "doc:3", "timestamp", "1713300000", "embedding", doc3_embedding});

  Run({"FT.CREATE", "my_index", "ON", "HASH", "SCHEMA", "timestamp", "NUMERIC", "SORTABLE",
       "embedding", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "4", "DISTANCE_METRIC",
       "COSINE"});

  const std::string query_vector =
      as_blob("\x3d\xcc\xcc\x3d\x00\x00\x80\x3f\xcd\xcc\x4c\x3e\x9a\x99\x19\x3f");

  // Plain KNN
  auto knn_only = Run({"FT.SEARCH", "my_index", "*=>[KNN 2 @embedding $vec]", "PARAMS", "2", "vec",
                       query_vector, "NOCONTENT"});
  EXPECT_THAT(knn_only, IsArray(2, "doc:1", "doc:3"));

  // FT.SEARCH with KNN + SORTBY
  auto knn_sorted = Run({"FT.SEARCH", "my_index", "*=>[KNN 2 @embedding $vec]", "PARAMS", "2",
                         "vec", query_vector, "SORTBY", "timestamp", "DESC", "NOCONTENT"});
  EXPECT_THAT(knn_sorted, IsArray(2, "doc:3", "doc:1"));
}
} // namespace dfly