diff --git a/src/facade/command_id.h b/src/facade/command_id.h index af95a021f..14a91a7ec 100644 --- a/src/facade/command_id.h +++ b/src/facade/command_id.h @@ -93,7 +93,7 @@ class CommandId { static uint32_t OptCount(uint32_t mask); protected: - std::string_view name_; + std::string name_; uint32_t opt_mask_; int8_t arity_; diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 81d9d0793..97e4b4b23 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -598,7 +598,6 @@ void AclFamily::Register(dfly::CommandRegistry* registry) { .HFUNC(DryRun); *registry << CI{"ACL GENPASS", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0, acl::kGenPass}.HFUNC( GenPass); - cmd_registry_ = registry; } diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 7f5ac2443..0e5f1932c 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -4,6 +4,8 @@ #include "server/acl/acl_family.h" +#include "absl/container/flat_hash_map.h" +#include "absl/flags/internal/flag.h" #include "absl/strings/str_cat.h" #include "base/gtest.h" #include "base/logging.h" @@ -14,12 +16,21 @@ using namespace testing; +ABSL_DECLARE_FLAG(std::vector, rename_command); + namespace dfly { class AclFamilyTest : public BaseFamilyTest { protected: }; +class AclFamilyTestRename : public BaseFamilyTest { + void SetUp() override { + absl::SetFlag(&FLAGS_rename_command, {"ACL=ROCKS"}); + ResetService(); + } +}; + TEST_F(AclFamilyTest, AclSetUser) { TestInitAclFam(); auto resp = Run({"ACL", "SETUSER"}); @@ -37,6 +48,14 @@ TEST_F(AclFamilyTest, AclSetUser) { auto vec = resp.GetVec(); EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL", "user vlad off nopass +@NONE")); + + resp = Run({"ACL", "SETUSER", "vlad", "+ACL"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "LIST"}); + vec = resp.GetVec(); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL", + "user vlad off nopass +@NONE +ACL")); } TEST_F(AclFamilyTest, AclDelUser) { @@ -321,4 +340,15 @@ TEST_F(AclFamilyTest, AclGenPass) { EXPECT_THAT(resp.GetString().length(), 3); } +TEST_F(AclFamilyTestRename, AclRename) { + auto resp = Run({"ACL", "SETUSER", "billy"}); + EXPECT_THAT(resp, ErrArg("ERR unknown command `ACL`")); + + resp = Run({"ROCKS", "SETUSER", "billy", "ON", ">mypass"}); + EXPECT_THAT(resp.GetString(), "OK"); + + resp = Run({"ROCKS", "DELUSER", "billy"}); + EXPECT_THAT(resp.GetString(), "OK"); +} + } // namespace dfly diff --git a/src/server/command_registry.cc b/src/server/command_registry.cc index 5e5d0071d..80fcc6f43 100644 --- a/src/server/command_registry.cc +++ b/src/server/command_registry.cc @@ -7,6 +7,8 @@ #include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "base/bits.h" #include "base/flags.h" @@ -113,22 +115,31 @@ void CommandRegistry::Init(unsigned int thread_count) { } CommandRegistry& CommandRegistry::operator<<(CommandId cmd) { - string_view k = cmd.name(); - auto it = cmd_rename_map_.find(k); + auto k = cmd.name(); + + absl::InlinedVector maybe_subcommand = StrSplit(cmd.name(), " "); + const bool is_sub_command = maybe_subcommand.size() == 2; + auto it = cmd_rename_map_.find(maybe_subcommand.front()); if (it != cmd_rename_map_.end()) { if (it->second.empty()) { return *this; // Incase of empty string we want to remove the command from registry. } - k = it->second; + k = is_sub_command ? absl::StrCat(it->second, " ", maybe_subcommand[1]) : it->second; } if (restricted_cmds_.find(k) != restricted_cmds_.end()) { cmd.SetRestricted(true); } - family_of_commands_.back().push_back(std::string(k)); cmd.SetFamily(family_of_commands_.size() - 1); - cmd.SetBitIndex(1ULL << bit_index_++); + if (!is_sub_command) { + cmd.SetBitIndex(1ULL << bit_index_); + family_of_commands_.back().push_back(std::string(k)); + ++bit_index_; + } else { + DCHECK(absl::StartsWith(k, family_of_commands_.back().back())); + cmd.SetBitIndex(1ULL << (bit_index_ - 1)); + } CHECK(cmd_map_.emplace(k, std::move(cmd)).second) << k; return *this; @@ -139,6 +150,13 @@ void CommandRegistry::StartFamily() { bit_index_ = 0; } +std::string_view CommandRegistry::RenamedOrOriginal(std::string_view orig) const { + if (cmd_rename_map_.contains(orig)) { + return cmd_rename_map_.find(orig)->second; + } + return orig; +} + CommandRegistry::FamiliesVec CommandRegistry::GetFamilies() { return std::move(family_of_commands_); } diff --git a/src/server/command_registry.h b/src/server/command_registry.h index a689f81ac..a56ec1b06 100644 --- a/src/server/command_registry.h +++ b/src/server/command_registry.h @@ -163,12 +163,15 @@ class CommandRegistry { } } - using FamiliesVec = std::vector>; void StartFamily(); + + std::string_view RenamedOrOriginal(std::string_view orig) const; + + using FamiliesVec = std::vector>; FamiliesVec GetFamilies(); private: - absl::flat_hash_map cmd_map_; + absl::flat_hash_map cmd_map_; absl::flat_hash_map cmd_rename_map_; absl::flat_hash_set restricted_cmds_; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 72f582b51..7a85b46a4 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1576,19 +1576,20 @@ optional StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptPa return false; } -static std::string FullAclCommandFromArgs(CmdArgList args) { +static std::string FullAclCommandFromArgs(CmdArgList args, std::string_view name) { ToUpper(&args[1]); - // Guranteed SSO no dynamic allocations here - return std::string("ACL ") + std::string(args[1].begin(), args[1].end()); + auto res = absl::StrCat(name, " ", ArgS(args, 1)); + return res; } std::pair Service::FindCmd(CmdArgList args) const { const std::string_view command = facade::ToSV(args[0]); - if (command == "ACL") { + std::string_view acl = "ACL"; + if (command == registry_.RenamedOrOriginal(acl)) { if (args.size() == 1) { return {registry_.Find(ArgS(args, 0)), args}; } - return {registry_.Find(FullAclCommandFromArgs(args)), args.subspan(2)}; + return {registry_.Find(FullAclCommandFromArgs(args, command)), args.subspan(2)}; } const CommandId* res = registry_.Find(ArgS(args, 0));