CmdArgParser improvement (#3633)

* feat: add processing of tail args into CmdArgParser::Check
* refactor: rename CmdArgParser::Switch to Map
* feat: add CheckMap method into CmdArgParser
This commit is contained in:
Borys 2024-09-05 15:30:54 +03:00 committed by GitHub
parent 3461419088
commit a1e9ee1b6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 128 additions and 86 deletions

View file

@ -66,7 +66,7 @@ ValueIterator& ValueIterator::operator++() {
return *this;
}
Reducer::Func FindReducerFunc(std::string_view name) {
Reducer::Func FindReducerFunc(ReducerFunc name) {
const static auto kCountReducer = [](ValueIterator it) -> double {
return std::distance(it, it.end());
};
@ -78,17 +78,24 @@ Reducer::Func FindReducerFunc(std::string_view name) {
return sum;
};
static const std::unordered_map<std::string_view, std::function<Value(ValueIterator)>> kReducers =
{{"COUNT", [](auto it) { return kCountReducer(it); }},
{"COUNT_DISTINCT",
[](auto it) { return double(std::unordered_set<Value>(it, it.end()).size()); }},
{"SUM", [](auto it) { return kSumReducer(it); }},
{"AVG", [](auto it) { return kSumReducer(it) / kCountReducer(it); }},
{"MAX", [](auto it) { return *std::max_element(it, it.end()); }},
{"MIN", [](auto it) { return *std::min_element(it, it.end()); }}};
switch (name) {
case ReducerFunc::COUNT:
return [](ValueIterator it) -> Value { return kCountReducer(it); };
case ReducerFunc::COUNT_DISTINCT:
return [](ValueIterator it) -> Value {
return double(std::unordered_set<Value>(it, it.end()).size());
};
case ReducerFunc::SUM:
return [](ValueIterator it) -> Value { return kSumReducer(it); };
case ReducerFunc::AVG:
return [](ValueIterator it) -> Value { return kSumReducer(it) / kCountReducer(it); };
case ReducerFunc::MAX:
return [](ValueIterator it) -> Value { return *std::max_element(it, it.end()); };
case ReducerFunc::MIN:
return [](ValueIterator it) -> Value { return *std::min_element(it, it.end()); };
}
auto it = kReducers.find(name);
return it != kReducers.end() ? it->second : Reducer::Func{};
return nullptr;
}
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,

View file

@ -61,13 +61,15 @@ struct ValueIterator {
};
struct Reducer {
using Func = std::function<Value(ValueIterator)>;
using Func = Value (*)(ValueIterator);
std::string source_field, result_field;
Func func;
};
enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
// Find reducer function for the requested reducer (COUNT, MAX, etc...), empty (null) Func if not found
Reducer::Func FindReducerFunc(std::string_view name);
Reducer::Func FindReducerFunc(ReducerFunc name);
// Make `GROUPBY [fields...]` with REDUCE step
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,

View file

@ -77,9 +77,10 @@ TEST(AggregatorTest, GroupWithReduce) {
std::string_view fields[] = {"tag"};
std::vector<Reducer> reducers = {
Reducer{"", "count", FindReducerFunc("COUNT")}, Reducer{"i", "sum-i", FindReducerFunc("SUM")},
Reducer{"half-i", "distinct-hi", FindReducerFunc("COUNT_DISTINCT")},
Reducer{"null-field", "distinct-null", FindReducerFunc("COUNT_DISTINCT")}};
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
auto result = Process(values, steps);

View file

@ -47,21 +47,17 @@ bool IsValidJsonPath(string_view path) {
search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
search::SchemaField::VectorParams params{};
params.use_hnsw = parser->Switch("HNSW", true, "FLAT", false);
params.use_hnsw = parser->MapNext("HNSW", true, "FLAT", false);
const size_t num_args = parser->Next<size_t>();
for (size_t i = 0; i * 2 < num_args; i++) {
if (parser->Check("DIM")) {
params.dim = parser->Next<size_t>();
if (parser->Check("DIM", &params.dim)) {
} else if (parser->Check("DISTANCE_METRIC")) {
params.sim = parser->Switch("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
} else if (parser->Check("INITIAL_CAP")) {
params.capacity = parser->Next<size_t>();
} else if (parser->Check("M")) {
params.hnsw_m = parser->Next<size_t>();
} else if (parser->Check("EF_CONSTRUCTION")) {
params.hnsw_ef_construction = parser->Next<size_t>();
params.sim = parser->MapNext("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
} else if (parser->Check("INITIAL_CAP", &params.capacity)) {
} else if (parser->Check("M", &params.hnsw_m)) {
} else if (parser->Check("EF_CONSTRUCTION", &params.hnsw_ef_construction)) {
} else if (parser->Check("EF_RUNTIME")) {
parser->Next<size_t>();
LOG(WARNING) << "EF_RUNTIME not supported";
@ -116,13 +112,12 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
}
// AS [alias]
if (parser.Check("AS"))
field_alias = parser.Next();
parser.Check("AS", &field_alias);
// Determine type
using search::SchemaField;
auto type = parser.Switch("TAG", SchemaField::TAG, "TEXT", SchemaField::TEXT, "NUMERIC",
SchemaField::NUMERIC, "VECTOR", SchemaField::VECTOR);
auto type = parser.MapNext("TAG", SchemaField::TAG, "TEXT", SchemaField::TEXT, "NUMERIC",
SchemaField::NUMERIC, "VECTOR", SchemaField::VECTOR);
if (auto err = parser.Error(); err) {
cntx->SendError(err->MakeReply());
return nullopt;
@ -265,24 +260,26 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
vector<aggregate::Reducer> reducers;
while (parser.Check("REDUCE")) {
parser.ToUpper(); // uppercase for func_name
auto [func_name, nargs] = parser.Next<string_view, size_t>();
auto func = aggregate::FindReducerFunc(func_name);
using RF = aggregate::ReducerFunc;
auto func_name =
parser.TryMapNext("COUNT", RF::COUNT, "COUNT_DISTINCT", RF::COUNT_DISTINCT, "SUM",
RF::SUM, "AVG", RF::AVG, "MAX", RF::MAX, "MIN", RF::MIN);
if (!parser.HasError() && !func) {
cntx->SendError(absl::StrCat("reducer function ", func_name, " not found"));
if (!func_name) {
cntx->SendError(absl::StrCat("reducer function ", parser.Next(), " not found"));
return nullopt;
}
string source_field = "";
if (nargs > 0) {
source_field = parser.Next<string>();
}
auto func = aggregate::FindReducerFunc(*func_name);
auto nargs = parser.Next<size_t>();
string source_field = nargs > 0 ? parser.Next<string>() : "";
parser.ExpectTag("AS");
string result_field = parser.Next<string>();
reducers.push_back(aggregate::Reducer{source_field, result_field, std::move(func)});
reducers.push_back(
aggregate::Reducer{std::move(source_field), std::move(result_field), std::move(func)});
}
params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
@ -435,7 +432,7 @@ void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
while (parser.HasNext()) {
// ON HASH | JSON
if (parser.Check("ON")) {
index.type = parser.Switch("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
index.type = parser.MapNext("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
continue;
}