diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bae8894e..ab4b17797 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,6 +96,13 @@ jobs: echo Run ctest -V -L DFLY #GLOG_logtostderr=1 GLOG_vmodule=transaction=1,engine_shard_set=1 GLOG_logtostderr=1 GLOG_vmodule=rdb_load=1,rdb_save=1,snapshot=1 ctest -V -L DFLY + + echo "Running tests with --cluster_mode=emulated" + FLAGS_cluster_mode=emulated ctest -V -L DFLY + + echo "Running tests with both --cluster_mode=emulated & --lock_on_hashtags" + FLAGS_cluster_mode=emulated FLAGS_lock_on_hashtags=true ctest -V -L DFLY + ./dragonfly_test --gtest_repeat=10 ./multi_test --multi_exec_mode=1 --gtest_repeat=10 ./multi_test --multi_exec_mode=3 --gtest_repeat=10 diff --git a/src/server/common.cc b/src/server/common.cc index 02d68b98d..81e2315db 100644 --- a/src/server/common.cc +++ b/src/server/common.cc @@ -16,6 +16,7 @@ extern "C" { #include "redis/util.h" } +#include "base/flags.h" #include "base/logging.h" #include "core/compact_object.h" #include "server/engine_shard_set.h" @@ -24,11 +25,43 @@ extern "C" { #include "server/server_state.h" #include "server/transaction.h" +ABSL_FLAG(bool, lock_on_hashtags, false, + "When true, locks are done in the {hashtag} level instead of key level. " + "Only use this with --cluster_mode=emulated|yes."); + namespace dfly { using namespace std; using namespace util; +namespace { +// Thread-local cache with static linkage. +thread_local std::optional is_enabled_flag_cache; +} // namespace + +void TEST_InvalidateLockHashTag() { + is_enabled_flag_cache = nullopt; + CHECK(shard_set != nullptr); + shard_set->pool()->Await( + [](ShardId shard, ProactorBase* proactor) { is_enabled_flag_cache = nullopt; }); +} + +bool KeyLockArgs::IsLockHashTagEnabled() { + if (!is_enabled_flag_cache.has_value()) { + is_enabled_flag_cache = absl::GetFlag(FLAGS_lock_on_hashtags); + } + + return *is_enabled_flag_cache; +} + +string_view KeyLockArgs::GetLockKey(string_view key) { + if (IsLockHashTagEnabled()) { + return ClusterConfig::KeyTag(key); + } + + return key; +} + atomic_uint64_t used_mem_peak(0); atomic_uint64_t used_mem_current(0); unsigned kernel_version = 0; diff --git a/src/server/common.h b/src/server/common.h index bfe762949..56c24300b 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -50,6 +50,11 @@ class Transaction; class EngineShard; struct KeyLockArgs { + static bool IsLockHashTagEnabled(); + + // Before acquiring and releasing keys, one must "normalize" them via GetLockKey(). + static std::string_view GetLockKey(std::string_view key); + DbIndex db_index = 0; ArgSlice args; unsigned key_step = 1; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index e2cabc401..026e72258 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -175,7 +175,7 @@ class ConnectionContext : public facade::ConnectionContext { void UnsubscribeAll(bool to_reply); void PUnsubscribeAll(bool to_reply); void ChangeMonitor(bool start); // either start or stop monitor on a given connection - void CancelBlocking(); // Cancel an ongoing blocking transaction if there is one. + void CancelBlocking(); // Cancel an ongoing blocking transaction if there is one. // Whether this connection is a connection from a replica to its master. // This flag is true only on replica side, where we need to setup a special ConnectionContext diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 8a0c4b199..7b2fc493b 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -755,12 +755,14 @@ bool DbSlice::Acquire(IntentLock::Mode mode, const KeyLockArgs& lock_args) { bool lock_acquired = true; if (lock_args.args.size() == 1) { - lock_acquired = lt[lock_args.args.front()].Acquire(mode); + string_view key = KeyLockArgs::GetLockKey(lock_args.args.front()); + lock_acquired = lt[key].Acquire(mode); + uniq_keys_ = {key}; } else { uniq_keys_.clear(); for (size_t i = 0; i < lock_args.args.size(); i += lock_args.key_step) { - auto s = lock_args.args[i]; + auto s = KeyLockArgs::GetLockKey(lock_args.args[i]); if (uniq_keys_.insert(s).second) { bool res = lt[s].Acquire(mode); lock_acquired &= res; @@ -774,18 +776,40 @@ bool DbSlice::Acquire(IntentLock::Mode mode, const KeyLockArgs& lock_args) { return lock_acquired; } +void DbSlice::Release(IntentLock::Mode mode, DbIndex db_index, std::string_view key, + unsigned count) { + return ReleaseNormalized(mode, db_index, KeyLockArgs::GetLockKey(key), count); +} + +void DbSlice::ReleaseNormalized(IntentLock::Mode mode, DbIndex db_index, std::string_view key, + unsigned count) { + DCHECK_EQ(key, KeyLockArgs::GetLockKey(key)); + DVLOG(1) << "Release " << IntentLock::ModeName(mode) << " " << count << " for " << key; + + auto& lt = db_arr_[db_index]->trans_locks; + auto it = lt.find(KeyLockArgs::GetLockKey(key)); + CHECK(it != lt.end()) << key; + it->second.Release(mode, count); + if (it->second.IsFree()) { + lt.erase(it); + } +} + void DbSlice::Release(IntentLock::Mode mode, const KeyLockArgs& lock_args) { if (lock_args.args.empty()) { return; } + DVLOG(2) << "Release " << IntentLock::ModeName(mode) << " for " << lock_args.args[0]; if (lock_args.args.size() == 1) { - Release(mode, lock_args.db_index, lock_args.args.front(), 1); + string_view key = KeyLockArgs::GetLockKey(lock_args.args.front()); + ReleaseNormalized(mode, lock_args.db_index, key, 1); + uniq_keys_ = {key}; } else { auto& lt = db_arr_[lock_args.db_index]->trans_locks; uniq_keys_.clear(); for (size_t i = 0; i < lock_args.args.size(); i += lock_args.key_step) { - auto s = lock_args.args[i]; + auto s = KeyLockArgs::GetLockKey(lock_args.args[i]); if (uniq_keys_.insert(s).second) { auto it = lt.find(s); CHECK(it != lt.end()); @@ -807,9 +831,11 @@ bool DbSlice::CheckLock(IntentLock::Mode mode, DbIndex dbid, string_view key) co } bool DbSlice::CheckLock(IntentLock::Mode mode, const KeyLockArgs& lock_args) const { + uniq_keys_.clear(); const auto& lt = db_arr_[lock_args.db_index]->trans_locks; for (size_t i = 0; i < lock_args.args.size(); i += lock_args.key_step) { - auto s = lock_args.args[i]; + auto s = KeyLockArgs::GetLockKey(lock_args.args[i]); + uniq_keys_.insert(s); auto it = lt.find(s); if (it != lt.end() && !it->second.Check(mode)) { return false; diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 10d87b127..eb8baed33 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -230,9 +230,7 @@ class DbSlice { void Release(IntentLock::Mode m, const KeyLockArgs& lock_args); - void Release(IntentLock::Mode m, DbIndex db_index, std::string_view key, unsigned count) { - db_arr_[db_index]->Release(m, key, count); - } + void Release(IntentLock::Mode m, DbIndex db_index, std::string_view key, unsigned count); // Returns true if the key can be locked under m. Does not lock. bool CheckLock(IntentLock::Mode m, DbIndex dbid, std::string_view key) const; @@ -313,6 +311,11 @@ class DbSlice { caching_mode_ = 1; } + // Test hook to inspect last locked keys. + absl::flat_hash_set TEST_GetLastLockedKeys() const { + return uniq_keys_; + } + void RegisterWatchedKey(DbIndex db_indx, std::string_view key, ConnectionState::ExecInfo* exec_info); @@ -329,6 +332,10 @@ class DbSlice { } private: + // Releases a single key. `key` must have been normalized by GetLockKey(). + void ReleaseNormalized(IntentLock::Mode m, DbIndex db_index, std::string_view key, + unsigned count); + std::pair AddOrUpdateInternal(const Context& cntx, std::string_view key, PrimeValue obj, uint64_t expire_at_ms, bool force_update) noexcept(false); @@ -367,7 +374,7 @@ class DbSlice { DbTableArray db_arr_; // Used in temporary computations in Acquire/Release. - absl::flat_hash_set uniq_keys_; + mutable absl::flat_hash_set uniq_keys_; // ordered from the smallest to largest version. std::vector> change_cb_; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 29967de0d..b278f7c6e 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -589,6 +589,11 @@ Service::Service(ProactorPool* pp) CHECK(pp); CHECK(shard_set == NULL); + if (KeyLockArgs::IsLockHashTagEnabled() && !ClusterConfig::IsEnabledOrEmulated()) { + LOG(ERROR) << "Setting --lock_on_hashtags without --cluster_mode is unsupported"; + exit(1); + } + shard_set = new EngineShardSet(pp); // We support less than 1024 threads and we support less than 1024 shards. @@ -736,14 +741,15 @@ OpStatus CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_info, const C const auto& key_index = *key_index_res; for (unsigned i = key_index.start; i < key_index.end; ++i) { - string_view key = ArgS(args, i); + string_view key = KeyLockArgs::GetLockKey(ArgS(args, i)); if (!eval_info.keys.contains(key)) { VLOG(1) << "Key " << key << " is not declared for command " << cid->name(); return OpStatus::KEY_NOTFOUND; } } - if (key_index.bonus && !eval_info.keys.contains(ArgS(args, *key_index.bonus))) + if (key_index.bonus && + !eval_info.keys.contains(KeyLockArgs::GetLockKey(ArgS(args, *key_index.bonus)))) return OpStatus::KEY_NOTFOUND; return OpStatus::OK; @@ -770,7 +776,7 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionCon } bool is_trans_cmd = CO::IsTransKind(cid->name()); - bool under_script = bool(dfly_cntx->conn_state.script_info); + bool under_script = dfly_cntx->conn_state.script_info != nullptr; bool allowed_by_state = true; switch (etl.gstate()) { case GlobalState::LOADING: @@ -1356,7 +1362,7 @@ Transaction::MultiMode DetermineMultiMode(ScriptMgr::ScriptParams params) { } // Start multi transaction for eval. Returns true if transaction was scheduled. -// Skips scheduling if multi mode requies declaring keys, but no keys were declared. +// Skips scheduling if multi mode requires declaring keys, but no keys were declared. bool StartMultiEval(DbIndex dbid, CmdArgList keys, ScriptMgr::ScriptParams params, Transaction* trans) { Transaction::MultiMode multi_mode = DetermineMultiMode(params); @@ -1422,9 +1428,9 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, // and checking whether all invocations consist of RO commands. // we can do it once during script insertion into script mgr. auto& sinfo = cntx->conn_state.script_info; - sinfo.reset(new ConnectionState::ScriptInfo{}); + sinfo = make_unique(); for (size_t i = 0; i < eval_args.keys.size(); ++i) { - sinfo->keys.insert(ArgS(eval_args.keys, i)); + sinfo->keys.insert(KeyLockArgs::GetLockKey(ArgS(eval_args.keys, i))); } sinfo->async_cmds_heap_limit = absl::GetFlag(FLAGS_multi_eval_squash_buffer); DCHECK(cntx->transaction); diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index 0be828e75..c82ee9680 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -691,4 +691,76 @@ TEST_F(StringFamilyTest, SetWithGetParam) { EXPECT_EQ(Run({"get", "key2"}), "val3"); } +TEST_F(StringFamilyTest, SetWithHashtagsNoCluster) { + SetTestFlag("cluster_mode", ""); + SetTestFlag("lock_on_hashtags", "false"); + ResetService(); + + EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK"); + EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}1"), Not(Contains("key")))); + + EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK"); + EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}2"), Not(Contains("key")))); + + EXPECT_THAT(Run({"mget", "{key}1", "{key}2"}), RespArray(ElementsAre("val1", "val2"))); + EXPECT_NE(1, GetDebugInfo().shards_count); + EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("{key}1", "{key}2")); +} + +TEST_F(StringFamilyTest, SetWithHashtagsWithEmulatedCluster) { + SetTestFlag("cluster_mode", "emulated"); + SetTestFlag("lock_on_hashtags", "false"); + ResetService(); + + EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK"); + EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}1"), Not(Contains("key")))); + + EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK"); + EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("{key}2"), Not(Contains("key")))); + + EXPECT_THAT(Run({"mget", "{key}1", "{key}2"}), RespArray(ElementsAre("val1", "val2"))); + EXPECT_EQ(1, GetDebugInfo().shards_count); + EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("{key}1", "{key}2")); +} + +TEST_F(StringFamilyTest, SetWithHashtagsWithHashtagLock) { + SetTestFlag("cluster_mode", "emulated"); + SetTestFlag("lock_on_hashtags", "true"); + ResetService(); + + EXPECT_EQ(Run({"set", "{key}1", "val1"}), "OK"); + EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("key"), Not(Contains("{key}1")))); + + EXPECT_EQ(Run({"set", "{key}2", "val2"}), "OK"); + EXPECT_THAT(GetLastUsedKeys(), AllOf(Contains("key"), Not(Contains("{key}2")))); + + EXPECT_THAT(Run({"mget", "{key}1", "{key}2"}), RespArray(ElementsAre("val1", "val2"))); + EXPECT_EQ(1, GetDebugInfo().shards_count); + EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("key")); +} + +TEST_F(StringFamilyTest, MultiSetWithHashtagsDontLockHashtags) { + SetTestFlag("cluster_mode", ""); + SetTestFlag("lock_on_hashtags", "false"); + ResetService(); + + EXPECT_EQ(Run({"multi"}), "OK"); + EXPECT_EQ(Run({"set", "{key}1", "val1"}), "QUEUED"); + EXPECT_EQ(Run({"set", "{key}2", "val2"}), "QUEUED"); + EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK"))); + EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("{key}1", "{key}2")); +} + +TEST_F(StringFamilyTest, MultiSetWithHashtagsLockHashtags) { + SetTestFlag("cluster_mode", "emulated"); + SetTestFlag("lock_on_hashtags", "true"); + ResetService(); + + EXPECT_EQ(Run({"multi"}), "OK"); + EXPECT_EQ(Run({"set", "{key}1", "val1"}), "QUEUED"); + EXPECT_EQ(Run({"set", "{key}2", "val2"}), "QUEUED"); + EXPECT_THAT(Run({"exec"}), RespArray(ElementsAre("OK", "OK"))); + EXPECT_THAT(GetLastUsedKeys(), UnorderedElementsAre("key")); +} + } // namespace dfly diff --git a/src/server/table.cc b/src/server/table.cc index 8363e0e6c..a595ab19c 100644 --- a/src/server/table.cc +++ b/src/server/table.cc @@ -62,15 +62,4 @@ void DbTable::Clear() { stats = DbTableStats{}; } -void DbTable::Release(IntentLock::Mode mode, std::string_view key, unsigned count) { - DVLOG(1) << "Release " << IntentLock::ModeName(mode) << " " << count << " for " << key; - - auto it = trans_locks.find(key); - CHECK(it != trans_locks.end()) << key; - it->second.Release(mode, count); - if (it->second.IsFree()) { - trans_locks.erase(it); - } -} - } // namespace dfly diff --git a/src/server/table.h b/src/server/table.h index 444b8df00..f89077cdb 100644 --- a/src/server/table.h +++ b/src/server/table.h @@ -88,7 +88,6 @@ struct DbTable : boost::intrusive_ref_counterShutdown(); + service_ = nullptr; + + delete shard_set; + shard_set = nullptr; + + pp_->Stop(); + } + if (absl::GetFlag(FLAGS_force_epoll)) { pp_.reset(fb2::Pool::Epoll(num_threads_)); } else { pp_.reset(fb2::Pool::IOUring(16, num_threads_)); } pp_->Run(); - service_.reset(new Service{pp_.get()}); + service_ = std::make_unique(pp_.get()); Service::InitOpts opts; opts.disable_time_update = true; @@ -456,9 +484,31 @@ vector BaseFamilyTest::StrArray(const RespExpr& expr) { return res; } +absl::flat_hash_set BaseFamilyTest::GetLastUsedKeys() { + Mutex mu; + absl::flat_hash_set result; + + auto add_keys = [&](ProactorBase* proactor) { + EngineShard* shard = EngineShard::tlocal(); + if (shard == nullptr) { + return; + } + + lock_guard lk(mu); + for (string_view key : shard->db_slice().TEST_GetLastLockedKeys()) { + result.insert(string(key)); + } + }; + shard_set->pool()->AwaitFiberOnAll(add_keys); + + return result; +} + void BaseFamilyTest::SetTestFlag(string_view flag_name, string_view new_value) { auto* flag = absl::FindCommandLineFlag(flag_name); CHECK_NE(flag, nullptr); + VLOG(1) << "Changing flag " << flag_name << " from " << flag->CurrentValue() << " to " + << new_value; string error; CHECK(flag->ParseFrom(new_value, &error)) << "Error: " << error; } diff --git a/src/server/test_utils.h b/src/server/test_utils.h index 4d8fdf43c..737bae150 100644 --- a/src/server/test_utils.h +++ b/src/server/test_utils.h @@ -75,6 +75,8 @@ class BaseFamilyTest : public ::testing::Test { int64_t CheckedInt(ArgSlice list); std::string CheckedString(ArgSlice list); + void ResetService(); + bool IsLocked(DbIndex db_index, std::string_view key) const; ConnectionContext::DebugInfo GetDebugInfo(const std::string& id) const; @@ -102,9 +104,11 @@ class BaseFamilyTest : public ::testing::Test { const facade::Connection::PubMessage& GetPublishedMessage(std::string_view conn_id, size_t index) const; + static absl::flat_hash_set GetLastUsedKeys(); + static unsigned NumLocked(); - void SetTestFlag(std::string_view flag_name, std::string_view new_value); + static void SetTestFlag(std::string_view flag_name, std::string_view new_value); std::unique_ptr pp_; std::unique_ptr service_; diff --git a/src/server/transaction.cc b/src/server/transaction.cc index abd43f40a..0ec8495cb 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -170,7 +170,7 @@ void Transaction::InitMultiData(KeyIndex key_index) { tmp_uniques.clear(); auto lock_key = [this, mode, &tmp_uniques](string_view key) { - if (auto [_, inserted] = tmp_uniques.insert(key); !inserted) + if (auto [_, inserted] = tmp_uniques.insert(KeyLockArgs::GetLockKey(key)); !inserted) return; multi_->lock_counts[key][mode]++;