fix: acl compatibility (#3147)

* remove acl categories from context and all acl checks
* category assign,ent now assigns all the acl commands for that category to the user
* introduce modification order of acl's per user
* acl rules are now printed in the same order as in redis/valkey
* remove old user_registry_test which was part of the poc
This commit is contained in:
Kostas Kyrimis 2024-06-13 10:56:30 +03:00 committed by GitHub
parent 165631a5aa
commit d2ae0ab75c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 318 additions and 230 deletions

View file

@ -7,7 +7,6 @@
#include <string_view>
#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"
namespace facade {

View file

@ -433,7 +433,6 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (self->cntx()) {
if (msg.username == self->cntx()->authed_username) {
self->cntx()->acl_categories = msg.categories;
self->cntx()->acl_commands = msg.commands;
self->cntx()->keys = msg.keys;
}

View file

@ -117,7 +117,6 @@ class Connection : public util::Connection {
// ACL Update message, contains ACL updates to be applied to the connection.
struct AclUpdateMessage {
std::string username;
uint32_t categories;
std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
};

View file

@ -117,7 +117,6 @@ cxx_test(hll_family_test dfly_test_lib LABELS DFLY)
cxx_test(bloom_family_test dfly_test_lib LABELS DFLY)
cxx_test(cluster/cluster_config_test dfly_test_lib LABELS DFLY)
cxx_test(cluster/cluster_family_test dfly_test_lib LABELS DFLY)
cxx_test(acl/user_registry_test dfly_test_lib LABELS DFLY)
cxx_test(acl/acl_family_test dfly_test_lib LABELS DFLY)
cxx_test(engine_shard_set_test dfly_test_lib LABELS DFLY)
cxx_test(search/search_family_test dfly_test_lib LABELS DFLY)
@ -135,4 +134,4 @@ add_dependencies(check_dfly dragonfly_test json_family_test list_family_test
generic_family_test memcache_parser_test rdb_test journal_test
redis_parser_test stream_family_test string_family_test
bitops_family_test set_family_test zset_family_test hll_family_test
cluster_config_test cluster_family_test user_registry_test acl_family_test)
cluster_config_test cluster_family_test acl_family_test)

View file

@ -5,7 +5,11 @@
#pragma once
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "base/logging.h"
#include "facade/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
namespace dfly::acl {
@ -84,6 +88,14 @@ inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED",
"BLOOM", "FT_SEARCH", "THROTTLE", "JSON"};
// bit index to index in the REVERSE_CATEGORY_INDEX_TABLE
using CategoryToIdxStore = absl::flat_hash_map<uint32_t, uint32_t>;
inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) {
static CategoryToIdxStore cat_idx = std::move(store);
return cat_idx;
}
using RevCommandField = std::vector<std::string>;
using RevCommandsIndexStore = std::vector<RevCommandField>;
@ -104,9 +116,39 @@ inline const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore sto
return rev_index_store;
}
inline void BuildIndexers(std::vector<std::vector<std::string>> families) {
using CategoryToCommandsIndexStore = absl::flat_hash_map<std::string, std::vector<uint64_t>>;
inline const CategoryToCommandsIndexStore& CategoryToCommandsIndex(
CategoryToCommandsIndexStore store = {}) {
static CategoryToCommandsIndexStore index = std::move(store);
return index;
}
inline void BuildIndexers(RevCommandsIndexStore families, CommandRegistry* cmd_registry) {
acl::NumberOfFamilies(families.size());
acl::CommandsRevIndexer(std::move(families));
CategoryToCommandsIndexStore index;
cmd_registry->Traverse([&](std::string_view name, auto& cid) {
auto cat = cid.acl_categories();
for (size_t i = 0; i < 32; ++i) {
if (cat & (1 << i)) {
std::string_view cat_name = REVERSE_CATEGORY_INDEX_TABLE[i];
if (index[cat_name].empty()) {
index[cat_name].resize(CommandsRevIndexer().size());
}
auto family = cid.GetFamily();
auto bit_index = cid.GetBitIndex();
index[cat_name][family] |= bit_index;
}
}
});
CategoryToCommandsIndex(std::move(index));
CategoryToIdxStore idx_store;
for (size_t i = 0; i < 32; ++i) {
idx_store[1 << i] = i;
}
CategoryToIdx(std::move(idx_store));
}
} // namespace dfly::acl

View file

@ -66,22 +66,23 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass);
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";
const std::string maybe_space_com = acl_keys.empty() ? "" : " ";
using namespace std::string_view_literals;
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys);
acl_cat_and_commands, maybe_space_com, acl_keys);
cntx->SendSimpleString(buffer);
}
}
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys) {
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
@ -90,7 +91,7 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, u
if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() &&
connection->cntx()) {
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys});
facade::Connection::AclUpdateMessage{user, update_commands, update_keys});
}
};
@ -113,10 +114,14 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
if (!exists) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
user.Update(std::move(default_req));
}
user.Update(std::move(req));
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(),
user.AclCommands(), user.Keys());
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys());
}
cntx->SendOk();
};
@ -194,20 +199,15 @@ std::string AclFamily::RegistryToString() const {
const std::string_view pass = user.Password();
const std::string password =
pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " ");
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";
using namespace std::string_view_literals;
absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password,
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys, "\n");
}
if (!result.empty()) {
result.pop_back();
acl_cat_and_commands, maybe_space, acl_keys, "\n");
}
return result;
@ -298,7 +298,10 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,
}
for (size_t i = 0; i < usernames.size(); ++i) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
auto& user = registry[usernames[i]];
user.Update(std::move(default_req));
user.Update(std::move(requests[i]));
}
@ -446,6 +449,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
const uint32_t cid_mask = CATEGORY_INDEX_TABLE.find(category)->second;
std::vector<std::string_view> results;
// TODO replace this with indexer
auto cb = [cid_mask, &results](auto name, auto& cid) {
if (cid_mask & cid.acl_categories()) {
results.push_back(name);
@ -510,10 +514,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
}
rb->SendSimpleString("commands");
std::string acl = absl::StrCat(AclCatToString(user.AclCategory()), " ",
AclCommandToString(user.AclCommandsRef()));
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
rb->SendSimpleString(acl);
rb->SendSimpleString(acl_cat_and_commands);
rb->SendSimpleString("keys");
std::string keys = AclKeysToString(user.Keys());
@ -572,9 +576,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
}
const auto& user = registry.find(username)->second;
const bool is_allowed = IsUserAllowedToInvokeCommandGeneric(
user.AclCategory(), user.AclCommandsRef(), {{}, true}, {}, *cid)
.first;
const bool is_allowed =
IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first;
if (is_allowed) {
cntx->SendOk();
return;

View file

@ -49,7 +49,7 @@ class AclFamily final {
// Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads
using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys);

View file

@ -46,16 +46,16 @@ TEST_F(AclFamilyTest, AclSetUser) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user vlad off nopass +@NONE"));
EXPECT_THAT(
vec, UnorderedElementsAre("user default on nopass +@ALL ~*", "user vlad off nopass -@ALL"));
resp = Run({"ACL", "SETUSER", "vlad", "+ACL"});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user vlad off nopass +@NONE +ACL"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*",
"user vlad off nopass -@ALL +ACL"));
}
TEST_F(AclFamilyTest, AclDelUser) {
@ -82,7 +82,7 @@ TEST_F(AclFamilyTest, AclDelUser) {
EXPECT_THAT(resp, IntArg(0));
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL +ALL ~*");
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL ~*");
Run({"ACL", "SETUSER", "michael", "ON"});
Run({"ACL", "SETUSER", "kobe", "ON"});
@ -103,9 +103,9 @@ TEST_F(AclFamilyTest, AclList) {
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user kostas off d74ff0ee8da3b98 +@ADMIN",
"user adi off d74ff0ee8da3b98 +@FAST"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*",
"user kostas off d74ff0ee8da3b98 -@ALL +@ADMIN",
"user adi off d74ff0ee8da3b98 -@ALL +@FAST"));
}
TEST_F(AclFamilyTest, AclAuth) {
@ -154,16 +154,16 @@ TEST_F(AclFamilyTest, TestAllCategories) {
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@", cat)));
UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ", "+@", cat)));
resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)});
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@NONE")));
UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ", "-@", cat)));
resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
@ -201,16 +201,16 @@ TEST_F(AclFamilyTest, TestAllCommands) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass +@NONE ",
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ",
"+", command_name)));
resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)});
resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@NONE")));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass ",
"-@ALL ", "-", command_name)));
resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
@ -259,7 +259,7 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(vec[2], "passwords");
EXPECT_TRUE(vec[3].GetVec().empty());
EXPECT_THAT(vec[4], "commands");
EXPECT_THAT(vec[5], "+@ALL +ALL");
EXPECT_THAT(vec[5], "+@ALL");
EXPECT_THAT(vec[6], "keys");
EXPECT_THAT(vec[7], "~*");
@ -271,7 +271,7 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(kvec[2], "passwords");
EXPECT_TRUE(kvec[3].GetVec().empty());
EXPECT_THAT(kvec[4], "commands");
EXPECT_THAT(kvec[5], "+@STRING +HSET");
EXPECT_THAT(kvec[5], "-@ALL +@STRING +HSET");
}
TEST_F(AclFamilyTest, TestDryRun) {

View file

@ -12,64 +12,113 @@
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "core/overloaded.h"
#include "facade/acl_commands_def.h"
#include "server/acl/acl_commands_def.h"
#include "server/acl/user.h"
#include "server/common.h"
namespace dfly::acl {
std::string AclCatToString(uint32_t acl_category) {
std::string tmp;
namespace {
std::string AclCatToString(uint32_t acl_category, User::Sign sign) {
std::string res = sign == User::Sign::PLUS ? "+@" : "-@";
if (acl_category == acl::ALL) {
return "+@ALL";
absl::StrAppend(&res, "ALL");
return res;
}
if (acl_category == acl::NONE) {
return "+@NONE";
}
const std::string prefix = "+@";
const std::string postfix = " ";
for (uint32_t i = 0; i < 32; ++i) {
uint32_t cat_bit = 1ULL << i;
if (acl_category & cat_bit) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[i], postfix);
}
}
tmp.pop_back();
return tmp;
const auto& index = CategoryToIdx().at(acl_category);
absl::StrAppend(&res, REVERSE_CATEGORY_INDEX_TABLE[index]);
return res;
}
std::string AclCommandToString(const std::vector<uint64_t>& acl_category) {
std::string AclCommandToString(size_t family, uint64_t mask, User::Sign sign) {
// This is constant but can be optimized with an indexer
const auto& rev_index = CommandsRevIndexer();
std::string res;
std::string prefix = (sign == User::Sign::PLUS) ? "+" : "-";
if (mask == ALL_COMMANDS) {
for (const auto& cmd : rev_index[family]) {
absl::StrAppend(&res, prefix, cmd, " ");
}
res.pop_back();
return res;
}
size_t pos = 0;
while (mask != 0) {
++pos;
mask = mask >> 1;
}
--pos;
absl::StrAppend(&res, prefix, rev_index[family][pos]);
return res;
}
struct CategoryAndMetadata {
User::CategoryChange change;
User::ChangeMetadata metadata;
};
struct CommandAndMetadata {
User::CommandChange change;
User::ChangeMetadata metadata;
};
using MergeResult = std::vector<std::variant<CategoryAndMetadata, CommandAndMetadata>>;
} // namespace
// Merge Category and Command changes and sort them by global order seq_no
MergeResult MergeTables(const User::CategoryChanges& categories,
const User::CommandChanges& commands) {
MergeResult result;
for (auto [cat, meta] : categories) {
result.push_back(CategoryAndMetadata{cat, meta});
}
for (auto [cmd, meta] : commands) {
result.push_back(CommandAndMetadata{cmd, meta});
}
std::sort(result.begin(), result.end(), [](const auto& l, const auto& r) {
auto fetch = [](const auto& l) { return l.metadata.seq_no; };
return std::visit(fetch, l) < std::visit(fetch, r);
});
return result;
}
std::string AclCatAndCommandToString(const User::CategoryChanges& cat,
const User::CommandChanges& cmds) {
std::string result;
const std::string prefix = "+";
const std::string postfix = " ";
const auto& rev_index = CommandsRevIndexer();
bool all = true;
auto tables = MergeTables(cat, cmds);
size_t family_id = 0;
for (auto family : acl_category) {
for (uint64_t i = 0; i < 64; ++i) {
const uint64_t cmd_bit = 1ULL << i;
if (family & cmd_bit && i < rev_index[family_id].size()) {
absl::StrAppend(&result, prefix, rev_index[family_id][i], postfix);
continue;
}
if (i < rev_index[family_id].size()) {
all = false;
}
}
++family_id;
auto cat_visitor = [&result](const CategoryAndMetadata& val) {
const auto& [change, meta] = val;
absl::StrAppend(&result, AclCatToString(change, meta.sign), " ");
};
auto cmd_visitor = [&result](const CommandAndMetadata& val) {
const auto& [change, meta] = val;
const auto [family, bit_index] = change;
absl::StrAppend(&result, AclCommandToString(family, bit_index, meta.sign), " ");
};
Overloaded visitor{cat_visitor, cmd_visitor};
for (auto change : tables) {
std::visit(visitor, change);
}
if (!result.empty()) {
result.pop_back();
}
return all ? "+ALL" : result;
return result;
}
std::string PrettyPrintSha(std::string_view pass, bool all) {
@ -157,21 +206,8 @@ std::pair<OptCat, bool> MaybeParseAclCategory(std::string_view command) {
return {};
}
bool IsIndexAllCommandsFlag(size_t index) {
return index == std::numeric_limits<size_t>::max();
}
std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,
const CommandRegistry& registry) {
const auto all_commands = std::pair<size_t, uint64_t>{std::numeric_limits<size_t>::max(), 0};
if (command == "+ALL") {
return {all_commands, true};
}
if (command == "-ALL") {
return {all_commands, false};
}
if (absl::StartsWith(command, "+")) {
auto res = registry.Find(command.substr(1));
if (!res) {
@ -281,7 +317,7 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
using Sign = User::Sign;
using Val = std::pair<Sign, uint32_t>;
auto val = add ? Val{Sign::PLUS, *cat} : Val{Sign::MINUS, *cat};
req.categories.push_back(val);
req.updates.push_back(val);
continue;
}
@ -292,10 +328,9 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
using Sign = User::Sign;
using Val = User::UpdateRequest::CommandsValueType;
;
auto [index, bit] = *cmd;
auto val = sign ? Val{Sign::PLUS, index, bit} : Val{Sign::MINUS, index, bit};
req.commands.push_back(val);
req.updates.push_back(val);
}
return req;

View file

@ -17,9 +17,8 @@
namespace dfly::acl {
std::string AclCatToString(uint32_t acl_category);
std::string AclCommandToString(const std::vector<uint64_t>& acl_category);
std::string AclCatAndCommandToString(const User::CategoryChanges& cat,
const User::CommandChanges& cmds);
std::string PrettyPrintSha(std::string_view pass, bool all = false);

View file

@ -9,6 +9,7 @@
#include <limits>
#include "absl/strings/escaping.h"
#include "core/overloaded.h"
#include "server/acl/helpers.h"
namespace dfly::acl {
@ -33,20 +34,28 @@ void User::Update(UpdateRequest&& req) {
SetPasswordHash(*req.password, req.is_hashed);
}
for (auto [sign, category] : req.categories) {
auto cat_visitor = [this](UpdateRequest::CategoryValueType cat) {
auto [sign, category] = cat;
if (sign == Sign::PLUS) {
SetAclCategories(category);
continue;
SetAclCategoriesAndIncrSeq(category);
return;
}
UnsetAclCategories(category);
}
UnsetAclCategoriesAndIncrSeq(category);
};
for (auto [sign, index, bit_index] : req.commands) {
auto cmd_visitor = [this](UpdateRequest::CommandsValueType cmd) {
auto [sign, index, bit_index] = cmd;
if (sign == Sign::PLUS) {
SetAclCommands(index, bit_index);
continue;
SetAclCommandsAndIncrSeq(index, bit_index);
return;
}
UnsetAclCommands(index, bit_index);
UnsetAclCommandsAndIncrSeq(index, bit_index);
};
Overloaded visitor{cat_visitor, cmd_visitor};
for (auto req : req.updates) {
std::visit(visitor, req);
}
if (!req.keys.empty()) {
@ -78,17 +87,42 @@ bool User::HasPassword(std::string_view password) const {
return *password_hash_ == StringSHA256(password);
}
void User::SetAclCategories(uint32_t cat) {
void User::SetAclCategoriesAndIncrSeq(uint32_t cat) {
acl_categories_ |= cat;
if (cat == acl::ALL) {
SetAclCommands(std::numeric_limits<size_t>::max(), 0);
} else {
auto id = CategoryToIdx().at(cat);
std::string_view name = REVERSE_CATEGORY_INDEX_TABLE[id];
const auto& commands_group = CategoryToCommandsIndex().at(name);
for (size_t fam_id = 0; fam_id < commands_group.size(); ++fam_id) {
SetAclCommands(fam_id, commands_group[fam_id]);
}
}
CategoryChange change{cat};
cat_changes_[change] = ChangeMetadata{Sign::PLUS, seq_++};
}
void User::UnsetAclCategories(uint32_t cat) {
SetAclCategories(cat);
void User::UnsetAclCategoriesAndIncrSeq(uint32_t cat) {
acl_categories_ ^= cat;
if (cat == acl::ALL) {
UnsetAclCommands(std::numeric_limits<size_t>::max(), 0);
} else {
auto id = CategoryToIdx().at(cat);
std::string_view name = REVERSE_CATEGORY_INDEX_TABLE[id];
const auto& commands_group = CategoryToCommandsIndex().at(name);
for (size_t fam_id = 0; fam_id < commands_group.size(); ++fam_id) {
UnsetAclCommands(fam_id, commands_group[fam_id]);
}
}
CategoryChange change{cat};
cat_changes_[change] = ChangeMetadata{Sign::MINUS, seq_++};
}
void User::SetAclCommands(size_t index, uint64_t bit_index) {
if (IsIndexAllCommandsFlag(index)) {
if (index == std::numeric_limits<size_t>::max()) {
for (auto& family : commands_) {
family = ALL_COMMANDS;
}
@ -97,8 +131,14 @@ void User::SetAclCommands(size_t index, uint64_t bit_index) {
commands_[index] |= bit_index;
}
void User::SetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index) {
SetAclCommands(index, bit_index);
CommandChange change{index, bit_index};
cmd_changes_[change] = ChangeMetadata{Sign::PLUS, seq_++};
}
void User::UnsetAclCommands(size_t index, uint64_t bit_index) {
if (IsIndexAllCommandsFlag(index)) {
if (index == std::numeric_limits<size_t>::max()) {
for (auto& family : commands_) {
family = NONE_COMMANDS;
}
@ -108,6 +148,12 @@ void User::UnsetAclCommands(size_t index, uint64_t bit_index) {
commands_[index] ^= bit_index;
}
void User::UnsetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index) {
UnsetAclCommands(index, bit_index);
CommandChange change{index, bit_index};
cmd_changes_[change] = ChangeMetadata{Sign::MINUS, seq_++};
}
uint32_t User::AclCategory() const {
return acl_categories_;
}
@ -138,6 +184,14 @@ const AclKeys& User::Keys() const {
return keys_;
}
const User::CategoryChanges& User::CatChanges() const {
return cat_changes_;
}
const User::CommandChanges& User::CmdChanges() const {
return cmd_changes_;
}
void User::SetKeyGlobs(std::vector<UpdateKey> keys) {
for (auto& key : keys) {
if (key.all_keys) {

View file

@ -33,16 +33,16 @@ class User final {
struct UpdateRequest {
std::optional<std::string> password{};
std::vector<std::pair<Sign, uint32_t>> categories;
std::optional<bool> is_active{};
bool is_hashed{false};
// Categories and commands
using CategoryValueType = std::pair<Sign, uint32_t>;
// If index s numberic_limits::max() then it's a +all flag
using CommandsValueType = std::tuple<Sign, size_t /*index*/, uint64_t /*bit*/>;
using CommandsUpdateType = std::vector<CommandsValueType>;
CommandsUpdateType commands;
using UpdateType = std::vector<std::variant<CategoryValueType, CommandsValueType>>;
UpdateType updates;
// keys
std::vector<UpdateKey> keys;
@ -50,6 +50,14 @@ class User final {
bool allow_all_keys{false};
};
using CategoryChange = uint32_t;
using CommandChange = std::pair<size_t, uint64_t>;
struct ChangeMetadata {
Sign sign;
size_t seq_no;
};
/* Used for default user
* password = nopass
* acl_categories = +@all
@ -80,15 +88,24 @@ class User final {
const AclKeys& Keys() const;
using CategoryChanges = absl::flat_hash_map<CategoryChange, ChangeMetadata>;
using CommandChanges = absl::flat_hash_map<CommandChange, ChangeMetadata>;
const CategoryChanges& CatChanges() const;
const CommandChanges& CmdChanges() const;
private:
// For ACL categories
void SetAclCategories(uint32_t cat);
void UnsetAclCategories(uint32_t cat);
void SetAclCategoriesAndIncrSeq(uint32_t cat);
void UnsetAclCategoriesAndIncrSeq(uint32_t cat);
// For ACL commands
void SetAclCommands(size_t index, uint64_t bit_index);
void UnsetAclCommands(size_t index, uint64_t bit_index);
void SetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index);
void UnsetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index);
// For is_active flag
void SetIsActive(bool is_active);
@ -108,6 +125,16 @@ class User final {
// on how this mapping is built during the startup/registration of commands
std::vector<uint64_t> commands_;
// We also need to track all the explicit changes (ACL SETUSER) of acl's in-order.
// To speed up insertion we use the flat_hash_map and a seq_ variable which is a
// strictly monotonically increasing number that is used for ordering. Both of these
// indexers are merged and then sorted by the seq_ number when for example we print
// the ACL rules of each user via ACL LIST.
CategoryChanges cat_changes_;
CommandChanges cmd_changes_;
// Global modification order for changes in rules for acl commands and categories
size_t seq_ = 0;
// Glob patterns for the keys that a user is allowed to read/write
AclKeys keys_;

View file

@ -8,6 +8,7 @@
#include <mutex>
#include "base/flags.h"
#include "facade/acl_commands_def.h"
#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"
@ -71,29 +72,13 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock<fb2::SharedM
: user(user), exists(exists), registry_lk_(std::move(lk)) {
}
UserRegistry::UserWithWriteLock UserRegistry::MaybeAddAndUpdateWithLock(std::string_view username,
User::UpdateRequest req) {
std::unique_lock<fb2::SharedMutex> lock(mu_);
const bool exists = registry_.contains(username);
auto& user = registry_[username];
user.Update(std::move(req));
return {std::move(lock), user, exists};
}
User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const {
User::UpdateRequest::CommandsUpdateType tmp(NumberOfFamilies());
size_t id = 0;
for (auto& elem : tmp) {
elem = {User::Sign::PLUS, id++, acl::ALL_COMMANDS};
}
std::pair<User::Sign, uint32_t> acl{User::Sign::PLUS, acl::ALL};
auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false};
return {{}, {acl}, true, false, std::move(tmp), {std::move(key)}};
return {{}, true, false, {std::move(acl)}, {std::move(key)}};
}
void UserRegistry::Init() {
// Add default user
User::UpdateRequest::CommandsUpdateType tmp(NumberOfFamilies());
// if there exists an acl file to load from, requirepass
// will not overwrite the default's user password loaded from
// that file. Loading the default's user password from a file

View file

@ -77,8 +77,6 @@ class UserRegistry {
std::unique_lock<util::fb2::SharedMutex> registry_lk_;
};
UserWithWriteLock MaybeAddAndUpdateWithLock(std::string_view username, User::UpdateRequest req);
User::UpdateRequest DefaultUserUpdateRequest() const;
private:

View file

@ -1,49 +0,0 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/acl/user_registry.h"
#include <string>
#include <string_view>
#include "base/gtest.h"
#include "base/logging.h"
#include "server/acl/acl_commands_def.h"
#include "server/acl/user.h"
using namespace testing;
namespace dfly::acl {
class UserRegistryTest : public Test {};
TEST_F(UserRegistryTest, BasicOp) {
UserRegistry registry;
const std::string username = "kostas";
const std::string pass = "mypass";
User::UpdateRequest req{pass, {}, true, false, {}, {}, false, false};
registry.MaybeAddAndUpdate(username, std::move(req));
CHECK_EQ(registry.AuthUser(username, pass), true);
CHECK_EQ(registry.IsUserActive(username), true);
CHECK_EQ(registry.GetCredentials(username).acl_categories, NONE);
using Sign = User::Sign;
std::vector<std::pair<Sign, uint32_t>> cat = {{Sign::PLUS, LIST}, {Sign::PLUS, SET}};
req = User::UpdateRequest{{}, std::move(cat), true, false, {}, {}, false, false};
registry.MaybeAddAndUpdate(username, std::move(req));
auto acl_categories = registry.GetCredentials(username).acl_categories;
uint32_t expected_result = NONE | LIST | SET;
CHECK_EQ(acl_categories, expected_result);
cat.push_back({Sign::MINUS, LIST});
req = User::UpdateRequest{{}, std::move(cat), true, false, {}, {}, false, false};
registry.MaybeAddAndUpdate(username, std::move(req));
acl_categories = registry.GetCredentials(username).acl_categories;
expected_result = NONE | SET;
CHECK_EQ(acl_categories, expected_result);
}
} // namespace dfly::acl

View file

@ -23,8 +23,8 @@ namespace dfly::acl {
return true;
}
const auto [is_authed, reason] = IsUserAllowedToInvokeCommandGeneric(
cntx.acl_categories, cntx.acl_commands, cntx.keys, tail_args, id);
const auto [is_authed, reason] =
IsUserAllowedToInvokeCommandGeneric(cntx.acl_commands, cntx.keys, tail_args, id);
if (!is_authed) {
auto& log = ServerState::tlocal()->acl_log;
@ -41,15 +41,13 @@ namespace dfly::acl {
#endif
[[nodiscard]] std::pair<bool, AclLog::Reason> IsUserAllowedToInvokeCommandGeneric(
uint32_t acl_cat, const std::vector<uint64_t>& acl_commands, const AclKeys& keys,
CmdArgList tail_args, const CommandId& id) {
const auto cat_credentials = id.acl_categories();
const std::vector<uint64_t>& acl_commands, const AclKeys& keys, CmdArgList tail_args,
const CommandId& id) {
const size_t index = id.GetFamily();
const uint64_t command_mask = id.GetBitIndex();
DCHECK_LT(index, acl_commands.size());
const bool command =
(acl_cat & cat_credentials) != 0 || (acl_commands[index] & command_mask) != 0;
const bool command = (acl_commands[index] & command_mask) != 0;
if (!command) {
return {false, AclLog::Reason::COMMAND};

View file

@ -13,8 +13,8 @@
namespace dfly::acl {
std::pair<bool, AclLog::Reason> IsUserAllowedToInvokeCommandGeneric(
uint32_t acl_cat, const std::vector<uint64_t>& acl_commands, const AclKeys& keys,
CmdArgList tail_args, const CommandId& id);
const std::vector<uint64_t>& acl_commands, const AclKeys& keys, CmdArgList tail_args,
const CommandId& id);
bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id,
CmdArgList tail_args);

View file

@ -129,7 +129,7 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
}
cmd.SetFamily(family_of_commands_.size() - 1);
if (!is_sub_command) {
if (!is_sub_command || absl::StartsWith(cmd.name(), "ACL")) {
cmd.SetBitIndex(1ULL << bit_index_);
family_of_commands_.back().push_back(std::string(k));
++bit_index_;

View file

@ -2660,7 +2660,7 @@ void Service::RegisterCommands() {
cluster_family_.Register(&registry_);
acl_family_.Register(&registry_);
acl::BuildIndexers(registry_.GetFamilies());
acl::BuildIndexers(registry_.GetFamilies(), &registry_);
// Only after all the commands are registered
registry_.Init(pp_.size());

View file

@ -14,59 +14,59 @@ async def test_acl_setuser(async_client):
await async_client.execute_command("ACL SETUSER kostas")
result = await async_client.execute_command("ACL LIST")
assert 2 == len(result)
assert "user kostas off nopass +@NONE" in result
assert "user kostas off nopass -@ALL" in result
await async_client.execute_command("ACL SETUSER kostas ON")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@NONE" in result
assert "user kostas on nopass -@ALL" in result
await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin")
result = await async_client.execute_command("ACL LIST")
# TODO consider printing to lowercase
assert "user kostas on nopass +@LIST +@STRING +@ADMIN" in result
assert "user kostas on nopass -@ALL +@LIST +@STRING +@ADMIN" in result
await async_client.execute_command("ACL SETUSER kostas -@list -@admin")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@STRING" in result
assert "user kostas on nopass -@ALL +@STRING -@LIST -@ADMIN" in result
# mix and match
await async_client.execute_command("ACL SETUSER kostas +@list -@string")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@LIST" in result
assert "user kostas on nopass -@ALL -@ADMIN +@LIST -@STRING" in result
# mix and match interleaved
await async_client.execute_command("ACL SETUSER kostas +@set -@set +@set")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@SET +@LIST" in result
assert "user kostas on nopass -@ALL -@ADMIN +@LIST -@STRING +@SET" in result
await async_client.execute_command("ACL SETUSER kostas +@all")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL" in result
assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET +@ALL" in result
# commands
await async_client.execute_command("ACL SETUSER kostas +set +get +hset")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL +SET +GET +HSET" in result
assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET +@ALL +SET +GET +HSET" in result
await async_client.execute_command("ACL SETUSER kostas -set -get +hset")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL +HSET" in result
assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET +@ALL -SET -GET +HSET" in result
# interleaved
await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@NONE" in result
assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET -SET -HSET -GET -@ALL" in result
# interleaved with categories
await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@STRING +SET" in result
assert "user kostas on nopass -@ADMIN +@LIST +@SET -HSET -@ALL +@STRING -GET +SET" in result
@pytest.mark.asyncio
async def test_acl_categories(async_client):
await async_client.execute_command(
"ACL SETUSER vlad ON >mypass +@string +@list +@connection ~*"
"ACL SETUSER vlad ON >mypass -@ALL +@string +@list +@connection ~*"
)
result = await async_client.execute_command("AUTH vlad mypass")
@ -80,7 +80,7 @@ async def test_acl_categories(async_client):
# This should fail, vlad does not have @admin
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("ACL SETUSER vlad ON >mypass")
result = await async_client.execute_command("ACL SETUSER vlad ON >mypass")
# This should fail, vlad does not have @sortedset
with pytest.raises(redis.exceptions.ResponseError):
@ -116,7 +116,7 @@ async def test_acl_categories(async_client):
@pytest.mark.asyncio
async def test_acl_commands(async_client):
await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get ~*")
await async_client.execute_command("ACL SETUSER random ON >mypass -@ALL +set +get ~*")
result = await async_client.execute_command("AUTH random mypass")
assert result == "OK"
@ -332,8 +332,8 @@ async def test_good_acl_file(df_local_factory, tmp_dir):
await client.execute_command("ACL LOAD")
result = await client.execute_command("ACL LIST")
assert 2 == len(result)
assert "user MrFoo on ea71c25a7a60224 +@NONE" in result
assert "user default on nopass +@ALL +ALL ~*" in result
assert "user MrFoo on ea71c25a7a60224 -@ALL" in result
assert "user default on nopass +@ALL ~*" in result
await client.execute_command("ACL DELUSER MrFoo")
await client.execute_command("ACL SETUSER roy ON >mypass +@STRING +HSET")
@ -342,10 +342,10 @@ async def test_good_acl_file(df_local_factory, tmp_dir):
result = await client.execute_command("ACL LIST")
assert 4 == len(result)
assert "user roy on ea71c25a7a60224 +@STRING +HSET" in result
assert "user shahar off ea71c25a7a60224 +@SET" in result
assert "user vlad off nopass +@STRING ~foo ~bar*" in result
assert "user default on nopass +@ALL +ALL ~*" in result
assert "user roy on ea71c25a7a60224 -@ALL +@STRING +HSET" in result
assert "user shahar off ea71c25a7a60224 -@ALL +@SET" in result
assert "user vlad off nopass -@ALL +@STRING ~foo ~bar*" in result
assert "user default on nopass +@ALL ~*" in result
result = await client.execute_command("ACL DELUSER shahar")
assert result == 1
@ -356,9 +356,9 @@ async def test_good_acl_file(df_local_factory, tmp_dir):
result = await client.execute_command("ACL LIST")
assert 3 == len(result)
assert "user roy on ea71c25a7a60224 +@STRING +HSET" in result
assert "user vlad off nopass +@STRING ~foo ~bar*" in result
assert "user default on nopass +@ALL +ALL ~*" in result
assert "user roy on ea71c25a7a60224 -@ALL +@STRING +HSET" in result
assert "user vlad off nopass -@ALL +@STRING ~foo ~bar*" in result
assert "user default on nopass +@ALL ~*" in result
await client.close()