diff --git a/src/server/acl/acl_commands_def.h b/src/server/acl/acl_commands_def.h index 76d8c7a90..9cdf52101 100644 --- a/src/server/acl/acl_commands_def.h +++ b/src/server/acl/acl_commands_def.h @@ -72,7 +72,7 @@ inline const absl::flat_hash_map CATEGORY_INDEX_TABL {"READ", READ}, {"WRITE", WRITE}, {"SET", SET}, - {"SORTED_SET", SORTEDSET}, + {"SORTEDSET", SORTEDSET}, {"LIST", LIST}, {"HASH", HASH}, {"STRING", STRING}, @@ -99,7 +99,7 @@ inline const absl::flat_hash_map CATEGORY_INDEX_TABL // bit 1 at index 1 // bit n at index n inline const std::vector REVERSE_CATEGORY_INDEX_TABLE{ - "KEYSPACE", "READ", "WRITE", "SET", "SORTED_SET", "LIST", "HASH", + "KEYSPACE", "READ", "WRITE", "SET", "SORTEDSET", "LIST", "HASH", "STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM", "PUBSUB", "ADMIN", "FAST", "SLOW", "BLOCKING", "DANGEROUS", "CONNECTION", "TRANSACTION", "SCRIPTING", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 56a5fc163..bbce5afd1 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -548,7 +548,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) { } const auto& user = registry.find(username)->second; - if (IsUserAllowedToInvokeCommandGeneric(user.AclCategory(), user.AclCommandsRef(), *cid)) { + if (IsUserAllowedToInvokeCommandGeneric(user.AclCategory(), user.AclCommandsRef(), {{}, true}, {}, + *cid)) { cntx->SendOk(); return; } diff --git a/src/server/acl/validator.cc b/src/server/acl/validator.cc index 8a8f6e059..37c3e8613 100644 --- a/src/server/acl/validator.cc +++ b/src/server/acl/validator.cc @@ -7,18 +7,24 @@ #include "base/logging.h" #include "facade/dragonfly_connection.h" #include "server/acl/acl_commands_def.h" +#include "server/command_registry.h" #include "server/server_state.h" +#include "server/transaction.h" +// we need this because of stringmatchlen +extern "C" { +#include "redis/util.h" +} namespace dfly::acl { -[[nodiscard]] bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, - const facade::CommandId& id) { +[[nodiscard]] bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id, + CmdArgList tail_args) { if (cntx.skip_acl_validation) { return true; } - const bool is_authed = - IsUserAllowedToInvokeCommandGeneric(cntx.acl_categories, cntx.acl_commands, id); + const auto is_authed = IsUserAllowedToInvokeCommandGeneric(cntx.acl_categories, cntx.acl_commands, + cntx.keys, tail_args, id); if (!is_authed) { auto& log = ServerState::tlocal()->acl_log; @@ -29,14 +35,70 @@ namespace dfly::acl { return is_authed; } +// GCC yields a wrong warning about uninitialized optional use +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + [[nodiscard]] bool IsUserAllowedToInvokeCommandGeneric(uint32_t acl_cat, const std::vector& acl_commands, - const facade::CommandId& id) { + const AclKeys& keys, CmdArgList tail_args, + const CommandId& id) { const auto cat_credentials = id.acl_categories(); const size_t index = id.GetFamily(); const uint64_t command_mask = id.GetBitIndex(); DCHECK_LT(index, acl_commands.size()); - return (acl_cat & cat_credentials) != 0 || (acl_commands[index] & command_mask) != 0; + + const bool command = + (acl_cat & cat_credentials) != 0 || (acl_commands[index] & command_mask) != 0; + + if (!command) { + return false; + } + + auto match = [](const auto& pattern, const auto& target) { + return stringmatchlen(pattern.data(), pattern.size(), target.data(), target.size(), 0); + }; + + const bool is_read_command = id.IsReadOnly(); + const bool is_write_command = id.IsWriteOnly(); + + auto iterate_globs = [&](auto target) { + for (auto& [elem, op] : keys.key_globs) { + if (match(elem, target)) { + if (is_read_command && (op == KeyOp::READ || op == KeyOp::READ_WRITE)) { + return true; + } + if (is_write_command && (op == KeyOp::WRITE || op == KeyOp::READ_WRITE)) { + return true; + } + } + } + return false; + }; + + bool keys_allowed = true; + if (!keys.all_keys && id.first_key_pos() != 0 && (is_read_command || is_write_command)) { + const auto keys_index = DetermineKeys(&id, tail_args).value(); + const size_t end = keys_index.end; + if (keys_index.bonus) { + auto target = facade::ToSV(tail_args[*keys_index.bonus]); + if (!iterate_globs(target)) { + keys_allowed = false; + } + } + if (keys_allowed) { + for (size_t i = keys_index.start; i < end; i += keys_index.step) { + auto target = facade::ToSV(tail_args[i]); + if (!iterate_globs(target)) { + keys_allowed = false; + break; + } + } + } + } + return keys_allowed; } +#pragma GCC diagnostic pop + } // namespace dfly::acl diff --git a/src/server/acl/validator.h b/src/server/acl/validator.h index 5a4090982..004aedfbc 100644 --- a/src/server/acl/validator.h +++ b/src/server/acl/validator.h @@ -4,15 +4,20 @@ #pragma once +#include + #include "facade/command_id.h" +#include "server/acl/acl_log.h" #include "server/conn_context.h" namespace dfly::acl { bool IsUserAllowedToInvokeCommandGeneric(uint32_t acl_cat, const std::vector& acl_commands, - const facade::CommandId& id); + const AclKeys& keys, CmdArgList tail_args, + const CommandId& id); -bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const facade::CommandId& id); +bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id, + CmdArgList tail_args); } // namespace dfly::acl diff --git a/src/server/command_registry.h b/src/server/command_registry.h index c605b8fd2..39fa88352 100644 --- a/src/server/command_registry.h +++ b/src/server/command_registry.h @@ -92,6 +92,14 @@ class CommandId : public facade::CommandId { bool IsTransactional() const; + bool IsReadOnly() const { + return opt_mask_ & CO::CommandOpt::READONLY; + } + + bool IsWriteOnly() const { + return opt_mask_ & CO::CommandOpt::WRITE; + } + static const char* OptName(CO::CommandOpt fl); CommandId&& SetHandler(Handler f) && { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 97c33c254..ca8065d32 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -924,23 +924,25 @@ OpStatus CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_info, const C static optional VerifyConnectionAclStatus(const CommandId* cid, const ConnectionContext* cntx, - string_view error_msg) { + string_view error_msg, CmdArgList tail_args) { // If we are on a squashed context we need to use the owner, because the // context we are operating on is a stub and the acl username is not copied // See: MultiCommandSquasher::SquashedHopCb if (cntx->conn_state.squashing_info) cntx = cntx->conn_state.squashing_info->owner; - if (!acl::IsUserAllowedToInvokeCommand(*cntx, *cid)) { + if (!acl::IsUserAllowedToInvokeCommand(*cntx, *cid, tail_args)) { return ErrorReply(absl::StrCat("NOPERM: ", cntx->authed_username, " ", error_msg)); } return nullopt; } optional Service::VerifyCommandExecution(const CommandId* cid, - const ConnectionContext* cntx) { + const ConnectionContext* cntx, + CmdArgList tail_args) { // TODO: Move OOM check here - return VerifyConnectionAclStatus(cid, cntx, "ACL rules changed between the MULTI and EXEC"); + return VerifyConnectionAclStatus(cid, cntx, "ACL rules changed between the MULTI and EXEC", + tail_args); } std::optional Service::VerifyCommandState(const CommandId* cid, CmdArgList tail_args, @@ -960,7 +962,7 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA bool is_trans_cmd = CO::IsTransKind(cid->name()); bool under_script = dfly_cntx.conn_state.script_info != nullptr; - bool is_write_cmd = cid->opt_mask() & CO::WRITE; + bool is_write_cmd = cid->IsWriteOnly(); bool under_multi = dfly_cntx.conn_state.exec_info.IsCollecting() && !is_trans_cmd; // Check if the command is allowed to execute under this global state @@ -1037,7 +1039,7 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return ErrorReply{status}; } - return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions"); + return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); } void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { @@ -1064,7 +1066,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) } if (!dispatching_in_multi) { // Don't interrupt running multi commands - bool is_write = (cid->opt_mask() & CO::WRITE); + bool is_write = cid->IsWriteOnly(); is_write |= cid->name() == "PUBLISH" || cid->name() == "EVAL" || cid->name() == "EVALSHA"; is_write |= cid->name() == "EXEC" && dfly_cntx->conn_state.exec_info.is_write; @@ -1088,7 +1090,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // TODO: protect against aggregating huge transactions. StoredCmd stored_cmd{cid, args_no_cmd}; dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd)); - if (stored_cmd.Cid()->opt_mask() & CO::WRITE) { + if (stored_cmd.Cid()->IsWriteOnly()) { dfly_cntx->conn_state.exec_info.is_write = true; } return cntx->SendSimpleString("QUEUED"); @@ -1183,7 +1185,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo DCHECK(cid); DCHECK(!cid->Validate(tail_args)); - if (auto err = VerifyCommandExecution(cid, cntx); err) { + if (auto err = VerifyCommandExecution(cid, cntx, tail_args); err) { cntx->SendError(std::move(*err)); return true; // return false only for internal error aborts } diff --git a/src/server/main_service.h b/src/server/main_service.h index 8a7559359..f80e92b5b 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -58,7 +58,8 @@ class Service : public facade::ServiceInterface { // Verify command can be executed now (check out of memory), always called immediately before // execution std::optional VerifyCommandExecution(const CommandId* cid, - const ConnectionContext* cntx); + const ConnectionContext* cntx, + CmdArgList tail_args); // Verify command prepares excution in correct state. // It's usually called before command execution. Only for multi/exec transactions it's checked diff --git a/src/server/multi_command_squasher.cc b/src/server/multi_command_squasher.cc index 973da25b2..e1e6781d0 100644 --- a/src/server/multi_command_squasher.cc +++ b/src/server/multi_command_squasher.cc @@ -98,7 +98,7 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm auto& sinfo = PrepareShardInfo(last_sid); - sinfo.had_writes |= (cmd->Cid()->opt_mask() & CO::WRITE); + sinfo.had_writes |= (cmd->Cid()->IsWriteOnly()); sinfo.cmds.push_back(cmd); order_.push_back(last_sid); diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 354db74bd..fe65fef14 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1200,6 +1200,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) { auto cred = registry->GetCredentials(username); cntx->acl_categories = cred.acl_categories; cntx->acl_commands = cred.acl_commands; + cntx->keys = std::move(cred.keys); cntx->authenticated = true; return cntx->SendOk(); } diff --git a/src/server/transaction.cc b/src/server/transaction.cc index ebf29b034..7fad09129 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -31,7 +31,7 @@ constexpr size_t kTransSize [[maybe_unused]] = sizeof(Transaction); } // namespace IntentLock::Mode Transaction::Mode() const { - return (cid_->opt_mask() & CO::READONLY) ? IntentLock::SHARED : IntentLock::EXCLUSIVE; + return cid_->IsReadOnly() ? IntentLock::SHARED : IntentLock::EXCLUSIVE; } /** @@ -1362,7 +1362,7 @@ void Transaction::LogAutoJournalOnShard(EngineShard* shard) { return; // Only write commands and/or no-key-transactional commands are logged - if ((cid_->opt_mask() & CO::WRITE) == 0 && (cid_->opt_mask() & CO::NO_KEY_TRANSACTIONAL) == 0) + if ((cid_->IsWriteOnly()) == 0 && (cid_->opt_mask() & CO::NO_KEY_TRANSACTIONAL) == 0) return; // If autojournaling was disabled and not re-enabled, skip it diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index bbf7ec45e..fa7561204 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -65,7 +65,9 @@ async def test_acl_setuser(async_client): @pytest.mark.asyncio async def test_acl_categories(async_client): - await async_client.execute_command("ACL SETUSER vlad ON >mypass +@string +@list +@connection") + await async_client.execute_command( + "ACL SETUSER vlad ON >mypass +@string +@list +@connection ~*" + ) result = await async_client.execute_command("AUTH vlad mypass") assert result == "OK" @@ -114,7 +116,7 @@ async def test_acl_categories(async_client): @pytest.mark.asyncio async def test_acl_commands(async_client): - await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get") + await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get ~*") result = await async_client.execute_command("AUTH random mypass") assert result == "OK" @@ -134,7 +136,7 @@ async def test_acl_cat_commands_multi_exec_squash(df_local_factory): # Testing acl categories client = aioredis.Redis(port=df.port) - res = await client.execute_command("ACL SETUSER kk ON >kk +@transaction +@string") + res = await client.execute_command("ACL SETUSER kk ON >kk +@transaction +@string ~*") assert res == b"OK" res = await client.execute_command("AUTH kk kk") @@ -191,7 +193,7 @@ async def test_acl_cat_commands_multi_exec_squash(df_local_factory): # Testing acl commands client = aioredis.Redis(port=df.port) - res = await client.execute_command("ACL SETUSER myuser ON >kk +@transaction +set") + res = await client.execute_command("ACL SETUSER myuser ON >kk +@transaction +set ~*") assert res == b"OK" res = await client.execute_command("AUTH myuser kk") @@ -359,7 +361,7 @@ async def test_acl_log(async_client): res = await async_client.execute_command("ACL LOG") assert [] == res - await async_client.execute_command("ACL SETUSER elon >mars ON +@string +@dangerous") + await async_client.execute_command("ACL SETUSER elon >mars ON +@string +@dangerous ~*") with pytest.raises(redis.exceptions.AuthenticationError): await async_client.execute_command("AUTH elon wrong") @@ -472,3 +474,46 @@ async def test_set_len_acl_log(async_client): res = await async_client.execute_command("ACL LOG") assert 10 == len(res) + + +@pytest.mark.asyncio +async def test_acl_keys(async_client): + await async_client.execute_command("ACL SETUSER mrkeys ON >mrkeys allkeys +@admin") + await async_client.execute_command("AUTH mrkeys mrkeys") + + with pytest.raises(redis.exceptions.ResponseError): + await async_client.execute_command("SET foo bar") + + await async_client.execute_command( + "ACL SETUSER mrkeys ON >mrkeys resetkeys +@string ~foo ~bar* ~dr*gon" + ) + + with pytest.raises(redis.exceptions.ResponseError): + await async_client.execute_command("SET random rand") + + assert "OK" == await async_client.execute_command("SET foo val") + assert "OK" == await async_client.execute_command("SET bar val") + assert "OK" == await async_client.execute_command("SET barsomething val") + assert "OK" == await async_client.execute_command("SET dragon val") + + await async_client.execute_command("ACL SETUSER mrkeys ON >mrkeys allkeys +@sortedset") + assert "OK" == await async_client.execute_command("SET random rand") + + await async_client.execute_command( + "ACL SETUSER mrkeys ON >mrkeys resetkeys resetkeys %R~foo %W~bar" + ) + + with pytest.raises(redis.exceptions.ResponseError): + await async_client.execute_command("SET foo val") + assert "val" == await async_client.execute_command("GET foo") + + with pytest.raises(redis.exceptions.ResponseError): + await async_client.execute_command("GET bar") + assert "OK" == await async_client.execute_command("SET bar val") + + await async_client.execute_command("ACL SETUSER mrkeys resetkeys ~bar* +@sortedset") + assert 1 == await async_client.execute_command("ZADD barz1 1 val1") + assert 1 == await async_client.execute_command("ZADD barz2 1 val2") + # reject because bonus key does not match + with pytest.raises(redis.exceptions.ResponseError): + await async_client.execute_command("ZUNIONSTORE destkey 2 barz1 barz2")