CmdArgParser improvement (#3633)

* feat: add processing of tail args into CmdArgParser::Check
* refactor: rename CmdArgParser::Switch to Map
* feat: add CheckMap method into CmdArgParser
This commit is contained in:
Borys 2024-09-05 15:30:54 +03:00 committed by GitHub
parent 3461419088
commit a1e9ee1b6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 128 additions and 86 deletions

View file

@ -36,6 +36,7 @@ ErrorReply CmdArgParser::ErrorInfo::MakeReply() const {
// Asserts (via DCHECK) that the parser is not destroyed while it still holds a parsing error,
// i.e. that the caller looked at the error before letting the parser go out of scope.
CmdArgParser::~CmdArgParser() {
DCHECK(!error_.has_value()) << "Parsing error occured but not checked";
// Not enforced yet: callers may still leave trailing arguments unprocessed.
// TODO DCHECK(!HasNext()) << "Not all args were processed";
}
void CmdArgParser::ToUpper(size_t i) {

View file

@ -65,15 +65,15 @@ struct CmdArgParser {
void ExpectTag(std::string_view tag);
// Consume next value
template <class... Cases> auto Switch(Cases&&... cases) {
template <class... Cases> auto MapNext(Cases&&... cases) {
if (cur_i_ >= args_.size()) {
Report(OUT_OF_BOUNDS, cur_i_);
return typename decltype(SwitchImpl(std::string_view(),
std::forward<Cases>(cases)...))::value_type{};
return typename decltype(MapImpl(std::string_view(),
std::forward<Cases>(cases)...))::value_type{};
}
auto idx = cur_i_++;
auto res = SwitchImpl(SafeSV(idx), std::forward<Cases>(cases)...);
auto res = MapImpl(SafeSV(idx), std::forward<Cases>(cases)...);
if (!res) {
Report(INVALID_CASES, idx);
return typename decltype(res)::value_type{};
@ -81,16 +81,31 @@ struct CmdArgParser {
return *res;
}
// Check if the next value if equal to a specific tag. If equal, its consumed.
bool Check(std::string_view tag) {
if (cur_i_ >= args_.size())
// Tries to map the next argument through the given (tag, value) pairs without reporting errors.
// When a tag matches (case-insensitively) the argument is consumed and its value is returned.
// When nothing matches, or no arguments remain, returns nullopt and leaves the position as is.
// The result holds the decayed type of the first value, so values passed as lvalues do not
// produce an (ill-formed) optional of a reference.
template <class... Cases>
auto TryMapNext(Cases&&... cases)
    -> std::optional<std::decay_t<std::tuple_element_t<1, std::tuple<Cases...>>>> {
  if (cur_i_ >= args_.size()) {
    return std::nullopt;
  }

  auto res = MapImpl(SafeSV(cur_i_), std::forward<Cases>(cases)...);
  if (res) {
    ++cur_i_;  // consume the argument only on a successful match
  }
  return res;
}
// Check if the next value is equal to a specific tag (case-insensitively) and is followed by at
// least sizeof...(Args) more values. If so, the tag is consumed, every following value is
// converted into the corresponding `args` output pointer, and those values are consumed too.
// Returns false and consumes nothing when the tag differs or too few values remain.
// NOTE(review): a tail value that fails to convert appears to be reported through the parser's
// error state rather than through the return value - confirm against Convert().
template <class... Args> bool Check(std::string_view tag, Args*... args) {
  if (cur_i_ + sizeof...(Args) >= args_.size())
    return false;

  std::string_view arg = SafeSV(cur_i_);
  if (!absl::EqualsIgnoreCase(arg, tag))
    return false;

  // cur_i_ still points at the tag here: each pre-increment steps onto the next tail value
  // before converting it, and the final increment moves past the last consumed value.
  ((*args = Convert<Args>(++cur_i_)), ...);
  ++cur_i_;

  return true;
}
@ -137,13 +152,13 @@ struct CmdArgParser {
private:
template <class T, class... Cases>
std::optional<std::decay_t<T>> SwitchImpl(std::string_view arg, std::string_view tag, T&& value,
Cases&&... cases) {
std::optional<std::decay_t<T>> MapImpl(std::string_view arg, std::string_view tag, T&& value,
Cases&&... cases) {
if (absl::EqualsIgnoreCase(arg, tag))
return std::forward<T>(value);
if constexpr (sizeof...(cases) > 0)
return SwitchImpl(arg, cases...);
return MapImpl(arg, cases...);
return std::nullopt;
}
@ -172,8 +187,10 @@ struct CmdArgParser {
}
// Records a parsing error unless one is already stored, so only the first error is kept.
// Recording an error also moves the cursor to the end of the argument list.
void Report(ErrorType type, size_t idx) {
  if (error_)
    return;

  error_ = {type, idx};
  cur_i_ = args_.size();
}
template <typename T> T Num(size_t idx) {

View file

@ -79,6 +79,7 @@ TEST_F(CmdArgParserTest, Check) {
EXPECT_FALSE(parser.Check("NOT_TAG_2"));
EXPECT_TRUE(parser.Check("TAG_2"));
EXPECT_EQ(parser.Next<int>(), 22);
}
TEST_F(CmdArgParserTest, NextStatement) {
@ -95,31 +96,42 @@ TEST_F(CmdArgParserTest, NextStatement) {
}
// Check() with tail arguments: it succeeds only when the tag matches and enough values follow,
// and a tail value that cannot be converted leaves a parser error behind.
TEST_F(CmdArgParserTest, CheckTailFail) {
  auto parser = Make({"TAG", "11", "22", "TAG", "text"});

  int first = 0;
  string_view second;
  EXPECT_TRUE(parser.Check("TAG", &first, &second));
  EXPECT_EQ(first, 11);
  EXPECT_EQ(second, "22");

  // Only one value follows the second TAG, so asking for two must fail.
  EXPECT_FALSE(parser.Check("TAG", &first, &second));
  // "text" is not an int: the tag still matches, but the conversion records an error.
  EXPECT_TRUE(parser.Check("TAG", &first));
  EXPECT_TRUE(parser.Error());
}
// MapNext() returns the mapped value for a known tag; for an unknown one it returns a
// value-initialized result and reports INVALID_CASES at the offending index.
TEST_F(CmdArgParserTest, Map) {
  auto parser = Make({"TWO", "NONE"});

  EXPECT_EQ(parser.MapNext("ONE", 1, "TWO", 2), 2);

  // "NONE" matches neither tag.
  EXPECT_EQ(parser.MapNext("ONE", 1, "TWO", 2), 0);
  auto err = parser.Error();
  EXPECT_TRUE(err);
  EXPECT_EQ(err->type, CmdArgParser::INVALID_CASES);
  EXPECT_EQ(err->index, 1);
}
// TryMapNext() consumes an argument only when it matches one of the tags and never reports an
// error on a mismatch.
TEST_F(CmdArgParserTest, TryMapNext) {
auto parser = Make({"TWO", "GREEN"});
// "TWO" matches, so it is consumed and mapped to 2.
EXPECT_EQ(parser.TryMapNext("ONE", 1, "TWO", 2), std::make_optional(2));
// "GREEN" matches no tag: nullopt, no error, and the argument stays pending.
EXPECT_EQ(parser.TryMapNext("ONE", 1, "TWO", 2), std::nullopt);
EXPECT_FALSE(parser.HasError());
// Tags are compared case-insensitively, so the still-pending "GREEN" maps through "green".
EXPECT_EQ(parser.TryMapNext("green", 1, "yellow", 2), std::make_optional(1));
EXPECT_FALSE(parser.HasError());
}
TEST_F(CmdArgParserTest, IgnoreCase) {
auto parser = Make({"hello", "marker", "taail", "world"});

View file

@ -1015,7 +1015,7 @@ nonstd::expected<CommandList, std::string> ParseToCommandList(CmdArgList args, b
make_unexpected("BITFIELD_RO only supports the GET subcommand");
}
using pol = Overflow::Policy;
auto res = parser.Switch("SAT", pol::SAT, "WRAP", pol::WRAP, "FAIL", pol::FAIL);
auto res = parser.MapNext("SAT", pol::SAT, "WRAP", pol::WRAP, "FAIL", pol::FAIL);
if (!parser.HasError()) {
result.push_back(Overflow{res});
continue;

View file

@ -1300,13 +1300,11 @@ OpStatus OpMerge(const OpArgs& op_args, string_view key, string_view path,
void JsonFamily::Set(CmdArgList args, ConnectionContext* cntx) {
CmdArgParser parser{args};
string_view key = parser.Next();
string_view path = parser.Next();
string_view json_str = parser.Next();
auto [key, path, json_str] = parser.Next<string_view, string_view, string_view>();
WrappedJsonPath json_path = GET_OR_SEND_UNEXPECTED(ParseJsonPath(path));
int res = parser.HasNext() ? parser.Switch("NX", 1, "XX", 2) : 0;
auto res = parser.TryMapNext("NX", 1, "XX", 2);
bool is_xx_condition = (res == 2), is_nx_condition = (res == 1);
if (parser.Error() || parser.HasNext()) // also clear the parser error dcheck

View file

@ -100,7 +100,7 @@ string ListPop(ListDir dir, quicklist* ql) {
}
// Maps the next argument, LEFT or RIGHT (compared case-insensitively), to the matching ListDir.
// Any other value, or a missing argument, is recorded in the parser's error state; the caller
// is expected to check it.
ListDir ParseDir(facade::CmdArgParser* parser) {
  return parser->MapNext("LEFT", ListDir::LEFT, "RIGHT", ListDir::RIGHT);
}
string_view DirToSv(ListDir dir) {
@ -989,13 +989,14 @@ void ListFamily::LIndex(CmdArgList args, ConnectionContext* cntx) {
void ListFamily::LInsert(CmdArgList args, ConnectionContext* cntx) {
facade::CmdArgParser parser{args};
string_view key = parser.Next();
InsertParam where = parser.Switch("AFTER", INSERT_AFTER, "BEFORE", INSERT_BEFORE);
InsertParam where = parser.MapNext("AFTER", INSERT_AFTER, "BEFORE", INSERT_BEFORE);
auto [pivot, elem] = parser.Next<string_view, string_view>();
DCHECK(pivot.data() && elem.data());
if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());
DCHECK(pivot.data() && elem.data());
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpInsert(t->GetOpArgs(shard), key, pivot, elem, where);
};

View file

@ -66,7 +66,7 @@ ValueIterator& ValueIterator::operator++() {
return *this;
}
Reducer::Func FindReducerFunc(std::string_view name) {
Reducer::Func FindReducerFunc(ReducerFunc name) {
const static auto kCountReducer = [](ValueIterator it) -> double {
return std::distance(it, it.end());
};
@ -78,17 +78,24 @@ Reducer::Func FindReducerFunc(std::string_view name) {
return sum;
};
static const std::unordered_map<std::string_view, std::function<Value(ValueIterator)>> kReducers =
{{"COUNT", [](auto it) { return kCountReducer(it); }},
{"COUNT_DISTINCT",
[](auto it) { return double(std::unordered_set<Value>(it, it.end()).size()); }},
{"SUM", [](auto it) { return kSumReducer(it); }},
{"AVG", [](auto it) { return kSumReducer(it) / kCountReducer(it); }},
{"MAX", [](auto it) { return *std::max_element(it, it.end()); }},
{"MIN", [](auto it) { return *std::min_element(it, it.end()); }}};
switch (name) {
case ReducerFunc::COUNT:
return [](ValueIterator it) -> Value { return kCountReducer(it); };
case ReducerFunc::COUNT_DISTINCT:
return [](ValueIterator it) -> Value {
return double(std::unordered_set<Value>(it, it.end()).size());
};
case ReducerFunc::SUM:
return [](ValueIterator it) -> Value { return kSumReducer(it); };
case ReducerFunc::AVG:
return [](ValueIterator it) -> Value { return kSumReducer(it) / kCountReducer(it); };
case ReducerFunc::MAX:
return [](ValueIterator it) -> Value { return *std::max_element(it, it.end()); };
case ReducerFunc::MIN:
return [](ValueIterator it) -> Value { return *std::min_element(it, it.end()); };
}
auto it = kReducers.find(name);
return it != kReducers.end() ? it->second : Reducer::Func{};
return nullptr;
}
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,

View file

@ -61,13 +61,15 @@ struct ValueIterator {
};
// One reducer of a group step: a reducing function plus the names of the field it is bound to
// and the field its result is stored under (presumably input and output - confirm against
// MakeGroupStep).
struct Reducer {
  // Plain function pointer taking the iterator over the values to reduce.
  using Func = Value (*)(ValueIterator);
  std::string source_field, result_field;
  // Unlike std::function, a raw pointer is not empty by default; make "no function" explicit so
  // a default-constructed Reducer never carries an indeterminate pointer.
  Func func = nullptr;
};
// Identifiers of the reducer functions that FindReducerFunc accepts.
enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
// Find reducer function by its ReducerFunc value (COUNT, MAX, etc...), nullptr if not found
Reducer::Func FindReducerFunc(std::string_view name);
Reducer::Func FindReducerFunc(ReducerFunc name);
// Make `GROUPBY [fields...]` with REDUCE step
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,

View file

@ -77,9 +77,10 @@ TEST(AggregatorTest, GroupWithReduce) {
std::string_view fields[] = {"tag"};
std::vector<Reducer> reducers = {
Reducer{"", "count", FindReducerFunc("COUNT")}, Reducer{"i", "sum-i", FindReducerFunc("SUM")},
Reducer{"half-i", "distinct-hi", FindReducerFunc("COUNT_DISTINCT")},
Reducer{"null-field", "distinct-null", FindReducerFunc("COUNT_DISTINCT")}};
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
auto result = Process(values, steps);

View file

@ -47,21 +47,17 @@ bool IsValidJsonPath(string_view path) {
search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
search::SchemaField::VectorParams params{};
params.use_hnsw = parser->Switch("HNSW", true, "FLAT", false);
params.use_hnsw = parser->MapNext("HNSW", true, "FLAT", false);
const size_t num_args = parser->Next<size_t>();
for (size_t i = 0; i * 2 < num_args; i++) {
if (parser->Check("DIM")) {
params.dim = parser->Next<size_t>();
if (parser->Check("DIM", &params.dim)) {
} else if (parser->Check("DISTANCE_METRIC")) {
params.sim = parser->Switch("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
} else if (parser->Check("INITIAL_CAP")) {
params.capacity = parser->Next<size_t>();
} else if (parser->Check("M")) {
params.hnsw_m = parser->Next<size_t>();
} else if (parser->Check("EF_CONSTRUCTION")) {
params.hnsw_ef_construction = parser->Next<size_t>();
params.sim = parser->MapNext("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
} else if (parser->Check("INITIAL_CAP", &params.capacity)) {
} else if (parser->Check("M", &params.hnsw_m)) {
} else if (parser->Check("EF_CONSTRUCTION", &params.hnsw_ef_construction)) {
} else if (parser->Check("EF_RUNTIME")) {
parser->Next<size_t>();
LOG(WARNING) << "EF_RUNTIME not supported";
@ -116,13 +112,12 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
}
// AS [alias]
if (parser.Check("AS"))
field_alias = parser.Next();
parser.Check("AS", &field_alias);
// Determine type
using search::SchemaField;
auto type = parser.Switch("TAG", SchemaField::TAG, "TEXT", SchemaField::TEXT, "NUMERIC",
SchemaField::NUMERIC, "VECTOR", SchemaField::VECTOR);
auto type = parser.MapNext("TAG", SchemaField::TAG, "TEXT", SchemaField::TEXT, "NUMERIC",
SchemaField::NUMERIC, "VECTOR", SchemaField::VECTOR);
if (auto err = parser.Error(); err) {
cntx->SendError(err->MakeReply());
return nullopt;
@ -265,24 +260,26 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
vector<aggregate::Reducer> reducers;
// Parse every `REDUCE <func> <nargs> [source_field] AS <result_field>` clause.
while (parser.Check("REDUCE")) {
  using RF = aggregate::ReducerFunc;
  auto func_name =
      parser.TryMapNext("COUNT", RF::COUNT, "COUNT_DISTINCT", RF::COUNT_DISTINCT, "SUM",
                        RF::SUM, "AVG", RF::AVG, "MAX", RF::MAX, "MIN", RF::MIN);

  if (!func_name) {
    // TryMapNext leaves an unmatched argument in place, so Next() yields the unknown name.
    // If REDUCE was the last token, Next() records an out-of-bounds error instead: reply with
    // that error (so it does not stay unchecked) rather than with an empty function name.
    string_view unknown_name = parser.Next();
    if (auto err = parser.Error(); err)
      cntx->SendError(err->MakeReply());
    else
      cntx->SendError(absl::StrCat("reducer function ", unknown_name, " not found"));
    return nullopt;
  }

  auto func = aggregate::FindReducerFunc(*func_name);
  auto nargs = parser.Next<size_t>();
  string source_field = nargs > 0 ? parser.Next<string>() : "";

  parser.ExpectTag("AS");
  string result_field = parser.Next<string>();

  reducers.push_back(
      aggregate::Reducer{std::move(source_field), std::move(result_field), func});
}
params.steps.push_back(aggregate::MakeGroupStep(fields, std::move(reducers)));
@ -435,7 +432,7 @@ void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
while (parser.HasNext()) {
// ON HASH | JSON
if (parser.Check("ON")) {
index.type = parser.Switch("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
index.type = parser.MapNext("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
continue;
}

View file

@ -435,7 +435,7 @@ void ClientPauseCmd(CmdArgList args, vector<facade::Listener*> listeners, Connec
auto timeout = parser.Next<uint64_t>();
ClientPause pause_state = ClientPause::ALL;
if (parser.HasNext()) {
pause_state = parser.Switch("WRITE", ClientPause::WRITE, "ALL", ClientPause::ALL);
pause_state = parser.MapNext("WRITE", ClientPause::WRITE, "ALL", ClientPause::ALL);
}
if (auto err = parser.Error(); err) {
return cntx->SendError(err->MakeReply());

View file

@ -44,6 +44,8 @@ using namespace facade;
using CI = CommandId;
// Expiry option kinds parsed from EX / PX / EXAT / PXAT. The code below treats EXAT and PXAT
// as absolute times, and PX and PXAT as milliseconds (the others as seconds).
enum class ExpT { EX, PX, EXAT, PXAT };
constexpr uint32_t kMaxStrLen = 1 << 28;
void CopyValueToBuffer(const PrimeValue& pv, char* dest) {
@ -758,9 +760,10 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
facade::SinkReplyBuilder* builder = cntx->reply_builder();
while (parser.HasNext()) {
parser.ToUpper();
if (base::_in(parser.Peek(), {"EX", "PX", "EXAT", "PXAT"})) {
auto [opt, int_arg] = parser.Next<string_view, int64_t>();
if (auto exp_type = parser.TryMapNext("EX", ExpT::EX, "PX", ExpT::PX, "EXAT", ExpT::EXAT,
"PXAT", ExpT::PXAT);
exp_type) {
auto int_arg = parser.Next<int64_t>();
if (auto err = parser.Error(); err) {
return cntx->SendError(err->MakeReply());
@ -779,8 +782,8 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
DbSlice::ExpireParams expiry{
.value = int_arg,
.unit = (opt[0] == 'P') ? TimeUnit::MSEC : TimeUnit::SEC,
.absolute = absl::EndsWith(opt, "AT"),
.unit = *exp_type == ExpT::PX || *exp_type == ExpT::PXAT ? TimeUnit::MSEC : TimeUnit::SEC,
.absolute = *exp_type == ExpT::EXAT || *exp_type == ExpT::PXAT,
};
int64_t now_ms = GetCurrentTimeMs();
@ -802,7 +805,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
} else if (parser.Check("_MCFLAGS")) {
sparams.memcache_flags = parser.Next<uint32_t>();
} else {
uint16_t flag = parser.Switch( //
uint16_t flag = parser.MapNext( //
"GET", SetCmd::SET_GET, "STICK", SetCmd::SET_STICK, "KEEPTTL", SetCmd::SET_KEEP_EXPIRE,
"XX", SetCmd::SET_IF_EXISTS, "NX", SetCmd::SET_IF_NOTEXIST);
sparams.flags |= flag;
@ -970,9 +973,11 @@ void StringFamily::GetEx(CmdArgList args, ConnectionContext* cntx) {
DbSlice::ExpireParams exp_params;
bool defined = false;
while (parser.ToUpper().HasNext()) {
if (base::_in(parser.Peek(), {"EX", "PX", "EXAT", "PXAT"})) {
auto [ex, int_arg] = parser.Next<string_view, int64_t>();
while (parser.HasNext()) {
if (auto exp_type = parser.TryMapNext("EX", ExpT::EX, "PX", ExpT::PX, "EXAT", ExpT::EXAT,
"PXAT", ExpT::PXAT);
exp_type) {
auto int_arg = parser.Next<int64_t>();
if (auto err = parser.Error(); err) {
return cntx->SendError(err->MakeReply());
}
@ -985,9 +990,10 @@ void StringFamily::GetEx(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(InvalidExpireTime("getex"));
}
exp_params.absolute = base::_in(ex, {"EXAT", "PXAT"});
exp_params.absolute = *exp_type == ExpT::EXAT || *exp_type == ExpT::PXAT;
exp_params.value = int_arg;
exp_params.unit = ex[0] == 'P' ? TimeUnit::MSEC : TimeUnit::SEC;
exp_params.unit =
*exp_type == ExpT::PX || *exp_type == ExpT::PXAT ? TimeUnit::MSEC : TimeUnit::SEC;
defined = true;
} else if (parser.Check("PERSIST")) {
exp_params.persist = true;