Basic FT.AGGREGATE (#2413)

Introduces basic FT.AGGREGATE command, supporting GROUPBY, SORTBY, LIMIT

Signed-off-by: Vladislav <vladislav.oleshko@gmail.com>

---------
This commit is contained in:
Vladislav 2024-03-08 08:51:51 +03:00 committed by GitHub
parent 98616755c0
commit 22e413a00b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 310 additions and 34 deletions

View file

@ -21,6 +21,10 @@ WrappedStrPtr::WrappedStrPtr(const PMR_NS::string& s)
std::strcpy(ptr.get(), s.c_str());
}
WrappedStrPtr::WrappedStrPtr(const std::string& s) : ptr{std::make_unique<char[]>(s.size() + 1)} {
std::strcpy(ptr.get(), s.c_str());
}
bool WrappedStrPtr::operator<(const WrappedStrPtr& other) const {
return std::strcmp(ptr.get(), other.ptr.get()) < 0;
}
@ -29,4 +33,8 @@ bool WrappedStrPtr::operator>=(const WrappedStrPtr& other) const {
return !operator<(other);
}
WrappedStrPtr::operator std::string_view() const {
return std::string_view{ptr.get(), std::strlen(ptr.get())};
}
} // namespace dfly::search

View file

@ -43,18 +43,28 @@ struct SortOption {
bool descending = false;
};
// Comparable string stored as char[]. Used to reduce size of std::variant with strings.
struct WrappedStrPtr {
// Intentionally implicit and const std::string& for use in templates
WrappedStrPtr(const PMR_NS::string& s);
WrappedStrPtr(const std::string& s);
bool operator<(const WrappedStrPtr& other) const;
bool operator>=(const WrappedStrPtr& other) const;
operator std::string_view() const;
private:
std::unique_ptr<char[]> ptr;
};
// Score produced either by KNN (float) or SORT (double / wrapped str)
using ResultScore = std::variant<std::monostate, float, double, WrappedStrPtr>;
// Values are either sortable as doubles or strings, or not sortable at all.
// Contrary to ResultScore it doesn't include KNN results and is not optimized for smaller struct
// size.
using SortableValue = std::variant<std::monostate, double, std::string>;
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector;
@ -78,6 +88,7 @@ struct BaseIndex {
// Base class for type-specific sorting indices.
struct BaseSortIndex : BaseIndex {
virtual SortableValue Lookup(DocId doc) const = 0;
virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
};

View file

@ -443,6 +443,12 @@ struct BasicSearch {
} // namespace
string_view Schema::LookupAlias(string_view alias) const {
if (auto it = field_names.find(alias); it != field_names.end())
return it->second;
return alias;
}
FieldIndices::FieldIndices(Schema schema, PMR_NS::memory_resource* mr)
: schema_{std::move(schema)}, all_ids_{}, indices_{} {
CreateIndices(mr);
@ -521,20 +527,12 @@ void FieldIndices::Remove(DocId doc, DocumentAccessor* access) {
}
BaseIndex* FieldIndices::GetIndex(string_view field) const {
// Replace short field name with full identifier
if (auto it = schema_.field_names.find(field); it != schema_.field_names.end())
field = it->second;
auto it = indices_.find(field);
auto it = indices_.find(schema_.LookupAlias(field));
return it != indices_.end() ? it->second.get() : nullptr;
}
BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const {
// Replace short field name with full identifier
if (auto it = schema_.field_names.find(field); it != schema_.field_names.end())
field = it->second;
auto it = sort_indices_.find(field);
auto it = sort_indices_.find(schema_.LookupAlias(field));
return it != sort_indices_.end() ? it->second.get() : nullptr;
}
@ -558,6 +556,14 @@ const Schema& FieldIndices::GetSchema() const {
return schema_;
}
vector<pair<string, SortableValue>> FieldIndices::ExtractStoredValues(DocId doc) const {
vector<pair<string, SortableValue>> out;
for (const auto& [ident, index] : sort_indices_) {
out.emplace_back(ident, index->Lookup(doc));
}
return out;
}
SearchAlgorithm::SearchAlgorithm() = default;
SearchAlgorithm::~SearchAlgorithm() = default;

View file

@ -51,6 +51,9 @@ struct Schema {
// Mapping for short field names (aliases).
absl::flat_hash_map<std::string /* short name*/, std::string /*identifier*/> field_names;
// Return identifier for alias if found, otherwise return passed value
std::string_view LookupAlias(std::string_view alias) const;
};
// Collection of indices for all fields in schema
@ -64,13 +67,14 @@ class FieldIndices {
BaseIndex* GetIndex(std::string_view field) const;
BaseSortIndex* GetSortIndex(std::string_view field) const;
std::vector<TextIndex*> GetAllTextIndices() const;
const std::vector<DocId>& GetAllDocs() const;
const Schema& GetSchema() const;
// Extract values stored in sort indices
std::vector<std::pair<std::string, SortableValue>> ExtractStoredValues(DocId doc) const;
private:
void CreateIndices(PMR_NS::memory_resource* mr);
void CreateSortIndices(PMR_NS::memory_resource* mr);

View file

@ -10,6 +10,7 @@
#include <absl/strings/str_split.h>
#include <algorithm>
#include <type_traits>
namespace dfly::search {
@ -21,6 +22,15 @@ template <typename T>
SimpleValueSortIndex<T>::SimpleValueSortIndex(PMR_NS::memory_resource* mr) : values_{mr} {
}
template <typename T> SortableValue SimpleValueSortIndex<T>::Lookup(DocId doc) const {
DCHECK_LT(doc, values_.size());
if constexpr (is_same_v<T, PMR_NS::string>) {
return std::string(values_[doc]);
} else {
return values_[doc];
}
}
template <typename T>
std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids, size_t limit,
bool desc) const {
@ -37,7 +47,7 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
template <typename T>
void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_view field) {
DCHECK(id <= values_.size()); // Doc ids grow at most by one
DCHECK_LE(id, values_.size()); // Doc ids grow at most by one
if (id >= values_.size())
values_.resize(id + 1);
values_[id] = Get(id, doc, field);
@ -45,6 +55,7 @@ void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_v
template <typename T>
void SimpleValueSortIndex<T>::Remove(DocId id, DocumentAccessor* doc, std::string_view field) {
DCHECK_LT(id, values_.size());
values_[id] = T{};
}

View file

@ -21,6 +21,7 @@ namespace dfly::search {
template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
SimpleValueSortIndex(PMR_NS::memory_resource* mr);
SortableValue Lookup(DocId doc) const override;
std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const override;
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;

View file

@ -96,8 +96,8 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
}
PipelineStep MakeSortStep(std::string field, bool descending) {
return [field, descending](std::vector<DocValues> values) -> PipelineResult {
PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
@ -117,7 +117,7 @@ PipelineStep MakeLimitStep(size_t offset, size_t num) {
};
}
PipelineResult Process(std::vector<DocValues> values, absl::Span<PipelineStep> steps) {
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
for (auto& step : steps) {
auto result = step(std::move(values));
if (!result.has_value())

View file

@ -10,14 +10,17 @@
#include <string>
#include <variant>
#include "core/search/base.h"
#include "facade/reply_builder.h"
#include "io/io.h"
namespace dfly::aggregate {
using Value = std::variant<std::monostate, double, std::string>;
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline
// TODO: Replace DocValues with compact linear search map instead of hash map
using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.
@ -71,12 +74,12 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);
// Make `SORYBY field [DESC]` step
PipelineStep MakeSortStep(std::string field, bool descending = false);
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
// Make `LIMIT offset num` step
PipelineStep MakeLimitStep(size_t offset, size_t num);
// Process values with given steps
PipelineResult Process(std::vector<DocValues> values, absl::Span<PipelineStep> steps);
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
} // namespace dfly::aggregate

View file

@ -220,8 +220,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
: accessor->Serialize(base_->schema);
auto score =
search_results.scores.empty() ? std::monostate{} : std::move(search_results.scores[i]);
auto score = search_results.scores.empty() ? monostate{} : std::move(search_results.scores[i]);
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)});
}
@ -229,6 +228,38 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
std::move(search_results.profile)};
}
vector<absl::flat_hash_map<string, search::SortableValue>> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.shard->db_slice();
auto search_results = search_algo->Search(&indices_);
if (!search_results.error.empty())
return {};
// Convert load_fields into return_list required by accessor interface
SearchParams::FieldReturnList return_fields;
for (string_view load_field : load_fields)
return_fields.emplace_back(indices_.GetSchema().LookupAlias(load_field), load_field);
vector<absl::flat_hash_map<string, search::SortableValue>> out;
for (DocId doc : search_results.ids) {
auto key = key_index_.Get(doc);
auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode());
if (!it || !IsValid(*it)) // Item must have expired
continue;
auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_.ExtractStoredValues(doc);
auto loaded = accessor->Serialize(base_->schema, return_fields);
out.emplace_back(make_move_iterator(extracted.begin()), make_move_iterator(extracted.end()));
out.back().insert(make_move_iterator(loaded.begin()), make_move_iterator(loaded.end()));
}
return out;
}
DocIndexInfo ShardDocIndex::GetInfo() const {
return {*base_, key_index_.Size()};
}

View file

@ -126,6 +126,10 @@ class ShardDocIndex {
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const;
// Perform search and load requested values - note params might be interpreted differently.
std::vector<absl::flat_hash_map<std::string, search::SortableValue>> SearchForAggregator(
const OpArgs& op_args, ArgSlice load_fields, search::SearchAlgorithm* search_algo) const;
// Return whether base index matches
bool Matches(std::string_view key, unsigned obj_code) const;

View file

@ -24,8 +24,10 @@
#include "server/conn_context.h"
#include "server/container_utils.h"
#include "server/engine_shard_set.h"
#include "server/search/aggregator.h"
#include "server/search/doc_index.h"
#include "server/transaction.h"
#include "src/core/overloaded.h"
namespace dfly {
@ -152,6 +154,16 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
return schema;
}
search::QueryParams ParseQueryParams(CmdArgParser* parser) {
search::QueryParams params;
size_t num_args = parser->Next<size_t>();
while (parser->HasNext() && params.Size() * 2 < num_args) {
auto [k, v] = parser->Next<string_view, string_view>();
params[k] = v;
}
return params;
}
optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionContext* cntx) {
SearchParams params;
@ -183,12 +195,7 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
// [PARAMS num(ignored) name(ignored) knn_vector]
if (parser.Check("PARAMS").ExpectTail(1)) {
size_t num_args = parser.Next<size_t>();
while (parser.HasNext() && params.query_params.Size() * 2 < num_args) {
string_view k = parser.Next();
string_view v = parser.Next();
params.query_params[k] = v;
}
params.query_params = ParseQueryParams(&parser);
continue;
}
@ -210,6 +217,95 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
return params;
}
struct AggregateParams {
string_view index, query;
search::QueryParams params;
vector<string_view> load_fields;
vector<aggregate::PipelineStep> steps;
};
optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
ConnectionContext* cntx) {
AggregateParams params;
tie(params.index, params.query) = parser.Next<string_view, string_view>();
while (parser.ToUpper().HasNext()) {
// LOAD count field [field ...]
if (parser.Check("LOAD").ExpectTail(1)) {
params.load_fields.resize(parser.Next<size_t>());
for (string_view& field : params.load_fields)
field = parser.Next();
continue;
}
// GROUPBY nargs property [property ...]
if (parser.Check("GROUPBY").ExpectTail(1)) {
vector<string_view> fields(parser.Next<size_t>());
for (string_view& field : fields)
field = parser.Next();
vector<aggregate::Reducer> reducers;
while (parser.ToUpper().Check("REDUCE").ExpectTail(2)) {
parser.ToUpper(); // uppercase for func_name
auto [func_name, nargs] = parser.Next<string_view, size_t>();
auto func = aggregate::FindReducerFunc(func_name);
if (!parser.HasError() && !func) {
cntx->SendError(absl::StrCat("reducer function ", func_name, " not found"));
return nullopt;
}
string source_field = "";
if (nargs > 0) {
source_field = parser.Next<string>();
}
parser.ExpectTag("AS");
string result_field = parser.Next<string>();
reducers.push_back(aggregate::Reducer{source_field, result_field, std::move(func)});
}
params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
continue;
}
// SORTBY nargs
if (parser.Check("SORTBY").ExpectTail(1)) {
parser.ExpectTag("1");
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC").IgnoreCase());
params.steps.push_back(aggregate::MakeSortStep(field, desc));
continue;
}
// LIMIT
if (parser.Check("LIMIT").ExpectTail(2)) {
auto [offset, num] = parser.Next<size_t, size_t>();
params.steps.push_back(aggregate::MakeLimitStep(offset, num));
continue;
}
// PARAMS
if (parser.Check("PARAMS").ExpectTail(1)) {
params.params = ParseQueryParams(&parser);
continue;
}
cntx->SendError(absl::StrCat("Unknown clause: ", parser.Peek()));
return nullopt;
}
if (auto err = parser.Error(); err) {
cntx->SendError(err->MakeReply());
return nullopt;
}
return params;
}
void SendSerializedDoc(const SerializedSearchDoc& doc, ConnectionContext* cntx) {
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
rb->SendBulkString(doc.key);
@ -594,6 +690,55 @@ void SearchFamily::FtProfile(CmdArgList args, ConnectionContext* cntx) {
}
}
void SearchFamily::FtAggregate(CmdArgList args, ConnectionContext* cntx) {
const auto params = ParseAggregatorParamsOrReply(args, cntx);
if (!params)
return;
search::SearchAlgorithm search_algo;
if (!search_algo.Init(params->query, &params->params, nullptr))
return cntx->SendError("Query syntax error");
using ResultContainer =
decltype(declval<ShardDocIndex>().SearchForAggregator(declval<OpArgs>(), {}, &search_algo));
vector<ResultContainer> query_results(shard_set->size());
cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
if (auto* index = es->search_indices()->GetIndex(params->index); index) {
query_results[es->shard_id()] =
index->SearchForAggregator(t->GetOpArgs(es), params->load_fields, &search_algo);
}
return OpStatus::OK;
});
vector<aggregate::DocValues> values;
for (auto& sub_results : query_results) {
values.insert(values.end(), make_move_iterator(sub_results.begin()),
make_move_iterator(sub_results.end()));
}
auto agg_results = aggregate::Process(std::move(values), params->steps);
if (!agg_results.has_value())
return cntx->SendError(agg_results.error());
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
Overloaded replier{
[rb](monostate) { rb->SendNull(); },
[rb](double d) { rb->SendDouble(d); },
[rb](const string& s) { rb->SendBulkString(s); },
};
rb->StartArray(agg_results->size());
for (const auto& result : agg_results.value()) {
rb->StartArray(result.size());
for (const auto& [k, v] : result) {
rb->StartArray(2);
rb->SendBulkString(k);
visit(replier, v);
}
}
}
#define HFUNC(x) SetHandler(&SearchFamily::x)
// Redis search is a module. Therefore we introduce dragonfly extension search
@ -616,6 +761,7 @@ void SearchFamily::Register(CommandRegistry* registry) {
// Underscore same as in RediSearch because it's "temporary" (long time already)
<< CI{"FT._LIST", kReadOnlyMask, 1, 0, 0, acl::FT_SEARCH}.HFUNC(FtList)
<< CI{"FT.SEARCH", kReadOnlyMask, -3, 0, 0, acl::FT_SEARCH}.HFUNC(FtSearch)
<< CI{"FT.AGGREGATE", kReadOnlyMask, -3, 0, 0, acl::FT_SEARCH}.HFUNC(FtAggregate)
<< CI{"FT.PROFILE", kReadOnlyMask, -4, 0, 0, acl::FT_SEARCH}.HFUNC(FtProfile);
}

View file

@ -20,6 +20,7 @@ class SearchFamily {
static void FtList(CmdArgList args, ConnectionContext* cntx);
static void FtSearch(CmdArgList args, ConnectionContext* cntx);
static void FtProfile(CmdArgList args, ConnectionContext* cntx);
static void FtAggregate(CmdArgList args, ConnectionContext* cntx);
public:
static void Register(CommandRegistry* registry);

View file

@ -62,6 +62,14 @@ template <typename... Args> auto AreDocIds(Args... args) {
return DocIds(sizeof...(args), vector<string>{args...});
}
template <typename... Args> auto IsArray(Args... args) {
return RespArray(ElementsAre(std::forward<Args>(args)...));
}
template <typename... Args> auto IsUnordArray(Args... args) {
return RespArray(UnorderedElementsAre(std::forward<Args>(args)...));
}
TEST_F(SearchFamilyTest, CreateDropListIndex) {
EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK");
EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK");
@ -90,12 +98,10 @@ TEST_F(SearchFamilyTest, InfoIndex) {
}
auto info = Run({"ft.info", "idx-1"});
EXPECT_THAT(
info, RespArray(ElementsAre(
_, _, _, RespArray(ElementsAre("key_type", "HASH", "prefix", "doc-")), "attributes",
RespArray(ElementsAre(RespArray(
ElementsAre("identifier", "name", "attribute", "name", "type", "TEXT")))),
"num_docs", IntArg(15))));
EXPECT_THAT(info,
IsArray(_, _, _, IsArray("key_type", "HASH", "prefix", "doc-"), "attributes",
IsArray(IsArray("identifier", "name", "attribute", "name", "type", "TEXT")),
"num_docs", IntArg(15)));
}
TEST_F(SearchFamilyTest, Stats) {
@ -385,7 +391,7 @@ TEST_F(SearchFamilyTest, TestReturn) {
"NUMERIC", "vector", "VECTOR", "FLAT", "2", "DIM", "1"});
auto MatchEntry = [](string key, auto... fields) {
return RespArray(ElementsAre(IntArg(1), key, RespArray(UnorderedElementsAre(fields...))));
return RespArray(ElementsAre(IntArg(1), key, IsUnordArray(fields...)));
};
// Check all fields are returned
@ -596,4 +602,48 @@ TEST_F(SearchFamilyTest, SimpleExpiry) {
Run({"flushall"});
}
TEST_F(SearchFamilyTest, AggregateGroupByReduceSort) {
for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd
Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value",
absl::StrCat(i)});
}
Run({"ft.create", "i1", "schema", "even", "tag", "sortable", "value", "numeric", "sortable"});
// clang-format off
auto resp = Run({"ft.aggregate", "i1", "*",
"GROUPBY", "1", "even",
"REDUCE", "count", "0", "as", "count",
"REDUCE", "count_distinct", "1", "even", "as", "distinct_tags",
"REDUCE", "count_distinct", "1", "value", "as", "distinct_vals",
"REDUCE", "max", "1", "value", "as", "max_val",
"REDUCE", "min", "1", "value", "as", "min_val",
"SORTBY", "1", "count"});
// clang-format on
EXPECT_THAT(resp,
IsArray(IsUnordArray(IsArray("even", "false"), IsArray("count", "50"),
IsArray("distinct_tags", "1"), IsArray("distinct_vals", "50"),
IsArray("max_val", "99"), IsArray("min_val", "1")),
IsUnordArray(IsArray("even", "true"), IsArray("count", "51"),
IsArray("distinct_tags", "1"), IsArray("distinct_vals", "51"),
IsArray("max_val", "100"), IsArray("min_val", "0"))));
}
TEST_F(SearchFamilyTest, AggregateLoadGroupBy) {
for (size_t i = 0; i < 101; i++) { // 51 even, 50 odd
Run({"hset", absl::StrCat("k", i), "even", (i % 2 == 0) ? "true" : "false", "value",
absl::StrCat(i)});
}
Run({"ft.create", "i1", "schema", "value", "numeric", "sortable"});
// clang-format off
auto resp = Run({"ft.aggregate", "i1", "*",
"LOAD", "1", "even",
"GROUPBY", "1", "even"});
// clang-format on
EXPECT_THAT(resp, IsUnordArray(IsUnordArray(IsArray("even", "false")),
IsUnordArray(IsArray("even", "true"))));
}
} // namespace dfly