mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-10 18:05:44 +02:00
fix(search_family): Support multiple fields in SORTBY option in the FT.AGGREGATE command. SECOND PR (#4232)
fix(search_family): Support multiple fields in SORTBY option in the FT.AGGREGATE command fixes dragonfly#3631 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
parent
3c7e31240f
commit
aeeb625393
5 changed files with 245 additions and 24 deletions
|
@ -65,9 +65,9 @@ void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const
|
|||
}
|
||||
}
|
||||
|
||||
void Aggregator::DoSort(std::string_view field, bool descending) {
|
||||
void Aggregator::DoSort(const SortParams& sort_params) {
|
||||
/*
|
||||
Comparator for sorting DocValues by field.
|
||||
Comparator for sorting DocValues by fields.
|
||||
If any of the fields is missing from the DocValues, the comparator returns:
|
||||
1. l_it == l.end() && r_it != r.end()
|
||||
asc -> false
|
||||
|
@ -80,22 +80,41 @@ void Aggregator::DoSort(std::string_view field, bool descending) {
|
|||
desc -> false
|
||||
*/
|
||||
auto comparator = [&](const DocValues& l, const DocValues& r) {
|
||||
auto l_it = l.find(field);
|
||||
auto r_it = r.find(field);
|
||||
for (const auto& [field, order] : sort_params.fields) {
|
||||
auto l_it = l.find(field);
|
||||
auto r_it = r.find(field);
|
||||
|
||||
// If some of the values is not present
|
||||
if (l_it == l.end() || r_it == r.end()) {
|
||||
return l_it != l.end();
|
||||
// If some of the values is not present
|
||||
if (l_it == l.end() || r_it == r.end()) {
|
||||
if (l_it == l.end() && r_it == r.end()) {
|
||||
continue;
|
||||
}
|
||||
return l_it != l.end();
|
||||
}
|
||||
|
||||
const auto& lv = l_it->second;
|
||||
const auto& rv = r_it->second;
|
||||
if (lv == rv) {
|
||||
continue;
|
||||
}
|
||||
return order == SortParams::SortOrder::ASC ? lv < rv : lv > rv;
|
||||
}
|
||||
|
||||
auto& lv = l_it->second;
|
||||
auto& rv = r_it->second;
|
||||
return !descending ? lv < rv : lv > rv;
|
||||
return false;
|
||||
};
|
||||
|
||||
std::sort(result.values.begin(), result.values.end(), std::move(comparator));
|
||||
auto& values = result.values;
|
||||
if (sort_params.SortAll()) {
|
||||
std::sort(values.begin(), values.end(), comparator);
|
||||
} else {
|
||||
DCHECK_GE(sort_params.max, 0);
|
||||
const size_t limit = std::min(values.size(), size_t(sort_params.max));
|
||||
std::partial_sort(values.begin(), values.begin() + limit, values.end(), comparator);
|
||||
values.resize(limit);
|
||||
}
|
||||
|
||||
result.fields_to_print.insert(field);
|
||||
for (auto& field : sort_params.fields) {
|
||||
result.fields_to_print.insert(field.first);
|
||||
}
|
||||
}
|
||||
|
||||
void Aggregator::DoLimit(size_t offset, size_t num) {
|
||||
|
@ -152,10 +171,8 @@ AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reduc
|
|||
};
|
||||
}
|
||||
|
||||
AggregationStep MakeSortStep(std::string field, bool descending) {
|
||||
return [field = std::move(field), descending](Aggregator* aggregator) {
|
||||
aggregator->DoSort(field, descending);
|
||||
};
|
||||
// Builds the SORTBY step. The returned callable takes ownership of `sort_params` and hands them
// to Aggregator::DoSort every time it is invoked.
AggregationStep MakeSortStep(SortParams sort_params) {
  return [owned_params = std::move(sort_params)](Aggregator* agg) {
    agg->DoSort(owned_params);
  };
}
|
||||
|
||||
AggregationStep MakeLimitStep(size_t offset, size_t num) {
|
||||
|
|
|
@ -33,9 +33,27 @@ struct AggregationResult {
|
|||
absl::flat_hash_set<std::string_view> fields_to_print;
|
||||
};
|
||||
|
||||
struct SortParams {
|
||||
enum class SortOrder { ASC, DESC };
|
||||
|
||||
constexpr static int64_t kSortAll = -1;
|
||||
|
||||
bool SortAll() const {
|
||||
return max == kSortAll;
|
||||
}
|
||||
|
||||
/* Fields to sort by. If multiple fields are provided, sorting works hierarchically:
|
||||
- First, the i-th field is compared.
|
||||
- If the i-th field values are equal, the (i + 1)-th field is compared, and so on. */
|
||||
absl::InlinedVector<std::pair<std::string, SortOrder>, 2> fields;
|
||||
/* Max number of elements to include in the sorted result.
|
||||
If set, only the first [max] elements are fully sorted using partial_sort. */
|
||||
int64_t max = kSortAll;
|
||||
};
|
||||
|
||||
struct Aggregator {
|
||||
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
|
||||
void DoSort(std::string_view field, bool descending = false);
|
||||
void DoSort(const SortParams& sort_params);
|
||||
void DoLimit(size_t offset, size_t num);
|
||||
|
||||
AggregationResult result;
|
||||
|
@ -94,7 +112,7 @@ Reducer::Func FindReducerFunc(ReducerFunc name);
|
|||
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);
|
||||
|
||||
// Make `SORTBY nargs field [ASC|DESC] ... [MAX num]` step
|
||||
AggregationStep MakeSortStep(std::string field, bool descending = false);
|
||||
AggregationStep MakeSortStep(SortParams sort_params);
|
||||
|
||||
// Make `LIMIT offset num` step
|
||||
AggregationStep MakeLimitStep(size_t offset, size_t num);
|
||||
|
|
|
@ -18,7 +18,10 @@ TEST(AggregatorTest, Sort) {
|
|||
DocValues{{"a", 0.5}},
|
||||
DocValues{{"a", 1.5}},
|
||||
};
|
||||
StepsList steps = {MakeSortStep("a", false)};
|
||||
|
||||
SortParams params;
|
||||
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
|
||||
StepsList steps = {MakeSortStep(std::move(params))};
|
||||
|
||||
auto result = Process(values, {"a"}, steps);
|
||||
|
||||
|
|
|
@ -306,6 +306,42 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
|
|||
return params;
|
||||
}
|
||||
|
||||
// Parses the arguments that follow the SORTBY keyword of FT.AGGREGATE:
//   nargs field [ASC|DESC] [field [ASC|DESC] ...] [MAX num]
// `nargs` counts the field names together with the ASC/DESC tokens.
// Returns std::nullopt when the arguments run out before `nargs` tokens were consumed.
std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
  using SortOrder = aggregate::SortParams::SortOrder;

  // Upper bound for the up-front reservation. `strings_num` comes straight from the client, so
  // reserving proportionally to it would let a single request ask for an arbitrarily large
  // allocation before any field has been parsed.
  constexpr size_t kMaxReservedFields = 16;

  size_t strings_num = parser->Next<size_t>();

  aggregate::SortParams sort_params;
  const size_t expected_fields = strings_num / 2;
  sort_params.fields.reserve(expected_fields < kMaxReservedFields ? expected_fields
                                                                  : kMaxReservedFields);

  while (parser->HasNext() && strings_num > 0) {
    // TODO: Throw an error if the field has no '@' sign at the beginning
    std::string_view parsed_field = ParseFieldWithAtSign(parser);
    strings_num--;

    // The direction token is optional and defaults to ASC. It is only looked for while the
    // declared count still has room for it.
    SortOrder sort_order = SortOrder::ASC;
    if (strings_num > 0) {
      auto order = parser->TryMapNext("ASC", SortOrder::ASC, "DESC", SortOrder::DESC);
      if (order) {
        sort_order = order.value();
        strings_num--;
      }
    }

    sort_params.fields.emplace_back(parsed_field, sort_order);
  }

  // The declared count was larger than the number of tokens actually available.
  if (strings_num) {
    return std::nullopt;
  }

  if (parser->Check("MAX")) {
    const size_t parsed_max = parser->Next<size_t>();
    // `SortParams::max` is a signed 64-bit value whose negative range is reserved for the
    // kSortAll sentinel. Clamp instead of letting an oversized value wrap around to a negative
    // number.
    sort_params.max = parsed_max > static_cast<size_t>(INT64_MAX)
                          ? INT64_MAX
                          : static_cast<int64_t>(parsed_max);
  }

  return sort_params;
}
|
||||
|
||||
optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
|
||||
SinkReplyBuilder* builder) {
|
||||
AggregateParams params;
|
||||
|
@ -372,11 +408,13 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
|
|||
|
||||
// SORTBY nargs
|
||||
if (parser.Check("SORTBY")) {
|
||||
parser.ExpectTag("1");
|
||||
string_view field = parser.Next();
|
||||
bool desc = bool(parser.Check("DESC"));
|
||||
auto sort_params = ParseAggregatorSortParams(&parser);
|
||||
if (!sort_params) {
|
||||
builder->SendError("bad arguments for SORTBY: specified invalid number of strings");
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc));
|
||||
params.steps.push_back(aggregate::MakeSortStep(std::move(sort_params).value()));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -1680,4 +1680,149 @@ TEST_F(SearchFamilyTest, AggregateResultFields) {
|
|||
IsMap(), IsMap()));
|
||||
}
|
||||
|
||||
// Exercises FT.AGGREGATE SORTBY over a JSON index: several sort fields with mixed directions,
// the MAX option, and SORTBY followed by GROUPBY/REDUCE. Documents j8 and j9 have no "number".
TEST_F(SearchFamilyTest, AggregateSortByJson) {
  Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"});
  Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"});
  Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"});
  Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"});
  Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"});
  Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"});
  Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"});
  Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"});
  Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"});

  Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
       "AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});

  // name DESC, then number ASC
  auto reply =
      Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "DESC", "@number", "ASC"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("name", "\"third\"", "number", "300"),
                                          IsMap("name", "\"sixth\"", "number", "300"),
                                          IsMap("name", "\"seventh\"", "number", "400"),
                                          IsMap("name", "\"second\"", "number", "800"),
                                          IsMap("name", "\"ninth\""),
                                          IsMap("name", "\"fourth\"", "number", "400"),
                                          IsMap("name", "\"first\"", "number", "1200"),
                                          IsMap("name", "\"fifth\"", "number", "900"),
                                          IsMap("name", "\"eighth\"")));

  // name ASC, then number DESC
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "ASC", "@number", "DESC"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("name", "\"eighth\""),
                                          IsMap("name", "\"fifth\"", "number", "900"),
                                          IsMap("name", "\"first\"", "number", "1200"),
                                          IsMap("name", "\"fourth\"", "number", "400"),
                                          IsMap("name", "\"ninth\""),
                                          IsMap("name", "\"second\"", "number", "800"),
                                          IsMap("name", "\"seventh\"", "number", "400"),
                                          IsMap("name", "\"sixth\"", "number", "300"),
                                          IsMap("name", "\"third\"", "number", "300")));

  // group ASC, number DESC, then name with the default direction
  reply = Run(
      {"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@group", "ASC", "@number", "DESC", "@name"});
  EXPECT_THAT(reply,
              IsUnordArrayWithSize(
                  IsMap("group", "\"first\"", "number", "1200", "name", "\"first\""),
                  IsMap("group", "\"first\"", "number", "800", "name", "\"second\""),
                  IsMap("group", "\"first\"", "number", "300", "name", "\"sixth\""),
                  IsMap("group", "\"first\"", "number", "300", "name", "\"third\""),
                  IsMap("group", "\"first\"", "name", "\"eighth\""),
                  IsMap("group", "\"second\"", "number", "900", "name", "\"fifth\""),
                  IsMap("group", "\"second\"", "number", "400", "name", "\"fourth\""),
                  IsMap("group", "\"second\"", "number", "400", "name", "\"seventh\""),
                  IsMap("group", "\"second\"", "name", "\"ninth\"")));

  // number ASC, group DESC, then name with the default direction
  reply = Run(
      {"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@number", "ASC", "@group", "DESC", "@name"});
  EXPECT_THAT(reply,
              IsUnordArrayWithSize(
                  IsMap("number", "300", "group", "\"first\"", "name", "\"sixth\""),
                  IsMap("number", "300", "group", "\"first\"", "name", "\"third\""),
                  IsMap("number", "400", "group", "\"second\"", "name", "\"fourth\""),
                  IsMap("number", "400", "group", "\"second\"", "name", "\"seventh\""),
                  IsMap("number", "800", "group", "\"first\"", "name", "\"second\""),
                  IsMap("number", "900", "group", "\"second\"", "name", "\"fifth\""),
                  IsMap("number", "1200", "group", "\"first\"", "name", "\"first\""),
                  IsMap("group", "\"second\"", "name", "\"ninth\""),
                  IsMap("group", "\"first\"", "name", "\"eighth\"")));

  // number with the default direction, MAX 3
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "3"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
                                          IsMap("number", "400")));

  // number DESC, MAX 3
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "2", "@number", "DESC", "MAX", "3"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("number", "1200"), IsMap("number", "900"),
                                          IsMap("number", "800")));

  // number with the default direction, MAX larger than the number of documents
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "999"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
                                          IsMap("number", "400"), IsMap("number", "400"),
                                          IsMap("number", "800"), IsMap("number", "900"),
                                          IsMap("number", "1200"), IsMap(), IsMap()));

  // name with the default direction, then number DESC
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "3", "@name", "@number", "DESC"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("name", "\"eighth\""),
                                          IsMap("name", "\"fifth\"", "number", "900"),
                                          IsMap("name", "\"first\"", "number", "1200"),
                                          IsMap("name", "\"fourth\"", "number", "400"),
                                          IsMap("name", "\"ninth\""),
                                          IsMap("name", "\"second\"", "number", "800"),
                                          IsMap("name", "\"seventh\"", "number", "400"),
                                          IsMap("name", "\"sixth\"", "number", "300"),
                                          IsMap("name", "\"third\"", "number", "300")));

  // SORTBY with MAX, followed by GROUPBY on one field and REDUCE COUNT
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "1",
               "@number", "REDUCE", "COUNT", "0", "AS", "count"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("number", "900", "count", "1"),
                                          IsMap("number", ArgType(RespExpr::NIL), "count", "1"),
                                          IsMap("number", "1200", "count", "1")));

  // SORTBY with MAX, followed by GROUPBY on zero fields and REDUCE COUNT
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "0",
               "REDUCE", "COUNT", "0", "AS", "count"});
  EXPECT_THAT(reply, IsUnordArrayWithSize(IsMap("count", "3")));
}
|
||||
|
||||
// Checks the error replies produced for malformed SORTBY arguments of FT.AGGREGATE.
TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) {
  Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"});
  Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"});
  Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"});
  Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"});
  Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"});
  Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"});
  Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"});
  Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"});
  Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"});

  Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
       "AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});

  // Declared argument count exceeds the arguments actually supplied
  auto reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "999", "@name", "@number", "DESC"});
  EXPECT_THAT(reply, ErrArg("bad arguments for SORTBY: specified invalid number of strings"));

  // Negative argument count
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "-3", "@name", "@number", "DESC"});
  EXPECT_THAT(reply, ErrArg("value is not an integer or out of range"));

  // Negative MAX value
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "-10"});
  EXPECT_THAT(reply, ErrArg("value is not an integer or out of range"));

  // MAX keyword with no value after it
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX"});
  EXPECT_THAT(reply, ErrArg("syntax error"));

  // SORTBY with a non-existing field
  /* Temporarily unsupported
  resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@nonexistingfield"});
  EXPECT_THAT(resp, ErrArg("Property `nonexistingfield` not loaded nor in schema")); */

  // Non-numeric argument count
  reply = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "notvalue", "@name"});
  EXPECT_THAT(reply, ErrArg("value is not an integer or out of range"));
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue