mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 10:25:47 +02:00
feat(cmd): add restricted commands flag (#1967)
This commit is contained in:
parent
ba4eba991d
commit
e6b8cd1d76
5 changed files with 47 additions and 0 deletions
|
@ -80,6 +80,16 @@ class CommandId {
|
|||
return bit_index_;
|
||||
}
|
||||
|
||||
// Returns true if the command can only be used by admin connections, false
|
||||
// otherwise.
|
||||
bool IsRestricted() const {
|
||||
return restricted_;
|
||||
}
|
||||
|
||||
void SetRestricted(bool restricted) {
|
||||
restricted_ = restricted;
|
||||
}
|
||||
|
||||
static uint32_t OptCount(uint32_t mask);
|
||||
|
||||
protected:
|
||||
|
@ -95,6 +105,9 @@ class CommandId {
|
|||
// Acl commands indices
|
||||
size_t family_;
|
||||
uint64_t bit_index_;
|
||||
|
||||
// Whether the command can only be used by admin connections.
|
||||
bool restricted_ = false;
|
||||
};
|
||||
|
||||
} // namespace facade
|
||||
|
|
|
@ -20,6 +20,8 @@ using namespace std;
|
|||
ABSL_FLAG(vector<string>, rename_command, {},
|
||||
"Change the name of commands, format is: <cmd1_name>=<cmd1_new_name>, "
|
||||
"<cmd2_name>=<cmd2_new_name>");
|
||||
ABSL_FLAG(vector<string>, restricted_commands, {},
|
||||
"Commands restricted to connections on the admin port");
|
||||
|
||||
namespace dfly {
|
||||
|
||||
|
@ -87,6 +89,10 @@ CommandRegistry::CommandRegistry() {
|
|||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
for (string name : GetFlag(FLAGS_restricted_commands)) {
|
||||
restricted_cmds_.emplace(AsciiStrToUpper(name));
|
||||
}
|
||||
}
|
||||
|
||||
void CommandRegistry::Init(unsigned int thread_count) {
|
||||
|
@ -105,6 +111,10 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
|
|||
k = it->second;
|
||||
}
|
||||
|
||||
if (restricted_cmds_.find(k) != restricted_cmds_.end()) {
|
||||
cmd.SetRestricted(true);
|
||||
}
|
||||
|
||||
family_of_commands_.back().push_back(std::string(k));
|
||||
cmd.SetFamily(family_of_commands_.size() - 1);
|
||||
cmd.SetBitIndex(1ULL << bit_index_++);
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <absl/container/flat_hash_map.h>
|
||||
#include <absl/container/flat_hash_set.h>
|
||||
#include <absl/types/span.h>
|
||||
|
||||
#include <functional>
|
||||
|
@ -169,6 +170,7 @@ class CommandRegistry {
|
|||
private:
|
||||
absl::flat_hash_map<std::string_view, CommandId> cmd_map_;
|
||||
absl::flat_hash_map<std::string, std::string> cmd_rename_map_;
|
||||
absl::flat_hash_set<std::string> restricted_cmds_;
|
||||
|
||||
FamiliesVec family_of_commands_;
|
||||
size_t bit_index_;
|
||||
|
|
|
@ -821,6 +821,12 @@ std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdA
|
|||
|
||||
ServerState& etl = *ServerState::tlocal();
|
||||
|
||||
// If there is no connection owner, it means the command it being called
|
||||
// from another command or used internally, therefore is always permitted.
|
||||
if (dfly_cntx.owner() != nullptr && !dfly_cntx.owner()->IsAdmin() && cid->IsRestricted()) {
|
||||
return ErrorReply{"Cannot execute restricted command (admin only)"};
|
||||
}
|
||||
|
||||
if (auto err = cid->Validate(tail_args); err)
|
||||
return err;
|
||||
|
||||
|
|
|
@ -103,3 +103,19 @@ async def test_unknown_dfly_env(df_local_factory, export_dfly_password):
|
|||
with pytest.raises(DflyStartException):
|
||||
dfly = df_local_factory.create()
|
||||
dfly.start()
|
||||
|
||||
|
||||
async def test_restricted_commands(df_local_factory):
|
||||
# Restrict GET and SET, then verify non-admin clients are blocked from
|
||||
# using these commands, though admin clients can use them.
|
||||
with df_local_factory.create(restricted_commands="get,set", admin_port=1112) as server:
|
||||
async with aioredis.Redis(port=server.port) as client:
|
||||
with pytest.raises(redis.exceptions.ResponseError):
|
||||
await client.get("foo")
|
||||
|
||||
with pytest.raises(redis.exceptions.ResponseError):
|
||||
await client.set("foo", "bar")
|
||||
|
||||
async with aioredis.Redis(port=server.admin_port) as admin_client:
|
||||
await admin_client.get("foo")
|
||||
await admin_client.set("foo", "bar")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue