feat(acl): add acl keys to acl setuser command (#2258)

* add parsing of ACL keys
* add ACL keys to acl setuser command
This commit is contained in:
Kostas Kyrimis 2023-12-08 11:53:22 +02:00 committed by GitHub
parent 636507c356
commit b642fb6901
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 110 additions and 40 deletions

View file

@ -104,11 +104,14 @@ class ConnectionContext {
// How many async subscription sources are active: monitor and/or pubsub - at most 2. // How many async subscription sources are active: monitor and/or pubsub - at most 2.
uint8_t subscriptions; uint8_t subscriptions;
// TODO fix inherit actual values from default
std::string authed_username{"default"}; std::string authed_username{"default"};
uint32_t acl_categories{dfly::acl::ALL}; uint32_t acl_categories{dfly::acl::ALL};
std::vector<uint64_t> acl_commands; std::vector<uint64_t> acl_commands;
// Skip ACL validation, used by internal commands and commands run on admin port // Skip ACL validation, used by internal commands and commands run on admin port
bool skip_acl_validation = false; bool skip_acl_validation = false;
// keys
dfly::acl::AclKeys keys{{}, true};
private: private:
Connection* owner_; Connection* owner_;

View file

@ -8,6 +8,7 @@
#include <absl/strings/match.h> #include <absl/strings/match.h>
#include <mimalloc.h> #include <mimalloc.h>
#include <numeric>
#include <variant> #include <variant>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
@ -205,9 +206,12 @@ size_t Connection::MessageHandle::UsedMemory() const {
return msg.capacity(); return msg.capacity();
} }
size_t operator()(const AclUpdateMessagePtr& msg) { size_t operator()(const AclUpdateMessagePtr& msg) {
return sizeof(AclUpdateMessage) + msg->username.capacity() * sizeof(string) + size_t key_cap = std::accumulate(
msg->commands.capacity() * sizeof(vector<int>) + msg->keys.key_globs.begin(), msg->keys.key_globs.end(), 0, [](auto acc, auto& str) {
msg->categories.capacity() * sizeof(uint32_t); return acc + (str.first.capacity() * sizeof(char)) + sizeof(str.second);
});
return sizeof(AclUpdateMessage) + msg->username.capacity() * sizeof(char) +
msg->commands.capacity() * sizeof(uint64_t) + key_cap;
} }
size_t operator()(const MigrationRequestMessage& msg) { size_t operator()(const MigrationRequestMessage& msg) {
return 0; return 0;
@ -240,11 +244,10 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) { void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (self->cntx()) { if (self->cntx()) {
for (size_t id = 0; id < msg.username.size(); ++id) { if (msg.username == self->cntx()->authed_username) {
if (msg.username[id] == self->cntx()->authed_username) { self->cntx()->acl_categories = msg.categories;
self->cntx()->acl_categories = msg.categories[id]; self->cntx()->acl_commands = msg.commands;
self->cntx()->acl_commands = msg.commands[id]; self->cntx()->keys = msg.keys;
}
} }
} }
} }

View file

@ -16,6 +16,7 @@
#include "base/io_buf.h" #include "base/io_buf.h"
#include "core/fibers.h" #include "core/fibers.h"
#include "facade/acl_commands_def.h"
#include "facade/facade_types.h" #include "facade/facade_types.h"
#include "facade/resp_expr.h" #include "facade/resp_expr.h"
#include "util/connection.h" #include "util/connection.h"
@ -101,9 +102,10 @@ class Connection : public util::Connection {
// ACL Update message, contains ACL updates to be applied to the connection. // ACL Update message, contains ACL updates to be applied to the connection.
struct AclUpdateMessage { struct AclUpdateMessage {
std::vector<std::string> username; std::string username;
std::vector<uint32_t> categories; uint32_t categories;
std::vector<std::vector<uint64_t>> commands; std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
}; };
// Migration request message, the dispatch fiber stops to give way for thread migration. // Migration request message, the dispatch fiber stops to give way for thread migration.

View file

@ -76,16 +76,14 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
} }
} }
void AclFamily::StreamUpdatesToAllProactorConnections(const std::vector<std::string>& user, void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
const std::vector<uint32_t>& update_cat, const Commands& update_commands,
const NestedVector& update_commands) { const AclKeys& update_keys) {
auto update_cb = [&user, &update_cat, &update_commands]([[maybe_unused]] size_t id, auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
util::Connection* conn) {
DCHECK(conn); DCHECK(conn);
auto connection = static_cast<facade::Connection*>(conn); auto connection = static_cast<facade::Connection*>(conn);
DCHECK(user.size() == update_cat.size());
connection->SendAclUpdateAsync( connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands}); facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys});
}; };
if (main_listener_) { if (main_listener_) {
@ -97,14 +95,20 @@ using facade::ErrorReply;
void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) { void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
std::string_view username = facade::ToSV(args[0]); std::string_view username = facade::ToSV(args[0]);
auto req = ParseAclSetUser(args.subspan(1), *cmd_registry_); auto reg = registry_->GetRegistryWithWriteLock();
auto error_case = [cntx](ErrorReply&& error) { cntx->SendError(std::move(error)); }; const bool exists = reg.registry.contains(username);
auto update_case = [username, cntx, this](User::UpdateRequest&& req) { const bool has_all_keys = exists ? reg.registry.find(username)->second.Keys().all_keys : false;
auto user_with_lock = registry_->MaybeAddAndUpdateWithLock(username, std::move(req));
if (user_with_lock.exists) { auto req = ParseAclSetUser(args.subspan(1), *cmd_registry_, false, has_all_keys);
StreamUpdatesToAllProactorConnections({std::string(username)},
{user_with_lock.user.AclCategory()}, auto error_case = [cntx](ErrorReply&& error) { cntx->SendError(error); };
{user_with_lock.user.AclCommands()});
auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
user.Update(std::move(req));
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(),
user.AclCommands(), user.Keys());
} }
cntx->SendOk(); cntx->SendOk();
}; };
@ -273,13 +277,10 @@ std::optional<facade::ErrorReply> AclFamily::LoadToRegistryFromFile(std::string_
EvictOpenConnectionsOnAllProactorsWithRegistry(registry); EvictOpenConnectionsOnAllProactorsWithRegistry(registry);
registry.clear(); registry.clear();
} }
std::vector<uint32_t> categories;
NestedVector commands;
for (size_t i = 0; i < usernames.size(); ++i) { for (size_t i = 0; i < usernames.size(); ++i) {
auto& user = registry[usernames[i]]; auto& user = registry[usernames[i]];
user.Update(std::move(requests[i])); user.Update(std::move(requests[i]));
categories.push_back(user.AclCategory());
commands.push_back(user.AclCommands());
} }
if (!registry.contains("default")) { if (!registry.contains("default")) {

View file

@ -47,10 +47,10 @@ class AclFamily final {
// Helper function that updates all open connections and their // Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads // respective ACL fields on all the available proactor threads
using NestedVector = std::vector<std::vector<uint64_t>>; using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::vector<std::string>& user, void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
const std::vector<uint32_t>& update_cat, const Commands& update_commands,
const NestedVector& update_commands); const AclKeys& update_keys);
// Helper function that closes all open connection from the deleted user // Helper function that closes all open connection from the deleted user
void EvictOpenConnectionsOnAllProactors(std::string_view user); void EvictOpenConnectionsOnAllProactors(std::string_view user);

View file

@ -79,6 +79,38 @@ std::string PrettyPrintSha(std::string_view pass, bool all) {
return absl::BytesToHexString(pass.substr(0, 15)).substr(0, 15); return absl::BytesToHexString(pass.substr(0, 15)).substr(0, 15);
}; };
std::optional<ParseKeyResult> MaybeParseAclKey(std::string_view command) {
if (absl::EqualsIgnoreCase(command, "ALLKEYS") || command == "~*") {
return ParseKeyResult{"", {}, true};
}
if (absl::EqualsIgnoreCase(command, "RESETKEYS")) {
return ParseKeyResult{"", {}, false, true};
}
auto op = KeyOp::READ_WRITE;
if (absl::StartsWith(command, "%RW")) {
command = command.substr(3);
} else if (absl::StartsWith(command, "%R")) {
op = KeyOp::READ;
command = command.substr(2);
} else if (absl::StartsWith(command, "%W")) {
op = KeyOp::WRITE;
command = command.substr(2);
}
if (!absl::StartsWith(command, "~")) {
return {};
}
auto key = command.substr(1);
if (key.empty()) {
return {};
}
return ParseKeyResult{std::string(key), op};
}
std::optional<std::string> MaybeParsePassword(std::string_view command, bool hashed) { std::optional<std::string> MaybeParsePassword(std::string_view command, bool hashed) {
if (command == "nopass") { if (command == "nopass") {
return std::string(command); return std::string(command);
@ -190,7 +222,7 @@ using facade::ErrorReply;
template <typename T> template <typename T>
std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args, std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
const CommandRegistry& registry, const CommandRegistry& registry,
bool hashed) { bool hashed, bool has_all_keys) {
User::UpdateRequest req; User::UpdateRequest req;
for (auto& arg : args) { for (auto& arg : args) {
@ -202,6 +234,26 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
req.is_hashed = hashed; req.is_hashed = hashed;
continue; continue;
} }
if (auto res = MaybeParseAclKey(facade::ToSV(arg)); res) {
auto& [glob, op, all_keys, reset_keys] = *res;
if ((has_all_keys && !all_keys && !reset_keys) ||
(req.allow_all_keys && !all_keys && !reset_keys)) {
return ErrorReply(
"Error in ACL SETUSER modifier '~tmp': Adding a pattern after the * pattern (or the "
"'allkeys' flag) is not valid and does not have any effect. Try 'resetkeys' to start "
"with an empty list of patterns");
}
req.allow_all_keys = all_keys;
req.reset_all_keys = reset_keys;
if (reset_keys) {
has_all_keys = false;
}
req.keys.push_back({std::move(glob), op, all_keys, reset_keys});
continue;
}
std::string buffer; std::string buffer;
std::string_view command; std::string_view command;
if constexpr (std::is_same_v<T, facade::CmdArgList>) { if constexpr (std::is_same_v<T, facade::CmdArgList>) {
@ -252,8 +304,9 @@ using facade::CmdArgList;
template std::variant<User::UpdateRequest, ErrorReply> template std::variant<User::UpdateRequest, ErrorReply>
ParseAclSetUser<std::vector<std::string_view>&>(std::vector<std::string_view>&, ParseAclSetUser<std::vector<std::string_view>&>(std::vector<std::string_view>&,
const CommandRegistry& registry, bool hashed); const CommandRegistry& registry, bool hashed,
bool has_all_keys);
template std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser<CmdArgList>( template std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser<CmdArgList>(
CmdArgList args, const CommandRegistry& registry, bool hashed); CmdArgList args, const CommandRegistry& registry, bool hashed, bool has_all_keys);
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -38,11 +38,19 @@ std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,
template <typename T> template <typename T>
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser( std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
T args, const CommandRegistry& registry, bool hashed = false); T args, const CommandRegistry& registry, bool hashed = false, bool has_all_keys = false);
using MaterializedContents = std::optional<std::vector<std::vector<std::string_view>>>; using MaterializedContents = std::optional<std::vector<std::vector<std::string_view>>>;
MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames, MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames,
std::string_view file_contents); std::string_view file_contents);
struct ParseKeyResult {
std::string glob;
KeyOp op;
bool all_keys{false};
bool reset_keys{false};
};
std::optional<ParseKeyResult> MaybeParseAclKey(std::string_view command);
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -138,7 +138,7 @@ const AclKeys& User::Keys() const {
return keys_; return keys_;
} }
void User::SetKeyGlobs(std::vector<UpdateKey>&& keys) { void User::SetKeyGlobs(std::vector<UpdateKey> keys) {
for (auto& key : keys) { for (auto& key : keys) {
if (key.all_keys) { if (key.all_keys) {
keys_.key_globs.clear(); keys_.key_globs.clear();

View file

@ -96,7 +96,7 @@ class User final {
void SetPasswordHash(std::string_view password, bool is_hashed); void SetPasswordHash(std::string_view password, bool is_hashed);
// For ACL key globs // For ACL key globs
void SetKeyGlobs(std::vector<UpdateKey>&& keys); void SetKeyGlobs(std::vector<UpdateKey> keys);
// when optional is empty, the special `nopass` password is implied // when optional is empty, the special `nopass` password is implied
// password hashed with xx64 // password hashed with xx64