mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 10:25:47 +02:00
Basic FT.AGGREGATE (#2413)
Introduces basic FT.AGGREGATE command, supporting GROUPBY, SORTBY, LIMIT Signed-off-by: Vladislav <vladislav.oleshko@gmail.com> ---------
This commit is contained in:
parent
98616755c0
commit
22e413a00b
13 changed files with 310 additions and 34 deletions
|
@ -21,6 +21,10 @@ WrappedStrPtr::WrappedStrPtr(const PMR_NS::string& s)
|
|||
std::strcpy(ptr.get(), s.c_str());
|
||||
}
|
||||
|
||||
WrappedStrPtr::WrappedStrPtr(const std::string& s) : ptr{std::make_unique<char[]>(s.size() + 1)} {
|
||||
std::strcpy(ptr.get(), s.c_str());
|
||||
}
|
||||
|
||||
bool WrappedStrPtr::operator<(const WrappedStrPtr& other) const {
|
||||
return std::strcmp(ptr.get(), other.ptr.get()) < 0;
|
||||
}
|
||||
|
@ -29,4 +33,8 @@ bool WrappedStrPtr::operator>=(const WrappedStrPtr& other) const {
|
|||
return !operator<(other);
|
||||
}
|
||||
|
||||
WrappedStrPtr::operator std::string_view() const {
|
||||
return std::string_view{ptr.get(), std::strlen(ptr.get())};
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -43,18 +43,28 @@ struct SortOption {
|
|||
bool descending = false;
|
||||
};
|
||||
|
||||
// Comparable string stored as char[]. Used to reduce size of std::variant with strings.
|
||||
struct WrappedStrPtr {
|
||||
// Intentionally implicit and const std::string& for use in templates
|
||||
WrappedStrPtr(const PMR_NS::string& s);
|
||||
WrappedStrPtr(const std::string& s);
|
||||
bool operator<(const WrappedStrPtr& other) const;
|
||||
bool operator>=(const WrappedStrPtr& other) const;
|
||||
|
||||
operator std::string_view() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<char[]> ptr;
|
||||
};
|
||||
|
||||
// Score produced either by KNN (float) or SORT (double / wrapped str)
|
||||
using ResultScore = std::variant<std::monostate, float, double, WrappedStrPtr>;
|
||||
|
||||
// Values are either sortable as doubles or strings, or not sortable at all.
|
||||
// Contrary to ResultScore it doesn't include KNN results and is not optimized for smaller struct
|
||||
// size.
|
||||
using SortableValue = std::variant<std::monostate, double, std::string>;
|
||||
|
||||
// Interface for accessing document values with different data structures underneath.
|
||||
struct DocumentAccessor {
|
||||
using VectorInfo = search::OwnedFtVector;
|
||||
|
@ -78,6 +88,7 @@ struct BaseIndex {
|
|||
|
||||
// Base class for type-specific sorting indices.
|
||||
struct BaseSortIndex : BaseIndex {
|
||||
virtual SortableValue Lookup(DocId doc) const = 0;
|
||||
virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -443,6 +443,12 @@ struct BasicSearch {
|
|||
|
||||
} // namespace
|
||||
|
||||
string_view Schema::LookupAlias(string_view alias) const {
|
||||
if (auto it = field_names.find(alias); it != field_names.end())
|
||||
return it->second;
|
||||
return alias;
|
||||
}
|
||||
|
||||
FieldIndices::FieldIndices(Schema schema, PMR_NS::memory_resource* mr)
|
||||
: schema_{std::move(schema)}, all_ids_{}, indices_{} {
|
||||
CreateIndices(mr);
|
||||
|
@ -521,20 +527,12 @@ void FieldIndices::Remove(DocId doc, DocumentAccessor* access) {
|
|||
}
|
||||
|
||||
BaseIndex* FieldIndices::GetIndex(string_view field) const {
|
||||
// Replace short field name with full identifier
|
||||
if (auto it = schema_.field_names.find(field); it != schema_.field_names.end())
|
||||
field = it->second;
|
||||
|
||||
auto it = indices_.find(field);
|
||||
auto it = indices_.find(schema_.LookupAlias(field));
|
||||
return it != indices_.end() ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const {
|
||||
// Replace short field name with full identifier
|
||||
if (auto it = schema_.field_names.find(field); it != schema_.field_names.end())
|
||||
field = it->second;
|
||||
|
||||
auto it = sort_indices_.find(field);
|
||||
auto it = sort_indices_.find(schema_.LookupAlias(field));
|
||||
return it != sort_indices_.end() ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
|
@ -558,6 +556,14 @@ const Schema& FieldIndices::GetSchema() const {
|
|||
return schema_;
|
||||
}
|
||||
|
||||
vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(DocId doc) const {
|
||||
vector<pair<string, SortableValue>> out;
|
||||
for (const auto& [ident, index] : sort_indices_) {
|
||||
out.emplace_back(ident, index->Lookup(doc));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
SearchAlgorithm::SearchAlgorithm() = default;
|
||||
SearchAlgorithm::~SearchAlgorithm() = default;
|
||||
|
||||
|
|
|
@ -51,6 +51,9 @@ struct Schema {
|
|||
|
||||
// Mapping for short field names (aliases).
|
||||
absl::flat_hash_map<std::string /* short name*/, std::string /*identifier*/> field_names;
|
||||
|
||||
// Return identifier for alias if found, otherwise return passed value
|
||||
std::string_view LookupAlias(std::string_view alias) const;
|
||||
};
|
||||
|
||||
// Collection of indices for all fields in schema
|
||||
|
@ -64,13 +67,14 @@ class FieldIndices {
|
|||
|
||||
BaseIndex* GetIndex(std::string_view field) const;
|
||||
BaseSortIndex* GetSortIndex(std::string_view field) const;
|
||||
|
||||
std::vector<TextIndex*> GetAllTextIndices() const;
|
||||
|
||||
const std::vector<DocId>& GetAllDocs() const;
|
||||
|
||||
const Schema& GetSchema() const;
|
||||
|
||||
// Extract values stored in sort indices
|
||||
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(DocId doc) const;
|
||||
|
||||
private:
|
||||
void CreateIndices(PMR_NS::memory_resource* mr);
|
||||
void CreateSortIndices(PMR_NS::memory_resource* mr);
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <absl/strings/str_split.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
namespace dfly::search {
|
||||
|
||||
|
@ -21,6 +22,15 @@ template <typename T>
|
|||
SimpleValueSortIndex<T>::SimpleValueSortIndex(PMR_NS::memory_resource* mr) : values_{mr} {
|
||||
}
|
||||
|
||||
template <typename T> SortableValue SimpleValueSortIndex<T>::Lookup(DocId doc) const {
|
||||
DCHECK_LT(doc, values_.size());
|
||||
if constexpr (is_same_v<T, PMR_NS::string>) {
|
||||
return std::string(values_[doc]);
|
||||
} else {
|
||||
return values_[doc];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids, size_t limit,
|
||||
bool desc) const {
|
||||
|
@ -37,7 +47,7 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
|
|||
|
||||
template <typename T>
|
||||
void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_view field) {
|
||||
DCHECK(id <= values_.size()); // Doc ids grow at most by one
|
||||
DCHECK_LE(id, values_.size()); // Doc ids grow at most by one
|
||||
if (id >= values_.size())
|
||||
values_.resize(id + 1);
|
||||
values_[id] = Get(id, doc, field);
|
||||
|
@ -45,6 +55,7 @@ void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_v
|
|||
|
||||
template <typename T>
|
||||
void SimpleValueSortIndex<T>::Remove(DocId id, DocumentAccessor* doc, std::string_view field) {
|
||||
DCHECK_LT(id, values_.size());
|
||||
values_[id] = T{};
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ namespace dfly::search {
|
|||
template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
|
||||
SimpleValueSortIndex(PMR_NS::memory_resource* mr);
|
||||
|
||||
SortableValue Lookup(DocId doc) const override;
|
||||
std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const override;
|
||||
|
||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
|
|
|
@ -96,8 +96,8 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
|
|||
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
|
||||
}
|
||||
|
||||
PipelineStep MakeSortStep(std::string field, bool descending) {
|
||||
return [field, descending](std::vector<DocValues> values) -> PipelineResult {
|
||||
PipelineStep MakeSortStep(std::string_view field, bool descending) {
|
||||
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
|
||||
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
|
||||
auto it1 = l.find(field);
|
||||
auto it2 = r.find(field);
|
||||
|
@ -117,7 +117,7 @@ PipelineStep MakeLimitStep(size_t offset, size_t num) {
|
|||
};
|
||||
}
|
||||
|
||||
PipelineResult Process(std::vector<DocValues> values, absl::Span<PipelineStep> steps) {
|
||||
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
|
||||
for (auto& step : steps) {
|
||||
auto result = step(std::move(values));
|
||||
if (!result.has_value())
|
||||
|
|
|
@ -10,14 +10,17 @@
|
|||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "core/search/base.h"
|
||||
#include "facade/reply_builder.h"
|
||||
#include "io/io.h"
|
||||
|
||||
namespace dfly::aggregate {
|
||||
|
||||
using Value = std::variant<std::monostate, double, std::string>;
|
||||
using Value = ::dfly::search::SortableValue;
|
||||
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline
|
||||
|
||||
// TODO: Replace DocValues with compact linear search map instead of hash map
|
||||
|
||||
using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
|
||||
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.
|
||||
|
||||
|
@ -71,12 +74,12 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
|
|||
std::vector<Reducer> reducers);
|
||||
|
||||
// Make `SORYBY field [DESC]` step
|
||||
PipelineStep MakeSortStep(std::string field, bool descending = false);
|
||||
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
|
||||
|
||||
// Make `LIMIT offset num` step
|
||||
PipelineStep MakeLimitStep(size_t offset, size_t num);
|
||||
|
||||
// Process values with given steps
|
||||
PipelineResult Process(std::vector<DocValues> values, absl::Span<PipelineStep> steps);
|
||||
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
|
||||
|
||||
} // namespace dfly::aggregate
|
||||
|
|
|
@ -220,8 +220,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
|
|||
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
|
||||
: accessor->Serialize(base_->schema);
|
||||
|
||||
auto score =
|
||||
search_results.scores.empty() ? std::monostate{} : std::move(search_results.scores[i]);
|
||||
auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]);
|
||||
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)});
|
||||
}
|
||||
|
||||
|
@ -229,6 +228,38 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
|
|||
std::move(search_results.profile)};
|
||||
}
|
||||
|
||||
vector<absl::flat_hash_map<string, search::SortableValue>> ShardDocIndex::SearchForAggregator(
|
||||
const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const {
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
auto search_results = search_algo->Search(&indices_);
|
||||
|
||||
if (!search_results.error.empty())
|
||||
return {};
|
||||
|
||||
// Convert load_fields into return_list required by accessor interface
|
||||
SearchParams::FieldReturnList return_fields;
|
||||
for (string_view load_field : load_fields)
|
||||
return_fields.emplace_back(indices_.GetSchema().LookupAlias(load_field), load_field);
|
||||
|
||||
vector<absl::flat_hash_map<string, search::SortableValue>> out;
|
||||
for (DocId doc : search_results.ids) {
|
||||
auto key = key_index_.Get(doc);
|
||||
auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode());
|
||||
|
||||
if (!it || !IsValid(*it)) // Item must have expired
|
||||
continue;
|
||||
|
||||
auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
|
||||
auto extracted = indices_.ExtractStoredValues(doc);
|
||||
auto loaded = accessor->Serialize(base_->schema, return_fields);
|
||||
|
||||
out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end()));
|
||||
out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end()));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
DocIndexInfo ShardDocIndex::GetInfo() const {
|
||||
return {*base_, key_index_.Size()};
|
||||
}
|
||||
|
|
|
@ -126,6 +126,10 @@ class ShardDocIndex {
|
|||
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
|
||||
search::SearchAlgorithm* search_algo) const;
|
||||
|
||||
// Perform search and load requested values - note params might be interpreted differently.
|
||||
std::vector<absl::flat_hash_map<std::string, search::SortableValue>> SearchForAggregator(
|
||||
const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const;
|
||||
|
||||
// Return whether base index matches
|
||||
bool Matches(std::string_view key, unsigned obj_code) const;
|
||||
|
||||
|
|
|
@ -24,8 +24,10 @@
|
|||
#include "server/conn_context.h"
|
||||
#include "server/container_utils.h"
|
||||
#include "server/engine_shard_set.h"
|
||||
#include "server/search/aggregator.h"
|
||||
#include "server/search/doc_index.h"
|
||||
#include "server/transaction.h"
|
||||
#include "src/core/overloaded.h"
|
||||
|
||||
namespace dfly {
|
||||
|
||||
|
@ -152,6 +154,16 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
|
|||
return schema;
|
||||
}
|
||||
|
||||
search::QueryParams ParseQueryParams(CmdArgParser* parser) {
|
||||
search::QueryParams params;
|
||||
size_t num_args = parser->Next<size_t>();
|
||||
while (parser->HasNext() && params.Size() * 2 < num_args) {
|
||||
auto [k, v] = parser->Next<string_view, string_view>();
|
||||
params[k] = v;
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionContext* cntx) {
|
||||
SearchParams params;
|
||||
|
||||
|
@ -183,12 +195,7 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
|
|||
|
||||
// [PARAMS num(ignored) name(ignored) knn_vector]
|
||||
if (parser.Check("PARAMS").ExpectTail(1)) {
|
||||
size_t num_args = parser.Next<size_t>();
|
||||
while (parser.HasNext() && params.query_params.Size() * 2 < num_args) {
|
||||
string_view k = parser.Next();
|
||||
string_view v = parser.Next();
|
||||
params.query_params[k] = v;
|
||||
}
|
||||
params.query_params = ParseQueryParams(&parser);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -210,6 +217,95 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
|
|||
return params;
|
||||
}
|
||||
|
||||
struct AggregateParams {
|
||||
string_view index, query;
|
||||
search::QueryParams params;
|
||||
|
||||
vector<string_view> load_fields;
|
||||
vector<aggregate::PipelineStep> steps;
|
||||
};
|
||||
|
||||
optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
|
||||
ConnectionContext* cntx) {
|
||||
AggregateParams params;
|
||||
tie(params.index, params.query) = parser.Next<string_view, string_view>();
|
||||
|
||||
while (parser.ToUpper().HasNext()) {
|
||||
// LOAD count field [field ...]
|
||||
if (parser.Check("LOAD").ExpectTail(1)) {
|
||||
params.load_fields.resize(parser.Next<size_t>());
|
||||
for (string_view& field : params.load_fields)
|
||||
field = parser.Next();
|
||||
continue;
|
||||
}
|
||||
|
||||
// GROUPBY nargs property [property ...]
|
||||
if (parser.Check("GROUPBY").ExpectTail(1)) {
|
||||
vector<string_view> fields(parser.Next<size_t>());
|
||||
for (string_view& field : fields)
|
||||
field = parser.Next();
|
||||
|
||||
vector<aggregate::Reducer> reducers;
|
||||
while (parser.ToUpper().Check("REDUCE").ExpectTail(2)) {
|
||||
parser.ToUpper(); // uppercase for func_name
|
||||
auto [func_name, nargs] = parser.Next<string_view, size_t>();
|
||||
auto func = aggregate::FindReducerFunc(func_name);
|
||||
|
||||
if (!parser.HasError() && !func) {
|
||||
cntx->SendError(absl::StrCat("reducer function ", func_name, " not found"));
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
string source_field = "";
|
||||
if (nargs > 0) {
|
||||
source_field = parser.Next<string>();
|
||||
}
|
||||
|
||||
parser.ExpectTag("AS");
|
||||
string result_field = parser.Next<string>();
|
||||
|
||||
reducers.push_back(aggregate::Reducer{source_field, result_field, std::move(func)});
|
||||
}
|
||||
|
||||
params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
|
||||
continue;
|
||||
}
|
||||
|
||||
// SORTBY nargs
|
||||
if (parser.Check("SORTBY").ExpectTail(1)) {
|
||||
parser.ExpectTag("1");
|
||||
string_view field = parser.Next();
|
||||
bool desc = bool(parser.Check("DESC").IgnoreCase());
|
||||
|
||||
params.steps.push_back(aggregate::MakeSortStep(field, desc));
|
||||
continue;
|
||||
}
|
||||
|
||||
// LIMIT
|
||||
if (parser.Check("LIMIT").ExpectTail(2)) {
|
||||
auto [offset, num] = parser.Next<size_t, size_t>();
|
||||
params.steps.push_back(aggregate::MakeLimitStep(offset, num));
|
||||
continue;
|
||||
}
|
||||
|
||||
// PARAMS
|
||||
if (parser.Check("PARAMS").ExpectTail(1)) {
|
||||
params.params = ParseQueryParams(&parser);
|
||||
continue;
|
||||
}
|
||||
|
||||
cntx->SendError(absl::StrCat("Unknown clause: ", parser.Peek()));
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
if (auto err = parser.Error(); err) {
|
||||
cntx->SendError(err->MakeReply());
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
void SendSerializedDoc(const SerializedSearchDoc& doc, ConnectionContext* cntx) {
|
||||
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
|
||||
rb->SendBulkString(doc.key);
|
||||
|
@ -594,6 +690,55 @@ void SearchFamily::FtProfile(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
}
|
||||
|
||||
void SearchFamily::FtAggregate(CmdArgList args, ConnectionContext* cntx) {
|
||||
const auto params = ParseAggregatorParamsOrReply(args, cntx);
|
||||
if (!params)
|
||||
return;
|
||||
|
||||
search::SearchAlgorithm search_algo;
|
||||
if (!search_algo.Init(params->query, ¶ms->params, nullptr))
|
||||
return cntx->SendError("Query syntax error");
|
||||
|
||||
using ResultContainer =
|
||||
decltype(declval<ShardDocIndex>().SearchForAggregator(declval<OpArgs>(), {}, &search_algo));
|
||||
|
||||
vector<ResultContainer> query_results(shard_set->size());
|
||||
cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
|
||||
if (auto* index = es->search_indices()->GetIndex(params->index); index) {
|
||||
query_results[es->shard_id()] =
|
||||
index->SearchForAggregator(t->GetOpArgs(es), params->load_fields, &search_algo);
|
||||
}
|
||||
return OpStatus::OK;
|
||||
});
|
||||
|
||||
vector<aggregate::DocValues> values;
|
||||
for (auto& sub_results : query_results) {
|
||||
values.insert(values.end(), make_move_iterator(sub_results.begin()),
|
||||
make_move_iterator(sub_results.end()));
|
||||
}
|
||||
|
||||
auto agg_results = aggregate::Process(std::move(values), params->steps);
|
||||
if (!agg_results.has_value())
|
||||
return cntx->SendError(agg_results.error());
|
||||
|
||||
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
|
||||
Overloaded replier{
|
||||
[rb](monostate) { rb->SendNull(); },
|
||||
[rb](double d) { rb->SendDouble(d); },
|
||||
[rb](const string& s) { rb->SendBulkString(s); },
|
||||
};
|
||||
|
||||
rb->StartArray(agg_results->size());
|
||||
for (const auto& result : agg_results.value()) {
|
||||
rb->StartArray(result.size());
|
||||
for (const auto& [k, v] : result) {
|
||||
rb->StartArray(2);
|
||||
rb->SendBulkString(k);
|
||||
visit(replier, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define HFUNC(x) SetHandler(&SearchFamily::x)
|
||||
|
||||
// Redis search is a module. Therefore we introduce dragonfly extension search
|
||||
|
@ -616,6 +761,7 @@ void SearchFamily::Register(CommandRegistry* registry) {
|
|||
// Underscore same as in RediSearch because it's "temporary" (long time already)
|
||||
<< CI{"FT._LIST", kReadOnlyMask, 1, 0, 0, acl::FT_SEARCH}.HFUNC(FtList)
|
||||
<< CI{"FT.SEARCH", kReadOnlyMask, -3, 0, 0, acl::FT_SEARCH}.HFUNC(FtSearch)
|
||||
<< CI{"FT.AGGREGATE", kReadOnlyMask, -3, 0, 0, acl::FT_SEARCH}.HFUNC(FtAggregate)
|
||||
<< CI{"FT.PROFILE", kReadOnlyMask, -4, 0, 0, acl::FT_SEARCH}.HFUNC(FtProfile);
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ class SearchFamily {
|
|||
static void FtList(CmdArgList args, ConnectionContext* cntx);
|
||||
static void FtSearch(CmdArgList args, ConnectionContext* cntx);
|
||||
static void FtProfile(CmdArgList args, ConnectionContext* cntx);
|
||||
static void FtAggregate(CmdArgList args, ConnectionContext* cntx);
|
||||
|
||||
public:
|
||||
static void Register(CommandRegistry* registry);
|
||||
|
|
|
@ -62,6 +62,14 @@ template <typename... Args> auto AreDocIds(Args... args) {
|
|||
return DocIds(sizeof...(args), vector<string>{args...});
|
||||
}
|
||||
|
||||
template <typename... Args> auto IsArray(Args... args) {
|
||||
return RespArray(ElementsAre(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
template <typename... Args> auto IsUnordArray(Args... args) {
|
||||
return RespArray(UnorderedElementsAre(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, CreateDropListIndex) {
|
||||
EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK");
|
||||
EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK");
|
||||
|
@ -90,12 +98,10 @@ TEST_F(SearchFamilyTest, InfoIndex) {
|
|||
}
|
||||
|
||||
auto info = Run({"ft.info", "idx-1"});
|
||||
EXPECT_THAT(
|
||||
info, RespArray(ElementsAre(
|
||||
_, _, _, RespArray(ElementsAre("key_type", "HASH", "prefix", "doc-")), "attributes",
|
||||
RespArray(ElementsAre(RespArray(
|
||||
ElementsAre("identifier", "name", "attribute", "name", "type", "TEXT")))),
|
||||
"num_docs", IntArg(15))));
|
||||
EXPECT_THAT(info,
|
||||
IsArray(_, _, _, IsArray("key_type", "HASH", "prefix", "doc-"), "attributes",
|
||||
IsArray(IsArray("identifier", "name", "attribute", "name", "type", "TEXT")),
|
||||
"num_docs", IntArg(15)));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, Stats) {
|
||||
|
@ -385,7 +391,7 @@ TEST_F(SearchFamilyTest, TestReturn) {
|
|||
"NUMERIC", "vector", "VECTOR", "FLAT", "2", "DIM", "1"});
|
||||
|
||||
auto MatchEntry = [](string key, auto... fields) {
|
||||
return RespArray(ElementsAre(IntArg(1), key, RespArray(UnorderedElementsAre(fields...))));
|
||||
return RespArray(ElementsAre(IntArg(1), key, IsUnordArray(fields...)));
|
||||
};
|
||||
|
||||
// Check all fields are returned
|
||||
|
@ -596,4 +602,48 @@ TEST_F(SearchFamilyTest, SimpleExpiry) {
|
|||
Run({"flushall"});
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, AggregateGroupByReduceSort) {
|
||||
for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd
|
||||
Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value",
|
||||
absl::StrCat(i)});
|
||||
}
|
||||
Run({"ft.create", "i1", "schema", "even", "tag", "sortable", "value", "numeric", "sortable"});
|
||||
|
||||
// clang-format off
|
||||
auto resp = Run({"ft.aggregate", "i1", "*",
|
||||
"GROUPBY", "1", "even",
|
||||
"REDUCE", "count", "0", "as", "count",
|
||||
"REDUCE", "count_distinct", "1", "even", "as", "distinct_tags",
|
||||
"REDUCE", "count_distinct", "1", "value", "as", "distinct_vals",
|
||||
"REDUCE", "max", "1", "value", "as", "max_val",
|
||||
"REDUCE", "min", "1", "value", "as", "min_val",
|
||||
"SORTBY", "1", "count"});
|
||||
// clang-format on
|
||||
|
||||
EXPECT_THAT(resp,
|
||||
IsArray(IsUnordArray(IsArray("even", "false"), IsArray("count", "50"),
|
||||
IsArray("distinct_tags", "1"), IsArray("distinct_vals", "50"),
|
||||
IsArray("max_val", "99"), IsArray("min_val", "1")),
|
||||
IsUnordArray(IsArray("even", "true"), IsArray("count", "51"),
|
||||
IsArray("distinct_tags", "1"), IsArray("distinct_vals", "51"),
|
||||
IsArray("max_val", "100"), IsArray("min_val", "0"))));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, AggregateLoadGroupBy) {
|
||||
for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd
|
||||
Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value",
|
||||
absl::StrCat(i)});
|
||||
}
|
||||
Run({"ft.create", "i1", "schema", "value", "numeric", "sortable"});
|
||||
|
||||
// clang-format off
|
||||
auto resp = Run({"ft.aggregate", "i1", "*",
|
||||
"LOAD", "1", "even",
|
||||
"GROUPBY", "1", "even"});
|
||||
// clang-format on
|
||||
|
||||
EXPECT_THAT(resp, IsUnordArray(IsUnordArray(IsArray("even", "false")),
|
||||
IsUnordArray(IsArray("even", "true"))));
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue