diff --git a/src/core/search/base.cc b/src/core/search/base.cc index 7f3dfa2ba..b6c4d3f6a 100644 --- a/src/core/search/base.cc +++ b/src/core/search/base.cc @@ -21,6 +21,10 @@ WrappedStrPtr::WrappedStrPtr(const PMR_NS::string& s) std::strcpy(ptr.get(), s.c_str()); } +WrappedStrPtr::WrappedStrPtr(const std::string& s) : ptr{std::make_unique(s.size() + 1)} { + std::strcpy(ptr.get(), s.c_str()); +} + bool WrappedStrPtr::operator<(const WrappedStrPtr& other) const { return std::strcmp(ptr.get(), other.ptr.get()) < 0; } @@ -29,4 +33,8 @@ bool WrappedStrPtr::operator>=(const WrappedStrPtr& other) const { return !operator<(other); } +WrappedStrPtr::operator std::string_view() const { + return std::string_view{ptr.get(), std::strlen(ptr.get())}; +} + } // namespace dfly::search diff --git a/src/core/search/base.h b/src/core/search/base.h index 89fe9b74c..4949f95ab 100644 --- a/src/core/search/base.h +++ b/src/core/search/base.h @@ -43,18 +43,28 @@ struct SortOption { bool descending = false; }; +// Comparable string stored as char[]. Used to reduce size of std::variant with strings. struct WrappedStrPtr { // Intentionally implicit and const std::string& for use in templates WrappedStrPtr(const PMR_NS::string& s); + WrappedStrPtr(const std::string& s); bool operator<(const WrappedStrPtr& other) const; bool operator>=(const WrappedStrPtr& other) const; + operator std::string_view() const; + private: std::unique_ptr ptr; }; +// Score produced either by KNN (float) or SORT (double / wrapped str) using ResultScore = std::variant; +// Values are either sortable as doubles or strings, or not sortable at all. +// Contrary to ResultScore it doesn't include KNN results and is not optimized for smaller struct +// size. +using SortableValue = std::variant; + // Interface for accessing document values with different data structures underneath. struct DocumentAccessor { using VectorInfo = search::OwnedFtVector; @@ -78,6 +88,7 @@ struct BaseIndex { // Base class for type-specific sorting indices. struct BaseSortIndex : BaseIndex { + virtual SortableValue Lookup(DocId doc) const = 0; virtual std::vector Sort(std::vector* ids, size_t limit, bool desc) const = 0; }; diff --git a/src/core/search/search.cc b/src/core/search/search.cc index fc59c34b6..1307226df 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -443,6 +443,12 @@ struct BasicSearch { } // namespace +string_view Schema::LookupAlias(string_view alias) const { + if (auto it = field_names.find(alias); it != field_names.end()) + return it->second; + return alias; +} + FieldIndices::FieldIndices(Schema schema, PMR_NS::memory_resource* mr) : schema_{std::move(schema)}, all_ids_{}, indices_{} { CreateIndices(mr); @@ -521,20 +527,12 @@ void FieldIndices::Remove(DocId doc, DocumentAccessor* access) { } BaseIndex* FieldIndices::GetIndex(string_view field) const { - // Replace short field name with full identifier - if (auto it = schema_.field_names.find(field); it != schema_.field_names.end()) - field = it->second; - - auto it = indices_.find(field); + auto it = indices_.find(schema_.LookupAlias(field)); return it != indices_.end() ? it->second.get() : nullptr; } BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const { - // Replace short field name with full identifier - if (auto it = schema_.field_names.find(field); it != schema_.field_names.end()) - field = it->second; - - auto it = sort_indices_.find(field); + auto it = sort_indices_.find(schema_.LookupAlias(field)); return it != sort_indices_.end() ? it->second.get() : nullptr; } @@ -558,6 +556,14 @@ const Schema& FieldIndices::GetSchema() const { return schema_; } +vector> FieldIndices::ExtractStoredValues(DocId doc) const { + vector> out; + for (const auto& [ident, index] : sort_indices_) { + out.emplace_back(ident, index->Lookup(doc)); + } + return out; +} + SearchAlgorithm::SearchAlgorithm() = default; SearchAlgorithm::~SearchAlgorithm() = default; diff --git a/src/core/search/search.h b/src/core/search/search.h index b7b464e0f..3202e0e9c 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -51,6 +51,9 @@ struct Schema { // Mapping for short field names (aliases). absl::flat_hash_map field_names; + + // Return identifier for alias if found, otherwise return passed value + std::string_view LookupAlias(std::string_view alias) const; }; // Collection of indices for all fields in schema @@ -64,13 +67,14 @@ class FieldIndices { BaseIndex* GetIndex(std::string_view field) const; BaseSortIndex* GetSortIndex(std::string_view field) const; - std::vector GetAllTextIndices() const; const std::vector& GetAllDocs() const; - const Schema& GetSchema() const; + // Extract values stored in sort indices + std::vector> ExtractStoredValues(DocId doc) const; + private: void CreateIndices(PMR_NS::memory_resource* mr); void CreateSortIndices(PMR_NS::memory_resource* mr); diff --git a/src/core/search/sort_indices.cc b/src/core/search/sort_indices.cc index 8f4cc0a6f..ed9e84255 100644 --- a/src/core/search/sort_indices.cc +++ b/src/core/search/sort_indices.cc @@ -10,6 +10,7 @@ #include #include +#include namespace dfly::search { @@ -21,6 +22,15 @@ template SimpleValueSortIndex::SimpleValueSortIndex(PMR_NS::memory_resource* mr) : values_{mr} { } +template SortableValue SimpleValueSortIndex::Lookup(DocId doc) const { + DCHECK_LT(doc, values_.size()); + if constexpr (is_same_v) { + return std::string(values_[doc]); + } else { + return values_[doc]; + } +} + template std::vector SimpleValueSortIndex::Sort(std::vector* ids, size_t limit, bool desc) const { @@ -37,7 +47,7 @@ std::vector SimpleValueSortIndex::Sort(std::vector* ids, template void SimpleValueSortIndex::Add(DocId id, DocumentAccessor* doc, std::string_view field) { - DCHECK(id <= values_.size()); // Doc ids grow at most by one + DCHECK_LE(id, values_.size()); // Doc ids grow at most by one if (id >= values_.size()) values_.resize(id + 1); values_[id] = Get(id, doc, field); @@ -45,6 +55,7 @@ void SimpleValueSortIndex::Add(DocId id, DocumentAccessor* doc, std::string_v template void SimpleValueSortIndex::Remove(DocId id, DocumentAccessor* doc, std::string_view field) { + DCHECK_LT(id, values_.size()); values_[id] = T{}; } diff --git a/src/core/search/sort_indices.h b/src/core/search/sort_indices.h index f66a1de14..591839a77 100644 --- a/src/core/search/sort_indices.h +++ b/src/core/search/sort_indices.h @@ -21,6 +21,7 @@ namespace dfly::search { template struct SimpleValueSortIndex : BaseSortIndex { SimpleValueSortIndex(PMR_NS::memory_resource* mr); + SortableValue Lookup(DocId doc) const override; std::vector Sort(std::vector* ids, size_t limit, bool desc) const override; void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; diff --git a/src/server/search/aggregator.cc b/src/server/search/aggregator.cc index 27645fd09..e70df3791 100644 --- a/src/server/search/aggregator.cc +++ b/src/server/search/aggregator.cc @@ -96,8 +96,8 @@ PipelineStep MakeGroupStep(absl::Span fields, return GroupStep{std::vector(fields.begin(), fields.end()), std::move(reducers)}; } -PipelineStep MakeSortStep(std::string field, bool descending) { - return [field, descending](std::vector values) -> PipelineResult { +PipelineStep MakeSortStep(std::string_view field, bool descending) { + return [field = std::string(field), descending](std::vector values) -> PipelineResult { std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) { auto it1 = l.find(field); auto it2 = r.find(field); @@ -117,7 +117,7 @@ PipelineStep MakeLimitStep(size_t offset, size_t num) { }; } -PipelineResult Process(std::vector values, absl::Span steps) { +PipelineResult Process(std::vector values, absl::Span steps) { for (auto& step : steps) { auto result = step(std::move(values)); if (!result.has_value()) diff --git a/src/server/search/aggregator.h b/src/server/search/aggregator.h index e8a9a22cf..0699ce233 100644 --- a/src/server/search/aggregator.h +++ b/src/server/search/aggregator.h @@ -10,14 +10,17 @@ #include #include +#include "core/search/base.h" #include "facade/reply_builder.h" #include "io/io.h" namespace dfly::aggregate { -using Value = std::variant; +using Value = ::dfly::search::SortableValue; using DocValues = absl::flat_hash_map; // documents sent through the pipeline +// TODO: Replace DocValues with compact linear search map instead of hash map + using PipelineResult = io::Result, facade::ErrorReply>; using PipelineStep = std::function)>; // Group, Sort, etc. @@ -71,12 +74,12 @@ PipelineStep MakeGroupStep(absl::Span fields, std::vector reducers); // Make `SORYBY field [DESC]` step -PipelineStep MakeSortStep(std::string field, bool descending = false); +PipelineStep MakeSortStep(std::string_view field, bool descending = false); // Make `LIMIT offset num` step PipelineStep MakeLimitStep(size_t offset, size_t num); // Process values with given steps -PipelineResult Process(std::vector values, absl::Span steps); +PipelineResult Process(std::vector values, absl::Span steps); } // namespace dfly::aggregate diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index cb445bfff..b277b53fe 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -220,8 +220,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields) : accessor->Serialize(base_->schema); - auto score = - search_results.scores.empty() ? std::monostate{} : std::move(search_results.scores[i]); + auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]); out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)}); } @@ -229,6 +228,38 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa std::move(search_results.profile)}; } +vector> ShardDocIndex::SearchForAggregator( + const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const { + auto& db_slice = op_args.shard->db_slice(); + auto search_results = search_algo->Search(&indices_); + + if (!search_results.error.empty()) + return {}; + + // Convert load_fields into return_list required by accessor interface + SearchParams::FieldReturnList return_fields; + for (string_view load_field : load_fields) + return_fields.emplace_back(indices_.GetSchema().LookupAlias(load_field), load_field); + + vector> out; + for (DocId doc : search_results.ids) { + auto key = key_index_.Get(doc); + auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode()); + + if (!it || !IsValid(*it)) // Item must have expired + continue; + + auto accessor = GetAccessor(op_args.db_cntx, (*it)->second); + auto extracted = indices_.ExtractStoredValues(doc); + auto loaded = accessor->Serialize(base_->schema, return_fields); + + out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end())); + out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end())); + } + + return out; +} + DocIndexInfo ShardDocIndex::GetInfo() const { return {*base_, key_index_.Size()}; } diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index e6bef434d..164680ea5 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -126,6 +126,10 @@ class ShardDocIndex { SearchResult Search(const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const; + // Perform search and load requested values - note params might be interpreted differently. + std::vector> SearchForAggregator( + const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const; + // Return whether base index matches bool Matches(std::string_view key, unsigned obj_code) const; diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 1420df666..082a17c74 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -24,8 +24,10 @@ #include "server/conn_context.h" #include "server/container_utils.h" #include "server/engine_shard_set.h" +#include "server/search/aggregator.h" #include "server/search/doc_index.h" #include "server/transaction.h" +#include "src/core/overloaded.h" namespace dfly { @@ -152,6 +154,16 @@ optional ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse return schema; } +search::QueryParams ParseQueryParams(CmdArgParser* parser) { + search::QueryParams params; + size_t num_args = parser->Next(); + while (parser->HasNext() && params.Size() * 2 < num_args) { + auto [k, v] = parser->Next(); + params[k] = v; + } + return params; +} + optional ParseSearchParamsOrReply(CmdArgParser parser, ConnectionContext* cntx) { SearchParams params; @@ -183,12 +195,7 @@ optional ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC // [PARAMS num(ignored) name(ignored) knn_vector] if (parser.Check("PARAMS").ExpectTail(1)) { - size_t num_args = parser.Next(); - while (parser.HasNext() && params.query_params.Size() * 2 < num_args) { - string_view k = parser.Next(); - string_view v = parser.Next(); - params.query_params[k] = v; - } + params.query_params = ParseQueryParams(&parser); continue; } @@ -210,6 +217,95 @@ optional ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC return params; } +struct AggregateParams { + string_view index, query; + search::QueryParams params; + + vector load_fields; + vector steps; +}; + +optional ParseAggregatorParamsOrReply(CmdArgParser parser, + ConnectionContext* cntx) { + AggregateParams params; + tie(params.index, params.query) = parser.Next(); + + while (parser.ToUpper().HasNext()) { + // LOAD count field [field ...] + if (parser.Check("LOAD").ExpectTail(1)) { + params.load_fields.resize(parser.Next()); + for (string_view& field : params.load_fields) + field = parser.Next(); + continue; + } + + // GROUPBY nargs property [property ...] + if (parser.Check("GROUPBY").ExpectTail(1)) { + vector fields(parser.Next()); + for (string_view& field : fields) + field = parser.Next(); + + vector reducers; + while (parser.ToUpper().Check("REDUCE").ExpectTail(2)) { + parser.ToUpper(); // uppercase for func_name + auto [func_name, nargs] = parser.Next(); + auto func = aggregate::FindReducerFunc(func_name); + + if (!parser.HasError() && !func) { + cntx->SendError(absl::StrCat("reducer function ", func_name, " not found")); + return nullopt; + } + + string source_field = ""; + if (nargs > 0) { + source_field = parser.Next(); + } + + parser.ExpectTag("AS"); + string result_field = parser.Next(); + + reducers.push_back(aggregate::Reducer{source_field, result_field, std::move(func)}); + } + + params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers))); + continue; + } + + // SORTBY nargs + if (parser.Check("SORTBY").ExpectTail(1)) { + parser.ExpectTag("1"); + string_view field = parser.Next(); + bool desc = bool(parser.Check("DESC").IgnoreCase()); + + params.steps.push_back(aggregate::MakeSortStep(field, desc)); + continue; + } + + // LIMIT + if (parser.Check("LIMIT").ExpectTail(2)) { + auto [offset, num] = parser.Next(); + params.steps.push_back(aggregate::MakeLimitStep(offset, num)); + continue; + } + + // PARAMS + if (parser.Check("PARAMS").ExpectTail(1)) { + params.params = ParseQueryParams(&parser); + continue; + } + + cntx->SendError(absl::StrCat("Unknown clause: ", parser.Peek())); + return nullopt; + } + + if (auto err = parser.Error(); err) { + cntx->SendError(err->MakeReply()); + return nullopt; + } + + return params; +} + void SendSerializedDoc(const SerializedSearchDoc& doc, ConnectionContext* cntx) { auto* rb = static_cast(cntx->reply_builder()); rb->SendBulkString(doc.key); @@ -594,6 +690,55 @@ void SearchFamily::FtProfile(CmdArgList args, ConnectionContext* cntx) { } } +void SearchFamily::FtAggregate(CmdArgList args, ConnectionContext* cntx) { + const auto params = ParseAggregatorParamsOrReply(args, cntx); + if (!params) + return; + + search::SearchAlgorithm search_algo; + if (!search_algo.Init(params->query, ¶ms->params, nullptr)) + return cntx->SendError("Query syntax error"); + + using ResultContainer = + decltype(declval().SearchForAggregator(declval(), {}, &search_algo)); + + vector query_results(shard_set->size()); + cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + if (auto* index = es->search_indices()->GetIndex(params->index); index) { + query_results[es->shard_id()] = + index->SearchForAggregator(t->GetOpArgs(es), params->load_fields, &search_algo); + } + return OpStatus::OK; + }); + + vector values; + for (auto& sub_results : query_results) { + values.insert(values.end(), make_move_iterator(sub_results.begin()), + make_move_iterator(sub_results.end())); + } + + auto agg_results = aggregate::Process(std::move(values), params->steps); + if (!agg_results.has_value()) + return cntx->SendError(agg_results.error()); + + auto* rb = static_cast(cntx->reply_builder()); + Overloaded replier{ + [rb](monostate) { rb->SendNull(); }, + [rb](double d) { rb->SendDouble(d); }, + [rb](const string& s) { rb->SendBulkString(s); }, + }; + + rb->StartArray(agg_results->size()); + for (const auto& result : agg_results.value()) { + rb->StartArray(result.size()); + for (const auto& [k, v] : result) { + rb->StartArray(2); + rb->SendBulkString(k); + visit(replier, v); + } + } +} + #define HFUNC(x) SetHandler(&SearchFamily::x) // Redis search is a module. Therefore we introduce dragonfly extension search @@ -616,6 +761,7 @@ void SearchFamily::Register(CommandRegistry* registry) { // Underscore same as in RediSearch because it's "temporary" (long time already) << CI{"FT._LIST", kReadOnlyMask, 1, 0, 0, acl::FT_SEARCH}.HFUNC(FtList) << CI{"FT.SEARCH", kReadOnlyMask, -3, 0, 0, acl::FT_SEARCH}.HFUNC(FtSearch) + << CI{"FT.AGGREGATE", kReadOnlyMask, -3, 0, 0, acl::FT_SEARCH}.HFUNC(FtAggregate) << CI{"FT.PROFILE", kReadOnlyMask, -4, 0, 0, acl::FT_SEARCH}.HFUNC(FtProfile); } diff --git a/src/server/search/search_family.h b/src/server/search/search_family.h index 8c97252f5..50edc39fe 100644 --- a/src/server/search/search_family.h +++ b/src/server/search/search_family.h @@ -20,6 +20,7 @@ class SearchFamily { static void FtList(CmdArgList args, ConnectionContext* cntx); static void FtSearch(CmdArgList args, ConnectionContext* cntx); static void FtProfile(CmdArgList args, ConnectionContext* cntx); + static void FtAggregate(CmdArgList args, ConnectionContext* cntx); public: static void Register(CommandRegistry* registry); diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index bf0981087..31a750905 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -62,6 +62,14 @@ template auto AreDocIds(Args... args) { return DocIds(sizeof...(args), vector{args...}); } +template auto IsArray(Args... args) { + return RespArray(ElementsAre(std::forward(args)...)); +} + +template auto IsUnordArray(Args... args) { + return RespArray(UnorderedElementsAre(std::forward(args)...)); +} + TEST_F(SearchFamilyTest, CreateDropListIndex) { EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK"); EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK"); @@ -90,12 +98,10 @@ TEST_F(SearchFamilyTest, InfoIndex) { } auto info = Run({"ft.info", "idx-1"}); - EXPECT_THAT( - info, RespArray(ElementsAre( - _, _, _, RespArray(ElementsAre("key_type", "HASH", "prefix", "doc-")), "attributes", - RespArray(ElementsAre(RespArray( - ElementsAre("identifier", "name", "attribute", "name", "type", "TEXT")))), - "num_docs", IntArg(15)))); + EXPECT_THAT(info, + IsArray(_, _, _, IsArray("key_type", "HASH", "prefix", "doc-"), "attributes", + IsArray(IsArray("identifier", "name", "attribute", "name", "type", "TEXT")), + "num_docs", IntArg(15))); } TEST_F(SearchFamilyTest, Stats) { @@ -385,7 +391,7 @@ TEST_F(SearchFamilyTest, TestReturn) { "NUMERIC", "vector", "VECTOR", "FLAT", "2", "DIM", "1"}); auto MatchEntry = [](string key, auto... fields) { - return RespArray(ElementsAre(IntArg(1), key, RespArray(UnorderedElementsAre(fields...)))); + return RespArray(ElementsAre(IntArg(1), key, IsUnordArray(fields...))); }; // Check all fields are returned @@ -596,4 +602,48 @@ TEST_F(SearchFamilyTest, SimpleExpiry) { Run({"flushall"}); } +TEST_F(SearchFamilyTest, AggregateGroupByReduceSort) { + for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd + Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value", + absl::StrCat(i)}); + } + Run({"ft.create", "i1", "schema", "even", "tag", "sortable", "value", "numeric", "sortable"}); + + // clang-format off + auto resp = Run({"ft.aggregate", "i1", "*", + "GROUPBY", "1", "even", + "REDUCE", "count", "0", "as", "count", + "REDUCE", "count_distinct", "1", "even", "as", "distinct_tags", + "REDUCE", "count_distinct", "1", "value", "as", "distinct_vals", + "REDUCE", "max", "1", "value", "as", "max_val", + "REDUCE", "min", "1", "value", "as", "min_val", + "SORTBY", "1", "count"}); + // clang-format on + + EXPECT_THAT(resp, + IsArray(IsUnordArray(IsArray("even", "false"), IsArray("count", "50"), + IsArray("distinct_tags", "1"), IsArray("distinct_vals", "50"), + IsArray("max_val", "99"), IsArray("min_val", "1")), + IsUnordArray(IsArray("even", "true"), IsArray("count", "51"), + IsArray("distinct_tags", "1"), IsArray("distinct_vals", "51"), + IsArray("max_val", "100"), IsArray("min_val", "0")))); +} + +TEST_F(SearchFamilyTest, AggregateLoadGroupBy) { + for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd + Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value", + absl::StrCat(i)}); + } + Run({"ft.create", "i1", "schema", "value", "numeric", "sortable"}); + + // clang-format off + auto resp = Run({"ft.aggregate", "i1", "*", + "LOAD", "1", "even", + "GROUPBY", "1", "even"}); + // clang-format on + + EXPECT_THAT(resp, IsUnordArray(IsUnordArray(IsArray("even", "false")), + IsUnordArray(IsArray("even", "true")))); +} + } // namespace dfly