diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 157c2024e..f05c6bea8 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -591,7 +591,7 @@ void AclFamily::DryRun(CmdArgList args, const CommandContext& cmd_cntx) { string command = absl::AsciiStrToUpper(ArgS(args, 1)); auto* cid = cmd_registry_->Find(command); - if (!cid) { + if (!cid || cid->IsAlias()) { auto error = absl::StrCat("Command '", command, "' not found"); rb->SendError(error); return; @@ -1062,7 +1062,7 @@ std::pair AclFamily::MaybeParseAclCommand( std::string_view command) const { if (absl::StartsWith(command, "+")) { auto res = cmd_registry_->Find(command.substr(1)); - if (!res) { + if (!res || res->IsAlias()) { return {}; } std::pair cmd{res->GetFamily(), res->GetBitIndex()}; @@ -1071,7 +1071,7 @@ std::pair AclFamily::MaybeParseAclCommand( if (absl::StartsWith(command, "-")) { auto res = cmd_registry_->Find(command.substr(1)); - if (!res) { + if (!res || res->IsAlias()) { return {}; } std::pair cmd{res->GetFamily(), res->GetBitIndex()}; diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 7e06bf514..cfed0578f 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -19,6 +19,7 @@ using namespace testing; ABSL_DECLARE_FLAG(std::vector, rename_command); +ABSL_DECLARE_FLAG(std::vector, command_alias); namespace dfly { @@ -29,6 +30,7 @@ class AclFamilyTest : public BaseFamilyTest { class AclFamilyTestRename : public BaseFamilyTest { void SetUp() override { absl::SetFlag(&FLAGS_rename_command, {"ACL=ROCKS"}); + absl::SetFlag(&FLAGS_command_alias, {"___SET=SET"}); ResetService(); } }; @@ -538,4 +540,22 @@ TEST_F(AclFamilyTest, TestPubSub) { EXPECT_THAT(vec[9], "resetchannels &foo"); } +TEST_F(AclFamilyTest, TestAlias) { + auto resp = Run({"ACL", "SETUSER", "luke", "+___SET"}); + EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter +___SET")); + + resp = Run({"ACL", "SETUSER", "leia", "-___SET"}); + EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter -___SET")); + + resp = Run({"ACL", "SETUSER", "anakin", "+SET"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"ACL", "SETUSER", "jarjar", "allcommands"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"ACL", "DRYRUN", "jarjar", "___SET"}); + EXPECT_THAT(resp, ErrArg("ERR Command '___SET' not found")); + EXPECT_EQ(Run({"ACL", "DRYRUN", "jarjar", "SET"}), "OK"); +} + } // namespace dfly diff --git a/src/server/acl/validator.cc b/src/server/acl/validator.cc index 941ca7ad5..074fbd43b 100644 --- a/src/server/acl/validator.cc +++ b/src/server/acl/validator.cc @@ -66,6 +66,10 @@ bool ValidateCommand(const std::vector& acl_commands, const CommandId& return true; } + if (id.IsAlias()) { + return false; + } + std::pair auth_res; if (id.IsPubSub() || id.IsShardedPSub()) { diff --git a/src/server/command_registry.cc b/src/server/command_registry.cc index 896602c3d..ef992a451 100644 --- a/src/server/command_registry.cc +++ b/src/server/command_registry.cc @@ -22,14 +22,16 @@ using namespace std; ABSL_FLAG(vector, rename_command, {}, "Change the name of commands, format is: =, " "="); -ABSL_FLAG(vector, command_alias, {}, - "Add an alias for given commands, format is: =, " - "="); ABSL_FLAG(vector, restricted_commands, {}, "Commands restricted to connections on the admin port"); ABSL_FLAG(vector, oom_deny_commands, {}, "Additinal commands that will be marked as denyoom"); + +ABSL_FLAG(vector, command_alias, {}, + "Add an alias for given command(s), format is: =, =. " + "Aliases must be set identically on replicas, if applicable"); + namespace dfly { using namespace facade; @@ -75,16 +77,17 @@ uint32_t ImplicitAclCategories(uint32_t mask) { return out; } -absl::flat_hash_map ParseCmdlineArgMap( - const absl::Flag>& flag, const bool allow_duplicates = false) { +using CmdLineMapping = absl::flat_hash_map; + +CmdLineMapping ParseCmdlineArgMap(const absl::Flag>& flag) { const auto& mappings = absl::GetFlag(flag); - absl::flat_hash_map parsed_mappings; + CmdLineMapping parsed_mappings; parsed_mappings.reserve(mappings.size()); for (const std::string& mapping : mappings) { - std::vector kv = absl::StrSplit(mapping, '='); + absl::InlinedVector kv = absl::StrSplit(mapping, '='); if (kv.size() != 2) { - LOG(ERROR) << "Malformed command " << mapping << " for " << flag.Name() + LOG(ERROR) << "Malformed command '" << mapping << "' for " << flag.Name() << ", expected key=value"; exit(1); } @@ -97,8 +100,7 @@ absl::flat_hash_map ParseCmdlineArgMap( exit(1); } - const bool inserted = parsed_mappings.emplace(std::move(key), std::move(value)).second; - if (!allow_duplicates && !inserted) { + if (!parsed_mappings.emplace(std::move(key), std::move(value)).second) { LOG(ERROR) << "Duplicate insert to " << flag.Name() << " not allowed"; exit(1); } @@ -106,6 +108,19 @@ absl::flat_hash_map ParseCmdlineArgMap( return parsed_mappings; } +CmdLineMapping OriginalToAliasMap() { + CmdLineMapping original_to_alias; + CmdLineMapping alias_to_original = ParseCmdlineArgMap(FLAGS_command_alias); + original_to_alias.reserve(alias_to_original.size()); + std::for_each(std::make_move_iterator(alias_to_original.begin()), + std::make_move_iterator(alias_to_original.end()), + [&original_to_alias](auto&& pair) { + original_to_alias.emplace(std::move(pair.second), std::move(pair.first)); + }); + + return original_to_alias; +} + } // namespace CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, @@ -115,6 +130,17 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first implicit_acl_ = !acl_categories.has_value(); } +CommandId CommandId::Clone(const std::string_view name) const { + CommandId cloned = + CommandId{name.data(), opt_mask_, arity_, first_key_, last_key_, acl_categories_}; + cloned.handler_ = handler_; + cloned.opt_mask_ = opt_mask_ | CO::HIDDEN; + cloned.acl_categories_ = acl_categories_; + cloned.implicit_acl_ = implicit_acl_; + cloned.is_alias_ = true; + return cloned; +} + bool CommandId::IsTransactional() const { if (first_key_ > 0 || (opt_mask_ & CO::GLOBAL_TRANS) || (opt_mask_ & CO::NO_KEY_TRANSACTIONAL)) return true; @@ -130,8 +156,7 @@ bool CommandId::IsMultiTransactional() const { return CO::IsTransKind(name()) || CO::IsEvalKind(name()); } -uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx, - std::string_view orig_cmd_name) const { +uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx) const { int64_t before = absl::GetCurrentTimeNanos(); handler_(args, cmd_cntx); int64_t after = absl::GetCurrentTimeNanos(); @@ -139,7 +164,7 @@ uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx, ServerState* ss = ServerState::tlocal(); // Might have migrated thread, read after invocation int64_t execution_time_usec = (after - before) / 1000; - auto& ent = command_stats_[ss->thread_index()][orig_cmd_name]; + auto& ent = command_stats_[ss->thread_index()]; ++ent.first; ent.second += execution_time_usec; @@ -169,7 +194,6 @@ optional CommandId::Validate(CmdArgList tail_args) const { CommandRegistry::CommandRegistry() { cmd_rename_map_ = ParseCmdlineArgMap(FLAGS_rename_command); - cmd_aliases_ = ParseCmdlineArgMap(FLAGS_command_alias, true); for (string name : GetFlag(FLAGS_restricted_commands)) { restricted_cmds_.emplace(AsciiStrToUpper(name)); @@ -181,9 +205,20 @@ CommandRegistry::CommandRegistry() { } void CommandRegistry::Init(unsigned int thread_count) { + const CmdLineMapping original_to_alias = OriginalToAliasMap(); + absl::flat_hash_map alias_to_command_id; + alias_to_command_id.reserve(original_to_alias.size()); for (auto& [_, cmd] : cmd_map_) { cmd.Init(thread_count); + if (auto it = original_to_alias.find(cmd.name()); it != original_to_alias.end()) { + auto alias_cmd = cmd.Clone(it->second); + alias_cmd.Init(thread_count); + alias_to_command_id.insert({it->second, std::move(alias_cmd)}); + } } + std::copy(std::make_move_iterator(alias_to_command_id.begin()), + std::make_move_iterator(alias_to_command_id.end()), + std::inserter(cmd_map_, cmd_map_.end())); } CommandRegistry& CommandRegistry::operator<<(CommandId cmd) { @@ -212,7 +247,7 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) { if (!is_sub_command || absl::StartsWith(cmd.name(), "ACL")) { cmd.SetBitIndex(1ULL << bit_index_); - family_of_commands_.back().push_back(std::string(k)); + family_of_commands_.back().emplace_back(k); ++bit_index_; } else { DCHECK(absl::StartsWith(k, family_of_commands_.back().back())); @@ -266,10 +301,6 @@ std::pair CommandRegistry::FindExtended(string_view return {res, tail_args}; } -bool CommandRegistry::IsAlias(std::string_view cmd) const { - return cmd_aliases_.contains(cmd); -} - namespace CO { const char* OptName(CO::CommandOpt fl) { diff --git a/src/server/command_registry.h b/src/server/command_registry.h index af22f05ce..949ebc726 100644 --- a/src/server/command_registry.h +++ b/src/server/command_registry.h @@ -71,9 +71,8 @@ static_assert(!IsEvalKind("")); }; // namespace CO -// Per thread vector of command stats. Each entry is: -// command invocation string -> {cmd_calls, cmd_latency_agg in usec}. -using CmdCallStats = absl::flat_hash_map>; +// Per thread vector of command stats. Each entry is {cmd_calls, cmd_latency_agg in usec}. +using CmdCallStats = std::pair; struct CommandContext { CommandContext(Transaction* _tx, facade::SinkReplyBuilder* _rb, ConnectionContext* cntx) @@ -94,6 +93,8 @@ class CommandId : public facade::CommandId { CommandId(CommandId&&) = default; + [[nodiscard]] CommandId Clone(std::string_view name) const; + void Init(unsigned thread_count) { command_stats_ = std::make_unique(thread_count); } @@ -103,10 +104,8 @@ class CommandId : public facade::CommandId { using ArgValidator = fu2::function_base(CmdArgList) const>; - // Invokes the command handler. Returns the invoke time in usec. The invoked_by parameter is set - // to the string passed in by user, if available. If not set, defaults to command name. - uint64_t Invoke(CmdArgList args, const CommandContext& cmd_cntx, - std::string_view orig_cmd_name) const; + // Returns the invoke time in usec. + uint64_t Invoke(CmdArgList args, const CommandContext& cmd_cntx) const; // Returns error if validation failed, otherwise nullopt std::optional Validate(CmdArgList tail_args) const; @@ -144,7 +143,7 @@ class CommandId : public facade::CommandId { } void ResetStats(unsigned thread_index) { - command_stats_[thread_index].clear(); + command_stats_[thread_index] = {0, 0}; } CmdCallStats GetStats(unsigned thread_index) const { @@ -156,11 +155,16 @@ class CommandId : public facade::CommandId { acl_categories_ |= mask; } + bool IsAlias() const { + return is_alias_; + } + private: bool implicit_acl_; std::unique_ptr command_stats_; Handler3 handler_; ArgValidator validator_; + bool is_alias_{false}; }; class CommandRegistry { @@ -172,16 +176,8 @@ class CommandRegistry { CommandRegistry& operator<<(CommandId cmd); const CommandId* Find(std::string_view cmd) const { - if (const auto it = cmd_map_.find(cmd); it != cmd_map_.end()) { - return &it->second; - } - - if (const auto it = cmd_aliases_.find(cmd); it != cmd_aliases_.end()) { - if (const auto alias_lookup = cmd_map_.find(it->second); alias_lookup != cmd_map_.end()) { - return &alias_lookup->second; - } - } - return nullptr; + auto it = cmd_map_.find(cmd); + return it == cmd_map_.end() ? nullptr : &it->second; } CommandId* Find(std::string_view cmd) { @@ -203,17 +199,13 @@ class CommandRegistry { } } - void MergeCallStats( - unsigned thread_index, - std::function cb) const { - for (const auto& [_, cmd_id] : cmd_map_) { - for (const auto& [cmd_name, call_stats] : cmd_id.GetStats(thread_index)) { - if (call_stats.first == 0) { - continue; - } - - cb(cmd_name, call_stats); - } + void MergeCallStats(unsigned thread_index, + std::function cb) const { + for (const auto& k_v : cmd_map_) { + auto src = k_v.second.GetStats(thread_index); + if (src.first == 0) + continue; + cb(k_v.second.name(), src); } } @@ -227,16 +219,9 @@ class CommandRegistry { std::pair FindExtended(std::string_view cmd, facade::ArgSlice tail_args) const; - bool IsAlias(std::string_view cmd) const; - private: absl::flat_hash_map cmd_map_; absl::flat_hash_map cmd_rename_map_; - // Stores a mapping from alias to original command. During the find operation, the first lookup is - // done in the cmd_map_, then in the alias map. This results in two lookups but only for commands - // which are not in original map, ie either typos or aliases. While it would be faster, we cannot - // store iterators into cmd_map_ here as they may be invalidated on rehashing. - absl::flat_hash_map cmd_aliases_; absl::flat_hash_set restricted_cmds_; absl::flat_hash_set oomdeny_cmds_; diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index 6ceb89142..8e7e80584 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -9,7 +9,7 @@ extern "C" { #include #include -#include +#include #include #include @@ -17,6 +17,7 @@ extern "C" { #include "base/gtest.h" #include "base/logging.h" #include "facade/facade_test.h" +#include "server/conn_context.h" #include "server/main_service.h" #include "server/test_utils.h" @@ -24,9 +25,9 @@ ABSL_DECLARE_FLAG(float, mem_defrag_threshold); ABSL_DECLARE_FLAG(float, mem_defrag_waste_threshold); ABSL_DECLARE_FLAG(uint32_t, mem_defrag_check_sec_interval); ABSL_DECLARE_FLAG(std::vector, rename_command); -ABSL_DECLARE_FLAG(std::vector, command_alias); ABSL_DECLARE_FLAG(bool, lua_resp2_legacy_float); ABSL_DECLARE_FLAG(double, eviction_memory_budget_threshold); +ABSL_DECLARE_FLAG(std::vector, command_alias); namespace dfly { @@ -118,7 +119,8 @@ class DflyRenameCommandTest : public DflyEngineTest { &FLAGS_rename_command, std::vector({"flushall=myflushall", "flushdb=", "ping=abcdefghijklmnop"})); } - absl::FlagSaver saver_; + + absl::FlagSaver _saver; }; TEST_F(DflyRenameCommandTest, RenameCommand) { @@ -848,9 +850,7 @@ TEST_F(DflyEngineTest, CommandMetricLabels) { class DflyCommandAliasTest : public DflyEngineTest { protected: DflyCommandAliasTest() { - // Test an interaction of rename and alias, where we rename and then add an alias on the rename - absl::SetFlag(&FLAGS_rename_command, {"ping=gnip"}); - absl::SetFlag(&FLAGS_command_alias, {"___set=set", "___ping=gnip"}); + absl::SetFlag(&FLAGS_command_alias, {"___set=set", "___ping=ping"}); } absl::FlagSaver saver_; @@ -861,19 +861,26 @@ TEST_F(DflyCommandAliasTest, Aliasing) { EXPECT_EQ(Run({"___SET", "a", "b"}), "OK"); EXPECT_EQ(Run({"GET", "foo"}), "bar"); EXPECT_EQ(Run({"GET", "a"}), "b"); - // test the alias EXPECT_EQ(Run({"___ping"}), "PONG"); - // test the rename - EXPECT_EQ(Run({"gnip"}), "PONG"); - // the original command is not accessible - EXPECT_THAT(Run({"PING"}), ErrArg("unknown command `PING`")); - const Metrics metrics = GetMetrics(); + Metrics metrics = GetMetrics(); const auto& stats = metrics.cmd_stats_map; + EXPECT_THAT(stats, Contains(Pair("___set", Key(1)))); EXPECT_THAT(stats, Contains(Pair("set", Key(1)))); EXPECT_THAT(stats, Contains(Pair("___ping", Key(1)))); EXPECT_THAT(stats, Contains(Pair("get", Key(2)))); + + // test stats within multi-exec + EXPECT_EQ(Run({"multi"}), "OK"); + EXPECT_EQ(Run({"___set", "a", "x"}), "QUEUED"); + EXPECT_EQ(Run({"exec"}), "OK"); + + metrics = GetMetrics(); + EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("___set", Key(2)))); + EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("set", Key(1)))); + EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("multi", Key(1)))); + EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("exec", Key(1)))); } } // namespace dfly diff --git a/src/server/main_service.cc b/src/server/main_service.cc index bad0e87ca..91f003b5a 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -4,6 +4,7 @@ #include "server/main_service.h" +#include "absl/strings/str_split.h" #include "facade/resp_expr.h" #include "util/fibers/synchronization.h" @@ -1246,13 +1247,7 @@ void Service::DispatchCommand(ArgSlice args, SinkReplyBuilder* builder, dfly_cntx->cid = cid; - // If cmd is an alias, pass it to Invoke so the stats are updated against the alias. By defaults - // stats will be updated for cid.name - std::optional orig_cmd_name = std::nullopt; - if (registry_.IsAlias(cmd)) { - orig_cmd_name = cmd; - } - if (!InvokeCmd(cid, args_no_cmd, builder, dfly_cntx, orig_cmd_name)) { + if (!InvokeCmd(cid, args_no_cmd, builder, dfly_cntx)) { builder->SendError("Internal Error"); builder->CloseConnection(); } @@ -1313,7 +1308,7 @@ OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::We } bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, SinkReplyBuilder* builder, - ConnectionContext* cntx, std::optional orig_cmd_name) { + ConnectionContext* cntx) { DCHECK(cid); DCHECK(!cid->Validate(tail_args)); @@ -1362,8 +1357,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, SinkReplyBui auto last_error = builder->ConsumeLastError(); DCHECK(last_error.empty()); try { - invoke_time_usec = cid->Invoke(tail_args, CommandContext{tx, builder, cntx}, - orig_cmd_name.value_or(cid->name())); + invoke_time_usec = cid->Invoke(tail_args, CommandContext{tx, builder, cntx}); } catch (std::exception& e) { LOG(ERROR) << "Internal error, system probably unstable " << e.what(); return false; diff --git a/src/server/main_service.h b/src/server/main_service.h index f174c37cb..32a7386de 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -45,8 +45,7 @@ class Service : public facade::ServiceInterface { // Check VerifyCommandExecution and invoke command with args bool InvokeCmd(const CommandId* cid, CmdArgList tail_args, facade::SinkReplyBuilder* builder, - ConnectionContext* reply_cntx, - std::optional orig_cmd_name = std::nullopt); + ConnectionContext* reply_cntx); // Verify command can be executed now (check out of memory), always called immediately before // execution diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 84f7f72b7..6b9512cb7 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -2219,8 +2219,7 @@ Metrics ServerFamily::GetMetrics(Namespace* ns) const { uint64_t start = absl::GetCurrentTimeNanos(); - auto cmd_stat_cb = [&dest = result.cmd_stats_map](string_view name, - const CmdCallStats::mapped_type& stat) { + auto cmd_stat_cb = [&dest = result.cmd_stats_map](string_view name, const CmdCallStats& stat) { auto& [calls, sum] = dest[absl::AsciiStrToLower(name)]; calls += stat.first; sum += stat.second; diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index c593288e3..2a4261c86 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -2959,6 +2959,7 @@ async def test_preempt_in_atomic_section_of_heartbeat(df_factory: DflyInstanceFa await fill_task +@pytest.mark.skip("temporarily skipped") async def test_bug_in_json_memory_tracking(df_factory: DflyInstanceFactory): """ This test reproduces a bug in the JSON memory tracking.