feat(AclFamily): add acl commands (#1844)

This commit is contained in:
Kostas Kyrimis 2023-09-15 14:28:36 +03:00 committed by GitHub
parent ff079f0af1
commit bbd4c6b636
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
39 changed files with 671 additions and 240 deletions

View file

@ -64,6 +64,22 @@ class CommandId {
return acl_categories_;
}
void SetFamily(size_t fam) {
family_ = fam;
}
void SetBitIndex(uint64_t bit) {
bit_index_ = bit;
}
size_t GetFamily() const {
return family_;
}
uint64_t GetBitIndex() const {
return bit_index_;
}
static uint32_t OptCount(uint32_t mask);
protected:
@ -74,7 +90,11 @@ class CommandId {
int8_t first_key_;
int8_t last_key_;
int8_t step_key_;
// Acl categories
uint32_t acl_categories_;
// Acl commands indices
size_t family_;
uint64_t bit_index_;
};
} // namespace facade

View file

@ -217,6 +217,7 @@ void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
for (size_t id = 0; id < msg.username.size(); ++id) {
if (msg.username[id] == ctx->authed_username) {
ctx->acl_categories = msg.categories[id];
ctx->acl_commands = msg.commands[id];
}
}
}
@ -989,7 +990,7 @@ void Connection::SendMonitorMessageAsync(string msg) {
}
void Connection::SendAclUpdateAsync(AclUpdateMessage msg) {
SendAsync({msg});
SendAsync({std::move(msg)});
}
void Connection::SendAsync(MessageHandle msg) {
@ -1009,11 +1010,9 @@ void Connection::SendAsync(MessageHandle msg) {
auto place_in_dispatch_q = [this](MessageHandle msg) {
auto it = dispatch_q_.begin();
for (; it < dispatch_q_.end(); ++it) {
if (!std::holds_alternative<AclUpdateMessage>(it->handle)) {
break;
}
}
const auto end = dispatch_q_.end();
while (it < end && std::holds_alternative<AclUpdateMessage>(it->handle))
++it;
dispatch_q_.insert(it, std::move(msg));
};

View file

@ -82,6 +82,7 @@ class Connection : public util::Connection {
struct AclUpdateMessage {
std::vector<std::string> username;
std::vector<uint32_t> categories;
std::vector<std::vector<uint64_t>> commands;
};
struct PipelineMessage {

View file

@ -41,7 +41,7 @@ add_library(dragonfly_lib channel_store.cc command_registry.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc
top_keys.cc multi_command_squasher.cc hll_family.cc cluster/cluster_config.cc
cluster/cluster_family.cc acl/user.cc acl/user_registry.cc acl/acl_family.cc
acl/validator.cc)
acl/validator.cc acl/helpers.cc)
find_library(ZSTD_LIB NAMES libzstd.a libzstdstatic.a zstd NAMES_PER_DIR REQUIRED)
@ -88,4 +88,4 @@ add_dependencies(check_dfly dragonfly_test json_family_test list_family_test
generic_family_test memcache_parser_test rdb_test journal_test
redis_parser_test snapshot_test stream_family_test string_family_test
bitops_family_test set_family_test zset_family_test hll_family_test
cluster_config_test cluster_family_test user_registry_test)
cluster_config_test cluster_family_test user_registry_test acl_family_test)

View file

@ -5,6 +5,7 @@
#pragma once
#include "absl/container/flat_hash_map.h"
#include "base/logging.h"
namespace dfly::acl {
/* There are 21 ACL categories as of redis 7
@ -106,4 +107,29 @@ inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED",
"_RESERVED", "FT_SEARCH", "THROTTLE", "JSON"};
using RevCommandField = std::vector<std::string>;
using RevCommandsIndexStore = std::vector<RevCommandField>;
constexpr uint64_t ALL_COMMANDS = std::numeric_limits<uint64_t>::max();
constexpr uint64_t NONE_COMMANDS = std::numeric_limits<uint64_t>::min();
// A variation of meyers singleton
// This is initialized when the constructor of Service is called.
// Basically, it calls this functions within the AclFamily::Register
// functions which has the number of all the acl families registered
inline size_t NumberOfFamilies(size_t number = 0) {
static size_t number_of_families = number;
return number_of_families;
}
inline const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore store = {}) {
static RevCommandsIndexStore rev_index_store = std::move(store);
return rev_index_store;
}
inline void BuildIndexers(std::vector<std::vector<std::string>> families) {
acl::NumberOfFamilies(families.size());
acl::CommandsRevIndexer(std::move(families));
}
} // namespace dfly::acl

View file

@ -12,11 +12,7 @@
#include <utility>
#include <variant>
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "base/flags.h"
#include "base/logging.h"
@ -27,47 +23,14 @@
#include "io/file_util.h"
#include "io/io.h"
#include "server/acl/acl_commands_def.h"
#include "server/acl/helpers.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/server_state.h"
ABSL_FLAG(std::string, aclfile, "", "Path and name to aclfile");
namespace dfly::acl {
static std::string AclToString(uint32_t acl_category) {
std::string tmp;
if (acl_category == acl::ALL) {
return "+@ALL";
}
if (acl_category == acl::NONE) {
return "+@NONE";
}
const std::string prefix = "+@";
const std::string postfix = " ";
for (uint32_t i = 0; i < 32; i++) {
uint32_t cat_bit = 1ULL << i;
if (acl_category & cat_bit) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[i], postfix);
}
}
tmp.pop_back();
return tmp;
}
static std::string PrettyPrintSha(std::string_view pass, bool all = false) {
if (all) {
return absl::BytesToHexString(pass);
}
return absl::BytesToHexString(pass.substr(0, 15)).substr(0, 15);
};
AclFamily::AclFamily(UserRegistry* registry) : registry_(registry) {
}
@ -84,139 +47,29 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass);
const std::string acl_cat = AclToString(user.AclCategory());
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space = acl_commands.empty() ? "" : " ";
using namespace std::string_view_literals;
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat);
acl_cat, maybe_space, acl_commands);
(*cntx)->SendSimpleString(buffer);
}
}
namespace {
std::optional<std::string> MaybeParsePassword(std::string_view command) {
if (command == "nopass") {
return std::string(command);
}
if (command[0] != '>') {
return {};
}
return std::string(command.substr(1));
}
std::optional<bool> MaybeParseStatus(std::string_view command) {
if (command == "ON") {
return true;
}
if (command == "OFF") {
return false;
}
return {};
}
using OptCat = std::optional<uint32_t>;
// bool == true if +
// bool == false if -
std::pair<OptCat, bool> MaybeParseAclCategory(std::string_view command) {
if (absl::StartsWith(command, "+@")) {
auto res = CATEGORY_INDEX_TABLE.find(command.substr(2));
if (res == CATEGORY_INDEX_TABLE.end()) {
return {};
}
return {res->second, true};
}
if (absl::StartsWith(command, "-@")) {
auto res = CATEGORY_INDEX_TABLE.find(command.substr(2));
if (res == CATEGORY_INDEX_TABLE.end()) {
return {};
}
return {res->second, false};
}
return {};
}
using facade::ErrorReply;
template <typename T>
std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args, bool hashed = false) {
User::UpdateRequest req;
for (auto& arg : args) {
if (auto pass = MaybeParsePassword(facade::ToSV(arg)); pass) {
if (req.password) {
return ErrorReply("Only one password is allowed");
}
req.password = std::move(pass);
req.is_hashed = hashed;
continue;
}
if constexpr (std::is_same_v<T, CmdArgList>) {
ToUpper(&arg);
}
const auto command = facade::ToSV(arg);
if (auto status = MaybeParseStatus(command); status) {
if (req.is_active) {
return ErrorReply("Multiple ON/OFF are not allowed");
}
req.is_active = *status;
continue;
}
auto [cat, add] = MaybeParseAclCategory(command);
if (!cat) {
return ErrorReply(absl::StrCat("Unrecognized parameter ", command));
}
using Sign = User::Sign;
using Val = std::pair<Sign, uint32_t>;
auto val = add ? Val{Sign::PLUS, *cat} : Val{Sign::MINUS, *cat};
req.categories.push_back(val);
}
return req;
}
using MaterializedContents = std::optional<std::vector<std::vector<std::string_view>>>;
MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames,
std::string_view file_contents) {
// This is fine, a very large file will top at 1-2 mb. And that's for 5000+ users with 400
// characters of ACL infor...
std::vector<std::string_view> commands = absl::StrSplit(file_contents, "\n");
std::vector<std::vector<std::string_view>> materialized;
materialized.reserve(commands.size());
usernames->reserve(commands.size());
for (auto& command : commands) {
if (command.empty())
continue;
std::vector<std::string_view> cmds = absl::StrSplit(command, ' ');
if (cmds[0] != "ACL" || cmds[1] != "SETUSER" || cmds.size() < 3) {
return {};
}
usernames->push_back(std::string(cmds[2]));
cmds.erase(cmds.begin(), cmds.begin() + 3);
materialized.push_back(cmds);
}
return materialized;
}
} // namespace
void AclFamily::StreamUpdatesToAllProactorConnections(const std::vector<std::string>& user,
const std::vector<uint32_t>& update_cat) {
auto update_cb = [&user, &update_cat]([[maybe_unused]] size_t id, util::Connection* conn) {
const std::vector<uint32_t>& update_cat,
const NestedVector& update_commands) {
auto update_cb = [&user, &update_cat, &update_commands]([[maybe_unused]] size_t id,
util::Connection* conn) {
DCHECK(conn);
auto connection = static_cast<facade::Connection*>(conn);
DCHECK(user.size() == update_cat.size());
connection->SendAclUpdateAsync(facade::Connection::AclUpdateMessage{user, update_cat});
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands});
};
if (main_listener_) {
@ -224,15 +77,18 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::vector<std::str
}
}
using facade::ErrorReply;
void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
std::string_view username = facade::ToSV(args[0]);
auto req = ParseAclSetUser(args.subspan(1));
auto req = ParseAclSetUser(args.subspan(1), *cmd_registry_);
auto error_case = [cntx](ErrorReply&& error) { (*cntx)->SendError(error); };
auto update_case = [username, cntx, this](User::UpdateRequest&& req) {
auto user_with_lock = registry_->MaybeAddAndUpdateWithLock(username, std::move(req));
if (user_with_lock.exists) {
StreamUpdatesToAllProactorConnections({std::string(username)},
{user_with_lock.user.AclCategory()});
{user_with_lock.user.AclCategory()},
{user_with_lock.user.AclCommands()});
}
cntx->SendOk();
};
@ -280,15 +136,19 @@ std::string AclFamily::RegistryToString() const {
const std::string_view pass = user.Password();
const std::string password =
pass == "nopass" ? "nopass " : absl::StrCat(">", PrettyPrintSha(pass, true), " ");
const std::string acl_cat = AclToString(user.AclCategory());
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space = acl_commands.empty() ? "" : " ";
using namespace std::string_view_literals;
absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password,
acl_cat, "\n");
acl_cat, maybe_space, acl_commands, "\n");
}
result.pop_back();
if (!result.empty()) {
result.pop_back();
}
return result;
}
@ -359,7 +219,7 @@ std::optional<facade::ErrorReply> AclFamily::LoadToRegistryFromFile(std::string_
std::vector<User::UpdateRequest> requests;
for (auto& cmds : *materialized) {
auto req = ParseAclSetUser<std::vector<std::string_view>&>(cmds, true);
auto req = ParseAclSetUser<std::vector<std::string_view>&>(cmds, *cmd_registry_, true);
if (std::holds_alternative<ErrorReply>(req)) {
auto error = std::move(std::get<ErrorReply>(req));
LOG(WARNING) << "Error while parsing aclfile: " << error.ToSv();
@ -375,22 +235,24 @@ std::optional<facade::ErrorReply> AclFamily::LoadToRegistryFromFile(std::string_
registry.clear();
}
std::vector<uint32_t> categories;
NestedVector commands;
for (size_t i = 0; i < usernames.size(); ++i) {
auto& user = registry[usernames[i]];
user.Update(std::move(requests[i]));
categories.push_back(user.AclCategory());
commands.push_back(user.AclCommands());
}
if (!init) {
StreamUpdatesToAllProactorConnections(usernames, categories);
StreamUpdatesToAllProactorConnections(usernames, categories, commands);
}
return {};
}
void AclFamily::Load() {
bool AclFamily::Load() {
auto acl_file = absl::GetFlag(FLAGS_aclfile);
LoadToRegistryFromFile(acl_file, true);
return !LoadToRegistryFromFile(acl_file, true).has_value();
}
void AclFamily::Load(CmdArgList args, ConnectionContext* cntx) {
@ -410,8 +272,6 @@ void AclFamily::Load(CmdArgList args, ConnectionContext* cntx) {
cntx->SendOk();
}
using CI = dfly::CommandId;
using MemberFunc = void (AclFamily::*)(CmdArgList args, ConnectionContext* cntx);
inline CommandId::Handler HandlerFunc(AclFamily* acl, MemberFunc f) {
@ -436,6 +296,9 @@ constexpr uint32_t kLoad = acl::ADMIN | acl::SLOW | acl::DANGEROUS;
// easy to handle that case explicitly in `DispatchCommand`.
void AclFamily::Register(dfly::CommandRegistry* registry) {
using CI = dfly::CommandId;
registry->StartFamily();
*registry << CI{"ACL", CO::NOSCRIPT | CO::LOADING, 0, 0, 0, 0, acl::kAcl}.HFUNC(Acl);
*registry << CI{"ACL LIST", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, 1, 0, 0, 0, acl::kList}.HFUNC(
List);
@ -449,6 +312,8 @@ void AclFamily::Register(dfly::CommandRegistry* registry) {
Save);
*registry << CI{"ACL LOAD", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, 1, 0, 0, 0, acl::kLoad}.HFUNC(
Load);
cmd_registry_ = registry;
}
#undef HFUNC
@ -458,8 +323,12 @@ void AclFamily::Init(facade::Listener* main_listener, UserRegistry* registry) {
registry_ = registry;
auto acl_file = absl::GetFlag(FLAGS_aclfile);
if (!acl_file.empty()) {
Load();
if (!Load()) {
registry_->Init();
}
return;
}
registry_->Init();
}
} // namespace dfly::acl

View file

@ -13,13 +13,12 @@
#include "facade/facade_types.h"
#include "helio/util/proactor_pool.h"
#include "server/acl/user_registry.h"
#include "server/command_registry.h"
#include "server/common.h"
namespace dfly {
class ConnectionContext;
class CommandRegistry;
namespace acl {
class AclFamily final {
@ -37,12 +36,14 @@ class AclFamily final {
void WhoAmI(CmdArgList args, ConnectionContext* cntx);
void Save(CmdArgList args, ConnectionContext* cntx);
void Load(CmdArgList args, ConnectionContext* cntx);
void Load();
bool Load();
// Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads
using NestedVector = std::vector<std::vector<uint64_t>>;
void StreamUpdatesToAllProactorConnections(const std::vector<std::string>& user,
const std::vector<uint32_t>& update_cat);
const std::vector<uint32_t>& update_cat,
const NestedVector& update_commands);
// Helper function that closes all open connection from the deleted user
void EvictOpenConnectionsOnAllProactors(std::string_view user);
@ -54,6 +55,8 @@ class AclFamily final {
facade::Listener* main_listener_{nullptr};
UserRegistry* registry_;
CommandRegistry* cmd_registry_;
util::ProactorPool* pool_;
};
} // namespace acl

View file

@ -21,6 +21,7 @@ class AclFamilyTest : public BaseFamilyTest {
};
TEST_F(AclFamilyTest, AclSetUser) {
TestInitAclFam();
auto resp = Run({"ACL", "SETUSER"});
EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl setuser' command"));
@ -34,11 +35,12 @@ TEST_F(AclFamilyTest, AclSetUser) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec,
UnorderedElementsAre("user default on nopass +@ALL", "user vlad off nopass +@NONE"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL",
"user vlad off nopass +@NONE"));
}
TEST_F(AclFamilyTest, AclDelUser) {
TestInitAclFam();
auto resp = Run({"ACL", "DELUSER"});
EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl deluser' command"));
@ -55,10 +57,11 @@ TEST_F(AclFamilyTest, AclDelUser) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL");
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL +ALL");
}
TEST_F(AclFamilyTest, AclList) {
TestInitAclFam();
auto resp = Run({"ACL", "LIST", "NONSENSE"});
EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl list' command"));
@ -70,12 +73,13 @@ TEST_F(AclFamilyTest, AclList) {
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL",
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL",
"user kostas off d74ff0ee8da3b98 +@ADMIN",
"user adi off d74ff0ee8da3b98 +@FAST"));
}
TEST_F(AclFamilyTest, AclAuth) {
TestInitAclFam();
auto resp = Run({"ACL", "SETUSER", "shahar", ">mypass"});
EXPECT_THAT(resp, "OK");
@ -94,6 +98,7 @@ TEST_F(AclFamilyTest, AclAuth) {
}
TEST_F(AclFamilyTest, AclWhoAmI) {
TestInitAclFam();
auto resp = Run({"ACL", "WHOAMI", "WHO"});
EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl whoami' command"));
@ -108,6 +113,7 @@ TEST_F(AclFamilyTest, AclWhoAmI) {
}
TEST_F(AclFamilyTest, TestAllCategories) {
TestInitAclFam();
for (auto& cat : acl::REVERSE_CATEGORY_INDEX_TABLE) {
if (cat != "_RESERVED") {
auto resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("+@", cat)});
@ -115,7 +121,7 @@ TEST_F(AclFamilyTest, TestAllCategories) {
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL",
UnorderedElementsAre("user default on nopass +@ALL +ALL",
absl::StrCat("user kostas off nopass ", "+@", cat)));
resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)});
@ -123,7 +129,7 @@ TEST_F(AclFamilyTest, TestAllCategories) {
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL",
UnorderedElementsAre("user default on nopass +@ALL +ALL",
absl::StrCat("user kostas off nopass ", "+@NONE")));
resp = Run({"ACL", "DELUSER", "kostas"});
@ -152,4 +158,31 @@ TEST_F(AclFamilyTest, TestAllCategories) {
// EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL", "user kostas
// off nopass +@NONE"));
}
TEST_F(AclFamilyTest, TestAllCommands) {
TestInitAclFam();
const auto& rev_indexer = acl::CommandsRevIndexer();
for (const auto& family : rev_indexer) {
for (const auto& command_name : family) {
auto resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("+", command_name)});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL +ALL",
absl::StrCat("user kostas off nopass +@NONE ",
"+", command_name)));
resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)});
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL",
absl::StrCat("user kostas off nopass ", "+@NONE")));
resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, "OK");
}
}
}
} // namespace dfly

249
src/server/acl/helpers.cc Normal file
View file

@ -0,0 +1,249 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/acl/helpers.h"
#include <limits>
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "server/acl/acl_commands_def.h"
#include "server/common.h"
namespace dfly::acl {
std::string AclCatToString(uint32_t acl_category) {
std::string tmp;
if (acl_category == acl::ALL) {
return "+@ALL";
}
if (acl_category == acl::NONE) {
return "+@NONE";
}
const std::string prefix = "+@";
const std::string postfix = " ";
for (uint32_t i = 0; i < 32; ++i) {
uint32_t cat_bit = 1ULL << i;
if (acl_category & cat_bit) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[i], postfix);
}
}
tmp.pop_back();
return tmp;
}
std::string AclCommandToString(const std::vector<uint64_t>& acl_category) {
std::string result;
const std::string prefix = "+";
const std::string postfix = " ";
const auto& rev_index = CommandsRevIndexer();
bool all = true;
size_t family_id = 0;
for (auto family : acl_category) {
for (uint64_t i = 0; i < 64; ++i) {
const uint64_t cmd_bit = 1ULL << i;
if (family & cmd_bit && i < rev_index[family_id].size()) {
absl::StrAppend(&result, prefix, rev_index[family_id][i], postfix);
continue;
}
if (i < rev_index[family_id].size()) {
all = false;
}
}
++family_id;
}
if (!result.empty()) {
result.pop_back();
}
return all ? "+ALL" : result;
}
std::string PrettyPrintSha(std::string_view pass, bool all) {
if (all) {
return absl::BytesToHexString(pass);
}
return absl::BytesToHexString(pass.substr(0, 15)).substr(0, 15);
};
std::optional<std::string> MaybeParsePassword(std::string_view command) {
if (command == "nopass") {
return std::string(command);
}
if (command[0] != '>') {
return {};
}
return std::string(command.substr(1));
}
std::optional<bool> MaybeParseStatus(std::string_view command) {
if (command == "ON") {
return true;
}
if (command == "OFF") {
return false;
}
return {};
}
using OptCat = std::optional<uint32_t>;
// bool == true if +
// bool == false if -
std::pair<OptCat, bool> MaybeParseAclCategory(std::string_view command) {
if (absl::StartsWith(command, "+@")) {
auto res = CATEGORY_INDEX_TABLE.find(command.substr(2));
if (res == CATEGORY_INDEX_TABLE.end()) {
return {};
}
return {res->second, true};
}
if (absl::StartsWith(command, "-@")) {
auto res = CATEGORY_INDEX_TABLE.find(command.substr(2));
if (res == CATEGORY_INDEX_TABLE.end()) {
return {};
}
return {res->second, false};
}
return {};
}
bool IsIndexAllCommandsFlag(size_t index) {
return index == std::numeric_limits<size_t>::max();
}
std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,
const CommandRegistry& registry) {
const auto all_commands = std::pair<size_t, uint64_t>{std::numeric_limits<size_t>::max(), 0};
if (command == "+ALL") {
return {all_commands, true};
}
if (command == "-ALL") {
return {all_commands, false};
}
if (absl::StartsWith(command, "+")) {
auto res = registry.Find(command.substr(1));
if (!res) {
return {};
}
std::pair<size_t, uint64_t> cmd{res->GetFamily(), res->GetBitIndex()};
return {cmd, true};
}
if (absl::StartsWith(command, "-")) {
auto res = registry.Find(command.substr(1));
if (!res) {
return {};
}
std::pair<size_t, uint64_t> cmd{res->GetFamily(), res->GetBitIndex()};
return {cmd, false};
}
return {};
}
MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames,
std::string_view file_contents) {
// This is fine, a very large file will top at 1-2 mb. And that's for 5000+ users with 400
// characters of ACL infor...
std::vector<std::string_view> commands = absl::StrSplit(file_contents, "\n");
std::vector<std::vector<std::string_view>> materialized;
materialized.reserve(commands.size());
usernames->reserve(commands.size());
for (auto& command : commands) {
if (command.empty())
continue;
std::vector<std::string_view> cmds = absl::StrSplit(command, ' ');
if (cmds[0] != "ACL" || cmds[1] != "SETUSER" || cmds.size() < 3) {
return {};
}
usernames->push_back(std::string(cmds[2]));
cmds.erase(cmds.begin(), cmds.begin() + 3);
materialized.push_back(cmds);
}
return materialized;
}
using facade::ErrorReply;
template <typename T>
std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
const CommandRegistry& registry,
bool hashed) {
User::UpdateRequest req;
for (auto& arg : args) {
if (auto pass = MaybeParsePassword(facade::ToSV(arg)); pass) {
if (req.password) {
return ErrorReply("Only one password is allowed");
}
req.password = std::move(pass);
req.is_hashed = hashed;
continue;
}
if constexpr (std::is_same_v<T, facade::CmdArgList>) {
ToUpper(&arg);
}
const auto command = facade::ToSV(arg);
if (auto status = MaybeParseStatus(command); status) {
if (req.is_active) {
return ErrorReply("Multiple ON/OFF are not allowed");
}
req.is_active = *status;
continue;
}
auto [cat, add] = MaybeParseAclCategory(command);
if (cat) {
using Sign = User::Sign;
using Val = std::pair<Sign, uint32_t>;
auto val = add ? Val{Sign::PLUS, *cat} : Val{Sign::MINUS, *cat};
req.categories.push_back(val);
continue;
}
auto [cmd, sign] = MaybeParseAclCommand(command, registry);
if (!cmd) {
return ErrorReply(absl::StrCat("Unrecognized parameter ", command));
}
using Sign = User::Sign;
using Val = User::UpdateRequest::CommandsValueType;
;
auto [index, bit] = *cmd;
auto val = sign ? Val{Sign::PLUS, index, bit} : Val{Sign::MINUS, index, bit};
req.commands.push_back(val);
}
return req;
}
using facade::CmdArgList;
template std::variant<User::UpdateRequest, ErrorReply>
ParseAclSetUser<std::vector<std::string_view>&>(std::vector<std::string_view>&,
const CommandRegistry& registry, bool hashed);
template std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser<CmdArgList>(
CmdArgList args, const CommandRegistry& registry, bool hashed);
} // namespace dfly::acl

47
src/server/acl/helpers.h Normal file
View file

@ -0,0 +1,47 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <variant>
#include "facade/facade_types.h"
#include "server/acl/user.h"
#include "server/command_registry.h"
namespace dfly::acl {
std::string AclCatToString(uint32_t acl_category);
std::string AclCommandToString(const std::vector<uint64_t>& acl_category);
std::string PrettyPrintSha(std::string_view pass, bool all = false);
std::optional<std::string> MaybeParsePassword(std::string_view command);
std::optional<bool> MaybeParseStatus(std::string_view command);
using OptCat = std::optional<uint32_t>;
std::pair<OptCat, bool> MaybeParseAclCategory(std::string_view command);
bool IsIndexAllCommandsFlag(size_t index);
using OptCommand = std::optional<std::pair<size_t, uint64_t>>;
std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,
const CommandRegistry& registry);
template <typename T>
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
T args, const CommandRegistry& registry, bool hashed = false);
using MaterializedContents = std::optional<std::vector<std::vector<std::string_view>>>;
MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames,
std::string_view file_contents);
} // namespace dfly::acl

View file

@ -6,7 +6,10 @@
#include <openssl/sha.h>
#include <limits>
#include "absl/strings/escaping.h"
#include "server/acl/helpers.h"
namespace dfly::acl {
@ -22,7 +25,7 @@ std::string StringSHA256(std::string_view password) {
} // namespace
User::User() {
// acl_categories_ = AclCat::ACL_CATEGORY_ADMIN;
commands_ = std::vector<uint64_t>(NumberOfFamilies(), 0);
}
void User::Update(UpdateRequest&& req) {
@ -38,6 +41,14 @@ void User::Update(UpdateRequest&& req) {
UnsetAclCategories(category);
}
for (auto [sign, index, bit_index] : req.commands) {
if (sign == Sign::PLUS) {
SetAclCommands(index, bit_index);
continue;
}
UnsetAclCommands(index, bit_index);
}
if (req.is_active) {
SetIsActive(*req.is_active);
}
@ -66,22 +77,47 @@ bool User::HasPassword(std::string_view password) const {
return *password_hash_ == StringSHA256(password);
}
void User::SetAclCategories(uint64_t cat) {
void User::SetAclCategories(uint32_t cat) {
acl_categories_ |= cat;
}
void User::UnsetAclCategories(uint64_t cat) {
void User::UnsetAclCategories(uint32_t cat) {
SetAclCategories(cat);
acl_categories_ ^= cat;
}
void User::SetAclCommands(size_t index, uint64_t bit_index) {
if (IsIndexAllCommandsFlag(index)) {
for (auto& family : commands_) {
family = ALL_COMMANDS;
}
return;
}
commands_[index] |= bit_index;
}
void User::UnsetAclCommands(size_t index, uint64_t bit_index) {
if (IsIndexAllCommandsFlag(index)) {
for (auto& family : commands_) {
family = NONE_COMMANDS;
}
return;
}
SetAclCommands(index, bit_index);
commands_[index] ^= bit_index;
}
uint32_t User::AclCategory() const {
return acl_categories_;
}
// For ACL commands
// void SetAclCommand()
// void AclCommand() const;
std::vector<uint64_t> User::AclCommands() const {
return commands_;
}
const std::vector<uint64_t>& User::AclCommandsRef() const {
return commands_;
}
void User::SetIsActive(bool is_active) {
is_active_ = is_active;

View file

@ -9,6 +9,7 @@
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>
@ -27,11 +28,14 @@ class User final {
std::vector<std::pair<Sign, uint32_t>> categories;
// DATATYPE_BITSET commands;
std::optional<bool> is_active{};
bool is_hashed{false};
// If index s numberic_limits::max() then it's a +all flag
using CommandsValueType = std::tuple<Sign, size_t /*index*/, uint64_t /*bit*/>;
using CommandsUpdateType = std::vector<CommandsValueType>;
CommandsUpdateType commands;
};
/* Used for default user
@ -51,19 +55,25 @@ class User final {
uint32_t AclCategory() const;
// TODO
// For ACL commands
// void SetAclCommand()
// void AclCommand() const;
std::vector<uint64_t> AclCommands() const;
const std::vector<uint64_t>& AclCommandsRef() const;
bool IsActive() const;
std::string_view Password() const;
// Selector maps a command string (like HSET, SET etc) to
// its respective ID within the commands vector.
static size_t Selector(std::string_view);
private:
// For ACL categories
void SetAclCategories(uint64_t cat);
void UnsetAclCategories(uint64_t cat);
void SetAclCategories(uint32_t cat);
void UnsetAclCategories(uint32_t cat);
// For ACL commands
void SetAclCommands(size_t index, uint64_t bit_index);
void UnsetAclCommands(size_t index, uint64_t bit_index);
// For is_active flag
void SetIsActive(bool is_active);
@ -75,6 +85,11 @@ class User final {
// password hashed with xx64
std::optional<std::string> password_hash_;
uint32_t acl_categories_{NONE};
// Each element index in the vector corresponds to a familly of commands
// Each bit in the uin64_t field at index id, corresponds to a specific
// command of that family. Look on TableCommandBuilder and on Service::Register
// on how this mapping is built during the startup/registration of commands
std::vector<uint64_t> commands_;
// we have at least 221 commands including a bunch of subcommands
// LARGE_BITFIELD_DATATYPE acl_commands_;

View file

@ -4,6 +4,7 @@
#include "server/acl/user_registry.h"
#include <limits>
#include <mutex>
#include "core/fibers.h"
@ -12,12 +13,6 @@
namespace dfly::acl {
UserRegistry::UserRegistry() {
std::pair<User::Sign, uint32_t> acl{User::Sign::PLUS, acl::ALL};
User::UpdateRequest req{{}, {acl}, true};
MaybeAddAndUpdate("default", std::move(req));
}
void UserRegistry::MaybeAddAndUpdate(std::string_view username, User::UpdateRequest req) {
std::unique_lock<util::SharedMutex> lock(mu_);
auto& user = registry_[username];
@ -35,7 +30,7 @@ UserRegistry::UserCredentials UserRegistry::GetCredentials(std::string_view user
if (it == registry_.end()) {
return {};
}
return {it->second.AclCategory()};
return {it->second.AclCategory(), it->second.AclCommands()};
}
bool UserRegistry::IsUserActive(std::string_view username) const {
@ -81,4 +76,16 @@ UserRegistry::UserWithWriteLock UserRegistry::MaybeAddAndUpdateWithLock(std::str
return {std::move(lock), user, exists};
}
void UserRegistry::Init() {
// Add default user
User::UpdateRequest::CommandsUpdateType tmp(NumberOfFamilies());
size_t id = 0;
for (auto& elem : tmp) {
elem = {User::Sign::PLUS, id++, acl::ALL_COMMANDS};
}
std::pair<User::Sign, uint32_t> acl{User::Sign::PLUS, acl::ALL};
User::UpdateRequest req{{}, {acl}, true, false, std::move(tmp)};
MaybeAddAndUpdate("default", std::move(req));
}
} // namespace dfly::acl

View file

@ -23,11 +23,13 @@ class UserRegistry {
template <template <typename T> typename LockT, typename RegT> class RegistryWithLock;
public:
UserRegistry();
UserRegistry() = default;
UserRegistry(const UserRegistry&) = delete;
UserRegistry(UserRegistry&&) = delete;
void Init();
using RegistryType = absl::flat_hash_map<std::string, User>;
// Acquires a write lock of mu_
@ -43,6 +45,7 @@ class UserRegistry {
struct UserCredentials {
uint32_t acl_categories{0};
std::vector<uint64_t> acl_commands;
};
// Acquires a read lock

View file

@ -23,7 +23,7 @@ TEST_F(UserRegistryTest, BasicOp) {
const std::string username = "kostas";
const std::string pass = "mypass";
User::UpdateRequest req{pass, {}, true};
User::UpdateRequest req{pass, {}, true, false, {}};
registry.MaybeAddAndUpdate(username, std::move(req));
CHECK_EQ(registry.AuthUser(username, pass), true);
CHECK_EQ(registry.IsUserActive(username), true);
@ -32,14 +32,14 @@ TEST_F(UserRegistryTest, BasicOp) {
using Sign = User::Sign;
std::vector<std::pair<Sign, uint32_t>> cat = {{Sign::PLUS, LIST}, {Sign::PLUS, SET}};
req = User::UpdateRequest{{}, std::move(cat), {}};
req = User::UpdateRequest{{}, std::move(cat), true, false, {}};
registry.MaybeAddAndUpdate(username, std::move(req));
auto acl_categories = registry.GetCredentials(username).acl_categories;
uint32_t expected_result = NONE | LIST | SET;
CHECK_EQ(acl_categories, expected_result);
cat.push_back({Sign::MINUS, LIST});
req = User::UpdateRequest{{}, std::move(cat), {}};
req = User::UpdateRequest{{}, std::move(cat), true, false, {}};
registry.MaybeAddAndUpdate(username, std::move(req));
acl_categories = registry.GetCredentials(username).acl_categories;
expected_result = NONE | SET;

View file

@ -4,12 +4,22 @@
#include "server/acl/validator.h"
#include "base/logging.h"
#include "server/acl/acl_commands_def.h"
#include "server/server_state.h"
namespace dfly::acl {
[[nodiscard]] bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx,
const facade::CommandId& id) {
auto command_credentials = id.acl_categories();
return (cntx.acl_categories & command_credentials) != 0;
const auto cat_credentials = id.acl_categories();
const size_t index = id.GetFamily();
const uint64_t command_mask = id.GetBitIndex();
DCHECK_LT(index, cntx.acl_commands.size());
return (cntx.acl_categories & cat_credentials) != 0 ||
(cntx.acl_commands[index] & command_mask) != 0;
}
} // namespace dfly::acl

View file

@ -841,7 +841,7 @@ constexpr uint32_t kSetBit = WRITE | BITMAP | SLOW;
void BitOpsFamily::Register(CommandRegistry* registry) {
using CI = CommandId;
registry->StartFamily();
*registry
<< CI{"BITPOS", CO::CommandOpt::READONLY, -3, 1, 1, 1, acl::kBitPos}.SetHandler(&BitPos)
<< CI{"BITCOUNT", CO::READONLY, -2, 1, 1, 1, acl::kBitCount}.SetHandler(&BitCount)

View file

@ -608,6 +608,7 @@ constexpr uint32_t kReadWrite = FAST | CONNECTION;
} // namespace acl
void ClusterFamily::Register(CommandRegistry* registry) {
registry->StartFamily();
*registry << CI{"CLUSTER", CO::READONLY, -2, 0, 0, 0, acl::kCluster}.HFUNC(Cluster)
<< CI{"DFLYCLUSTER", CO::ADMIN | CO::GLOBAL_TRANS | CO::HIDDEN, -2, 0, 0, 0,
acl::kDflyCluster}

View file

@ -12,6 +12,7 @@
#include "base/flags.h"
#include "base/logging.h"
#include "facade/error.h"
#include "server/acl/acl_commands_def.h"
#include "server/conn_context.h"
#include "server/server_state.h"
@ -103,11 +104,24 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
}
k = it->second;
}
family_of_commands_.back().push_back(std::string(k));
cmd.SetFamily(family_of_commands_.size() - 1);
cmd.SetBitIndex(1ULL << bit_index_++);
CHECK(cmd_map_.emplace(k, std::move(cmd)).second) << k;
return *this;
}
void CommandRegistry::StartFamily() {
family_of_commands_.push_back({});
bit_index_ = 0;
}
CommandRegistry::FamiliesVec CommandRegistry::GetFamilies() {
return std::move(family_of_commands_);
}
namespace CO {
const char* OptName(CO::CommandOpt fl) {

View file

@ -116,9 +116,6 @@ class CommandId : public facade::CommandId {
};
class CommandRegistry {
absl::flat_hash_map<std::string_view, CommandId> cmd_map_;
absl::flat_hash_map<std::string, std::string> cmd_rename_map_;
public:
CommandRegistry();
@ -159,6 +156,17 @@ class CommandRegistry {
cb(k_v.second.name(), src);
}
}
using FamiliesVec = std::vector<std::vector<std::string>>;
void StartFamily();
FamiliesVec GetFamilies();
private:
absl::flat_hash_map<std::string_view, CommandId> cmd_map_;
absl::flat_hash_map<std::string, std::string> cmd_rename_map_;
FamiliesVec family_of_commands_;
size_t bit_index_;
};
} // namespace dfly

View file

@ -5,6 +5,7 @@
#include "server/conn_context.h"
#include "base/logging.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/engine_shard_set.h"
#include "server/server_family.h"
@ -79,6 +80,7 @@ const CommandId* StoredCmd::Cid() const {
ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb)
: facade::ConnectionContext(nullptr, nullptr), transaction{tx} {
acl_commands = std::vector<uint64_t>(acl::NumberOfFamilies(), acl::ALL_COMMANDS);
if (tx) { // If we have a carrier transaction, this context is used for squashing
DCHECK(owner);
conn_state.db_index = owner->conn_state.db_index;

View file

@ -157,6 +157,7 @@ class ConnectionContext : public facade::ConnectionContext {
public:
ConnectionContext(::io::Sink* stream, facade::Connection* owner)
: facade::ConnectionContext(stream, owner) {
acl_commands = std::vector<uint64_t>(acl::NumberOfFamilies(), acl::ALL_COMMANDS);
}
ConnectionContext(const ConnectionContext* owner, Transaction* tx,
@ -199,6 +200,7 @@ class ConnectionContext : public facade::ConnectionContext {
std::string authed_username{"default"};
uint32_t acl_categories{acl::ALL};
std::vector<uint64_t> acl_commands;
private:
void EnableMonitoring(bool enable) {

View file

@ -1468,7 +1468,7 @@ constexpr uint32_t kRestore = KEYSPACE | WRITE | SLOW | DANGEROUS;
void GenericFamily::Register(CommandRegistry* registry) {
constexpr auto kSelectOpts = CO::LOADING | CO::FAST | CO::NOSCRIPT;
registry->StartFamily();
*registry
<< CI{"DEL", CO::WRITE, -2, 1, -1, 1, acl::kDel}.HFUNC(Del)
/* Redis compatibility:

View file

@ -294,7 +294,7 @@ constexpr uint32_t kPFMerge = WRITE | HYPERLOGLOG | SLOW;
void HllFamily::Register(CommandRegistry* registry) {
using CI = CommandId;
registry->StartFamily();
*registry << CI{"PFADD", CO::WRITE, -3, 1, 1, 1, acl::kPFAdd}.SetHandler(PFAdd)
<< CI{"PFCOUNT", CO::WRITE, -2, 1, -1, 1, acl::kPFCount}.SetHandler(PFCount)
<< CI{"PFMERGE", CO::WRITE, -2, 1, -1, 1, acl::kPFMerge}.SetHandler(PFMerge);

View file

@ -1141,6 +1141,7 @@ constexpr uint32_t kHVals = READ | HASH | SLOW;
} // namespace acl
void HSetFamily::Register(CommandRegistry* registry) {
registry->StartFamily();
*registry
<< CI{"HDEL", CO::FAST | CO::WRITE, -3, 1, 1, 1, acl::kHDel}.HFUNC(HDel)
<< CI{"HLEN", CO::FAST | CO::READONLY, 2, 1, 1, 1, acl::kHLen}.HFUNC(HLen)

View file

@ -1796,6 +1796,7 @@ void JsonFamily::Get(CmdArgList args, ConnectionContext* cntx) {
// TODO: Add sensible defaults/categories to json commands
void JsonFamily::Register(CommandRegistry* registry) {
registry->StartFamily();
*registry << CI{"JSON.GET", CO::READONLY | CO::FAST, -2, 1, 1, 1, acl::JSON}.HFUNC(Get);
*registry << CI{"JSON.MGET", CO::READONLY | CO::FAST | CO::REVERSE_MAPPING, -3, 1, -2, 1,
acl::JSON}

View file

@ -1326,6 +1326,7 @@ constexpr uint32_t kBLMove = READ | LIST | SLOW | BLOCKING;
} // namespace acl
void ListFamily::Register(CommandRegistry* registry) {
registry->StartFamily();
*registry
<< CI{"LPUSH", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1, acl::kLPush}.HFUNC(LPush)
<< CI{"LPUSHX", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1, acl::kLPushX}.HFUNC(LPushX)

View file

@ -2139,10 +2139,10 @@ constexpr uint32_t kPubSub = SLOW;
constexpr uint32_t kCommand = SLOW | CONNECTION;
} // namespace acl
void Service::RegisterCommands() {
void Service::Register(CommandRegistry* registry) {
using CI = CommandId;
registry_
registry->StartFamily();
*registry
<< CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0, acl::kQuit}.HFUNC(Quit)
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0, acl::kMulti}.HFUNC(Multi)
<< CI{"WATCH", CO::LOADING, -2, 1, -1, 1, acl::kWatch}.HFUNC(Watch)
@ -2168,7 +2168,10 @@ void Service::RegisterCommands() {
<< CI{"MONITOR", CO::ADMIN, 1, 0, 0, 0, acl::kMonitor}.MFUNC(Monitor)
<< CI{"PUBSUB", CO::LOADING | CO::FAST, -1, 0, 0, 0, acl::kPubSub}.MFUNC(Pubsub)
<< CI{"COMMAND", CO::LOADING | CO::NOSCRIPT, -1, 0, 0, 0, acl::kCommand}.MFUNC(Command);
}
void Service::RegisterCommands() {
Register(&registry_);
StreamFamily::Register(&registry_);
StringFamily::Register(&registry_);
GenericFamily::Register(&registry_);
@ -2184,14 +2187,16 @@ void Service::RegisterCommands() {
SearchFamily::Register(&registry_);
#endif
acl_family_.Register(&registry_);
server_family_.Register(&registry_);
cluster_family_.Register(&registry_);
acl_family_.Register(&registry_);
acl::BuildIndexers(registry_.GetFamilies());
// Only after all the commands are registered
registry_.Init(pp_.size());
using CI = CommandId;
if (VLOG_IS_ON(1)) {
LOG(INFO) << "Multi-key commands are: ";
registry_.Traverse([](std::string_view key, const CI& cid) {
@ -2214,6 +2219,10 @@ void Service::RegisterCommands() {
}
}
void Service::TestInit() {
acl_family_.Init(nullptr, &user_registry_);
}
void SetMaxMemoryFlag(uint64_t value) {
absl::SetFlag(&FLAGS_maxmemory, {value});
}

View file

@ -9,6 +9,7 @@
#include "base/varz_value.h"
#include "core/interpreter.h"
#include "facade/service_interface.h"
#include "server/acl/acl_commands_def.h"
#include "server/acl/acl_family.h"
#include "server/acl/user_registry.h"
#include "server/cluster/cluster_family.h"
@ -116,6 +117,10 @@ class Service : public facade::ServiceInterface {
return server_family_;
}
// Utility function used in unit tests
// Do not use in production, only meant to be used by unit tests
void TestInit();
private:
static void Quit(CmdArgList args, ConnectionContext* cntx);
static void Multi(CmdArgList args, ConnectionContext* cntx);
@ -160,6 +165,7 @@ class Service : public facade::ServiceInterface {
void CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& args);
void RegisterCommands();
void Register(CommandRegistry* registry);
base::VarzValue::Map GetVarzStats();

View file

@ -520,7 +520,7 @@ void SearchFamily::FtProfile(CmdArgList args, ConnectionContext* cntx) {
void SearchFamily::Register(CommandRegistry* registry) {
using CI = CommandId;
registry->StartFamily();
*registry << CI{"FT.CREATE", CO::GLOBAL_TRANS, -2, 0, 0, 0, acl::FT_SEARCH}.HFUNC(FtCreate)
<< CI{"FT.DROPINDEX", CO::GLOBAL_TRANS, -2, 0, 0, 0, acl::FT_SEARCH}.HFUNC(FtDropIndex)
<< CI{"FT.INFO", CO::GLOBAL_TRANS, 2, 0, 0, 0, acl::FT_SEARCH}.HFUNC(FtInfo)

View file

@ -1013,6 +1013,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
cntx->authed_username = username;
auto cred = registry->GetCredentials(username);
cntx->acl_categories = cred.acl_categories;
cntx->acl_commands = cred.acl_commands;
return (*cntx)->SendOk();
}
return (*cntx)->SendError(absl::StrCat("Could not authorize user: ", username));
@ -1989,13 +1990,14 @@ constexpr uint32_t kReplConf = ADMIN | SLOW | DANGEROUS;
constexpr uint32_t kRole = ADMIN | FAST | DANGEROUS;
constexpr uint32_t kSlowLog = ADMIN | SLOW | DANGEROUS;
constexpr uint32_t kScript = SLOW | SCRIPTING;
// TODO(check this)
constexpr uint32_t kDfly = ADMIN;
} // namespace acl
void ServerFamily::Register(CommandRegistry* registry) {
constexpr auto kReplicaOpts = CO::LOADING | CO::ADMIN | CO::GLOBAL_TRANS;
constexpr auto kMemOpts = CO::LOADING | CO::READONLY | CO::FAST | CO::NOSCRIPT;
registry->StartFamily();
*registry
<< CI{"AUTH", CO::NOSCRIPT | CO::FAST | CO::LOADING, -2, 0, 0, 0, acl::kAuth}.HFUNC(Auth)
<< CI{"BGSAVE", CO::ADMIN | CO::GLOBAL_TRANS, 1, 0, 0, 0, acl::kBGSave}.HFUNC(Save)

View file

@ -1580,6 +1580,7 @@ constexpr uint32_t kSScan = READ | SET | SLOW;
} // namespace acl
void SetFamily::Register(CommandRegistry* registry) {
registry->StartFamily();
*registry
<< CI{"SADD", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1, acl::kSAdd}.HFUNC(SAdd)
<< CI{"SDIFF", CO::READONLY, -2, 1, -1, 1, acl::kSDiff}.HFUNC(SDiff)

View file

@ -7,7 +7,6 @@
#include "facade/op_status.h"
#include "server/common.h"
typedef struct intset intset;
typedef struct redisObject robj;
typedef struct dict dict;

View file

@ -2541,7 +2541,7 @@ constexpr uint32_t kXGroupHelp = READ | STREAM | SLOW;
void StreamFamily::Register(CommandRegistry* registry) {
using CI = CommandId;
registry->StartFamily();
*registry
<< CI{"XADD", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, 1, acl::kXAdd}.HFUNC(XAdd)
<< CI{"XCLAIM", CO::WRITE | CO::FAST, -6, 1, 1, 1, acl::kXClaim}.HFUNC(XClaim)

View file

@ -1499,6 +1499,7 @@ constexpr uint32_t kClThrottle = THROTTLE;
} // namespace acl
void StringFamily::Register(CommandRegistry* registry) {
registry->StartFamily();
*registry
<< CI{"SET", CO::WRITE | CO::DENYOOM | CO::NO_AUTOJOURNAL, -3, 1, 1, 1, acl::kSet}.HFUNC(Set)
<< CI{"SETEX", CO::WRITE | CO::DENYOOM | CO::NO_AUTOJOURNAL, 4, 1, 1, 1, acl::kSetEx}.HFUNC(

View file

@ -562,4 +562,8 @@ void BaseFamilyTest::SetTestFlag(string_view flag_name, string_view new_value) {
CHECK(flag->ParseFrom(new_value, &error)) << "Error: " << error;
}
void BaseFamilyTest::TestInitAclFam() {
service_->TestInit();
}
} // namespace dfly

View file

@ -127,6 +127,8 @@ class BaseFamilyTest : public ::testing::Test {
static void SetTestFlag(std::string_view flag_name, std::string_view new_value);
void TestInitAclFam();
std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_;
unsigned num_threads_ = 3;

View file

@ -2574,7 +2574,7 @@ constexpr uint32_t kGeoDist = READ | GEO | SLOW;
void ZSetFamily::Register(CommandRegistry* registry) {
constexpr uint32_t kStoreMask = CO::WRITE | CO::VARIADIC_KEYS | CO::REVERSE_MAPPING | CO::DENYOOM;
registry->StartFamily();
*registry
<< CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1, acl::kZAdd}.HFUNC(ZAdd)
<< CI{"BZPOPMIN",

View file

@ -43,6 +43,25 @@ async def test_acl_setuser(async_client):
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL" in result
# commands
await async_client.execute_command("ACL SETUSER kostas +set +get +hset")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL +SET +GET +HSET" in result
await async_client.execute_command("ACL SETUSER kostas -set -get +hset")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL +HSET" in result
# interleaved
await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@NONE" in result
# interleaved with categories
await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@STRING +SET" in result
@pytest.mark.asyncio
async def test_acl_categories(async_client):
@ -94,11 +113,26 @@ async def test_acl_categories(async_client):
@pytest.mark.asyncio
async def test_acl_categories_multi_exec_squash(df_local_factory):
async def test_acl_commands(async_client):
await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get")
result = await async_client.execute_command("AUTH random mypass")
assert result == "OK"
result = await async_client.execute_command("SET foo bar")
assert result == "OK"
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("ZADD myset 1 two")
@pytest.mark.asyncio
async def test_acl_cat_commands_multi_exec_squash(df_local_factory):
df = df_local_factory.create(multi_exec_squash=True, port=1111)
df.start()
# Testing acl categories
client = aioredis.Redis(port=df.port)
res = await client.execute_command("ACL SETUSER kk ON >kk +@transaction +@string")
assert res == b"OK"
@ -155,6 +189,34 @@ async def test_acl_categories_multi_exec_squash(df_local_factory):
await admin_client.close()
await client.close()
# Testing acl commands
client = aioredis.Redis(port=df.port)
res = await client.execute_command("ACL SETUSER myuser ON >kk +@transaction +set")
assert res == b"OK"
res = await client.execute_command("AUTH myuser kk")
assert res == b"OK"
await client.execute_command("MULTI")
assert res == b"OK"
for x in range(33):
await client.execute_command(f"SET x{x} {x}")
await client.execute_command("EXEC")
# NOPERM between multi and exec
admin_client = aioredis.Redis(port=df.port)
res = await admin_client.execute_command("ACL SETUSER myuser -set")
assert res == b"OK"
# NOPERM while executing multi
await client.execute_command("MULTI")
with pytest.raises(redis.exceptions.ResponseError):
await client.execute_command(f"SET x{x} {x}")
await admin_client.close()
await client.close()
@pytest.mark.asyncio
async def test_acl_deluser(df_server):
@ -264,15 +326,14 @@ async def test_good_acl_file(df_local_factory, tmp_dir):
df.start()
client = aioredis.Redis(port=df.port)
await client.execute_command("ACL SETUSER roy ON >mypass +@STRING")
await client.execute_command("ACL SETUSER roy ON >mypass +@STRING +HSET")
await client.execute_command("ACL SETUSER shahar >mypass +@SET")
await client.execute_command("ACL SETUSER vlad +@STRING")
result = await client.execute_command("ACL LIST")
assert 4 == len(result)
assert "user roy on ea71c25a7a60224 +@STRING" in result
assert 3 == len(result)
assert "user roy on ea71c25a7a60224 +@STRING +HSET" in result
assert "user shahar off ea71c25a7a60224 +@SET" in result
assert "user default on nopass +@ALL" in result
assert "user vlad off nopass +@STRING" in result
result = await client.execute_command("ACL DELUSER shahar")
@ -281,12 +342,10 @@ async def test_good_acl_file(df_local_factory, tmp_dir):
result = await client.execute_command("ACL SAVE")
result = await client.execute_command("ACL LOAD")
# assert result == b"OK"
result = await client.execute_command("ACL LIST")
assert 3 == len(result)
assert "user roy on ea71c25a7a60224 +@STRING" in result
assert "user default on nopass +@ALL" in result
assert 2 == len(result)
assert "user roy on ea71c25a7a60224 +@STRING +HSET" in result
assert "user vlad off nopass +@STRING" in result
await client.close()