mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 18:35:46 +02:00
refactor(search_family): Add Aggregator class (#4290)
* refactor(search_family): Add Aggregator class Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * fix(aggregator_test): Fix tests failing Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: address comments Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: Restore the previous comment Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: address comments 2 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor: address comments 3 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * fix(aggregator): Simplify comparator for the case when one of the values is not present Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> --------- Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
parent
8d66c25bc6
commit
1fa9a47a86
5 changed files with 157 additions and 112 deletions
|
@ -10,63 +10,100 @@ namespace dfly::aggregate {
|
|||
|
||||
namespace {
|
||||
|
||||
struct GroupStep {
|
||||
PipelineResult operator()(PipelineResult result) {
|
||||
// Separate items into groups
|
||||
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
|
||||
for (auto& value : result.values) {
|
||||
groups[Extract(value)].push_back(std::move(value));
|
||||
}
|
||||
using ValuesList = absl::FixedArray<Value>;
|
||||
|
||||
// Restore DocValues and apply reducers
|
||||
std::vector<DocValues> out;
|
||||
while (!groups.empty()) {
|
||||
auto node = groups.extract(groups.begin());
|
||||
DocValues doc = Unpack(std::move(node.key()));
|
||||
for (auto& reducer : reducers_) {
|
||||
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
|
||||
}
|
||||
out.push_back(std::move(doc));
|
||||
}
|
||||
|
||||
absl::flat_hash_set<std::string> fields_to_print;
|
||||
fields_to_print.reserve(fields_.size() + reducers_.size());
|
||||
|
||||
for (auto& field : fields_) {
|
||||
fields_to_print.insert(std::move(field));
|
||||
}
|
||||
for (auto& reducer : reducers_) {
|
||||
fields_to_print.insert(std::move(reducer.result_field));
|
||||
}
|
||||
|
||||
return {std::move(out), std::move(fields_to_print)};
|
||||
ValuesList ExtractFieldsValues(const DocValues& dv, absl::Span<const std::string> fields) {
|
||||
ValuesList out(fields.size());
|
||||
for (size_t i = 0; i < fields.size(); i++) {
|
||||
auto it = dv.find(fields[i]);
|
||||
out[i] = (it != dv.end()) ? it->second : Value{};
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
absl::FixedArray<Value> Extract(const DocValues& dv) {
|
||||
absl::FixedArray<Value> out(fields_.size());
|
||||
for (size_t i = 0; i < fields_.size(); i++) {
|
||||
auto it = dv.find(fields_[i]);
|
||||
out[i] = (it != dv.end()) ? it->second : Value{};
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
DocValues Unpack(absl::FixedArray<Value>&& values) {
|
||||
DCHECK_EQ(values.size(), fields_.size());
|
||||
DocValues out;
|
||||
for (size_t i = 0; i < fields_.size(); i++)
|
||||
out[fields_[i]] = std::move(values[i]);
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<std::string> fields_;
|
||||
std::vector<Reducer> reducers_;
|
||||
};
|
||||
DocValues PackFields(ValuesList values, absl::Span<const std::string> fields) {
|
||||
DCHECK_EQ(values.size(), fields.size());
|
||||
DocValues out;
|
||||
for (size_t i = 0; i < fields.size(); i++)
|
||||
out[fields[i]] = std::move(values[i]);
|
||||
return out;
|
||||
}
|
||||
|
||||
const Value kEmptyValue = Value{};
|
||||
|
||||
} // namespace
|
||||
|
||||
void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers) {
|
||||
// Separate items into groups
|
||||
absl::flat_hash_map<ValuesList, std::vector<DocValues>> groups;
|
||||
for (auto& value : result.values) {
|
||||
groups[ExtractFieldsValues(value, fields)].push_back(std::move(value));
|
||||
}
|
||||
|
||||
// Restore DocValues and apply reducers
|
||||
auto& values = result.values;
|
||||
values.clear();
|
||||
values.reserve(groups.size());
|
||||
while (!groups.empty()) {
|
||||
auto node = groups.extract(groups.begin());
|
||||
DocValues doc = PackFields(std::move(node.key()), fields);
|
||||
for (auto& reducer : reducers) {
|
||||
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
|
||||
}
|
||||
values.push_back(std::move(doc));
|
||||
}
|
||||
|
||||
auto& fields_to_print = result.fields_to_print;
|
||||
fields_to_print.clear();
|
||||
fields_to_print.reserve(fields.size() + reducers.size());
|
||||
|
||||
for (auto& field : fields) {
|
||||
fields_to_print.insert(field);
|
||||
}
|
||||
for (auto& reducer : reducers) {
|
||||
fields_to_print.insert(reducer.result_field);
|
||||
}
|
||||
}
|
||||
|
||||
void Aggregator::DoSort(std::string_view field, bool descending) {
|
||||
/*
|
||||
Comparator for sorting DocValues by field.
|
||||
If some of the fields is not present in the DocValues, comparator returns:
|
||||
1. l_it == l.end() && r_it != r.end()
|
||||
asc -> false
|
||||
desc -> false
|
||||
2. l_it != l.end() && r_it == r.end()
|
||||
asc -> true
|
||||
desc -> true
|
||||
3. l_it == l.end() && r_it == r.end()
|
||||
asc -> false
|
||||
desc -> false
|
||||
*/
|
||||
auto comparator = [&](const DocValues& l, const DocValues& r) {
|
||||
auto l_it = l.find(field);
|
||||
auto r_it = r.find(field);
|
||||
|
||||
// If some of the values is not present
|
||||
if (l_it == l.end() || r_it == r.end()) {
|
||||
return l_it != l.end();
|
||||
}
|
||||
|
||||
auto& lv = l_it->second;
|
||||
auto& rv = r_it->second;
|
||||
return !descending ? lv < rv : lv > rv;
|
||||
};
|
||||
|
||||
std::sort(result.values.begin(), result.values.end(), std::move(comparator));
|
||||
|
||||
result.fields_to_print.insert(field);
|
||||
}
|
||||
|
||||
void Aggregator::DoLimit(size_t offset, size_t num) {
|
||||
auto& values = result.values;
|
||||
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
|
||||
values.resize(std::min(num, values.size()));
|
||||
}
|
||||
|
||||
const Value& ValueIterator::operator*() const {
|
||||
auto it = values_.front().find(field_);
|
||||
return it == values_.front().end() ? kEmptyValue : it->second;
|
||||
|
@ -109,48 +146,30 @@ Reducer::Func FindReducerFunc(ReducerFunc name) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
|
||||
std::vector<Reducer> reducers) {
|
||||
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
|
||||
}
|
||||
|
||||
PipelineStep MakeSortStep(std::string_view field, bool descending) {
|
||||
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
|
||||
auto& values = result.values;
|
||||
|
||||
std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
|
||||
auto it1 = l.find(field);
|
||||
auto it2 = r.find(field);
|
||||
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
|
||||
});
|
||||
|
||||
if (descending) {
|
||||
std::reverse(values.begin(), values.end());
|
||||
}
|
||||
|
||||
result.fields_to_print.insert(field);
|
||||
return result;
|
||||
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers) {
|
||||
return [fields = std::move(fields), reducers = std::move(reducers)](Aggregator* aggregator) {
|
||||
aggregator->DoGroup(fields, reducers);
|
||||
};
|
||||
}
|
||||
|
||||
PipelineStep MakeLimitStep(size_t offset, size_t num) {
|
||||
return [offset, num](PipelineResult result) {
|
||||
auto& values = result.values;
|
||||
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
|
||||
values.resize(std::min(num, values.size()));
|
||||
return result;
|
||||
AggregationStep MakeSortStep(std::string field, bool descending) {
|
||||
return [field = std::move(field), descending](Aggregator* aggregator) {
|
||||
aggregator->DoSort(field, descending);
|
||||
};
|
||||
}
|
||||
|
||||
PipelineResult Process(std::vector<DocValues> values,
|
||||
absl::Span<const std::string_view> fields_to_print,
|
||||
absl::Span<const PipelineStep> steps) {
|
||||
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
|
||||
AggregationStep MakeLimitStep(size_t offset, size_t num) {
|
||||
return [=](Aggregator* aggregator) { aggregator->DoLimit(offset, num); };
|
||||
}
|
||||
|
||||
AggregationResult Process(std::vector<DocValues> values,
|
||||
absl::Span<const std::string_view> fields_to_print,
|
||||
absl::Span<const AggregationStep> steps) {
|
||||
Aggregator aggregator{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
|
||||
for (auto& step : steps) {
|
||||
PipelineResult step_result = step(std::move(result));
|
||||
result = std::move(step_result);
|
||||
step(&aggregator);
|
||||
}
|
||||
return result;
|
||||
return aggregator.result;
|
||||
}
|
||||
|
||||
} // namespace dfly::aggregate
|
||||
|
|
|
@ -17,19 +17,31 @@
|
|||
|
||||
namespace dfly::aggregate {
|
||||
|
||||
using Value = ::dfly::search::SortableValue;
|
||||
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline
|
||||
struct Reducer;
|
||||
|
||||
struct PipelineResult {
|
||||
using Value = ::dfly::search::SortableValue;
|
||||
|
||||
// DocValues sent through the pipeline
|
||||
// TODO: Replace DocValues with compact linear search map instead of hash map
|
||||
using DocValues = absl::flat_hash_map<std::string_view, Value>;
|
||||
|
||||
struct AggregationResult {
|
||||
// Values to be passed to the next step
|
||||
// TODO: Replace DocValues with compact linear search map instead of hash map
|
||||
std::vector<DocValues> values;
|
||||
|
||||
// Fields from values to be printed
|
||||
absl::flat_hash_set<std::string> fields_to_print;
|
||||
absl::flat_hash_set<std::string_view> fields_to_print;
|
||||
};
|
||||
|
||||
using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
|
||||
struct Aggregator {
|
||||
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
|
||||
void DoSort(std::string_view field, bool descending = false);
|
||||
void DoLimit(size_t offset, size_t num);
|
||||
|
||||
AggregationResult result;
|
||||
};
|
||||
|
||||
using AggregationStep = std::function<void(Aggregator*)>; // Group, Sort, etc.
|
||||
|
||||
// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
|
||||
// Extra clumsy for STL compatibility!
|
||||
|
@ -79,18 +91,17 @@ enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
|
|||
Reducer::Func FindReducerFunc(ReducerFunc name);
|
||||
|
||||
// Make `GROUPBY [fields...]` with REDUCE step
|
||||
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
|
||||
std::vector<Reducer> reducers);
|
||||
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);
|
||||
|
||||
// Make `SORTBY field [DESC]` step
|
||||
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
|
||||
AggregationStep MakeSortStep(std::string field, bool descending = false);
|
||||
|
||||
// Make `LIMIT offset num` step
|
||||
PipelineStep MakeLimitStep(size_t offset, size_t num);
|
||||
AggregationStep MakeLimitStep(size_t offset, size_t num);
|
||||
|
||||
// Process values with given steps
|
||||
PipelineResult Process(std::vector<DocValues> values,
|
||||
absl::Span<const std::string_view> fields_to_print,
|
||||
absl::Span<const PipelineStep> steps);
|
||||
AggregationResult Process(std::vector<DocValues> values,
|
||||
absl::Span<const std::string_view> fields_to_print,
|
||||
absl::Span<const AggregationStep> steps);
|
||||
|
||||
} // namespace dfly::aggregate
|
||||
|
|
|
@ -10,13 +10,15 @@ namespace dfly::aggregate {
|
|||
|
||||
using namespace std::string_literals;
|
||||
|
||||
using StepsList = std::vector<AggregationStep>;
|
||||
|
||||
TEST(AggregatorTest, Sort) {
|
||||
std::vector<DocValues> values = {
|
||||
DocValues{{"a", 1.0}},
|
||||
DocValues{{"a", 0.5}},
|
||||
DocValues{{"a", 1.5}},
|
||||
};
|
||||
PipelineStep steps[] = {MakeSortStep("a", false)};
|
||||
StepsList steps = {MakeSortStep("a", false)};
|
||||
|
||||
auto result = Process(values, {"a"}, steps);
|
||||
|
||||
|
@ -32,7 +34,8 @@ TEST(AggregatorTest, Limit) {
|
|||
DocValues{{"i", 3.0}},
|
||||
DocValues{{"i", 4.0}},
|
||||
};
|
||||
PipelineStep steps[] = {MakeLimitStep(1, 2)};
|
||||
|
||||
StepsList steps = {MakeLimitStep(1, 2)};
|
||||
|
||||
auto result = Process(values, {"i"}, steps);
|
||||
|
||||
|
@ -49,8 +52,8 @@ TEST(AggregatorTest, SimpleGroup) {
|
|||
DocValues{{"i", 4.0}, {"tag", "even"}},
|
||||
};
|
||||
|
||||
std::string_view fields[] = {"tag"};
|
||||
PipelineStep steps[] = {MakeGroupStep(fields, {})};
|
||||
std::vector<std::string> fields = {"tag"};
|
||||
StepsList steps = {MakeGroupStep(std::move(fields), {})};
|
||||
|
||||
auto result = Process(values, {"i", "tag"}, steps);
|
||||
EXPECT_EQ(result.values.size(), 2);
|
||||
|
@ -72,13 +75,14 @@ TEST(AggregatorTest, GroupWithReduce) {
|
|||
});
|
||||
}
|
||||
|
||||
std::string_view fields[] = {"tag"};
|
||||
std::vector<std::string> fields = {"tag"};
|
||||
std::vector<Reducer> reducers = {
|
||||
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
|
||||
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
|
||||
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
|
||||
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
|
||||
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
|
||||
|
||||
StepsList steps = {MakeGroupStep(std::move(fields), std::move(reducers))};
|
||||
|
||||
auto result = Process(values, {"i", "half-i", "tag"}, steps);
|
||||
EXPECT_EQ(result.values.size(), 2);
|
||||
|
|
|
@ -168,7 +168,7 @@ struct AggregateParams {
|
|||
search::QueryParams params;
|
||||
|
||||
std::optional<SearchFieldsList> load_fields;
|
||||
std::vector<aggregate::PipelineStep> steps;
|
||||
std::vector<aggregate::AggregationStep> steps;
|
||||
};
|
||||
|
||||
// Stores basic info about a document index.
|
||||
|
|
|
@ -320,20 +320,23 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
|
|||
while (parser.HasNext()) {
|
||||
// GROUPBY nargs property [property ...]
|
||||
if (parser.Check("GROUPBY")) {
|
||||
vector<string_view> fields(parser.Next<size_t>());
|
||||
for (string_view& field : fields) {
|
||||
size_t num_fields = parser.Next<size_t>();
|
||||
|
||||
std::vector<std::string> fields;
|
||||
fields.reserve(num_fields);
|
||||
while (num_fields > 0 && parser.HasNext()) {
|
||||
auto parsed_field = ParseFieldWithAtSign(&parser);
|
||||
|
||||
/*
|
||||
TODO: Throw an error if the field has no '@' sign at the beginning
|
||||
|
||||
if (!parsed_field) {
|
||||
builder->SendError(absl::StrCat("bad arguments for GROUPBY: Unknown property '", field,
|
||||
"'. Did you mean '@", field, "`?"));
|
||||
return nullopt;
|
||||
} */
|
||||
|
||||
field = parsed_field;
|
||||
fields.emplace_back(parsed_field);
|
||||
num_fields--;
|
||||
}
|
||||
|
||||
vector<aggregate::Reducer> reducers;
|
||||
|
@ -363,7 +366,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
|
|||
aggregate::Reducer{std::move(source_field), std::move(result_field), std::move(func)});
|
||||
}
|
||||
|
||||
params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
|
||||
params.steps.push_back(aggregate::MakeGroupStep(std::move(fields), std::move(reducers)));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -373,7 +376,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
|
|||
string_view field = parser.Next();
|
||||
bool desc = bool(parser.Check("DESC"));
|
||||
|
||||
params.steps.push_back(aggregate::MakeSortStep(field, desc));
|
||||
params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -975,10 +978,18 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
|
|||
return OpStatus::OK;
|
||||
});
|
||||
|
||||
vector<aggregate::DocValues> values;
|
||||
// ResultContainer is absl::flat_hash_map<std::string, search::SortableValue>
|
||||
// DocValues is absl::flat_hash_map<std::string_view, SortableValue>
|
||||
// Keys of values should point to the keys of the query_results
|
||||
std::vector<aggregate::DocValues> values;
|
||||
for (auto& sub_results : query_results) {
|
||||
values.insert(values.end(), make_move_iterator(sub_results.begin()),
|
||||
make_move_iterator(sub_results.end()));
|
||||
for (auto& docs : sub_results) {
|
||||
aggregate::DocValues doc_value;
|
||||
for (auto& doc : docs) {
|
||||
doc_value[doc.first] = std::move(doc.second);
|
||||
}
|
||||
values.push_back(std::move(doc_value));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string_view> load_fields;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue