diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 97e4b4b23..456793a32 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -618,6 +618,15 @@ void AclFamily::Init(facade::Listener* main_listener, UserRegistry* registry) { } registry_->Init(); config_registry.RegisterMutable("aclfile"); + config_registry.RegisterMutable("acllog_max_len", [this](const absl::CommandLineFlag& flag) { + auto res = flag.TryGet(); + if (res.has_value()) { + pool_->AwaitFiberOnAll([&res](auto index, auto* context) { + ServerState::tlocal()->acl_log.SetTotalEntries(res.value()); + }); + } + return res.has_value(); + }); } } // namespace dfly::acl diff --git a/src/server/acl/acl_log.cc b/src/server/acl/acl_log.cc index d7bf39da6..9e05627a5 100644 --- a/src/server/acl/acl_log.cc +++ b/src/server/acl/acl_log.cc @@ -55,4 +55,12 @@ AclLog::LogType AclLog::GetLog(size_t number_of_entries) const { return {start, end}; } +void AclLog::SetTotalEntries(size_t total_entries) { + if (log_.size() > total_entries) { + log_.erase(std::next(log_.begin(), total_entries), log_.end()); + } + + total_entries_allowed_ = total_entries; +} + } // namespace dfly::acl diff --git a/src/server/acl/acl_log.h b/src/server/acl/acl_log.h index ced2af7ec..0a588e47d 100644 --- a/src/server/acl/acl_log.h +++ b/src/server/acl/acl_log.h @@ -42,9 +42,11 @@ class AclLog { LogType GetLog(size_t number_of_entries) const; + void SetTotalEntries(size_t total_entries); + private: LogType log_; - const size_t total_entries_allowed_; + size_t total_entries_allowed_; }; } // namespace dfly::acl diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index da11cd08c..f4841e166 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -439,3 +439,33 @@ async def test_set_acl_file(async_client: aioredis.Redis, tmp_dir): result = await async_client.execute_command("AUTH roy mypass") assert result == "OK" + + +@pytest.mark.asyncio +@dfly_args({"proactor_threads": 1}) +async def test_set_len_acl_log(async_client): + res = await async_client.execute_command("ACL LOG") + assert [] == res + + await async_client.execute_command("ACL SETUSER elon >mars ON +@string +@dangerous") + + for x in range(7): + with pytest.raises(redis.exceptions.AuthenticationError): + await async_client.execute_command("AUTH elon wrong") + + res = await async_client.execute_command("ACL LOG") + assert 7 == len(res) + + await async_client.execute_command(f"CONFIG SET acllog_max_len 3") + + res = await async_client.execute_command("ACL LOG") + assert 3 == len(res) + + await async_client.execute_command(f"CONFIG SET acllog_max_len 10") + + for x in range(7): + with pytest.raises(redis.exceptions.AuthenticationError): + await async_client.execute_command("AUTH elon wrong") + + res = await async_client.execute_command("ACL LOG") + assert 10 == len(res)