feat(acl): add validation for acl keys (#2272)

* add validation for acl keys
* add tests
This commit is contained in:
Kostas Kyrimis 2023-12-08 17:28:53 +02:00 committed by GitHub
parent 8126cf8252
commit 2703d4635d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 154 additions and 29 deletions

View file

@ -72,7 +72,7 @@ inline const absl::flat_hash_map<std::string_view, uint32_t> CATEGORY_INDEX_TABL
{"READ", READ}, {"READ", READ},
{"WRITE", WRITE}, {"WRITE", WRITE},
{"SET", SET}, {"SET", SET},
{"SORTED_SET", SORTEDSET}, {"SORTEDSET", SORTEDSET},
{"LIST", LIST}, {"LIST", LIST},
{"HASH", HASH}, {"HASH", HASH},
{"STRING", STRING}, {"STRING", STRING},
@ -99,7 +99,7 @@ inline const absl::flat_hash_map<std::string_view, uint32_t> CATEGORY_INDEX_TABL
// bit 1 at index 1 // bit 1 at index 1
// bit n at index n // bit n at index n
inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{ inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"KEYSPACE", "READ", "WRITE", "SET", "SORTED_SET", "LIST", "HASH", "KEYSPACE", "READ", "WRITE", "SET", "SORTEDSET", "LIST", "HASH",
"STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM", "PUBSUB", "ADMIN", "STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM", "PUBSUB", "ADMIN",
"FAST", "SLOW", "BLOCKING", "DANGEROUS", "CONNECTION", "TRANSACTION", "SCRIPTING", "FAST", "SLOW", "BLOCKING", "DANGEROUS", "CONNECTION", "TRANSACTION", "SCRIPTING",
"_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED",

View file

@ -548,7 +548,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
} }
const auto& user = registry.find(username)->second; const auto& user = registry.find(username)->second;
if (IsUserAllowedToInvokeCommandGeneric(user.AclCategory(), user.AclCommandsRef(), *cid)) { if (IsUserAllowedToInvokeCommandGeneric(user.AclCategory(), user.AclCommandsRef(), {{}, true}, {},
*cid)) {
cntx->SendOk(); cntx->SendOk();
return; return;
} }

View file

@ -7,18 +7,24 @@
#include "base/logging.h" #include "base/logging.h"
#include "facade/dragonfly_connection.h" #include "facade/dragonfly_connection.h"
#include "server/acl/acl_commands_def.h" #include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/server_state.h" #include "server/server_state.h"
#include "server/transaction.h"
// we need this because of stringmatchlen
extern "C" {
#include "redis/util.h"
}
namespace dfly::acl { namespace dfly::acl {
[[nodiscard]] bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, [[nodiscard]] bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id,
const facade::CommandId& id) { CmdArgList tail_args) {
if (cntx.skip_acl_validation) { if (cntx.skip_acl_validation) {
return true; return true;
} }
const bool is_authed = const auto is_authed = IsUserAllowedToInvokeCommandGeneric(cntx.acl_categories, cntx.acl_commands,
IsUserAllowedToInvokeCommandGeneric(cntx.acl_categories, cntx.acl_commands, id); cntx.keys, tail_args, id);
if (!is_authed) { if (!is_authed) {
auto& log = ServerState::tlocal()->acl_log; auto& log = ServerState::tlocal()->acl_log;
@ -29,14 +35,70 @@ namespace dfly::acl {
return is_authed; return is_authed;
} }
// GCC yields a wrong warning about uninitialized optional use
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
[[nodiscard]] bool IsUserAllowedToInvokeCommandGeneric(uint32_t acl_cat, [[nodiscard]] bool IsUserAllowedToInvokeCommandGeneric(uint32_t acl_cat,
const std::vector<uint64_t>& acl_commands, const std::vector<uint64_t>& acl_commands,
const facade::CommandId& id) { const AclKeys& keys, CmdArgList tail_args,
const CommandId& id) {
const auto cat_credentials = id.acl_categories(); const auto cat_credentials = id.acl_categories();
const size_t index = id.GetFamily(); const size_t index = id.GetFamily();
const uint64_t command_mask = id.GetBitIndex(); const uint64_t command_mask = id.GetBitIndex();
DCHECK_LT(index, acl_commands.size()); DCHECK_LT(index, acl_commands.size());
return (acl_cat & cat_credentials) != 0 || (acl_commands[index] & command_mask) != 0;
const bool command =
(acl_cat & cat_credentials) != 0 || (acl_commands[index] & command_mask) != 0;
if (!command) {
return false;
}
auto match = [](const auto& pattern, const auto& target) {
return stringmatchlen(pattern.data(), pattern.size(), target.data(), target.size(), 0);
};
const bool is_read_command = id.IsReadOnly();
const bool is_write_command = id.IsWriteOnly();
auto iterate_globs = [&](auto target) {
for (auto& [elem, op] : keys.key_globs) {
if (match(elem, target)) {
if (is_read_command && (op == KeyOp::READ || op == KeyOp::READ_WRITE)) {
return true;
}
if (is_write_command && (op == KeyOp::WRITE || op == KeyOp::READ_WRITE)) {
return true;
}
}
}
return false;
};
bool keys_allowed = true;
if (!keys.all_keys && id.first_key_pos() != 0 && (is_read_command || is_write_command)) {
const auto keys_index = DetermineKeys(&id, tail_args).value();
const size_t end = keys_index.end;
if (keys_index.bonus) {
auto target = facade::ToSV(tail_args[*keys_index.bonus]);
if (!iterate_globs(target)) {
keys_allowed = false;
}
}
if (keys_allowed) {
for (size_t i = keys_index.start; i < end; i += keys_index.step) {
auto target = facade::ToSV(tail_args[i]);
if (!iterate_globs(target)) {
keys_allowed = false;
break;
}
}
}
}
return keys_allowed;
} }
#pragma GCC diagnostic pop
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -4,15 +4,20 @@
#pragma once #pragma once
#include <utility>
#include "facade/command_id.h" #include "facade/command_id.h"
#include "server/acl/acl_log.h"
#include "server/conn_context.h" #include "server/conn_context.h"
namespace dfly::acl { namespace dfly::acl {
bool IsUserAllowedToInvokeCommandGeneric(uint32_t acl_cat, bool IsUserAllowedToInvokeCommandGeneric(uint32_t acl_cat,
const std::vector<uint64_t>& acl_commands, const std::vector<uint64_t>& acl_commands,
const facade::CommandId& id); const AclKeys& keys, CmdArgList tail_args,
const CommandId& id);
bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const facade::CommandId& id); bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id,
CmdArgList tail_args);
} // namespace dfly::acl } // namespace dfly::acl

View file

@ -92,6 +92,14 @@ class CommandId : public facade::CommandId {
bool IsTransactional() const; bool IsTransactional() const;
bool IsReadOnly() const {
return opt_mask_ & CO::CommandOpt::READONLY;
}
bool IsWriteOnly() const {
return opt_mask_ & CO::CommandOpt::WRITE;
}
static const char* OptName(CO::CommandOpt fl); static const char* OptName(CO::CommandOpt fl);
CommandId&& SetHandler(Handler f) && { CommandId&& SetHandler(Handler f) && {

View file

@ -924,23 +924,25 @@ OpStatus CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_info, const C
static optional<ErrorReply> VerifyConnectionAclStatus(const CommandId* cid, static optional<ErrorReply> VerifyConnectionAclStatus(const CommandId* cid,
const ConnectionContext* cntx, const ConnectionContext* cntx,
string_view error_msg) { string_view error_msg, CmdArgList tail_args) {
// If we are on a squashed context we need to use the owner, because the // If we are on a squashed context we need to use the owner, because the
// context we are operating on is a stub and the acl username is not copied // context we are operating on is a stub and the acl username is not copied
// See: MultiCommandSquasher::SquashedHopCb // See: MultiCommandSquasher::SquashedHopCb
if (cntx->conn_state.squashing_info) if (cntx->conn_state.squashing_info)
cntx = cntx->conn_state.squashing_info->owner; cntx = cntx->conn_state.squashing_info->owner;
if (!acl::IsUserAllowedToInvokeCommand(*cntx, *cid)) { if (!acl::IsUserAllowedToInvokeCommand(*cntx, *cid, tail_args)) {
return ErrorReply(absl::StrCat("NOPERM: ", cntx->authed_username, " ", error_msg)); return ErrorReply(absl::StrCat("NOPERM: ", cntx->authed_username, " ", error_msg));
} }
return nullopt; return nullopt;
} }
optional<ErrorReply> Service::VerifyCommandExecution(const CommandId* cid, optional<ErrorReply> Service::VerifyCommandExecution(const CommandId* cid,
const ConnectionContext* cntx) { const ConnectionContext* cntx,
CmdArgList tail_args) {
// TODO: Move OOM check here // TODO: Move OOM check here
return VerifyConnectionAclStatus(cid, cntx, "ACL rules changed between the MULTI and EXEC"); return VerifyConnectionAclStatus(cid, cntx, "ACL rules changed between the MULTI and EXEC",
tail_args);
} }
std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdArgList tail_args, std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdArgList tail_args,
@ -960,7 +962,7 @@ std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdA
bool is_trans_cmd = CO::IsTransKind(cid->name()); bool is_trans_cmd = CO::IsTransKind(cid->name());
bool under_script = dfly_cntx.conn_state.script_info != nullptr; bool under_script = dfly_cntx.conn_state.script_info != nullptr;
bool is_write_cmd = cid->opt_mask() & CO::WRITE; bool is_write_cmd = cid->IsWriteOnly();
bool under_multi = dfly_cntx.conn_state.exec_info.IsCollecting() && !is_trans_cmd; bool under_multi = dfly_cntx.conn_state.exec_info.IsCollecting() && !is_trans_cmd;
// Check if the command is allowed to execute under this global state // Check if the command is allowed to execute under this global state
@ -1037,7 +1039,7 @@ std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdA
return ErrorReply{status}; return ErrorReply{status};
} }
return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions"); return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args);
} }
void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) {
@ -1064,7 +1066,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
} }
if (!dispatching_in_multi) { // Don't interrupt running multi commands if (!dispatching_in_multi) { // Don't interrupt running multi commands
bool is_write = (cid->opt_mask() & CO::WRITE); bool is_write = cid->IsWriteOnly();
is_write |= cid->name() == "PUBLISH" || cid->name() == "EVAL" || cid->name() == "EVALSHA"; is_write |= cid->name() == "PUBLISH" || cid->name() == "EVAL" || cid->name() == "EVALSHA";
is_write |= cid->name() == "EXEC" && dfly_cntx->conn_state.exec_info.is_write; is_write |= cid->name() == "EXEC" && dfly_cntx->conn_state.exec_info.is_write;
@ -1088,7 +1090,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
// TODO: protect against aggregating huge transactions. // TODO: protect against aggregating huge transactions.
StoredCmd stored_cmd{cid, args_no_cmd}; StoredCmd stored_cmd{cid, args_no_cmd};
dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd)); dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd));
if (stored_cmd.Cid()->opt_mask() & CO::WRITE) { if (stored_cmd.Cid()->IsWriteOnly()) {
dfly_cntx->conn_state.exec_info.is_write = true; dfly_cntx->conn_state.exec_info.is_write = true;
} }
return cntx->SendSimpleString("QUEUED"); return cntx->SendSimpleString("QUEUED");
@ -1183,7 +1185,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo
DCHECK(cid); DCHECK(cid);
DCHECK(!cid->Validate(tail_args)); DCHECK(!cid->Validate(tail_args));
if (auto err = VerifyCommandExecution(cid, cntx); err) { if (auto err = VerifyCommandExecution(cid, cntx, tail_args); err) {
cntx->SendError(std::move(*err)); cntx->SendError(std::move(*err));
return true; // return false only for internal error aborts return true; // return false only for internal error aborts
} }

View file

@ -58,7 +58,8 @@ class Service : public facade::ServiceInterface {
// Verify command can be executed now (check out of memory), always called immediately before // Verify command can be executed now (check out of memory), always called immediately before
// execution // execution
std::optional<facade::ErrorReply> VerifyCommandExecution(const CommandId* cid, std::optional<facade::ErrorReply> VerifyCommandExecution(const CommandId* cid,
const ConnectionContext* cntx); const ConnectionContext* cntx,
CmdArgList tail_args);
// Verify command prepares excution in correct state. // Verify command prepares excution in correct state.
// It's usually called before command execution. Only for multi/exec transactions it's checked // It's usually called before command execution. Only for multi/exec transactions it's checked

View file

@ -98,7 +98,7 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm
auto& sinfo = PrepareShardInfo(last_sid); auto& sinfo = PrepareShardInfo(last_sid);
sinfo.had_writes |= (cmd->Cid()->opt_mask() & CO::WRITE); sinfo.had_writes |= (cmd->Cid()->IsWriteOnly());
sinfo.cmds.push_back(cmd); sinfo.cmds.push_back(cmd);
order_.push_back(last_sid); order_.push_back(last_sid);

View file

@ -1200,6 +1200,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
auto cred = registry->GetCredentials(username); auto cred = registry->GetCredentials(username);
cntx->acl_categories = cred.acl_categories; cntx->acl_categories = cred.acl_categories;
cntx->acl_commands = cred.acl_commands; cntx->acl_commands = cred.acl_commands;
cntx->keys = std::move(cred.keys);
cntx->authenticated = true; cntx->authenticated = true;
return cntx->SendOk(); return cntx->SendOk();
} }

View file

@ -31,7 +31,7 @@ constexpr size_t kTransSize [[maybe_unused]] = sizeof(Transaction);
} // namespace } // namespace
IntentLock::Mode Transaction::Mode() const { IntentLock::Mode Transaction::Mode() const {
return (cid_->opt_mask() & CO::READONLY) ? IntentLock::SHARED : IntentLock::EXCLUSIVE; return cid_->IsReadOnly() ? IntentLock::SHARED : IntentLock::EXCLUSIVE;
} }
/** /**
@ -1362,7 +1362,7 @@ void Transaction::LogAutoJournalOnShard(EngineShard* shard) {
return; return;
// Only write commands and/or no-key-transactional commands are logged // Only write commands and/or no-key-transactional commands are logged
if ((cid_->opt_mask() & CO::WRITE) == 0 && (cid_->opt_mask() & CO::NO_KEY_TRANSACTIONAL) == 0) if ((cid_->IsWriteOnly()) == 0 && (cid_->opt_mask() & CO::NO_KEY_TRANSACTIONAL) == 0)
return; return;
// If autojournaling was disabled and not re-enabled, skip it // If autojournaling was disabled and not re-enabled, skip it

View file

@ -65,7 +65,9 @@ async def test_acl_setuser(async_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acl_categories(async_client): async def test_acl_categories(async_client):
await async_client.execute_command("ACL SETUSER vlad ON >mypass +@string +@list +@connection") await async_client.execute_command(
"ACL SETUSER vlad ON >mypass +@string +@list +@connection ~*"
)
result = await async_client.execute_command("AUTH vlad mypass") result = await async_client.execute_command("AUTH vlad mypass")
assert result == "OK" assert result == "OK"
@ -114,7 +116,7 @@ async def test_acl_categories(async_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acl_commands(async_client): async def test_acl_commands(async_client):
await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get") await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get ~*")
result = await async_client.execute_command("AUTH random mypass") result = await async_client.execute_command("AUTH random mypass")
assert result == "OK" assert result == "OK"
@ -134,7 +136,7 @@ async def test_acl_cat_commands_multi_exec_squash(df_local_factory):
# Testing acl categories # Testing acl categories
client = aioredis.Redis(port=df.port) client = aioredis.Redis(port=df.port)
res = await client.execute_command("ACL SETUSER kk ON >kk +@transaction +@string") res = await client.execute_command("ACL SETUSER kk ON >kk +@transaction +@string ~*")
assert res == b"OK" assert res == b"OK"
res = await client.execute_command("AUTH kk kk") res = await client.execute_command("AUTH kk kk")
@ -191,7 +193,7 @@ async def test_acl_cat_commands_multi_exec_squash(df_local_factory):
# Testing acl commands # Testing acl commands
client = aioredis.Redis(port=df.port) client = aioredis.Redis(port=df.port)
res = await client.execute_command("ACL SETUSER myuser ON >kk +@transaction +set") res = await client.execute_command("ACL SETUSER myuser ON >kk +@transaction +set ~*")
assert res == b"OK" assert res == b"OK"
res = await client.execute_command("AUTH myuser kk") res = await client.execute_command("AUTH myuser kk")
@ -359,7 +361,7 @@ async def test_acl_log(async_client):
res = await async_client.execute_command("ACL LOG") res = await async_client.execute_command("ACL LOG")
assert [] == res assert [] == res
await async_client.execute_command("ACL SETUSER elon >mars ON +@string +@dangerous") await async_client.execute_command("ACL SETUSER elon >mars ON +@string +@dangerous ~*")
with pytest.raises(redis.exceptions.AuthenticationError): with pytest.raises(redis.exceptions.AuthenticationError):
await async_client.execute_command("AUTH elon wrong") await async_client.execute_command("AUTH elon wrong")
@ -472,3 +474,46 @@ async def test_set_len_acl_log(async_client):
res = await async_client.execute_command("ACL LOG") res = await async_client.execute_command("ACL LOG")
assert 10 == len(res) assert 10 == len(res)
@pytest.mark.asyncio
async def test_acl_keys(async_client):
await async_client.execute_command("ACL SETUSER mrkeys ON >mrkeys allkeys +@admin")
await async_client.execute_command("AUTH mrkeys mrkeys")
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("SET foo bar")
await async_client.execute_command(
"ACL SETUSER mrkeys ON >mrkeys resetkeys +@string ~foo ~bar* ~dr*gon"
)
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("SET random rand")
assert "OK" == await async_client.execute_command("SET foo val")
assert "OK" == await async_client.execute_command("SET bar val")
assert "OK" == await async_client.execute_command("SET barsomething val")
assert "OK" == await async_client.execute_command("SET dragon val")
await async_client.execute_command("ACL SETUSER mrkeys ON >mrkeys allkeys +@sortedset")
assert "OK" == await async_client.execute_command("SET random rand")
await async_client.execute_command(
"ACL SETUSER mrkeys ON >mrkeys resetkeys resetkeys %R~foo %W~bar"
)
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("SET foo val")
assert "val" == await async_client.execute_command("GET foo")
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("GET bar")
assert "OK" == await async_client.execute_command("SET bar val")
await async_client.execute_command("ACL SETUSER mrkeys resetkeys ~bar* +@sortedset")
assert 1 == await async_client.execute_command("ZADD barz1 1 val1")
assert 1 == await async_client.execute_command("ZADD barz2 1 val2")
# reject because bonus key does not match
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("ZUNIONSTORE destkey 2 barz1 barz2")