feat(server): Add --lock_on_hashtags mode. (#1611)

* feat(server): Add `--lock_on_hashtags` mode.

This new mode effectively locks hashtags (i.e. strings within {curly
braces}) instead of the full keys being used.
This allows scripts to access keys they did not declare, as long as each
such key shares its hashtag with a declared key, as is the case with BullMQ.

To make sure this mode is tested, I added a way to specify flags via env
variables, and modified `ci.yml` to run all tests using this mode as well.
While at it, I also added `--cluster_mode=emulated` mode to CI.
This commit is contained in:
Shahar Mike 2023-08-03 20:13:36 +03:00 committed by GitHub
parent 8040bed10f
commit 67a4c4e6cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 229 additions and 31 deletions

View file

@ -96,6 +96,13 @@ jobs:
echo Run ctest -V -L DFLY
#GLOG_logtostderr=1 GLOG_vmodule=transaction=1,engine_shard_set=1
GLOG_logtostderr=1 GLOG_vmodule=rdb_load=1,rdb_save=1,snapshot=1 ctest -V -L DFLY
echo "Running tests with --cluster_mode=emulated"
FLAGS_cluster_mode=emulated ctest -V -L DFLY
echo "Running tests with both --cluster_mode=emulated & --lock_on_hashtags"
FLAGS_cluster_mode=emulated FLAGS_lock_on_hashtags=true ctest -V -L DFLY
./dragonfly_test --gtest_repeat=10
./multi_test --multi_exec_mode=1 --gtest_repeat=10
./multi_test --multi_exec_mode=3 --gtest_repeat=10

View file

@ -16,6 +16,7 @@ extern "C" {
#include "redis/util.h"
}
#include "base/flags.h"
#include "base/logging.h"
#include "core/compact_object.h"
#include "server/engine_shard_set.h"
@ -24,11 +25,43 @@ extern "C" {
#include "server/server_state.h"
#include "server/transaction.h"
// Boolean flag, off by default. Its value is read (and cached per thread) by
// KeyLockArgs::IsLockHashTagEnabled(), which in turn decides whether
// KeyLockArgs::GetLockKey() maps a key to ClusterConfig::KeyTag(key) or leaves it unchanged.
ABSL_FLAG(bool, lock_on_hashtags, false,
"When true, locks are done in the {hashtag} level instead of key level. "
"Only use this with --cluster_mode=emulated|yes.");
namespace dfly {
using namespace std;
using namespace util;
namespace {
// Per-thread cached copy of the --lock_on_hashtags flag value, file-local.
// Empty until KeyLockArgs::IsLockHashTagEnabled() fills it on first use; reset back to empty by
// TEST_InvalidateLockHashTag() so that a changed flag value is picked up again.
thread_local std::optional<bool> is_enabled_flag_cache;
} // namespace
// Test hook: drops the cached --lock_on_hashtags value so the flag is re-read on next use.
// Clears the cache of the calling thread first, and then of every thread on which the
// shard_set pool runs the callback. Requires shard_set to be initialized.
void TEST_InvalidateLockHashTag() {
  is_enabled_flag_cache.reset();
  CHECK(shard_set != nullptr);

  auto clear_cache = [](ShardId, ProactorBase*) { is_enabled_flag_cache.reset(); };
  shard_set->pool()->Await(clear_cache);
}
// Returns the value of --lock_on_hashtags.
// The flag is read once per thread and memoized in is_enabled_flag_cache; later calls on the
// same thread return the memoized value until the cache is reset.
bool KeyLockArgs::IsLockHashTagEnabled() {
  if (is_enabled_flag_cache.has_value()) {
    return *is_enabled_flag_cache;
  }

  const bool enabled = absl::GetFlag(FLAGS_lock_on_hashtags);
  is_enabled_flag_cache = enabled;
  return enabled;
}
// Maps `key` to the string that is actually locked for it:
// ClusterConfig::KeyTag(key) when hashtag locking is enabled, otherwise `key` itself.
string_view KeyLockArgs::GetLockKey(string_view key) {
  return IsLockHashTagEnabled() ? ClusterConfig::KeyTag(key) : key;
}
atomic_uint64_t used_mem_peak(0);
atomic_uint64_t used_mem_current(0);
unsigned kernel_version = 0;

View file

@ -50,6 +50,11 @@ class Transaction;
class EngineShard;
struct KeyLockArgs {
static bool IsLockHashTagEnabled();
// Before acquiring and releasing keys, one must "normalize" them via GetLockKey().
static std::string_view GetLockKey(std::string_view key);
DbIndex db_index = 0;
ArgSlice args;
unsigned key_step = 1;

View file

@ -175,7 +175,7 @@ class ConnectionContext : public facade::ConnectionContext {
void UnsubscribeAll(bool to_reply);
void PUnsubscribeAll(bool to_reply);
void ChangeMonitor(bool start); // either start or stop monitor on a given connection
void CancelBlocking(); // Cancel an ongoing blocking transaction if there is one.
void CancelBlocking(); // Cancel an ongoing blocking transaction if there is one.
// Whether this connection is a connection from a replica to its master.
// This flag is true only on replica side, where we need to setup a special ConnectionContext

View file

@ -755,12 +755,14 @@ bool DbSlice::Acquire(IntentLock::Mode mode, const KeyLockArgs& lock_args) {
bool lock_acquired = true;
if (lock_args.args.size() == 1) {
lock_acquired = lt[lock_args.args.front()].Acquire(mode);
string_view key = KeyLockArgs::GetLockKey(lock_args.args.front());
lock_acquired = lt[key].Acquire(mode);
uniq_keys_ = {key};
} else {
uniq_keys_.clear();
for (size_t i = 0; i < lock_args.args.size(); i += lock_args.key_step) {
auto s = lock_args.args[i];
auto s = KeyLockArgs::GetLockKey(lock_args.args[i]);
if (uniq_keys_.insert(s).second) {
bool res = lt[s].Acquire(mode);
lock_acquired &= res;
@ -774,18 +776,40 @@ bool DbSlice::Acquire(IntentLock::Mode mode, const KeyLockArgs& lock_args) {
return lock_acquired;
}
// Releases `count` locks of `mode` held for `key` in database `db_index`.
// `key` is a raw key: it is normalized through KeyLockArgs::GetLockKey() here and the actual
// release is delegated to ReleaseNormalized().
void DbSlice::Release(IntentLock::Mode mode, DbIndex db_index, std::string_view key,
                      unsigned count) {
  const std::string_view lock_key = KeyLockArgs::GetLockKey(key);
  ReleaseNormalized(mode, db_index, lock_key, count);
}
// Releases `count` locks of `mode` held on `key` in the lock table of database `db_index`.
//
// Precondition: `key` has already been normalized via KeyLockArgs::GetLockKey() (verified by
// the DCHECK below), so it is used for the lookup as-is instead of being normalized a second
// time. An entry for `key` must exist in the lock table (enforced by CHECK). The entry is
// erased once the lock reports itself free.
void DbSlice::ReleaseNormalized(IntentLock::Mode mode, DbIndex db_index, std::string_view key,
                                unsigned count) {
  DCHECK_EQ(key, KeyLockArgs::GetLockKey(key));
  DVLOG(1) << "Release " << IntentLock::ModeName(mode) << " " << count << " for " << key;

  auto& lt = db_arr_[db_index]->trans_locks;
  auto it = lt.find(key);
  CHECK(it != lt.end()) << key;

  it->second.Release(mode, count);
  if (it->second.IsFree()) {
    lt.erase(it);
  }
}
void DbSlice::Release(IntentLock::Mode mode, const KeyLockArgs& lock_args) {
if (lock_args.args.empty()) {
return;
}
DVLOG(2) << "Release " << IntentLock::ModeName(mode) << " for " << lock_args.args[0];
if (lock_args.args.size() == 1) {
Release(mode, lock_args.db_index, lock_args.args.front(), 1);
string_view key = KeyLockArgs::GetLockKey(lock_args.args.front());
ReleaseNormalized(mode, lock_args.db_index, key, 1);
uniq_keys_ = {key};
} else {
auto& lt = db_arr_[lock_args.db_index]->trans_locks;
uniq_keys_.clear();
for (size_t i = 0; i < lock_args.args.size(); i += lock_args.key_step) {
auto s = lock_args.args[i];
auto s = KeyLockArgs::GetLockKey(lock_args.args[i]);
if (uniq_keys_.insert(s).second) {
auto it = lt.find(s);
CHECK(it != lt.end());
@ -807,9 +831,11 @@ bool DbSlice::CheckLock(IntentLock::Mode mode, DbIndex dbid, string_view key) co
}
bool DbSlice::CheckLock(IntentLock::Mode mode, const KeyLockArgs& lock_args) const {
uniq_keys_.clear();
const auto& lt = db_arr_[lock_args.db_index]->trans_locks;
for (size_t i = 0; i < lock_args.args.size(); i += lock_args.key_step) {
auto s = lock_args.args[i];
auto s = KeyLockArgs::GetLockKey(lock_args.args[i]);
uniq_keys_.insert(s);
auto it = lt.find(s);
if (it != lt.end() && !it->second.Check(mode)) {
return false;

View file

@ -230,9 +230,7 @@ class DbSlice {
void Release(IntentLock::Mode m, const KeyLockArgs& lock_args);
void Release(IntentLock::Mode m, DbIndex db_index, std::string_view key, unsigned count) {
db_arr_[db_index]->Release(m, key, count);
}
void Release(IntentLock::Mode m, DbIndex db_index, std::string_view key, unsigned count);
// Returns true if the key can be locked under m. Does not lock.
bool CheckLock(IntentLock::Mode m, DbIndex dbid, std::string_view key) const;
@ -313,6 +311,11 @@ class DbSlice {
caching_mode_ = 1;
}
// Test hook to inspect last locked keys.
// Returns a copy of uniq_keys_, the scratch set filled by Acquire/Release/CheckLock.
// NOTE(review): the elements are string_views derived from the lock arguments that were passed
// to those calls, not owned strings — confirm the underlying key bytes are still alive when
// the result is read.
absl::flat_hash_set<std::string_view> TEST_GetLastLockedKeys() const {
return uniq_keys_;
}
void RegisterWatchedKey(DbIndex db_indx, std::string_view key,
ConnectionState::ExecInfo* exec_info);
@ -329,6 +332,10 @@ class DbSlice {
}
private:
// Releases a single key. `key` must have been normalized by GetLockKey().
void ReleaseNormalized(IntentLock::Mode m, DbIndex db_index, std::string_view key,
unsigned count);
std::pair<PrimeIterator, bool> AddOrUpdateInternal(const Context& cntx, std::string_view key,
PrimeValue obj, uint64_t expire_at_ms,
bool force_update) noexcept(false);
@ -367,7 +374,7 @@ class DbSlice {
DbTableArray db_arr_;
// Used in temporary computations in Acquire/Release.
absl::flat_hash_set<std::string_view> uniq_keys_;
mutable absl::flat_hash_set<std::string_view> uniq_keys_;
// ordered from the smallest to largest version.
std::vector<std::pair<uint64_t, ChangeCallback>> change_cb_;

View file

@ -589,6 +589,11 @@ Service::Service(ProactorPool* pp)
CHECK(pp);
CHECK(shard_set == NULL);
if (KeyLockArgs::IsLockHashTagEnabled() && !ClusterConfig::IsEnabledOrEmulated()) {
LOG(ERROR) << "Setting --lock_on_hashtags without --cluster_mode is unsupported";
exit(1);
}
shard_set = new EngineShardSet(pp);
// We support less than 1024 threads and we support less than 1024 shards.
@ -736,14 +741,15 @@ OpStatus CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_info, const C
const auto& key_index = *key_index_res;
for (unsigned i = key_index.start; i < key_index.end; ++i) {
string_view key = ArgS(args, i);
string_view key = KeyLockArgs::GetLockKey(ArgS(args, i));
if (!eval_info.keys.contains(key)) {
VLOG(1) << "Key " << key << " is not declared for command " << cid->name();
return OpStatus::KEY_NOTFOUND;
}
}
if (key_index.bonus && !eval_info.keys.contains(ArgS(args, *key_index.bonus)))
if (key_index.bonus &&
!eval_info.keys.contains(KeyLockArgs::GetLockKey(ArgS(args, *key_index.bonus))))
return OpStatus::KEY_NOTFOUND;
return OpStatus::OK;
@ -770,7 +776,7 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionCon
}
bool is_trans_cmd = CO::IsTransKind(cid->name());
bool under_script = bool(dfly_cntx->conn_state.script_info);
bool under_script = dfly_cntx->conn_state.script_info != nullptr;
bool allowed_by_state = true;
switch (etl.gstate()) {
case GlobalState::LOADING:
@ -1356,7 +1362,7 @@ Transaction::MultiMode DetermineMultiMode(ScriptMgr::ScriptParams params) {
}
// Start multi transaction for eval. Returns true if transaction was scheduled.
// Skips scheduling if multi mode requies declaring keys, but no keys were declared.
// Skips scheduling if multi mode requires declaring keys, but no keys were declared.
bool StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams params,
Transaction* trans) {
Transaction::MultiMode multi_mode = DetermineMultiMode(params);
@ -1422,9 +1428,9 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
// and checking whether all invocations consist of RO commands.
// we can do it once during script insertion into script mgr.
auto& sinfo = cntx->conn_state.script_info;
sinfo.reset(new ConnectionState::ScriptInfo{});
sinfo = make_unique<ConnectionState::ScriptInfo>();
for (size_t i = 0; i < eval_args.keys.size(); ++i) {
sinfo->keys.insert(ArgS(eval_args.keys, i));
sinfo->keys.insert(KeyLockArgs::GetLockKey(ArgS(eval_args.keys, i)));
}
sinfo->async_cmds_heap_limit = absl::GetFlag(FLAGS_multi_eval_squash_buffer);
DCHECK(cntx->transaction);

View file

@ -691,4 +691,76 @@ TEST_F(StringFamilyTest, SetWithGetParam) {
EXPECT_EQ(Run({"get", "key2"}), "val3");
}
// With cluster mode unset and hashtag locking off, the keys reported by GetLastUsedKeys() are
// the full keys ("{key}1", "{key}2") and never the bare hashtag ("key").
TEST_F(StringFamilyTest, SetWithHashtagsNoCluster) {
SetTestFlag("cluster_mode", "");
SetTestFlag("lock_on_hashtags", "false");
// Recreate the service after changing the flags.
ResetService();
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK");
EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}1"), Not(Contains("key"))));
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK");
EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}2"), Not(Contains("key"))));
EXPECT_THAT(Run({"mget", "{key}1", "{key}2"}), RespArray(ElementsAre("val1", "val2")));
// NOTE(review): presumably shards_count is the number of shards the last command touched, and
// the two keys are expected to land on different shards without cluster mode — confirm.
EXPECT_NE(1, GetDebugInfo().shards_count);
EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("{key}1", "{key}2"));
}
// With emulated cluster mode but hashtag locking off, GetLastUsedKeys() still reports the full
// keys, while the MGET over two keys sharing a hashtag reports a shards_count of 1.
TEST_F(StringFamilyTest, SetWithHashtagsWithEmulatedCluster) {
SetTestFlag("cluster_mode", "emulated");
SetTestFlag("lock_on_hashtags", "false");
// Recreate the service after changing the flags.
ResetService();
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK");
EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}1"), Not(Contains("key"))));
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK");
EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}2"), Not(Contains("key"))));
EXPECT_THAT(Run({"mget", "{key}1", "{key}2"}), RespArray(ElementsAre("val1", "val2")));
// NOTE(review): presumably both keys map to one shard because they share the "key" hashtag
// in cluster mode — confirm.
EXPECT_EQ(1, GetDebugInfo().shards_count);
EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("{key}1", "{key}2"));
}
// With emulated cluster mode and hashtag locking on, GetLastUsedKeys() reports the hashtag
// ("key") instead of the full keys, and two keys sharing a hashtag collapse into one entry.
TEST_F(StringFamilyTest, SetWithHashtagsWithHashtagLock) {
SetTestFlag("cluster_mode", "emulated");
SetTestFlag("lock_on_hashtags", "true");
// Recreate the service after changing the flags.
ResetService();
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK");
EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("key"), Not(Contains("{key}1"))));
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK");
EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("key"), Not(Contains("{key}2"))));
EXPECT_THAT(Run({"mget", "{key}1", "{key}2"}), RespArray(ElementsAre("val1", "val2")));
EXPECT_EQ(1, GetDebugInfo().shards_count);
// Both MGET keys share the hashtag, so only a single entry is expected.
EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("key"));
}
// MULTI/EXEC with hashtag locking off: after EXEC, GetLastUsedKeys() reports both full keys.
TEST_F(StringFamilyTest, MultiSetWithHashtagsDontLockHashtags) {
SetTestFlag("cluster_mode", "");
SetTestFlag("lock_on_hashtags", "false");
// Recreate the service after changing the flags.
ResetService();
EXPECT_EQ(Run({"multi"}), "OK");
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "QUEUED");
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "QUEUED");
EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK")));
EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("{key}1", "{key}2"));
}
// MULTI/EXEC with emulated cluster mode and hashtag locking on: after EXEC, GetLastUsedKeys()
// reports only the shared hashtag ("key") for the two queued SETs.
TEST_F(StringFamilyTest, MultiSetWithHashtagsLockHashtags) {
SetTestFlag("cluster_mode", "emulated");
SetTestFlag("lock_on_hashtags", "true");
// Recreate the service after changing the flags.
ResetService();
EXPECT_EQ(Run({"multi"}), "OK");
EXPECT_EQ(Run({"set", "{key}1", "val1"}), "QUEUED");
EXPECT_EQ(Run({"set", "{key}2", "val2"}), "QUEUED");
EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK")));
EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("key"));
}
} // namespace dfly

View file

@ -62,15 +62,4 @@ void DbTable::Clear() {
stats = DbTableStats{};
}
void DbTable::Release(IntentLock::Mode mode, std::string_view key, unsigned count) {
DVLOG(1) << "Release " << IntentLock::ModeName(mode) << " " << count << " for " << key;
auto it = trans_locks.find(key);
CHECK(it != trans_locks.end()) << key;
it->second.Release(mode, count);
if (it->second.IsFree()) {
trans_locks.erase(it);
}
}
} // namespace dfly

View file

@ -88,7 +88,6 @@ struct DbTable : boost::intrusive_ref_counter<DbTable, boost::thread_unsafe_coun
~DbTable();
void Clear();
void Release(IntentLock::Mode mode, std::string_view key, unsigned count);
};
// We use reference counting semantics of DbTable when doing snapshotting.

View file

@ -125,16 +125,44 @@ void BaseFamilyTest::SetUpTestSuite() {
absl::SetFlag(&FLAGS_dbfilename, "");
init_zmalloc_threadlocal(mi_heap_get_backing());
// TODO: go over all env variables starting with FLAGS_ and make sure they are in the below list.
static constexpr const char* kEnvFlags[] = {"cluster_mode", "lock_on_hashtags"};
for (string_view flag : kEnvFlags) {
const char* value = getenv(absl::StrCat("FLAGS_", flag).data());
if (value != nullptr) {
SetTestFlag(flag, value);
}
}
}
// Per-test fixture setup: delegates entirely to ResetService().
void BaseFamilyTest::SetUp() {
ResetService();
}
// Test hook defined in common.cc.
void TEST_InvalidateLockHashTag();
void BaseFamilyTest::ResetService() {
if (service_ != nullptr) {
TEST_InvalidateLockHashTag();
service_->Shutdown();
service_ = nullptr;
delete shard_set;
shard_set = nullptr;
pp_->Stop();
}
if (absl::GetFlag(FLAGS_force_epoll)) {
pp_.reset(fb2::Pool::Epoll(num_threads_));
} else {
pp_.reset(fb2::Pool::IOUring(16, num_threads_));
}
pp_->Run();
service_.reset(new Service{pp_.get()});
service_ = std::make_unique<Service>(pp_.get());
Service::InitOpts opts;
opts.disable_time_update = true;
@ -456,9 +484,31 @@ vector<string> BaseFamilyTest::StrArray(const RespExpr& expr) {
return res;
}
// Gathers, as owned strings, the union of DbSlice::TEST_GetLastLockedKeys() over all pool
// threads that have a thread-local EngineShard. Threads without a shard contribute nothing.
absl::flat_hash_set<string> BaseFamilyTest::GetLastUsedKeys() {
  absl::flat_hash_set<string> all_keys;
  Mutex guard_mu;  // Serializes writes to `all_keys` from the per-thread callbacks.

  auto collect_from_shard = [&](ProactorBase*) {
    if (EngineShard* es = EngineShard::tlocal(); es != nullptr) {
      lock_guard guard(guard_mu);
      for (string_view locked_key : es->db_slice().TEST_GetLastLockedKeys()) {
        all_keys.insert(string(locked_key));
      }
    }
  };

  shard_set->pool()->AwaitFiberOnAll(collect_from_shard);
  return all_keys;
}
// Sets the command-line flag named `flag_name` to the value parsed from `new_value`.
// CHECK-fails if no such flag is registered or if `new_value` cannot be parsed for it.
void BaseFamilyTest::SetTestFlag(string_view flag_name, string_view new_value) {
  absl::CommandLineFlag* target = absl::FindCommandLineFlag(flag_name);
  CHECK_NE(target, nullptr);

  VLOG(1) << "Changing flag " << flag_name << " from " << target->CurrentValue() << " to "
          << new_value;

  string parse_error;
  const bool parsed = target->ParseFrom(new_value, &parse_error);
  CHECK(parsed) << "Error: " << parse_error;
}

View file

@ -75,6 +75,8 @@ class BaseFamilyTest : public ::testing::Test {
int64_t CheckedInt(ArgSlice list);
std::string CheckedString(ArgSlice list);
void ResetService();
bool IsLocked(DbIndex db_index, std::string_view key) const;
ConnectionContext::DebugInfo GetDebugInfo(const std::string& id) const;
@ -102,9 +104,11 @@ class BaseFamilyTest : public ::testing::Test {
const facade::Connection::PubMessage& GetPublishedMessage(std::string_view conn_id,
size_t index) const;
static absl::flat_hash_set<std::string> GetLastUsedKeys();
static unsigned NumLocked();
void SetTestFlag(std::string_view flag_name, std::string_view new_value);
static void SetTestFlag(std::string_view flag_name, std::string_view new_value);
std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_;

View file

@ -170,7 +170,7 @@ void Transaction::InitMultiData(KeyIndex key_index) {
tmp_uniques.clear();
auto lock_key = [this, mode, &tmp_uniques](string_view key) {
if (auto [_, inserted] = tmp_uniques.insert(key); !inserted)
if (auto [_, inserted] = tmp_uniques.insert(KeyLockArgs::GetLockKey(key)); !inserted)
return;
multi_->lock_counts[key][mode]++;