feat(AclFamilly): add acl list command (#1722)

* Add acl-family source and header
* Add `ACL LIST` command
* Add a simple test to check the default user
This commit is contained in:
Kostas Kyrimis 2023-08-22 18:33:14 +03:00 committed by GitHub
parent eae02a16da
commit 898061d738
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 234 additions and 39 deletions

View file

@ -62,7 +62,7 @@ struct ConnectionStats {
struct ErrorReply {
explicit ErrorReply(std::string&& msg, std::string_view kind = {})
: message{move(msg)}, kind{kind} {
: message{std::move(msg)}, kind{kind} {
}
explicit ErrorReply(std::string_view msg, std::string_view kind = {}) : message{msg}, kind{kind} {
}

View file

@ -25,7 +25,7 @@ add_library(dragonfly_lib channel_store.cc command_registry.cc
zset_family.cc version.cc bitops_family.cc container_utils.cc io_utils.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc
top_keys.cc multi_command_squasher.cc hll_family.cc cluster/cluster_config.cc
cluster/cluster_family.cc acl/user.cc acl/user_registry.cc)
cluster/cluster_family.cc acl/user.cc acl/user_registry.cc acl/acl_family.cc)
cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib aws_lib strings_lib html_lib

View file

@ -89,5 +89,20 @@ inline const absl::flat_hash_map<std::string_view, uint32_t> CATEGORY_INDEX_TABL
{"DANGEROUS", DANGEROUS},
{"CONNECTION", CONNECTION},
{"TRANSACTION", TRANSACTION},
{"SCRIPTING", SCRIPTING}};
{"SCRIPTING", SCRIPTING},
{"FT_SEARCH", FT_SEARCH},
{"THROTTLE", THROTTLE},
{"JSON", JSON}
};
// bit 0 at index 0
// bit 1 at index 1
// bit n at index n
inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"KEYSPACE", "READ", "WRITE", "SET", "SORTED_SET", "LIST",
"HASH", "STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM",
"PUBSUB", "ADMIN", "FAST", "SLOW", "BLOCKING", "DANGEROUS",
"CONNECTION", "TRANSACTION", "SCRIPTING", "FT_SEARCH", "THROTTLE", "JSON"};
} // namespace dfly::acl

View file

@ -0,0 +1,78 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
#include "server/acl/acl_family.h"
#include "absl/strings/str_cat.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/server_state.h"
namespace dfly::acl {
constexpr uint32_t kList = acl::ADMIN | acl::SLOW | acl::DANGEROUS;
static std::string AclToString(uint32_t acl_category) {
std::string tmp;
if (acl_category == acl::ALL) {
return "+@all";
}
if (acl_category == acl::NONE) {
return "+@none";
}
const std::string prefix = "+@";
const std::string postfix = " ";
for (uint32_t step = 0, cat = 0; cat != JSON; cat = 1ULL << ++step) {
if (acl_category & cat) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[step], postfix);
}
}
tmp.erase(tmp.size());
return tmp;
}
void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
const auto registry_with_lock = ServerState::tlocal()->user_registry->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
(*cntx)->StartArray(registry.size());
for (const auto& [username, user] : registry) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : std::string(pass.substr(0, 15));
const std::string acl_cat = AclToString(user.AclCategory());
using namespace std::string_view_literals;
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat);
(*cntx)->SendSimpleString(buffer);
}
}
using CI = dfly::CommandId;
#define HFUNC(x) SetHandler(&AclFamily::x)
// We can't implement the ACL commands and its respective subcommands LIST, CAT, etc
// the usual way, (that is, one command called ACL which then dispatches to the subcommand
// based on the secocond argument) because each of the subcommands has different ACL
// categories. Therefore, to keep it compatible with the CommandId, I need to treat them
// as separate commands in the registry. This is the least intrusive change because it's very
// easy to handle that case explicitly in `DispatchCommand`.
void AclFamily::Register(dfly::CommandRegistry* registry) {
*registry << CI{"ACL LIST", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, 0, 0, 0, 0, acl::kList}.HFUNC(
List);
}
#undef HFUNC
} // namespace dfly::acl

View file

@ -0,0 +1,24 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include "server/common.h"
namespace dfly {
class ConnectionContext;
class CommandRegistry;
namespace acl {
class AclFamily {
public:
static void Register(CommandRegistry* registry);
private:
static void List(CmdArgList args, ConnectionContext* cntx);
};
} // namespace acl
} // namespace dfly

View file

@ -81,4 +81,10 @@ bool User::IsActive() const {
return is_active_;
}
static const std::string_view default_pass = "nopass";
std::string_view User::Password() const {
return password_hash_ ? *password_hash_ : default_pass;
}
} // namespace dfly::acl

View file

@ -16,8 +16,6 @@
namespace dfly::acl {
class CommandId;
// TODO implement these
//#bool CheckIfCommandAllowed(uint64_t command_id, const CommandId& command);
//#bool CheckIfAclCategoryAllowed(uint64_t command_id, const CommandId& command);
@ -59,6 +57,8 @@ class User final {
bool IsActive() const;
std::string_view Password() const;
private:
// For ACL categories
void SetAclCategories(uint64_t cat);

View file

@ -4,12 +4,16 @@
#include "server/acl/user_registry.h"
#include <shared_mutex>
#include "core/fibers.h"
#include "server/acl/acl_commands_def.h"
namespace dfly::acl {
UserRegistry::UserRegistry() {
User::UpdateRequest req{{}, acl::ALL, {}, true};
MaybeAddAndUpdate("default", std::move(req));
}
void UserRegistry::MaybeAddAndUpdate(std::string_view username, User::UpdateRequest req) {
std::unique_lock<util::SharedMutex> lock(mu_);
auto& user = registry_[username];
@ -50,4 +54,14 @@ bool UserRegistry::AuthUser(std::string_view username, std::string_view password
return user->second.HasPassword(password);
}
UserRegistry::RegistryViewWithLock::RegistryViewWithLock(std::shared_lock<util::SharedMutex> mu,
const RegistryType& registry)
: registry(registry), registry_mu_(std::move(mu)) {
}
UserRegistry::RegistryViewWithLock UserRegistry::GetRegistryWithLock() const {
std::shared_lock<util::SharedMutex> lock(mu_);
return {std::move(lock), registry_};
}
} // namespace dfly::acl

View file

@ -7,7 +7,10 @@
#include <absl/container/flat_hash_map.h>
#include <absl/synchronization/mutex.h>
#include <shared_mutex>
#include <string>
#include <utility>
#include <vector>
#include "core/fibers.h"
#include "server/acl/user.h"
@ -16,11 +19,13 @@ namespace dfly::acl {
class UserRegistry {
public:
UserRegistry() = default;
UserRegistry();
UserRegistry(const UserRegistry&) = delete;
UserRegistry(UserRegistry&&) = delete;
using RegistryType = absl::flat_hash_map<std::string, User>;
// Acquires a write lock of mu_
// If the user with name `username` does not exist, it's added in the store with
// the exact fields found in req
@ -48,8 +53,21 @@ class UserRegistry {
// Used by Auth
bool AuthUser(std::string_view username, std::string_view password) const;
// Helper class for accessing the registry with a ReadLock outside the scope of UserRegistry
class RegistryViewWithLock {
public:
RegistryViewWithLock(std::shared_lock<util::SharedMutex> mu, const RegistryType& registry);
const RegistryType& registry;
private:
std::shared_lock<util::SharedMutex> registry_mu_;
};
// Helper function used for printing users via ACL LIST
RegistryViewWithLock GetRegistryWithLock() const;
private:
absl::flat_hash_map<std::string, User> registry_;
RegistryType registry_;
// TODO add abseil mutex attributes
mutable util::SharedMutex mu_;
};

View file

@ -40,8 +40,8 @@ constexpr size_t kNumThreads = 3;
void BlockingControllerTest::SetUp() {
pp_.reset(fb2::Pool::IOUring(16, kNumThreads));
pp_->Run();
pp_->Await([](unsigned index, ProactorBase* p) { ServerState::Init(index); });
ServerState::Init(kNumThreads);
pp_->Await([](unsigned index, ProactorBase* p) { ServerState::Init(index, nullptr); });
ServerState::Init(kNumThreads, nullptr);
shard_set = new EngineShardSet(pp_.get());
shard_set->Init(kNumThreads, false);

View file

@ -9,6 +9,8 @@
#include <absl/types/span.h>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <string_view>
#include <vector>

View file

@ -4,6 +4,9 @@
#include "server/main_service.h"
#include "facade/resp_expr.h"
#include "server/acl/user_registry.h"
extern "C" {
#include "redis/redis_aux.h"
}
@ -23,6 +26,7 @@ extern "C" {
#include "facade/error.h"
#include "facade/reply_capture.h"
#include "server/acl/acl_commands_def.h"
#include "server/acl/acl_family.h"
#include "server/bitops_family.h"
#include "server/cluster/cluster_family.h"
#include "server/conn_context.h"
@ -654,8 +658,8 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
config_registry.Register("requirepass");
config_registry.Register("masterauth");
config_registry.Register("tcp_keepalive");
pp_.Await([](uint32_t index, ProactorBase* pb) { ServerState::Init(index); });
acl::UserRegistry* reg = &user_registry_;
pp_.Await([reg](uint32_t index, ProactorBase* pb) { ServerState::Init(index, reg); });
uint32_t shard_num = GetFlag(FLAGS_num_shards);
if (shard_num == 0) {
@ -878,7 +882,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
ServerState& etl = *ServerState::tlocal();
ToUpper(&args[0]);
const CommandId* cid = FindCmd(args);
const auto [cid, args_no_cmd] = FindCmd(args);
if (cid == nullptr) {
return (*cntx)->SendError(ReportUnknownCmd(ArgS(args, 0)));
@ -897,13 +901,11 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
etl.RecordCmd();
auto args_no_cmd = args.subspan(1);
if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
if (auto& exec_info = dfly_cntx->conn_state.exec_info; exec_info.IsCollecting())
exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR;
(*dfly_cntx)->SendError(move(*err));
(*dfly_cntx)->SendError(std::move(*err));
return;
}
@ -985,7 +987,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo
DCHECK(!cid->Validate(tail_args));
if (auto err = VerifyCommandExecution(cid); err) {
(*cntx)->SendError(move(*err));
(*cntx)->SendError(std::move(*err));
return true; // return false only for internal error aborts
}
@ -1042,7 +1044,7 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
for (auto args : args_list) {
ToUpper(&args[0]);
const CommandId* cid = FindCmd(args);
const auto [cid, tail_args] = FindCmd(args);
// MULTI...EXEC commands need to be collected into a single context, so squashing is not
// possible
@ -1051,7 +1053,7 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
if (!is_multi && cid != nullptr) {
stored_cmds.reserve(args_list.size());
stored_cmds.emplace_back(cid, args.subspan(1));
stored_cmds.emplace_back(cid, tail_args);
continue;
}
@ -1204,6 +1206,10 @@ facade::ConnectionStats* Service::GetThreadLocalConnectionStats() {
return ServerState::tl_connection_stats();
}
const CommandId* Service::FindCmd(std::string_view cmd) const {
return registry_.Find(cmd);
}
bool Service::IsLocked(DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_count());
bool is_open = pp_.at(sid)->AwaitBrief([db_index, key] {
@ -1312,7 +1318,7 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
info->async_cmds.clear();
auto reply = crb.Take();
return CapturingReplyBuilder::GetError(reply) ? make_optional(move(reply)) : nullopt;
return CapturingReplyBuilder::GetError(reply) ? make_optional(std::move(reply)) : nullopt;
}
void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) {
@ -1331,7 +1337,7 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
// Full command verification happens during squashed execution
if (auto* cid = registry_.Find(ArgS(ca.args, 0)); cid != nullptr) {
auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies);
info->async_cmds.emplace_back(std::move(*ca.buffer), cid, ca.args.subspan(1), replies);
info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory();
} else if (ca.error_abort) { // If we don't abort on errors, we can ignore it completely
findcmd_err = ReportUnknownCmd(ArgS(ca.args, 0));
@ -1339,13 +1345,13 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
}
if (auto err = FlushEvalAsyncCmds(cntx, !ca.async || findcmd_err.has_value()); err) {
CapturingReplyBuilder::Apply(move(*err), &replier); // forward error to lua
CapturingReplyBuilder::Apply(std::move(*err), &replier); // forward error to lua
*ca.requested_abort = true;
return;
}
if (findcmd_err.has_value()) {
replier.RedisReplyBuilder::SendError(move(*findcmd_err));
replier.RedisReplyBuilder::SendError(std::move(*findcmd_err));
*ca.requested_abort |= ca.error_abort;
}
@ -1370,7 +1376,7 @@ void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
if (!res)
return (*cntx)->SendError(res.error().Format(), facade::kScriptErrType);
string sha{move(res.value())};
string sha{std::move(res.value())};
CallSHA(args, sha, interpreter, cntx);
}
@ -1466,10 +1472,21 @@ bool StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams param
return false;
}
const CommandId* Service::FindCmd(CmdArgList args) const {
static std::string FullAclCommandFromArgs(CmdArgList args) {
ToUpper(&args[1]);
// Guranteed SSO no dynamic allocations here
return std::string("ACL ") + std::string(args[1].begin(), args[1].end());
}
std::pair<const CommandId*, CmdArgList> Service::FindCmd(CmdArgList args) const {
const std::string_view command = facade::ToSV(args[0]);
if (command == "ACL") {
return {registry_.Find(FullAclCommandFromArgs(args)), args.subspan(2)};
}
const CommandId* res = registry_.Find(ArgS(args, 0));
if (!res)
return nullptr;
return {nullptr, args};
// A workaround for XGROUP HELP that does not fit our static taxonomy of commands.
if (args.size() == 2 && res->name() == "XGROUP") {
@ -1477,7 +1494,7 @@ const CommandId* Service::FindCmd(CmdArgList args) const {
res = registry_.Find("_XGROUP_HELP");
}
}
return res;
return {res, args.subspan(1)};
}
void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
@ -1766,7 +1783,7 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
for (auto& sub : subscribers)
sub.conn_cntx->owner()->EnsureAsyncMemoryBudget();
auto subscribers_ptr = make_shared<decltype(subscribers)>(move(subscribers));
auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
auto buf = shared_ptr<char[]>{new char[channel.size() + msg.size()]};
memcpy(buf.get(), channel.data(), channel.size());
memcpy(buf.get() + channel.size(), msg.data(), msg.size());
@ -1778,7 +1795,8 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
while (it != subscribers_ptr->end() && it->thread_id == idx) {
facade::Connection* conn = it->conn_cntx->owner();
DCHECK(conn);
conn->SendPubMessageAsync({move(it->pattern), move(buf), channel.size(), msg.size()});
conn->SendPubMessageAsync(
{std::move(it->pattern), std::move(buf), channel.size(), msg.size()});
it->borrow_token.Dec();
it++;
}
@ -2081,6 +2099,7 @@ void Service::RegisterCommands() {
BitOpsFamily::Register(&registry_);
HllFamily::Register(&registry_);
SearchFamily::Register(&registry_);
acl::AclFamily::Register(&registry_);
server_family_.Register(&registry_);
cluster_family_.Register(&registry_);

View file

@ -4,9 +4,12 @@
#pragma once
#include <utility>
#include "base/varz_value.h"
#include "core/interpreter.h"
#include "facade/service_interface.h"
#include "server/acl/user_registry.h"
#include "server/cluster/cluster_family.h"
#include "server/command_registry.h"
#include "server/config_registry.h"
@ -68,9 +71,8 @@ class Service : public facade::ServiceInterface {
facade::ConnectionStats* GetThreadLocalConnectionStats() final;
const CommandId* FindCmd(std::string_view cmd) const {
return registry_.Find(cmd);
}
std::pair<const CommandId*, CmdArgList> FindCmd(CmdArgList args) const;
const CommandId* FindCmd(std::string_view) const;
CommandRegistry* mutable_registry() {
return &registry_;
@ -145,8 +147,6 @@ class Service : public facade::ServiceInterface {
std::optional<facade::ErrorReply> CheckKeysOwnership(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx);
const CommandId* FindCmd(CmdArgList args) const;
void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx);
void CallSHA(CmdArgList args, std::string_view sha, Interpreter* interpreter,
ConnectionContext* cntx);
@ -161,9 +161,9 @@ class Service : public facade::ServiceInterface {
base::VarzValue::Map GetVarzStats();
private:
util::ProactorPool& pp_;
acl::UserRegistry user_registry_;
ServerFamily server_family_;
ClusterFamily cluster_family_;
CommandRegistry registry_;

View file

@ -6,6 +6,8 @@
#include <mimalloc.h>
#include "server/acl/user_registry.h"
extern "C" {
#include "redis/zmalloc.h"
}
@ -60,10 +62,11 @@ ServerState::ServerState() : interpreter_mgr_{absl::GetFlag(FLAGS_interpreter_pe
ServerState::~ServerState() {
}
void ServerState::Init(uint32_t thread_index) {
void ServerState::Init(uint32_t thread_index, acl::UserRegistry* registry) {
state_ = new ServerState();
state_->gstate_ = GlobalState::ACTIVE;
state_->thread_index_ = thread_index;
state_->user_registry = registry;
}
void ServerState::Destroy() {

View file

@ -9,6 +9,7 @@
#include "base/histogram.h"
#include "core/interpreter.h"
#include "server/acl/user_registry.h"
#include "server/common.h"
#include "server/script_mgr.h"
#include "util/sliding_counter.h"
@ -103,7 +104,7 @@ class ServerState { // public struct - to allow initialization.
ServerState();
~ServerState();
static void Init(uint32_t thread_index);
static void Init(uint32_t thread_index, acl::UserRegistry* registry);
static void Destroy();
void EnterLameDuck() {
@ -197,7 +198,6 @@ class ServerState { // public struct - to allow initialization.
channel_store_ = replacement;
}
public:
Stats stats;
bool is_master = true;
@ -205,6 +205,8 @@ class ServerState { // public struct - to allow initialization.
facade::ConnectionStats connection_stats;
acl::UserRegistry* user_registry;
private:
int64_t live_transactions_ = 0;
mi_heap_t* data_heap_;

View file

@ -0,0 +1,14 @@
import pytest
from redis import asyncio as aioredis
from . import DflyInstanceFactory
from .utility import disconnect_clients
@pytest.mark.asyncio
async def test_acl_list_default_user(async_client):
"""
make sure that the default created user is printed correctly
"""
result = await async_client.execute_command("ACL LIST")
assert 1 == len(result)
assert "user default on nopass +@all" == result[0]