feat(cmd): add restricted commands flag (#1967)

This commit is contained in:
Andy Dunstall 2023-09-29 16:16:06 +01:00 committed by GitHub
parent ba4eba991d
commit e6b8cd1d76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 0 deletions

View file

@ -80,6 +80,16 @@ class CommandId {
return bit_index_;
}
// Returns true if the command can only be used by admin connections, false
// otherwise.
bool IsRestricted() const {
return restricted_;
}
void SetRestricted(bool restricted) {
restricted_ = restricted;
}
static uint32_t OptCount(uint32_t mask);
protected:
@ -95,6 +105,9 @@ class CommandId {
// Acl commands indices
size_t family_;
uint64_t bit_index_;
// Whether the command can only be used by admin connections.
bool restricted_ = false;
};
} // namespace facade

View file

@ -20,6 +20,8 @@ using namespace std;
ABSL_FLAG(vector<string>, rename_command, {},
"Change the name of commands, format is: <cmd1_name>=<cmd1_new_name>, "
"<cmd2_name>=<cmd2_new_name>");
ABSL_FLAG(vector<string>, restricted_commands, {},
"Commands restricted to connections on the admin port");
namespace dfly {
@ -87,6 +89,10 @@ CommandRegistry::CommandRegistry() {
exit(1);
}
}
for (string name : GetFlag(FLAGS_restricted_commands)) {
restricted_cmds_.emplace(AsciiStrToUpper(name));
}
}
void CommandRegistry::Init(unsigned int thread_count) {
@ -105,6 +111,10 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
k = it->second;
}
if (restricted_cmds_.find(k) != restricted_cmds_.end()) {
cmd.SetRestricted(true);
}
family_of_commands_.back().push_back(std::string(k));
cmd.SetFamily(family_of_commands_.size() - 1);
cmd.SetBitIndex(1ULL << bit_index_++);

View file

@ -5,6 +5,7 @@
#pragma once
#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h>
#include <functional>
@ -169,6 +170,7 @@ class CommandRegistry {
private:
absl::flat_hash_map<std::string_view, CommandId> cmd_map_;
absl::flat_hash_map<std::string, std::string> cmd_rename_map_;
absl::flat_hash_set<std::string> restricted_cmds_;
FamiliesVec family_of_commands_;
size_t bit_index_;

View file

@ -821,6 +821,12 @@ std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdA
ServerState& etl = *ServerState::tlocal();
// If there is no connection owner, it means the command it being called
// from another command or used internally, therefore is always permitted.
if (dfly_cntx.owner() != nullptr && !dfly_cntx.owner()->IsAdmin() && cid->IsRestricted()) {
return ErrorReply{"Cannot execute restricted command (admin only)"};
}
if (auto err = cid->Validate(tail_args); err)
return err;

View file

@ -103,3 +103,19 @@ async def test_unknown_dfly_env(df_local_factory, export_dfly_password):
with pytest.raises(DflyStartException):
dfly = df_local_factory.create()
dfly.start()
async def test_restricted_commands(df_local_factory):
# Restrict GET and SET, then verify non-admin clients are blocked from
# using these commands, though admin clients can use them.
with df_local_factory.create(restricted_commands="get,set", admin_port=1112) as server:
async with aioredis.Redis(port=server.port) as client:
with pytest.raises(redis.exceptions.ResponseError):
await client.get("foo")
with pytest.raises(redis.exceptions.ResponseError):
await client.set("foo", "bar")
async with aioredis.Redis(port=server.admin_port) as admin_client:
await admin_client.get("foo")
await admin_client.set("foo", "bar")