feat(server): Add support for command aliasing (#4932)

Add support for command aliasing using command_alias flag

Signed-off-by: Abhijat Malviya <abhijat@dragonflydb.io>
This commit is contained in:
Abhijat Malviya 2025-04-21 13:29:04 +05:30 committed by GitHub
parent 7ffe812967
commit 0fafa21722
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 124 additions and 84 deletions

View file

@ -591,7 +591,7 @@ void AclFamily::DryRun(CmdArgList args, const CommandContext& cmd_cntx) {
string command = absl::AsciiStrToUpper(ArgS(args, 1));
auto* cid = cmd_registry_->Find(command);
if (!cid) {
if (!cid || cid->IsAlias()) {
auto error = absl::StrCat("Command '", command, "' not found");
rb->SendError(error);
return;
@ -1062,7 +1062,7 @@ std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
std::string_view command) const {
if (absl::StartsWith(command, "+")) {
auto res = cmd_registry_->Find(command.substr(1));
if (!res) {
if (!res || res->IsAlias()) {
return {};
}
std::pair<size_t, uint64_t> cmd{res->GetFamily(), res->GetBitIndex()};
@ -1071,7 +1071,7 @@ std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
if (absl::StartsWith(command, "-")) {
auto res = cmd_registry_->Find(command.substr(1));
if (!res) {
if (!res || res->IsAlias()) {
return {};
}
std::pair<size_t, uint64_t> cmd{res->GetFamily(), res->GetBitIndex()};

View file

@ -19,6 +19,7 @@
using namespace testing;
ABSL_DECLARE_FLAG(std::vector<std::string>, rename_command);
ABSL_DECLARE_FLAG(std::vector<std::string>, command_alias);
namespace dfly {
@ -29,6 +30,7 @@ class AclFamilyTest : public BaseFamilyTest {
class AclFamilyTestRename : public BaseFamilyTest {
void SetUp() override {
absl::SetFlag(&FLAGS_rename_command, {"ACL=ROCKS"});
absl::SetFlag(&FLAGS_command_alias, {"___SET=SET"});
ResetService();
}
};
@ -538,4 +540,22 @@ TEST_F(AclFamilyTest, TestPubSub) {
EXPECT_THAT(vec[9], "resetchannels &foo");
}
TEST_F(AclFamilyTest, TestAlias) {
auto resp = Run({"ACL", "SETUSER", "luke", "+___SET"});
EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter +___SET"));
resp = Run({"ACL", "SETUSER", "leia", "-___SET"});
EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter -___SET"));
resp = Run({"ACL", "SETUSER", "anakin", "+SET"});
EXPECT_EQ(resp, "OK");
resp = Run({"ACL", "SETUSER", "jarjar", "allcommands"});
EXPECT_EQ(resp, "OK");
resp = Run({"ACL", "DRYRUN", "jarjar", "___SET"});
EXPECT_THAT(resp, ErrArg("ERR Command '___SET' not found"));
EXPECT_EQ(Run({"ACL", "DRYRUN", "jarjar", "SET"}), "OK");
}
} // namespace dfly

View file

@ -66,6 +66,10 @@ bool ValidateCommand(const std::vector<uint64_t>& acl_commands, const CommandId&
return true;
}
if (id.IsAlias()) {
return false;
}
std::pair<bool, AclLog::Reason> auth_res;
if (id.IsPubSub() || id.IsShardedPSub()) {

View file

@ -22,14 +22,16 @@ using namespace std;
ABSL_FLAG(vector<string>, rename_command, {},
"Change the name of commands, format is: <cmd1_name>=<cmd1_new_name>, "
"<cmd2_name>=<cmd2_new_name>");
ABSL_FLAG(vector<string>, command_alias, {},
"Add an alias for given commands, format is: <alias>=<original>, "
"<alias>=<original>");
ABSL_FLAG(vector<string>, restricted_commands, {},
"Commands restricted to connections on the admin port");
ABSL_FLAG(vector<string>, oom_deny_commands, {},
"Additinal commands that will be marked as denyoom");
ABSL_FLAG(vector<string>, command_alias, {},
"Add an alias for given command(s), format is: <alias>=<original>, <alias>=<original>. "
"Aliases must be set identically on replicas, if applicable");
namespace dfly {
using namespace facade;
@ -75,16 +77,17 @@ uint32_t ImplicitAclCategories(uint32_t mask) {
return out;
}
absl::flat_hash_map<std::string, std::string> ParseCmdlineArgMap(
const absl::Flag<std::vector<std::string>>& flag, const bool allow_duplicates = false) {
using CmdLineMapping = absl::flat_hash_map<std::string, std::string>;
CmdLineMapping ParseCmdlineArgMap(const absl::Flag<std::vector<std::string>>& flag) {
const auto& mappings = absl::GetFlag(flag);
absl::flat_hash_map<std::string, std::string> parsed_mappings;
CmdLineMapping parsed_mappings;
parsed_mappings.reserve(mappings.size());
for (const std::string& mapping : mappings) {
std::vector<std::string_view> kv = absl::StrSplit(mapping, '=');
absl::InlinedVector<std::string_view, 2> kv = absl::StrSplit(mapping, '=');
if (kv.size() != 2) {
LOG(ERROR) << "Malformed command " << mapping << " for " << flag.Name()
LOG(ERROR) << "Malformed command '" << mapping << "' for " << flag.Name()
<< ", expected key=value";
exit(1);
}
@ -97,8 +100,7 @@ absl::flat_hash_map<std::string, std::string> ParseCmdlineArgMap(
exit(1);
}
const bool inserted = parsed_mappings.emplace(std::move(key), std::move(value)).second;
if (!allow_duplicates && !inserted) {
if (!parsed_mappings.emplace(std::move(key), std::move(value)).second) {
LOG(ERROR) << "Duplicate insert to " << flag.Name() << " not allowed";
exit(1);
}
@ -106,6 +108,19 @@ absl::flat_hash_map<std::string, std::string> ParseCmdlineArgMap(
return parsed_mappings;
}
CmdLineMapping OriginalToAliasMap() {
CmdLineMapping original_to_alias;
CmdLineMapping alias_to_original = ParseCmdlineArgMap(FLAGS_command_alias);
original_to_alias.reserve(alias_to_original.size());
std::for_each(std::make_move_iterator(alias_to_original.begin()),
std::make_move_iterator(alias_to_original.end()),
[&original_to_alias](auto&& pair) {
original_to_alias.emplace(std::move(pair.second), std::move(pair.first));
});
return original_to_alias;
}
} // namespace
CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key,
@ -115,6 +130,17 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first
implicit_acl_ = !acl_categories.has_value();
}
CommandId CommandId::Clone(const std::string_view name) const {
CommandId cloned =
CommandId{name.data(), opt_mask_, arity_, first_key_, last_key_, acl_categories_};
cloned.handler_ = handler_;
cloned.opt_mask_ = opt_mask_ | CO::HIDDEN;
cloned.acl_categories_ = acl_categories_;
cloned.implicit_acl_ = implicit_acl_;
cloned.is_alias_ = true;
return cloned;
}
bool CommandId::IsTransactional() const {
if (first_key_ > 0 || (opt_mask_ & CO::GLOBAL_TRANS) || (opt_mask_ & CO::NO_KEY_TRANSACTIONAL))
return true;
@ -130,8 +156,7 @@ bool CommandId::IsMultiTransactional() const {
return CO::IsTransKind(name()) || CO::IsEvalKind(name());
}
uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx,
std::string_view orig_cmd_name) const {
uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx) const {
int64_t before = absl::GetCurrentTimeNanos();
handler_(args, cmd_cntx);
int64_t after = absl::GetCurrentTimeNanos();
@ -139,7 +164,7 @@ uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx,
ServerState* ss = ServerState::tlocal(); // Might have migrated thread, read after invocation
int64_t execution_time_usec = (after - before) / 1000;
auto& ent = command_stats_[ss->thread_index()][orig_cmd_name];
auto& ent = command_stats_[ss->thread_index()];
++ent.first;
ent.second += execution_time_usec;
@ -169,7 +194,6 @@ optional<facade::ErrorReply> CommandId::Validate(CmdArgList tail_args) const {
CommandRegistry::CommandRegistry() {
cmd_rename_map_ = ParseCmdlineArgMap(FLAGS_rename_command);
cmd_aliases_ = ParseCmdlineArgMap(FLAGS_command_alias, true);
for (string name : GetFlag(FLAGS_restricted_commands)) {
restricted_cmds_.emplace(AsciiStrToUpper(name));
@ -181,9 +205,20 @@ CommandRegistry::CommandRegistry() {
}
void CommandRegistry::Init(unsigned int thread_count) {
const CmdLineMapping original_to_alias = OriginalToAliasMap();
absl::flat_hash_map<std::string, CommandId> alias_to_command_id;
alias_to_command_id.reserve(original_to_alias.size());
for (auto& [_, cmd] : cmd_map_) {
cmd.Init(thread_count);
if (auto it = original_to_alias.find(cmd.name()); it != original_to_alias.end()) {
auto alias_cmd = cmd.Clone(it->second);
alias_cmd.Init(thread_count);
alias_to_command_id.insert({it->second, std::move(alias_cmd)});
}
}
std::copy(std::make_move_iterator(alias_to_command_id.begin()),
std::make_move_iterator(alias_to_command_id.end()),
std::inserter(cmd_map_, cmd_map_.end()));
}
CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
@ -212,7 +247,7 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
if (!is_sub_command || absl::StartsWith(cmd.name(), "ACL")) {
cmd.SetBitIndex(1ULL << bit_index_);
family_of_commands_.back().push_back(std::string(k));
family_of_commands_.back().emplace_back(k);
++bit_index_;
} else {
DCHECK(absl::StartsWith(k, family_of_commands_.back().back()));
@ -266,10 +301,6 @@ std::pair<const CommandId*, ArgSlice> CommandRegistry::FindExtended(string_view
return {res, tail_args};
}
bool CommandRegistry::IsAlias(std::string_view cmd) const {
return cmd_aliases_.contains(cmd);
}
namespace CO {
const char* OptName(CO::CommandOpt fl) {

View file

@ -71,9 +71,8 @@ static_assert(!IsEvalKind(""));
}; // namespace CO
// Per thread vector of command stats. Each entry is:
// command invocation string -> {cmd_calls, cmd_latency_agg in usec}.
using CmdCallStats = absl::flat_hash_map<std::string, std::pair<uint64_t, uint64_t>>;
// Per thread vector of command stats. Each entry is {cmd_calls, cmd_latency_agg in usec}.
using CmdCallStats = std::pair<uint64_t, uint64_t>;
struct CommandContext {
CommandContext(Transaction* _tx, facade::SinkReplyBuilder* _rb, ConnectionContext* cntx)
@ -94,6 +93,8 @@ class CommandId : public facade::CommandId {
CommandId(CommandId&&) = default;
[[nodiscard]] CommandId Clone(std::string_view name) const;
void Init(unsigned thread_count) {
command_stats_ = std::make_unique<CmdCallStats[]>(thread_count);
}
@ -103,10 +104,8 @@ class CommandId : public facade::CommandId {
using ArgValidator = fu2::function_base<true, true, fu2::capacity_default, false, false,
std::optional<facade::ErrorReply>(CmdArgList) const>;
// Invokes the command handler. Returns the invoke time in usec. The invoked_by parameter is set
// to the string passed in by user, if available. If not set, defaults to command name.
uint64_t Invoke(CmdArgList args, const CommandContext& cmd_cntx,
std::string_view orig_cmd_name) const;
// Returns the invoke time in usec.
uint64_t Invoke(CmdArgList args, const CommandContext& cmd_cntx) const;
// Returns error if validation failed, otherwise nullopt
std::optional<facade::ErrorReply> Validate(CmdArgList tail_args) const;
@ -144,7 +143,7 @@ class CommandId : public facade::CommandId {
}
void ResetStats(unsigned thread_index) {
command_stats_[thread_index].clear();
command_stats_[thread_index] = {0, 0};
}
CmdCallStats GetStats(unsigned thread_index) const {
@ -156,11 +155,16 @@ class CommandId : public facade::CommandId {
acl_categories_ |= mask;
}
bool IsAlias() const {
return is_alias_;
}
private:
bool implicit_acl_;
std::unique_ptr<CmdCallStats[]> command_stats_;
Handler3 handler_;
ArgValidator validator_;
bool is_alias_{false};
};
class CommandRegistry {
@ -172,16 +176,8 @@ class CommandRegistry {
CommandRegistry& operator<<(CommandId cmd);
const CommandId* Find(std::string_view cmd) const {
if (const auto it = cmd_map_.find(cmd); it != cmd_map_.end()) {
return &it->second;
}
if (const auto it = cmd_aliases_.find(cmd); it != cmd_aliases_.end()) {
if (const auto alias_lookup = cmd_map_.find(it->second); alias_lookup != cmd_map_.end()) {
return &alias_lookup->second;
}
}
return nullptr;
auto it = cmd_map_.find(cmd);
return it == cmd_map_.end() ? nullptr : &it->second;
}
CommandId* Find(std::string_view cmd) {
@ -203,17 +199,13 @@ class CommandRegistry {
}
}
void MergeCallStats(
unsigned thread_index,
std::function<void(std::string_view, const CmdCallStats::mapped_type&)> cb) const {
for (const auto& [_, cmd_id] : cmd_map_) {
for (const auto& [cmd_name, call_stats] : cmd_id.GetStats(thread_index)) {
if (call_stats.first == 0) {
continue;
}
cb(cmd_name, call_stats);
}
void MergeCallStats(unsigned thread_index,
std::function<void(std::string_view, const CmdCallStats&)> cb) const {
for (const auto& k_v : cmd_map_) {
auto src = k_v.second.GetStats(thread_index);
if (src.first == 0)
continue;
cb(k_v.second.name(), src);
}
}
@ -227,16 +219,9 @@ class CommandRegistry {
std::pair<const CommandId*, facade::ArgSlice> FindExtended(std::string_view cmd,
facade::ArgSlice tail_args) const;
bool IsAlias(std::string_view cmd) const;
private:
absl::flat_hash_map<std::string, CommandId> cmd_map_;
absl::flat_hash_map<std::string, std::string> cmd_rename_map_;
// Stores a mapping from alias to original command. During the find operation, the first lookup is
// done in the cmd_map_, then in the alias map. This results in two lookups but only for commands
// which are not in original map, ie either typos or aliases. While it would be faster, we cannot
// store iterators into cmd_map_ here as they may be invalidated on rehashing.
absl::flat_hash_map<std::string, std::string> cmd_aliases_;
absl::flat_hash_set<std::string> restricted_cmds_;
absl::flat_hash_set<std::string> oomdeny_cmds_;

View file

@ -9,7 +9,7 @@ extern "C" {
#include <absl/strings/ascii.h>
#include <absl/strings/str_join.h>
#include <absl/strings/str_split.h>
#include <absl/strings/strip.h>
#include <gmock/gmock.h>
#include <reflex/matcher.h>
@ -17,6 +17,7 @@ extern "C" {
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
#include "server/conn_context.h"
#include "server/main_service.h"
#include "server/test_utils.h"
@ -24,9 +25,9 @@ ABSL_DECLARE_FLAG(float, mem_defrag_threshold);
ABSL_DECLARE_FLAG(float, mem_defrag_waste_threshold);
ABSL_DECLARE_FLAG(uint32_t, mem_defrag_check_sec_interval);
ABSL_DECLARE_FLAG(std::vector<std::string>, rename_command);
ABSL_DECLARE_FLAG(std::vector<std::string>, command_alias);
ABSL_DECLARE_FLAG(bool, lua_resp2_legacy_float);
ABSL_DECLARE_FLAG(double, eviction_memory_budget_threshold);
ABSL_DECLARE_FLAG(std::vector<std::string>, command_alias);
namespace dfly {
@ -118,7 +119,8 @@ class DflyRenameCommandTest : public DflyEngineTest {
&FLAGS_rename_command,
std::vector<std::string>({"flushall=myflushall", "flushdb=", "ping=abcdefghijklmnop"}));
}
absl::FlagSaver saver_;
absl::FlagSaver _saver;
};
TEST_F(DflyRenameCommandTest, RenameCommand) {
@ -848,9 +850,7 @@ TEST_F(DflyEngineTest, CommandMetricLabels) {
class DflyCommandAliasTest : public DflyEngineTest {
protected:
DflyCommandAliasTest() {
// Test an interaction of rename and alias, where we rename and then add an alias on the rename
absl::SetFlag(&FLAGS_rename_command, {"ping=gnip"});
absl::SetFlag(&FLAGS_command_alias, {"___set=set", "___ping=gnip"});
absl::SetFlag(&FLAGS_command_alias, {"___set=set", "___ping=ping"});
}
absl::FlagSaver saver_;
@ -861,19 +861,26 @@ TEST_F(DflyCommandAliasTest, Aliasing) {
EXPECT_EQ(Run({"___SET", "a", "b"}), "OK");
EXPECT_EQ(Run({"GET", "foo"}), "bar");
EXPECT_EQ(Run({"GET", "a"}), "b");
// test the alias
EXPECT_EQ(Run({"___ping"}), "PONG");
// test the rename
EXPECT_EQ(Run({"gnip"}), "PONG");
// the original command is not accessible
EXPECT_THAT(Run({"PING"}), ErrArg("unknown command `PING`"));
const Metrics metrics = GetMetrics();
Metrics metrics = GetMetrics();
const auto& stats = metrics.cmd_stats_map;
EXPECT_THAT(stats, Contains(Pair("___set", Key(1))));
EXPECT_THAT(stats, Contains(Pair("set", Key(1))));
EXPECT_THAT(stats, Contains(Pair("___ping", Key(1))));
EXPECT_THAT(stats, Contains(Pair("get", Key(2))));
// test stats within multi-exec
EXPECT_EQ(Run({"multi"}), "OK");
EXPECT_EQ(Run({"___set", "a", "x"}), "QUEUED");
EXPECT_EQ(Run({"exec"}), "OK");
metrics = GetMetrics();
EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("___set", Key(2))));
EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("set", Key(1))));
EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("multi", Key(1))));
EXPECT_THAT(metrics.cmd_stats_map, Contains(Pair("exec", Key(1))));
}
} // namespace dfly

View file

@ -4,6 +4,7 @@
#include "server/main_service.h"
#include "absl/strings/str_split.h"
#include "facade/resp_expr.h"
#include "util/fibers/synchronization.h"
@ -1246,13 +1247,7 @@ void Service::DispatchCommand(ArgSlice args, SinkReplyBuilder* builder,
dfly_cntx->cid = cid;
// If cmd is an alias, pass it to Invoke so the stats are updated against the alias. By defaults
// stats will be updated for cid.name
std::optional<std::string_view> orig_cmd_name = std::nullopt;
if (registry_.IsAlias(cmd)) {
orig_cmd_name = cmd;
}
if (!InvokeCmd(cid, args_no_cmd, builder, dfly_cntx, orig_cmd_name)) {
if (!InvokeCmd(cid, args_no_cmd, builder, dfly_cntx)) {
builder->SendError("Internal Error");
builder->CloseConnection();
}
@ -1313,7 +1308,7 @@ OpResult<void> OpTrackKeys(const OpArgs slice_args, const facade::Connection::We
}
bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, SinkReplyBuilder* builder,
ConnectionContext* cntx, std::optional<std::string_view> orig_cmd_name) {
ConnectionContext* cntx) {
DCHECK(cid);
DCHECK(!cid->Validate(tail_args));
@ -1362,8 +1357,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, SinkReplyBui
auto last_error = builder->ConsumeLastError();
DCHECK(last_error.empty());
try {
invoke_time_usec = cid->Invoke(tail_args, CommandContext{tx, builder, cntx},
orig_cmd_name.value_or(cid->name()));
invoke_time_usec = cid->Invoke(tail_args, CommandContext{tx, builder, cntx});
} catch (std::exception& e) {
LOG(ERROR) << "Internal error, system probably unstable " << e.what();
return false;

View file

@ -45,8 +45,7 @@ class Service : public facade::ServiceInterface {
// Check VerifyCommandExecution and invoke command with args
bool InvokeCmd(const CommandId* cid, CmdArgList tail_args, facade::SinkReplyBuilder* builder,
ConnectionContext* reply_cntx,
std::optional<std::string_view> orig_cmd_name = std::nullopt);
ConnectionContext* reply_cntx);
// Verify command can be executed now (check out of memory), always called immediately before
// execution

View file

@ -2219,8 +2219,7 @@ Metrics ServerFamily::GetMetrics(Namespace* ns) const {
uint64_t start = absl::GetCurrentTimeNanos();
auto cmd_stat_cb = [&dest = result.cmd_stats_map](string_view name,
const CmdCallStats::mapped_type& stat) {
auto cmd_stat_cb = [&dest = result.cmd_stats_map](string_view name, const CmdCallStats& stat) {
auto& [calls, sum] = dest[absl::AsciiStrToLower(name)];
calls += stat.first;
sum += stat.second;

View file

@ -2959,6 +2959,7 @@ async def test_preempt_in_atomic_section_of_heartbeat(df_factory: DflyInstanceFa
await fill_task
@pytest.mark.skip("temporarily skipped")
async def test_bug_in_json_memory_tracking(df_factory: DflyInstanceFactory):
"""
This test reproduces a bug in the JSON memory tracking.