refactor: remove ToUpper() from cmd_arg_parser (#3599)

* refactor: remove usage of ToUpper() from cmd_arg_parser

* refactor: remove CmdArgParser::NextUpper
This commit is contained in:
Borys 2024-08-29 15:19:52 +03:00 committed by GitHub
parent 72fc0391f0
commit 88229cf365
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 63 additions and 91 deletions

View file

@ -49,12 +49,6 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) {
} while (cursor);
}
// Table of schema field type keywords (spelled in upper case) and the FieldType each one denotes.
// Consulted by ParseSearchFieldType (keyword -> type) and scanned by SearchFieldTypeToString
// (type -> keyword).
const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTypes = {
{"TAG"sv, search::SchemaField::TAG},
{"TEXT"sv, search::SchemaField::TEXT},
{"NUMERIC"sv, search::SchemaField::NUMERIC},
{"VECTOR"sv, search::SchemaField::VECTOR}};
} // namespace
bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const {
@ -70,15 +64,17 @@ bool SearchParams::ShouldReturnField(std::string_view field) const {
return !return_fields || any_of(return_fields->begin(), return_fields->end(), cb);
}
// Resolves a schema type keyword to its FieldType by looking it up in kSchemaTypes.
// Returns nullopt when `name` is not a key of that table. The string is looked up exactly as
// passed in, so any case normalization has to happen before the call.
optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
auto it = kSchemaTypes.find(name);
return it != kSchemaTypes.end() ? make_optional(it->second) : nullopt;
}
// Returns the upper-case keyword ("TAG", "TEXT", "NUMERIC" or "VECTOR") for a schema field type.
// NOTE(review): this span looks like rendered diff output with the +/- markers stripped, so the
// kSchemaTypes scan and the switch below are presumably two alternative bodies shown back to
// back rather than one live function — confirm against the repository.
string_view SearchFieldTypeToString(search::SchemaField::FieldType type) {
// Reverse lookup: linear scan of the keyword table for an entry whose value equals `type`.
for (auto [it_name, it_type] : kSchemaTypes)
if (it_type == type)
return it_name;
// Direct mapping with one case per field type handled here.
switch (type) {
case search::SchemaField::TAG:
return "TAG";
case search::SchemaField::TEXT:
return "TEXT";
case search::SchemaField::NUMERIC:
return "NUMERIC";
case search::SchemaField::VECTOR:
return "VECTOR";
}
// Only reached when `type` matched none of the cases above; the macro marks that as impossible.
ABSL_UNREACHABLE();
// Presumably kept so every control path still has a return statement — TODO confirm.
return "";
}

View file

@ -22,7 +22,6 @@ namespace dfly {
using SearchDocData = absl::flat_hash_map<std::string /*field*/, std::string /*value*/>;
std::optional<search::SchemaField::FieldType> ParseSearchFieldType(std::string_view name);
std::string_view SearchFieldTypeToString(search::SchemaField::FieldType);
struct SerializedSearchDoc {

View file

@ -47,20 +47,18 @@ bool IsValidJsonPath(string_view path) {
search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
search::SchemaField::VectorParams params{};
params.use_hnsw = parser->ToUpper().Switch("HNSW", true, "FLAT", false);
params.use_hnsw = parser->Switch("HNSW", true, "FLAT", false);
size_t num_args = parser->Next<size_t>();
for (size_t i = 0; i * 2 < num_args; i++) {
parser->ToUpper();
if (parser->Check("DIM").ExpectTail(1)) {
params.dim = parser->Next<size_t>();
continue;
}
if (parser->Check("DISTANCE_METRIC").ExpectTail(1)) {
params.sim = parser->ToUpper().Switch("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
params.sim = parser->Switch("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
continue;
}
@ -100,13 +98,13 @@ search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
search::SchemaField::TagParams ParseTagParams(CmdArgParser* parser) {
search::SchemaField::TagParams params{};
while (parser->HasNext()) {
if (parser->Check("SEPARATOR").IgnoreCase().ExpectTail(1)) {
if (parser->Check("SEPARATOR").ExpectTail(1)) {
string_view separator = parser->Next();
params.separator = separator.front();
continue;
}
if (parser->Check("CASESENSITIVE").IgnoreCase()) {
if (parser->Check("CASESENSITIVE")) {
params.case_sensitive = true;
continue;
}
@ -137,26 +135,25 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
return nullopt;
}
parser.ToUpper();
// AS [alias]
if (parser.Check("AS").ExpectTail(1).NextUpper())
if (parser.Check("AS").ExpectTail(1))
field_alias = parser.Next();
// Determine type
string_view type_str = parser.Next();
auto type = ParseSearchFieldType(type_str);
if (!type) {
cntx->SendError("Invalid field type: " + string{type_str});
using search::SchemaField;
auto type = parser.Switch("TAG", SchemaField::TAG, "TEXT", SchemaField::TEXT, "NUMERIC",
SchemaField::NUMERIC, "VECTOR", SchemaField::VECTOR);
if (auto err = parser.Error(); err) {
cntx->SendError(err->MakeReply());
return nullopt;
}
// Tag fields include: [separator char] [casesensitive]
// Vector fields include: {algorithm} num_args args...
search::SchemaField::ParamsVariant params(monostate{});
if (*type == search::SchemaField::TAG) {
if (type == search::SchemaField::TAG) {
params = ParseTagParams(&parser);
} else if (*type == search::SchemaField::VECTOR) {
} else if (type == search::SchemaField::VECTOR) {
auto vector_params = ParseVectorParams(&parser);
if (parser.HasError()) {
auto err = *parser.Error();
@ -175,12 +172,12 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
// Flags: check for SORTABLE and NOINDEX
uint8_t flags = 0;
while (parser.HasNext()) {
if (parser.Check("NOINDEX").IgnoreCase()) {
if (parser.Check("NOINDEX")) {
flags |= search::SchemaField::NOINDEX;
continue;
}
if (parser.Check("SORTABLE").IgnoreCase()) {
if (parser.Check("SORTABLE")) {
flags |= search::SchemaField::SORTABLE;
continue;
}
@ -192,7 +189,7 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
while (kIgnoredOptions.count(parser.Peek()) > 0)
parser.Skip(2);
schema.fields[field] = {*type, flags, string{field_alias}, std::move(params)};
schema.fields[field] = {type, flags, string{field_alias}, std::move(params)};
}
// Build field name mapping table
@ -224,7 +221,7 @@ search::QueryParams ParseQueryParams(CmdArgParser* parser) {
optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionContext* cntx) {
SearchParams params;
while (parser.ToUpper().HasNext()) {
while (parser.HasNext()) {
// [LIMIT offset total]
if (parser.Check("LIMIT").ExpectTail(2)) {
params.limit_offset = parser.Next<size_t>();
@ -238,7 +235,7 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
params.return_fields = SearchParams::FieldReturnList{};
while (params.return_fields->size() < num_fields) {
string_view ident = parser.Next();
string_view alias = parser.Check("AS").IgnoreCase().ExpectTail(1) ? parser.Next() : ident;
string_view alias = parser.Check("AS").ExpectTail(1) ? parser.Next() : ident;
params.return_fields->emplace_back(ident, alias);
}
continue;
@ -257,8 +254,7 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
}
if (parser.Check("SORTBY").ExpectTail(1)) {
params.sort_option =
search::SortOption{string{parser.Next()}, bool(parser.Check("DESC").IgnoreCase())};
params.sort_option = search::SortOption{string{parser.Next()}, bool(parser.Check("DESC"))};
continue;
}
@ -287,7 +283,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
AggregateParams params;
tie(params.index, params.query) = parser.Next<string_view, string_view>();
while (parser.ToUpper().HasNext()) {
while (parser.HasNext()) {
// LOAD count field [field ...]
if (parser.Check("LOAD").ExpectTail(1)) {
params.load_fields.resize(parser.Next<size_t>());
@ -303,7 +299,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
field = parser.Next();
vector<aggregate::Reducer> reducers;
while (parser.ToUpper().Check("REDUCE").ExpectTail(2)) {
while (parser.Check("REDUCE").ExpectTail(2)) {
parser.ToUpper(); // uppercase for func_name
auto [func_name, nargs] = parser.Next<string_view, size_t>();
auto func = aggregate::FindReducerFunc(func_name);
@ -332,7 +328,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
if (parser.Check("SORTBY").ExpectTail(1)) {
parser.ExpectTag("1");
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC").IgnoreCase());
bool desc = bool(parser.Check("DESC"));
params.steps.push_back(aggregate::MakeSortStep(field, desc));
continue;
@ -471,10 +467,10 @@ void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
CmdArgParser parser{args};
string_view idx_name = parser.Next();
while (parser.ToUpper().HasNext()) {
while (parser.HasNext()) {
// ON HASH | JSON
if (parser.Check("ON").ExpectTail(1)) {
index.type = parser.ToUpper().Switch("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
index.type = parser.Switch("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
continue;
}