feat(AclFamilly): add acl list command (#1722)

* Add acl-family source and header
* Add `ACL LIST` command
* Add a simple test to check the default user
This commit is contained in:
Kostas Kyrimis 2023-08-22 18:33:14 +03:00 committed by GitHub
parent eae02a16da
commit 898061d738
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 234 additions and 39 deletions

View file

@ -62,7 +62,7 @@ struct ConnectionStats {
struct ErrorReply { struct ErrorReply {
explicit ErrorReply(std::string&& msg, std::string_view kind = {}) explicit ErrorReply(std::string&& msg, std::string_view kind = {})
: message{move(msg)}, kind{kind} { : message{std::move(msg)}, kind{kind} {
} }
explicit ErrorReply(std::string_view msg, std::string_view kind = {}) : message{msg}, kind{kind} { explicit ErrorReply(std::string_view msg, std::string_view kind = {}) : message{msg}, kind{kind} {
} }

View file

@ -25,7 +25,7 @@ add_library(dragonfly_lib channel_store.cc command_registry.cc
zset_family.cc version.cc bitops_family.cc container_utils.cc io_utils.cc zset_family.cc version.cc bitops_family.cc container_utils.cc io_utils.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc
top_keys.cc multi_command_squasher.cc hll_family.cc cluster/cluster_config.cc top_keys.cc multi_command_squasher.cc hll_family.cc cluster/cluster_config.cc
cluster/cluster_family.cc acl/user.cc acl/user_registry.cc) cluster/cluster_family.cc acl/user.cc acl/user_registry.cc acl/acl_family.cc)
cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib aws_lib strings_lib html_lib cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib aws_lib strings_lib html_lib

View file

@ -89,5 +89,20 @@ inline const absl::flat_hash_map<std::string_view, uint32_t> CATEGORY_INDEX_TABL
{"DANGEROUS", DANGEROUS}, {"DANGEROUS", DANGEROUS},
{"CONNECTION", CONNECTION}, {"CONNECTION", CONNECTION},
{"TRANSACTION", TRANSACTION}, {"TRANSACTION", TRANSACTION},
{"SCRIPTING", SCRIPTING}}; {"SCRIPTING", SCRIPTING},
{"FT_SEARCH", FT_SEARCH},
{"THROTTLE", THROTTLE},
{"JSON", JSON}
};
// bit 0 at index 0
// bit 1 at index 1
// bit n at index n
inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"KEYSPACE", "READ", "WRITE", "SET", "SORTED_SET", "LIST",
"HASH", "STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM",
"PUBSUB", "ADMIN", "FAST", "SLOW", "BLOCKING", "DANGEROUS",
"CONNECTION", "TRANSACTION", "SCRIPTING", "FT_SEARCH", "THROTTLE", "JSON"};
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -0,0 +1,78 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
#include "server/acl/acl_family.h"
#include "absl/strings/str_cat.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/server_state.h"
namespace dfly::acl {
constexpr uint32_t kList = acl::ADMIN | acl::SLOW | acl::DANGEROUS;
static std::string AclToString(uint32_t acl_category) {
std::string tmp;
if (acl_category == acl::ALL) {
return "+@all";
}
if (acl_category == acl::NONE) {
return "+@none";
}
const std::string prefix = "+@";
const std::string postfix = " ";
for (uint32_t step = 0, cat = 0; cat != JSON; cat = 1ULL << ++step) {
if (acl_category & cat) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[step], postfix);
}
}
tmp.erase(tmp.size());
return tmp;
}
void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
const auto registry_with_lock = ServerState::tlocal()->user_registry->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
(*cntx)->StartArray(registry.size());
for (const auto& [username, user] : registry) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : std::string(pass.substr(0, 15));
const std::string acl_cat = AclToString(user.AclCategory());
using namespace std::string_view_literals;
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat);
(*cntx)->SendSimpleString(buffer);
}
}
using CI = dfly::CommandId;
#define HFUNC(x) SetHandler(&AclFamily::x)
// We can't implement the ACL commands and its respective subcommands LIST, CAT, etc
// the usual way, (that is, one command called ACL which then dispatches to the subcommand
// based on the secocond argument) because each of the subcommands has different ACL
// categories. Therefore, to keep it compatible with the CommandId, I need to treat them
// as separate commands in the registry. This is the least intrusive change because it's very
// easy to handle that case explicitly in `DispatchCommand`.
void AclFamily::Register(dfly::CommandRegistry* registry) {
*registry << CI{"ACL LIST", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, 0, 0, 0, 0, acl::kList}.HFUNC(
List);
}
#undef HFUNC
} // namespace dfly::acl

View file

@ -0,0 +1,24 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include "server/common.h"
namespace dfly {
class ConnectionContext;
class CommandRegistry;
namespace acl {
class AclFamily {
public:
static void Register(CommandRegistry* registry);
private:
static void List(CmdArgList args, ConnectionContext* cntx);
};
} // namespace acl
} // namespace dfly

View file

@ -81,4 +81,10 @@ bool User::IsActive() const {
return is_active_; return is_active_;
} }
static const std::string_view default_pass = "nopass";
std::string_view User::Password() const {
return password_hash_ ? *password_hash_ : default_pass;
}
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -16,8 +16,6 @@
namespace dfly::acl { namespace dfly::acl {
class CommandId;
// TODO implement these // TODO implement these
//#bool CheckIfCommandAllowed(uint64_t command_id, const CommandId& command); //#bool CheckIfCommandAllowed(uint64_t command_id, const CommandId& command);
//#bool CheckIfAclCategoryAllowed(uint64_t command_id, const CommandId& command); //#bool CheckIfAclCategoryAllowed(uint64_t command_id, const CommandId& command);
@ -59,6 +57,8 @@ class User final {
bool IsActive() const; bool IsActive() const;
std::string_view Password() const;
private: private:
// For ACL categories // For ACL categories
void SetAclCategories(uint64_t cat); void SetAclCategories(uint64_t cat);

View file

@ -4,12 +4,16 @@
#include "server/acl/user_registry.h" #include "server/acl/user_registry.h"
#include <shared_mutex>
#include "core/fibers.h" #include "core/fibers.h"
#include "server/acl/acl_commands_def.h"
namespace dfly::acl { namespace dfly::acl {
UserRegistry::UserRegistry() {
User::UpdateRequest req{{}, acl::ALL, {}, true};
MaybeAddAndUpdate("default", std::move(req));
}
void UserRegistry::MaybeAddAndUpdate(std::string_view username, User::UpdateRequest req) { void UserRegistry::MaybeAddAndUpdate(std::string_view username, User::UpdateRequest req) {
std::unique_lock<util::SharedMutex> lock(mu_); std::unique_lock<util::SharedMutex> lock(mu_);
auto& user = registry_[username]; auto& user = registry_[username];
@ -50,4 +54,14 @@ bool UserRegistry::AuthUser(std::string_view username, std::string_view password
return user->second.HasPassword(password); return user->second.HasPassword(password);
} }
UserRegistry::RegistryViewWithLock::RegistryViewWithLock(std::shared_lock<util::SharedMutex> mu,
const RegistryType& registry)
: registry(registry), registry_mu_(std::move(mu)) {
}
UserRegistry::RegistryViewWithLock UserRegistry::GetRegistryWithLock() const {
std::shared_lock<util::SharedMutex> lock(mu_);
return {std::move(lock), registry_};
}
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -7,7 +7,10 @@
#include <absl/container/flat_hash_map.h> #include <absl/container/flat_hash_map.h>
#include <absl/synchronization/mutex.h> #include <absl/synchronization/mutex.h>
#include <shared_mutex>
#include <string> #include <string>
#include <utility>
#include <vector>
#include "core/fibers.h" #include "core/fibers.h"
#include "server/acl/user.h" #include "server/acl/user.h"
@ -16,11 +19,13 @@ namespace dfly::acl {
class UserRegistry { class UserRegistry {
public: public:
UserRegistry() = default; UserRegistry();
UserRegistry(const UserRegistry&) = delete; UserRegistry(const UserRegistry&) = delete;
UserRegistry(UserRegistry&&) = delete; UserRegistry(UserRegistry&&) = delete;
using RegistryType = absl::flat_hash_map<std::string, User>;
// Acquires a write lock of mu_ // Acquires a write lock of mu_
// If the user with name `username` does not exist, it's added in the store with // If the user with name `username` does not exist, it's added in the store with
// the exact fields found in req // the exact fields found in req
@ -48,8 +53,21 @@ class UserRegistry {
// Used by Auth // Used by Auth
bool AuthUser(std::string_view username, std::string_view password) const; bool AuthUser(std::string_view username, std::string_view password) const;
// Helper class for accessing the registry with a ReadLock outside the scope of UserRegistry
class RegistryViewWithLock {
public:
RegistryViewWithLock(std::shared_lock<util::SharedMutex> mu, const RegistryType& registry);
const RegistryType& registry;
private:
std::shared_lock<util::SharedMutex> registry_mu_;
};
// Helper function used for printing users via ACL LIST
RegistryViewWithLock GetRegistryWithLock() const;
private: private:
absl::flat_hash_map<std::string, User> registry_; RegistryType registry_;
// TODO add abseil mutex attributes // TODO add abseil mutex attributes
mutable util::SharedMutex mu_; mutable util::SharedMutex mu_;
}; };

View file

@ -40,8 +40,8 @@ constexpr size_t kNumThreads = 3;
void BlockingControllerTest::SetUp() { void BlockingControllerTest::SetUp() {
pp_.reset(fb2::Pool::IOUring(16, kNumThreads)); pp_.reset(fb2::Pool::IOUring(16, kNumThreads));
pp_->Run(); pp_->Run();
pp_->Await([](unsigned index, ProactorBase* p) { ServerState::Init(index); }); pp_->Await([](unsigned index, ProactorBase* p) { ServerState::Init(index, nullptr); });
ServerState::Init(kNumThreads); ServerState::Init(kNumThreads, nullptr);
shard_set = new EngineShardSet(pp_.get()); shard_set = new EngineShardSet(pp_.get());
shard_set->Init(kNumThreads, false); shard_set->Init(kNumThreads, false);

View file

@ -9,6 +9,8 @@
#include <absl/types/span.h> #include <absl/types/span.h>
#include <atomic> #include <atomic>
#include <cstddef>
#include <cstdint>
#include <string_view> #include <string_view>
#include <vector> #include <vector>

View file

@ -4,6 +4,9 @@
#include "server/main_service.h" #include "server/main_service.h"
#include "facade/resp_expr.h"
#include "server/acl/user_registry.h"
extern "C" { extern "C" {
#include "redis/redis_aux.h" #include "redis/redis_aux.h"
} }
@ -23,6 +26,7 @@ extern "C" {
#include "facade/error.h" #include "facade/error.h"
#include "facade/reply_capture.h" #include "facade/reply_capture.h"
#include "server/acl/acl_commands_def.h" #include "server/acl/acl_commands_def.h"
#include "server/acl/acl_family.h"
#include "server/bitops_family.h" #include "server/bitops_family.h"
#include "server/cluster/cluster_family.h" #include "server/cluster/cluster_family.h"
#include "server/conn_context.h" #include "server/conn_context.h"
@ -654,8 +658,8 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
config_registry.Register("requirepass"); config_registry.Register("requirepass");
config_registry.Register("masterauth"); config_registry.Register("masterauth");
config_registry.Register("tcp_keepalive"); config_registry.Register("tcp_keepalive");
acl::UserRegistry* reg = &user_registry_;
pp_.Await([](uint32_t index, ProactorBase* pb) { ServerState::Init(index); }); pp_.Await([reg](uint32_t index, ProactorBase* pb) { ServerState::Init(index, reg); });
uint32_t shard_num = GetFlag(FLAGS_num_shards); uint32_t shard_num = GetFlag(FLAGS_num_shards);
if (shard_num == 0) { if (shard_num == 0) {
@ -878,7 +882,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
ServerState& etl = *ServerState::tlocal(); ServerState& etl = *ServerState::tlocal();
ToUpper(&args[0]); ToUpper(&args[0]);
const CommandId* cid = FindCmd(args); const auto [cid, args_no_cmd] = FindCmd(args);
if (cid == nullptr) { if (cid == nullptr) {
return (*cntx)->SendError(ReportUnknownCmd(ArgS(args, 0))); return (*cntx)->SendError(ReportUnknownCmd(ArgS(args, 0)));
@ -897,13 +901,11 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
etl.RecordCmd(); etl.RecordCmd();
auto args_no_cmd = args.subspan(1);
if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) { if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
if (auto& exec_info = dfly_cntx->conn_state.exec_info; exec_info.IsCollecting()) if (auto& exec_info = dfly_cntx->conn_state.exec_info; exec_info.IsCollecting())
exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR; exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR;
(*dfly_cntx)->SendError(move(*err)); (*dfly_cntx)->SendError(std::move(*err));
return; return;
} }
@ -985,7 +987,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo
DCHECK(!cid->Validate(tail_args)); DCHECK(!cid->Validate(tail_args));
if (auto err = VerifyCommandExecution(cid); err) { if (auto err = VerifyCommandExecution(cid); err) {
(*cntx)->SendError(move(*err)); (*cntx)->SendError(std::move(*err));
return true; // return false only for internal error aborts return true; // return false only for internal error aborts
} }
@ -1042,7 +1044,7 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
for (auto args : args_list) { for (auto args : args_list) {
ToUpper(&args[0]); ToUpper(&args[0]);
const CommandId* cid = FindCmd(args); const auto [cid, tail_args] = FindCmd(args);
// MULTI...EXEC commands need to be collected into a single context, so squashing is not // MULTI...EXEC commands need to be collected into a single context, so squashing is not
// possible // possible
@ -1051,7 +1053,7 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
if (!is_multi && cid != nullptr) { if (!is_multi && cid != nullptr) {
stored_cmds.reserve(args_list.size()); stored_cmds.reserve(args_list.size());
stored_cmds.emplace_back(cid, args.subspan(1)); stored_cmds.emplace_back(cid, tail_args);
continue; continue;
} }
@ -1204,6 +1206,10 @@ facade::ConnectionStats* Service::GetThreadLocalConnectionStats() {
return ServerState::tl_connection_stats(); return ServerState::tl_connection_stats();
} }
const CommandId* Service::FindCmd(std::string_view cmd) const {
return registry_.Find(cmd);
}
bool Service::IsLocked(DbIndex db_index, std::string_view key) const { bool Service::IsLocked(DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_count()); ShardId sid = Shard(key, shard_count());
bool is_open = pp_.at(sid)->AwaitBrief([db_index, key] { bool is_open = pp_.at(sid)->AwaitBrief([db_index, key] {
@ -1312,7 +1318,7 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
info->async_cmds.clear(); info->async_cmds.clear();
auto reply = crb.Take(); auto reply = crb.Take();
return CapturingReplyBuilder::GetError(reply) ? make_optional(move(reply)) : nullopt; return CapturingReplyBuilder::GetError(reply) ? make_optional(std::move(reply)) : nullopt;
} }
void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) { void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) {
@ -1331,7 +1337,7 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
// Full command verification happens during squashed execution // Full command verification happens during squashed execution
if (auto* cid = registry_.Find(ArgS(ca.args, 0)); cid != nullptr) { if (auto* cid = registry_.Find(ArgS(ca.args, 0)); cid != nullptr) {
auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE; auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies); info->async_cmds.emplace_back(std::move(*ca.buffer), cid, ca.args.subspan(1), replies);
info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory(); info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory();
} else if (ca.error_abort) { // If we don't abort on errors, we can ignore it completely } else if (ca.error_abort) { // If we don't abort on errors, we can ignore it completely
findcmd_err = ReportUnknownCmd(ArgS(ca.args, 0)); findcmd_err = ReportUnknownCmd(ArgS(ca.args, 0));
@ -1339,13 +1345,13 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
} }
if (auto err = FlushEvalAsyncCmds(cntx, !ca.async || findcmd_err.has_value()); err) { if (auto err = FlushEvalAsyncCmds(cntx, !ca.async || findcmd_err.has_value()); err) {
CapturingReplyBuilder::Apply(move(*err), &replier); // forward error to lua CapturingReplyBuilder::Apply(std::move(*err), &replier); // forward error to lua
*ca.requested_abort = true; *ca.requested_abort = true;
return; return;
} }
if (findcmd_err.has_value()) { if (findcmd_err.has_value()) {
replier.RedisReplyBuilder::SendError(move(*findcmd_err)); replier.RedisReplyBuilder::SendError(std::move(*findcmd_err));
*ca.requested_abort |= ca.error_abort; *ca.requested_abort |= ca.error_abort;
} }
@ -1370,7 +1376,7 @@ void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
if (!res) if (!res)
return (*cntx)->SendError(res.error().Format(), facade::kScriptErrType); return (*cntx)->SendError(res.error().Format(), facade::kScriptErrType);
string sha{move(res.value())}; string sha{std::move(res.value())};
CallSHA(args, sha, interpreter, cntx); CallSHA(args, sha, interpreter, cntx);
} }
@ -1466,10 +1472,21 @@ bool StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams param
return false; return false;
} }
const CommandId* Service::FindCmd(CmdArgList args) const { static std::string FullAclCommandFromArgs(CmdArgList args) {
ToUpper(&args[1]);
// Guranteed SSO no dynamic allocations here
return std::string("ACL ") + std::string(args[1].begin(), args[1].end());
}
std::pair<const CommandId*, CmdArgList> Service::FindCmd(CmdArgList args) const {
const std::string_view command = facade::ToSV(args[0]);
if (command == "ACL") {
return {registry_.Find(FullAclCommandFromArgs(args)), args.subspan(2)};
}
const CommandId* res = registry_.Find(ArgS(args, 0)); const CommandId* res = registry_.Find(ArgS(args, 0));
if (!res) if (!res)
return nullptr; return {nullptr, args};
// A workaround for XGROUP HELP that does not fit our static taxonomy of commands. // A workaround for XGROUP HELP that does not fit our static taxonomy of commands.
if (args.size() == 2 && res->name() == "XGROUP") { if (args.size() == 2 && res->name() == "XGROUP") {
@ -1477,7 +1494,7 @@ const CommandId* Service::FindCmd(CmdArgList args) const {
res = registry_.Find("_XGROUP_HELP"); res = registry_.Find("_XGROUP_HELP");
} }
} }
return res; return {res, args.subspan(1)};
} }
void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
@ -1766,7 +1783,7 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
for (auto& sub : subscribers) for (auto& sub : subscribers)
sub.conn_cntx->owner()->EnsureAsyncMemoryBudget(); sub.conn_cntx->owner()->EnsureAsyncMemoryBudget();
auto subscribers_ptr = make_shared<decltype(subscribers)>(move(subscribers)); auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
auto buf = shared_ptr<char[]>{new char[channel.size() + msg.size()]}; auto buf = shared_ptr<char[]>{new char[channel.size() + msg.size()]};
memcpy(buf.get(), channel.data(), channel.size()); memcpy(buf.get(), channel.data(), channel.size());
memcpy(buf.get() + channel.size(), msg.data(), msg.size()); memcpy(buf.get() + channel.size(), msg.data(), msg.size());
@ -1778,7 +1795,8 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
while (it != subscribers_ptr->end() && it->thread_id == idx) { while (it != subscribers_ptr->end() && it->thread_id == idx) {
facade::Connection* conn = it->conn_cntx->owner(); facade::Connection* conn = it->conn_cntx->owner();
DCHECK(conn); DCHECK(conn);
conn->SendPubMessageAsync({move(it->pattern), move(buf), channel.size(), msg.size()}); conn->SendPubMessageAsync(
{std::move(it->pattern), std::move(buf), channel.size(), msg.size()});
it->borrow_token.Dec(); it->borrow_token.Dec();
it++; it++;
} }
@ -2081,6 +2099,7 @@ void Service::RegisterCommands() {
BitOpsFamily::Register(&registry_); BitOpsFamily::Register(&registry_);
HllFamily::Register(&registry_); HllFamily::Register(&registry_);
SearchFamily::Register(&registry_); SearchFamily::Register(&registry_);
acl::AclFamily::Register(&registry_);
server_family_.Register(&registry_); server_family_.Register(&registry_);
cluster_family_.Register(&registry_); cluster_family_.Register(&registry_);

View file

@ -4,9 +4,12 @@
#pragma once #pragma once
#include <utility>
#include "base/varz_value.h" #include "base/varz_value.h"
#include "core/interpreter.h" #include "core/interpreter.h"
#include "facade/service_interface.h" #include "facade/service_interface.h"
#include "server/acl/user_registry.h"
#include "server/cluster/cluster_family.h" #include "server/cluster/cluster_family.h"
#include "server/command_registry.h" #include "server/command_registry.h"
#include "server/config_registry.h" #include "server/config_registry.h"
@ -68,9 +71,8 @@ class Service : public facade::ServiceInterface {
facade::ConnectionStats* GetThreadLocalConnectionStats() final; facade::ConnectionStats* GetThreadLocalConnectionStats() final;
const CommandId* FindCmd(std::string_view cmd) const { std::pair<const CommandId*, CmdArgList> FindCmd(CmdArgList args) const;
return registry_.Find(cmd); const CommandId* FindCmd(std::string_view) const;
}
CommandRegistry* mutable_registry() { CommandRegistry* mutable_registry() {
return &registry_; return &registry_;
@ -145,8 +147,6 @@ class Service : public facade::ServiceInterface {
std::optional<facade::ErrorReply> CheckKeysOwnership(const CommandId* cid, CmdArgList args, std::optional<facade::ErrorReply> CheckKeysOwnership(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx); const ConnectionContext& dfly_cntx);
const CommandId* FindCmd(CmdArgList args) const;
void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx); void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx);
void CallSHA(CmdArgList args, std::string_view sha, Interpreter* interpreter, void CallSHA(CmdArgList args, std::string_view sha, Interpreter* interpreter,
ConnectionContext* cntx); ConnectionContext* cntx);
@ -161,9 +161,9 @@ class Service : public facade::ServiceInterface {
base::VarzValue::Map GetVarzStats(); base::VarzValue::Map GetVarzStats();
private:
util::ProactorPool& pp_; util::ProactorPool& pp_;
acl::UserRegistry user_registry_;
ServerFamily server_family_; ServerFamily server_family_;
ClusterFamily cluster_family_; ClusterFamily cluster_family_;
CommandRegistry registry_; CommandRegistry registry_;

View file

@ -6,6 +6,8 @@
#include <mimalloc.h> #include <mimalloc.h>
#include "server/acl/user_registry.h"
extern "C" { extern "C" {
#include "redis/zmalloc.h" #include "redis/zmalloc.h"
} }
@ -60,10 +62,11 @@ ServerState::ServerState() : interpreter_mgr_{absl::GetFlag(FLAGS_interpreter_pe
ServerState::~ServerState() { ServerState::~ServerState() {
} }
void ServerState::Init(uint32_t thread_index) { void ServerState::Init(uint32_t thread_index, acl::UserRegistry* registry) {
state_ = new ServerState(); state_ = new ServerState();
state_->gstate_ = GlobalState::ACTIVE; state_->gstate_ = GlobalState::ACTIVE;
state_->thread_index_ = thread_index; state_->thread_index_ = thread_index;
state_->user_registry = registry;
} }
void ServerState::Destroy() { void ServerState::Destroy() {

View file

@ -9,6 +9,7 @@
#include "base/histogram.h" #include "base/histogram.h"
#include "core/interpreter.h" #include "core/interpreter.h"
#include "server/acl/user_registry.h"
#include "server/common.h" #include "server/common.h"
#include "server/script_mgr.h" #include "server/script_mgr.h"
#include "util/sliding_counter.h" #include "util/sliding_counter.h"
@ -103,7 +104,7 @@ class ServerState { // public struct - to allow initialization.
ServerState(); ServerState();
~ServerState(); ~ServerState();
static void Init(uint32_t thread_index); static void Init(uint32_t thread_index, acl::UserRegistry* registry);
static void Destroy(); static void Destroy();
void EnterLameDuck() { void EnterLameDuck() {
@ -197,7 +198,6 @@ class ServerState { // public struct - to allow initialization.
channel_store_ = replacement; channel_store_ = replacement;
} }
public:
Stats stats; Stats stats;
bool is_master = true; bool is_master = true;
@ -205,6 +205,8 @@ class ServerState { // public struct - to allow initialization.
facade::ConnectionStats connection_stats; facade::ConnectionStats connection_stats;
acl::UserRegistry* user_registry;
private: private:
int64_t live_transactions_ = 0; int64_t live_transactions_ = 0;
mi_heap_t* data_heap_; mi_heap_t* data_heap_;

View file

@ -0,0 +1,14 @@
import pytest
from redis import asyncio as aioredis
from . import DflyInstanceFactory
from .utility import disconnect_clients
@pytest.mark.asyncio
async def test_acl_list_default_user(async_client):
"""
make sure that the default created user is printed correctly
"""
result = await async_client.execute_command("ACL LIST")
assert 1 == len(result)
assert "user default on nopass +@all" == result[0]