refactor(search_family): Add Aggregator class (#4290)

* refactor(search_family): Add Aggregator class

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* fix(aggregator_test): Fix tests failing

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: Restore the previous comment

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments 2

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments 3

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* fix(aggregator): Simplify comparator for the case when one of the values is not present

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
Stepan Bagritsevich 2024-12-23 08:43:48 +04:00 committed by GitHub
parent 8d66c25bc6
commit 1fa9a47a86
5 changed files with 157 additions and 112 deletions


@@ -10,63 +10,100 @@ namespace dfly::aggregate {
namespace {
struct GroupStep {
PipelineResult operator()(PipelineResult result) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}
using ValuesList = absl::FixedArray<Value>;
// Restore DocValues and apply reducers
std::vector<DocValues> out;
while (!groups.empty()) {
auto node = groups.extract(groups.begin());
DocValues doc = Unpack(std::move(node.key()));
for (auto& reducer : reducers_) {
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
}
out.push_back(std::move(doc));
}
absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());
for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}
return {std::move(out), std::move(fields_to_print)};
ValuesList ExtractFieldsValues(const DocValues& dv, absl::Span<const std::string> fields) {
ValuesList out(fields.size());
for (size_t i = 0; i < fields.size(); i++) {
auto it = dv.find(fields[i]);
out[i] = (it != dv.end()) ? it->second : Value{};
}
return out;
}
absl::FixedArray<Value> Extract(const DocValues& dv) {
absl::FixedArray<Value> out(fields_.size());
for (size_t i = 0; i < fields_.size(); i++) {
auto it = dv.find(fields_[i]);
out[i] = (it != dv.end()) ? it->second : Value{};
}
return out;
}
DocValues Unpack(absl::FixedArray<Value>&& values) {
DCHECK_EQ(values.size(), fields_.size());
DocValues out;
for (size_t i = 0; i < fields_.size(); i++)
out[fields_[i]] = std::move(values[i]);
return out;
}
std::vector<std::string> fields_;
std::vector<Reducer> reducers_;
};
DocValues PackFields(ValuesList values, absl::Span<const std::string> fields) {
DCHECK_EQ(values.size(), fields.size());
DocValues out;
for (size_t i = 0; i < fields.size(); i++)
out[fields[i]] = std::move(values[i]);
return out;
}
const Value kEmptyValue = Value{};
} // namespace
void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers) {
// Separate items into groups
absl::flat_hash_map<ValuesList, std::vector<DocValues>> groups;
for (auto& value : result.values) {
groups[ExtractFieldsValues(value, fields)].push_back(std::move(value));
}
// Restore DocValues and apply reducers
auto& values = result.values;
values.clear();
values.reserve(groups.size());
while (!groups.empty()) {
auto node = groups.extract(groups.begin());
DocValues doc = PackFields(std::move(node.key()), fields);
for (auto& reducer : reducers) {
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
}
values.push_back(std::move(doc));
}
auto& fields_to_print = result.fields_to_print;
fields_to_print.clear();
fields_to_print.reserve(fields.size() + reducers.size());
for (auto& field : fields) {
fields_to_print.insert(field);
}
for (auto& reducer : reducers) {
fields_to_print.insert(reducer.result_field);
}
}
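A quick worked example of DoGroup (hypothetical data, in the spirit of the SimpleGroup and GroupWithReduce tests below): the output keeps only the group-key fields plus one field per reducer.

// Hypothetical example, not part of this diff. With fields = {"tag"} and a single
// COUNT reducer writing to "count":
//   input  values: [{tag:"even", i:2.0}, {tag:"odd", i:1.0}, {tag:"even", i:4.0}]
//   output values: [{tag:"even", count:2}, {tag:"odd", count:1}]  (group order is unspecified)
//   fields_to_print becomes {"tag", "count"}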
void Aggregator::DoSort(std::string_view field, bool descending) {
/*
Comparator for sorting DocValues by field.
If one of the fields is not present in the DocValues, the comparator returns:
1. l_it == l.end() && r_it != r.end()
asc -> false
desc -> false
2. l_it != l.end() && r_it == r.end()
asc -> true
desc -> true
3. l_it == l.end() && r_it == r.end()
asc -> false
desc -> false
*/
auto comparator = [&](const DocValues& l, const DocValues& r) {
auto l_it = l.find(field);
auto r_it = r.find(field);
// If one of the values is not present
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end();
}
auto& lv = l_it->second;
auto& rv = r_it->second;
return !descending ? lv < rv : lv > rv;
};
std::sort(result.values.begin(), result.values.end(), std::move(comparator));
result.fields_to_print.insert(field);
}
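A short illustration of the comparator's handling of missing fields (hypothetical input, not part of the diff): a document without the sort field compares as "greater" than any document that has it, so it lands at the end for both ascending and descending order.

// DoSort("a", /*descending=*/false): [{a:1.0}, {}, {a:0.5}] -> [{a:0.5}, {a:1.0}, {}]
// DoSort("a", /*descending=*/true):  [{a:1.0}, {}, {a:0.5}] -> [{a:1.0}, {a:0.5}, {}]
// The document lacking "a" sorts last in both directions.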
void Aggregator::DoLimit(size_t offset, size_t num) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
}
const Value& ValueIterator::operator*() const {
auto it = values_.front().find(field_);
return it == values_.front().end() ? kEmptyValue : it->second;
@@ -109,48 +146,30 @@ Reducer::Func FindReducerFunc(ReducerFunc name) {
return nullptr;
}
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers) {
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
}
PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});
if (descending) {
std::reverse(values.begin(), values.end());
}
result.fields_to_print.insert(field);
return result;
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers) {
return [fields = std::move(fields), reducers = std::move(reducers)](Aggregator* aggregator) {
aggregator->DoGroup(fields, reducers);
};
}
PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return result;
AggregationStep MakeSortStep(std::string field, bool descending) {
return [field = std::move(field), descending](Aggregator* aggregator) {
aggregator->DoSort(field, descending);
};
}
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
AggregationStep MakeLimitStep(size_t offset, size_t num) {
return [=](Aggregator* aggregator) { aggregator->DoLimit(offset, num); };
}
AggregationResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const AggregationStep> steps) {
Aggregator aggregator{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) {
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
step(&aggregator);
}
return result;
return aggregator.result;
}
} // namespace dfly::aggregate


@@ -17,19 +17,31 @@
namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline
struct Reducer;
struct PipelineResult {
using Value = ::dfly::search::SortableValue;
// DocValues sent through the pipeline
// TODO: Replace DocValues with compact linear search map instead of hash map
using DocValues = absl::flat_hash_map<std::string_view, Value>;
struct AggregationResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;
// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
absl::flat_hash_set<std::string_view> fields_to_print;
};
using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
struct Aggregator {
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
void DoSort(std::string_view field, bool descending = false);
void DoLimit(size_t offset, size_t num);
AggregationResult result;
};
using AggregationStep = std::function<void(Aggregator*)>; // Group, Sort, etc.
// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
@@ -79,18 +91,17 @@ enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
Reducer::Func FindReducerFunc(ReducerFunc name);
// Make `GROUPBY [fields...]` with REDUCE step
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);
// Make `SORTBY field [DESC]` step
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
AggregationStep MakeSortStep(std::string field, bool descending = false);
// Make `LIMIT offset num` step
PipelineStep MakeLimitStep(size_t offset, size_t num);
AggregationStep MakeLimitStep(size_t offset, size_t num);
// Process values with given steps
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);
AggregationResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const AggregationStep> steps);
} // namespace dfly::aggregate
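For orientation, a minimal usage sketch of the API declared above (invented data and field names; it assumes the aggregator header from this commit is included and mirrors the style of the tests below):

std::vector<dfly::aggregate::DocValues> values = {
    dfly::aggregate::DocValues{{"price", 10.0}, {"tag", "a"}},
    dfly::aggregate::DocValues{{"price", 5.0}, {"tag", "b"}},
};
std::vector<dfly::aggregate::AggregationStep> steps = {
    dfly::aggregate::MakeSortStep("price", /*descending=*/true),
    dfly::aggregate::MakeLimitStep(0, 1),
};
dfly::aggregate::AggregationResult result =
    dfly::aggregate::Process(std::move(values), {"price", "tag"}, steps);
// result.values now holds only the document with the highest "price";
// result.fields_to_print contains {"price", "tag"}.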


@@ -10,13 +10,15 @@ namespace dfly::aggregate {
using namespace std::string_literals;
using StepsList = std::vector<AggregationStep>;
TEST(AggregatorTest, Sort) {
std::vector<DocValues> values = {
DocValues{{"a", 1.0}},
DocValues{{"a", 0.5}},
DocValues{{"a", 1.5}},
};
PipelineStep steps[] = {MakeSortStep("a", false)};
StepsList steps = {MakeSortStep("a", false)};
auto result = Process(values, {"a"}, steps);
@@ -32,7 +34,8 @@ TEST(AggregatorTest, Limit) {
DocValues{{"i", 3.0}},
DocValues{{"i", 4.0}},
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};
StepsList steps = {MakeLimitStep(1, 2)};
auto result = Process(values, {"i"}, steps);
@@ -49,8 +52,8 @@ TEST(AggregatorTest, SimpleGroup) {
DocValues{{"i", 4.0}, {"tag", "even"}},
};
std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};
std::vector<std::string> fields = {"tag"};
StepsList steps = {MakeGroupStep(std::move(fields), {})};
auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
@@ -72,13 +75,14 @@ TEST(AggregatorTest, GroupWithReduce) {
});
}
std::string_view fields[] = {"tag"};
std::vector<std::string> fields = {"tag"};
std::vector<Reducer> reducers = {
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
StepsList steps = {MakeGroupStep(std::move(fields), std::move(reducers))};
auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);


@@ -168,7 +168,7 @@ struct AggregateParams {
search::QueryParams params;
std::optional<SearchFieldsList> load_fields;
std::vector<aggregate::PipelineStep> steps;
std::vector<aggregate::AggregationStep> steps;
};
// Stores basic info about a document index.


@@ -320,20 +320,23 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
while (parser.HasNext()) {
// GROUPBY nargs property [property ...]
if (parser.Check("GROUPBY")) {
vector<string_view> fields(parser.Next<size_t>());
for (string_view& field : fields) {
size_t num_fields = parser.Next<size_t>();
std::vector<std::string> fields;
fields.reserve(num_fields);
while (num_fields > 0 && parser.HasNext()) {
auto parsed_field = ParseFieldWithAtSign(&parser);
/*
TODO: Throw an error if the field has no '@' sign at the beginning
if (!parsed_field) {
builder->SendError(absl::StrCat("bad arguments for GROUPBY: Unknown property '", field,
"'. Did you mean '@", field, "`?"));
return nullopt;
} */
field = parsed_field;
fields.emplace_back(parsed_field);
num_fields--;
}
vector<aggregate::Reducer> reducers;
@@ -363,7 +366,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
aggregate::Reducer{std::move(source_field), std::move(result_field), std::move(func)});
}
params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
params.steps.push_back(aggregate::MakeGroupStep(std::move(fields), std::move(reducers)));
continue;
}
@@ -373,7 +376,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC"));
params.steps.push_back(aggregate::MakeSortStep(field, desc));
params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc));
continue;
}
@@ -975,10 +978,18 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
return OpStatus::OK;
});
vector<aggregate::DocValues> values;
// ResultContainer is absl::flat_hash_map<std::string, search::SortableValue>
// DocValues is absl::flat_hash_map<std::string_view, SortableValue>
// Keys of values should point to the keys of the query_results
std::vector<aggregate::DocValues> values;
for (auto& sub_results : query_results) {
values.insert(values.end(), make_move_iterator(sub_results.begin()),
make_move_iterator(sub_results.end()));
for (auto& docs : sub_results) {
aggregate::DocValues doc_value;
for (auto& doc : docs) {
doc_value[doc.first] = std::move(doc.second);
}
values.push_back(std::move(doc_value));
}
}
std::vector<std::string_view> load_fields;