From bb9819464f26c3be426295ee2c200424f1ffe6b1 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Wed, 29 Jan 2025 10:58:17 +0200 Subject: [PATCH] chore: introduce GlobMatcher (#4521) Right now it's just a wrapper around stringmatchlen, so no functional changes are expected. Signed-off-by: Roman Gershman --- src/core/CMakeLists.txt | 2 +- src/core/dfly_core_test.cc | 8 +++----- src/core/glob_matcher.cc | 22 ++++++++++++++++++++++ src/core/glob_matcher.h | 21 +++++++++++++++++++++ src/server/acl/validator.cc | 22 ++++++++-------------- src/server/channel_store.cc | 17 +++++------------ src/server/common.cc | 7 ++----- src/server/common.h | 3 ++- src/server/config_registry.cc | 9 ++++----- src/server/dragonfly_test.cc | 2 +- src/server/generic_family.cc | 3 +-- 11 files changed, 70 insertions(+), 46 deletions(-) create mode 100644 src/core/glob_matcher.cc create mode 100644 src/core/glob_matcher.h diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 7eb8fb73c..ed2d2ed59 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -5,7 +5,7 @@ set(SEARCH_LIB query_parser) add_library(dfly_core allocation_tracker.cc bloom.cc compact_object.cc dense_set.cc dragonfly_core.cc extent_tree.cc - interpreter.cc mi_memory_resource.cc qlist.cc sds_utils.cc + interpreter.cc glob_matcher.cc mi_memory_resource.cc qlist.cc sds_utils.cc segment_allocator.cc score_map.cc small_string.cc sorted_map.cc task_queue.cc tx_queue.cc string_set.cc string_map.cc detail/bitpacking.cc) diff --git a/src/core/dfly_core_test.cc b/src/core/dfly_core_test.cc index f9eebec28..66ae59496 100644 --- a/src/core/dfly_core_test.cc +++ b/src/core/dfly_core_test.cc @@ -3,13 +3,10 @@ // #include "base/gtest.h" +#include "core/glob_matcher.h" #include "core/intent_lock.h" #include "core/tx_queue.h" -extern "C" { -#include "redis/util.h" -} - namespace dfly { using namespace std; @@ -75,7 +72,8 @@ class StringMatchTest : public ::testing::Test { protected: // wrapper around 
stringmatchlen with stringview arguments int MatchLen(string_view pattern, string_view str, bool nocase) { - return stringmatchlen(pattern.data(), pattern.size(), str.data(), str.size(), nocase); + GlobMatcher matcher(pattern, !nocase); + return matcher.Matches(str); } }; diff --git a/src/core/glob_matcher.cc b/src/core/glob_matcher.cc new file mode 100644 index 000000000..e556b3b85 --- /dev/null +++ b/src/core/glob_matcher.cc @@ -0,0 +1,22 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "core/glob_matcher.h" + +extern "C" { +#include "redis/util.h" +} + +namespace dfly { + +GlobMatcher::GlobMatcher(std::string_view pattern, bool case_sensitive) + : pattern_(pattern), case_sensitive_(case_sensitive) { +} + +bool GlobMatcher::Matches(std::string_view str) const { + return stringmatchlen(pattern_.data(), pattern_.size(), str.data(), str.size(), + int(!case_sensitive_)) != 0; +} + +} // namespace dfly diff --git a/src/core/glob_matcher.h b/src/core/glob_matcher.h new file mode 100644 index 000000000..30a8da889 --- /dev/null +++ b/src/core/glob_matcher.h @@ -0,0 +1,21 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
+// +#pragma once + +#include <string_view> + +namespace dfly { + +class GlobMatcher { + public: + explicit GlobMatcher(std::string_view pattern, bool case_sensitive); + + bool Matches(std::string_view str) const; + + private: + std::string_view pattern_; + bool case_sensitive_; +}; + +} // namespace dfly diff --git a/src/server/acl/validator.cc b/src/server/acl/validator.cc index b7e7627b3..b1f5e39f1 100644 --- a/src/server/acl/validator.cc +++ b/src/server/acl/validator.cc @@ -5,18 +5,20 @@ #include "server/acl/validator.h" #include "base/logging.h" +#include "core/glob_matcher.h" #include "facade/dragonfly_connection.h" #include "server/acl/acl_commands_def.h" #include "server/command_registry.h" #include "server/server_state.h" #include "server/transaction.h" -// we need this because of stringmatchlen -extern "C" { -#include "redis/util.h" -} namespace dfly::acl { +inline bool Matches(std::string_view pattern, std::string_view target) { + GlobMatcher matcher(pattern, true); + return matcher.Matches(target); +}; + [[nodiscard]] bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id, ArgSlice tail_args) { if (cntx.skip_acl_validation) { @@ -58,16 +60,12 @@ static bool ValidateCommand(const std::vector& acl_commands, const Com return {false, AclLog::Reason::COMMAND}; } - auto match = [](const auto& pattern, const auto& target) { - return stringmatchlen(pattern.data(), pattern.size(), target.data(), target.size(), 0); - }; - const bool is_read_command = id.IsReadOnly(); const bool is_write_command = id.IsWriteOnly(); auto iterate_globs = [&](auto target) { for (auto& [elem, op] : keys.key_globs) { - if (match(elem, target)) { + if (Matches(elem, target)) { if (is_read_command && (op == KeyOp::READ || op == KeyOp::READ_WRITE)) { return true; } @@ -98,16 +96,12 @@ static bool ValidateCommand(const std::vector& acl_commands, const Com return {false, AclLog::Reason::COMMAND}; } - auto match = [](std::string_view pattern, std::string_view target) { -
return stringmatchlen(pattern.data(), pattern.size(), target.data(), target.size(), 0); - }; - auto iterate_globs = [&](std::string_view target) { for (auto& [glob, has_asterisk] : pub_sub.globs) { if (literal_match && (glob == target)) { return true; } - if (!literal_match && match(glob, target)) { + if (!literal_match && Matches(glob, target)) { return true; } } diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc index 06d4d3b10..2a1770898 100644 --- a/src/server/channel_store.cc +++ b/src/server/channel_store.cc @@ -4,15 +4,10 @@ // See LICENSE for licensing terms. // -#include - -extern "C" { -#include "redis/util.h" -} - #include #include "base/logging.h" +#include "core/glob_matcher.h" #include "server/engine_shard_set.h" #include "server/server_state.h" @@ -21,10 +16,6 @@ using namespace std; namespace { -bool Matches(string_view pattern, string_view channel) { - return stringmatchlen(pattern.data(), pattern.size(), channel.data(), channel.size(), 0) == 1; -} - // Build functor for sending messages to connection auto BuildSender(string_view channel, facade::ArgRange messages) { absl::FixedArray views(messages.Size()); @@ -171,7 +162,8 @@ vector ChannelStore::FetchSubscribers(string_view chan Fill(*it->second, string{}, &res); for (const auto& [pat, subs] : *patterns_) { - if (Matches(pat, channel)) + GlobMatcher matcher{pat, true}; + if (matcher.Matches(channel)) Fill(*subs, pat, &res); } @@ -192,8 +184,9 @@ void ChannelStore::Fill(const SubscribeMap& src, const string& pattern, vector ChannelStore::ListChannels(const string_view pattern) const { vector res; + GlobMatcher matcher{pattern, true}; for (const auto& [channel, _] : *channels_) { - if (pattern.empty() || Matches(pattern, channel)) + if (pattern.empty() || matcher.Matches(channel)) res.push_back(channel); } return res; diff --git a/src/server/common.cc b/src/server/common.cc index e16d0d79f..cd57c599a 100644 --- a/src/server/common.cc +++ b/src/server/common.cc @@ -12,7 +12,6 @@ 
extern "C" { #include "redis/rdb.h" -#include "redis/util.h" } #include "base/flags.h" @@ -298,7 +297,7 @@ OpResult ScanOpts::TryFrom(CmdArgList args) { } else if (opt == "MATCH") { string_view pattern = ArgS(args, i + 1); if (pattern != "*") - scan_opts.pattern = pattern; + scan_opts.matcher.emplace(pattern, true); } else if (opt == "TYPE") { auto obj_type = ObjTypeFromString(ArgS(args, i + 1)); if (!obj_type) { @@ -317,9 +316,7 @@ OpResult ScanOpts::TryFrom(CmdArgList args) { } bool ScanOpts::Matches(std::string_view val_name) const { - if (!pattern) - return true; - return stringmatchlen(pattern->data(), pattern->size(), val_name.data(), val_name.size(), 0) == 1; + return !matcher || matcher->Matches(val_name); } GenericError::operator std::error_code() const { diff --git a/src/server/common.h b/src/server/common.h index 67165e929..87d90ca45 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -15,6 +15,7 @@ #include #include "core/compact_object.h" +#include "core/glob_matcher.h" #include "facade/facade_types.h" #include "facade/op_status.h" #include "helio/io/proc_reader.h" @@ -303,7 +304,7 @@ class Context : protected Cancellation { }; struct ScanOpts { - std::optional pattern; + std::optional<GlobMatcher> matcher; size_t limit = 10; std::optional type_filter; unsigned bucket_id = UINT_MAX; diff --git a/src/server/config_registry.cc b/src/server/config_registry.cc index 5b3d42847..6888ec95d 100644 --- a/src/server/config_registry.cc +++ b/src/server/config_registry.cc @@ -7,12 +7,9 @@ #include #include "base/logging.h" +#include "core/glob_matcher.h" #include "server/common.h" -extern "C" { -#include "redis/util.h" -} - namespace dfly { namespace { using namespace std; @@ -67,11 +64,13 @@ void ConfigRegistry::Reset() { vector ConfigRegistry::List(string_view glob) const { string normalized_glob = NormalizeConfigName(glob); + GlobMatcher matcher(normalized_glob, false /* case insensitive*/); vector res; util::fb2::LockGuard lk(mu_); + for (const auto& [name, _] : 
registry_) { - if (stringmatchlen(normalized_glob.data(), normalized_glob.size(), name.data(), name.size(), 1)) + if (matcher.Matches(name)) res.push_back(name); } return res; diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index 1db8353ea..9a3b402aa 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -875,7 +875,7 @@ static void BM_MatchPattern(benchmark::State& state) { absl::InsecureBitGen eng; string random_val = GetRandomHex(eng, state.range(0)); ScanOpts scan_opts; - scan_opts.pattern = "*foobar*"; + scan_opts.matcher.emplace("*foobar*", true); while (state.KeepRunning()) { DoNotOptimize(scan_opts.Matches(random_val)); } diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index 0f89fd8db..29659d503 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -12,7 +12,6 @@ extern "C" { #include "redis/crc64.h" -#include "redis/util.h" } #include "base/flags.h" @@ -1180,7 +1179,7 @@ void GenericFamily::Keys(CmdArgList args, const CommandContext& cmd_cntx) { ScanOpts scan_opts; if (pattern != "*") { - scan_opts.pattern = pattern; + scan_opts.matcher.emplace(pattern, true); } scan_opts.limit = 512;