From 89b1d7d52a7c09db288a4026c796d901d51eb0b5 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Wed, 24 Apr 2024 13:36:34 +0300 Subject: [PATCH] chore: Introduce ShardArgs as a distinct type (#2952) Done in preparation to make ShardArgs a smart iterable type, but currently it's just a wrapper around ArgSlice. Also refactored common.{h,cc} into tx_base.{h,cc}. In addition, fixed a bug in key tracking, where we wrongly created weak_ref in a shard thread instead of doing this in the coordinator thread. Finally, identified another bug (not fixed yet) where we track all the arguments instead of tracking keys only. Besides this, no functional changes around the moved code. Signed-off-by: Roman Gershman --- src/server/CMakeLists.txt | 2 +- src/server/bitops_family.cc | 33 +++--- src/server/blocking_controller.cc | 7 +- src/server/blocking_controller.h | 5 +- src/server/common.cc | 44 -------- src/server/common.h | 108 +------------------ src/server/conn_context.h | 1 + src/server/container_utils.cc | 34 ++++-- src/server/db_slice.cc | 48 ++------- src/server/db_slice.h | 20 ++-- src/server/generic_family.cc | 86 +++++++-------- src/server/generic_family.h | 4 +- src/server/hll_family.cc | 10 +- src/server/json_family.cc | 27 +++-- src/server/list_family.cc | 10 +- src/server/main_service.cc | 32 ++++-- src/server/set_family.cc | 102 +++++++++--------- src/server/stream_family.cc | 50 +++++---- src/server/stream_family_test.cc | 5 +- src/server/string_family.cc | 52 ++++----- src/server/transaction.cc | 39 ++++--- src/server/transaction.h | 11 +- src/server/tx_base.cc | 61 +++++++++++ src/server/tx_base.h | 172 ++++++++++++++++++++++++++++++ src/server/zset_family.cc | 81 +++++++------- 25 files changed, 568 insertions(+), 476 deletions(-) create mode 100644 src/server/tx_base.cc create mode 100644 src/server/tx_base.h diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 14412c906..57e751451 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -31,7 +31,7 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc command_registry.cc cluster/unique_slot_checker.cc journal/tx_executor.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc - server_state.cc table.cc top_keys.cc transaction.cc + server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc ) diff --git a/src/server/bitops_family.cc b/src/server/bitops_family.cc index e8e01c6ea..11bad8896 100644 --- a/src/server/bitops_family.cc +++ b/src/server/bitops_family.cc @@ -24,6 +24,7 @@ namespace dfly { using namespace facade; +using namespace std; namespace { @@ -57,7 +58,6 @@ bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry); std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end); std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits); std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end); -OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args, ArgSlice keys); std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& values); // ------------------------------------------------------------------------- // @@ -444,12 +444,9 @@ OpResult CombineResultOp(ShardStringResults result, std::string_vie } // For bitop not - we cannot accumulate -OpResult RunBitOpNot(const OpArgs&
op_args, ArgSlice keys) { - DCHECK(keys.size() == 1); - +OpResult RunBitOpNot(const OpArgs& op_args, string_view key) { EngineShard* es = op_args.shard; // if we found the value, just return, if not found then skip, otherwise report an error - auto key = keys.front(); auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, key, OBJ_STRING); if (find_res) { return GetString(find_res.value()->second); @@ -460,18 +457,18 @@ OpResult RunBitOpNot(const OpArgs& op_args, ArgSlice keys) { // Read only operation where we are running the bit operation on all the // values that belong to same shard. -OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); +OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args, + ShardArgs::Iterator start, ShardArgs::Iterator end) { + DCHECK(start != end); if (op == NOT_OP_NAME) { - return RunBitOpNot(op_args, keys); + return RunBitOpNot(op_args, *start); } EngineShard* es = op_args.shard; BitsStrVec values; - values.reserve(keys.size()); // collect all the value for this shard - for (auto& key : keys) { - auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, key, OBJ_STRING); + for (; start != end; ++start) { + auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, *start, OBJ_STRING); if (find_res) { values.emplace_back(GetString(find_res.value()->second)); } else { @@ -1143,18 +1140,18 @@ void BitOp(CmdArgList args, ConnectionContext* cntx) { ShardId dest_shard = Shard(dest_key, result_set.size()); auto shard_bitop = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); - DCHECK(!largs.empty()); - + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + DCHECK(!largs.Empty()); + ShardArgs::Iterator start = largs.begin(), end = largs.end(); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - largs.remove_prefix(1); - if (largs.empty()) { // no more keys to check + CHECK_EQ(*start, dest_key); + ++start; + if (start == end) { // no more keys to check return OpStatus::OK; } } OpArgs op_args = t->GetOpArgs(shard); - result_set[shard->shard_id()] = RunBitOpOnShard(op, op_args, largs); + result_set[shard->shard_id()] = RunBitOpOnShard(op, op_args, start, end); return OpStatus::OK; }; diff --git a/src/server/blocking_controller.cc b/src/server/blocking_controller.cc index 7f1584adf..7eca07b43 100644 --- a/src/server/blocking_controller.cc +++ b/src/server/blocking_controller.cc @@ -118,7 +118,7 @@ bool BlockingController::DbWatchTable::AddAwakeEvent(string_view key) { } // Removes tx from its watch queues if tx appears there. 
-void BlockingController::FinalizeWatched(ArgSlice args, Transaction* tx) { +void BlockingController::FinalizeWatched(const ShardArgs& args, Transaction* tx) { DCHECK(tx); VLOG(1) << "FinalizeBlocking [" << owner_->shard_id() << "]" << tx->DebugId(); @@ -197,7 +197,8 @@ void BlockingController::NotifyPending() { awakened_indices_.clear(); } -void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transaction* trans) { +void BlockingController::AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc, + Transaction* trans) { auto [dbit, added] = watched_dbs_.emplace(trans->GetDbIndex(), nullptr); if (added) { dbit->second.reset(new DbWatchTable); @@ -205,7 +206,7 @@ void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transact DbWatchTable& wt = *dbit->second; - for (auto key : keys) { + for (auto key : watch_keys) { auto [res, inserted] = wt.queue_map.emplace(key, nullptr); if (inserted) { res->second.reset(new WatchQueue); diff --git a/src/server/blocking_controller.h b/src/server/blocking_controller.h index 081aff9f4..251811f4a 100644 --- a/src/server/blocking_controller.h +++ b/src/server/blocking_controller.h @@ -10,6 +10,7 @@ #include "base/string_view_sso.h" #include "server/common.h" +#include "server/tx_base.h" namespace dfly { @@ -28,7 +29,7 @@ class BlockingController { return awakened_transactions_; } - void FinalizeWatched(ArgSlice args, Transaction* tx); + void FinalizeWatched(const ShardArgs& args, Transaction* tx); // go over potential wakened keys, verify them and activate watch queues. void NotifyPending(); @@ -37,7 +38,7 @@ class BlockingController { // TODO: consider moving all watched functions to // EngineShard with separate per db map. //! AddWatched adds a transaction to the blocking queue. - void AddWatched(ArgSlice watch_keys, KeyReadyChecker krc, Transaction* me); + void AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc, Transaction* me); // Called from operations that create keys like lpush, rename etc. 
void AwakeWatched(DbIndex db_index, std::string_view db_key); diff --git a/src/server/common.cc b/src/server/common.cc index cc96d6f64..ade4e05d1 100644 --- a/src/server/common.cc +++ b/src/server/common.cc @@ -255,30 +255,6 @@ bool ParseDouble(string_view src, double* value) { return true; } -void RecordJournal(const OpArgs& op_args, string_view cmd, ArgSlice args, uint32_t shard_cnt, - bool multi_commands) { - VLOG(2) << "Logging command " << cmd << " from txn " << op_args.tx->txid(); - op_args.tx->LogJournalOnShard(op_args.shard, make_pair(cmd, args), shard_cnt, multi_commands, - false); -} - -void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt) { - op_args.tx->FinishLogJournalOnShard(op_args.shard, shard_cnt); -} - -void RecordExpiry(DbIndex dbid, string_view key) { - auto journal = EngineShard::tlocal()->journal(); - CHECK(journal); - journal->RecordEntry(0, journal::Op::EXPIRED, dbid, 1, cluster::KeySlot(key), - make_pair("DEL", ArgSlice{key}), false); -} - -void TriggerJournalWriteToSink() { - auto journal = EngineShard::tlocal()->journal(); - CHECK(journal); - journal->RecordEntry(0, journal::Op::NOOP, 0, 0, nullopt, {}, true); -} - #define ADD(x) (x) += o.x IoMgrStats& IoMgrStats::operator+=(const IoMgrStats& rhs) { @@ -462,24 +438,4 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) { return os << GlobalStateName(state); } -std::ostream& operator<<(std::ostream& os, ArgSlice list) { - os << "["; - if (!list.empty()) { - std::for_each(list.begin(), list.end() - 1, [&os](const auto& val) { os << val << ", "; }); - os << (*(list.end() - 1)); - } - return os << "]"; -} - -LockTag::LockTag(std::string_view key) { - if (LockTagOptions::instance().enabled) - str_ = LockTagOptions::instance().Tag(key); - else - str_ = key; -} - -LockFp LockTag::Fingerprint() const { - return XXH64(str_.data(), str_.size(), 0x1C69B3F74AC4AE35UL); -} - } // namespace dfly diff --git a/src/server/common.h b/src/server/common.h index 0e7ef4c65..56fdbdf01 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -1,4 +1,4 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. +// Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // @@ -27,8 +27,6 @@ enum class ListDir : uint8_t { LEFT, RIGHT }; constexpr int64_t kMaxExpireDeadlineSec = (1u << 28) - 1; // 8.5 years constexpr int64_t kMaxExpireDeadlineMs = kMaxExpireDeadlineSec * 1000; -using DbIndex = uint16_t; -using ShardId = uint16_t; using LSN = uint64_t; using TxId = uint64_t; using TxClock = uint64_t; @@ -39,17 +37,11 @@ using facade::CmdArgVec; using facade::MutableSlice; using facade::OpResult; -using ArgSlice = absl::Span; using StringVec = std::vector; // keys are RDB_TYPE_xxx constants. using RdbTypeFreqMap = absl::flat_hash_map; -constexpr DbIndex kInvalidDbId = DbIndex(-1); -constexpr ShardId kInvalidSid = ShardId(-1); -constexpr DbIndex kMaxDbId = 1024; // Reasonable starting point. -using LockFp = uint64_t; // a key fingerprint used by the LockTable. - class CommandId; class Transaction; class EngineShard; @@ -67,98 +59,6 @@ struct LockTagOptions { static const LockTagOptions& instance(); }; -struct KeyLockArgs { - DbIndex db_index = 0; - absl::Span fps; -}; - -// Describes key indices. -struct KeyIndex { - unsigned start; - unsigned end; // does not include this index (open limit). - unsigned step; // 1 for commands like mget. 2 for commands like mset. - - // if index is non-zero then adds another key index (usually 0). 
- // relevant for for commands like ZUNIONSTORE/ZINTERSTORE for destination key. - std::optional bonus{}; - bool has_reverse_mapping = false; - - KeyIndex(unsigned s = 0, unsigned e = 0, unsigned step = 0) : start(s), end(e), step(step) { - } - - static KeyIndex Range(unsigned start, unsigned end, unsigned step = 1) { - return KeyIndex{start, end, step}; - } - - bool HasSingleKey() const { - return !bonus && (start + step >= end); - } - - unsigned num_args() const { - return end - start + bool(bonus); - } -}; - -struct DbContext { - DbIndex db_index = 0; - uint64_t time_now_ms = 0; -}; - -struct OpArgs { - EngineShard* shard; - const Transaction* tx; - DbContext db_cntx; - - OpArgs() : shard(nullptr), tx(nullptr) { - } - - OpArgs(EngineShard* s, const Transaction* tx, const DbContext& cntx) - : shard(s), tx(tx), db_cntx(cntx) { - } -}; - -// A strong type for a lock tag. Helps to disambiguate between keys and the parts of the -// keys that are used for locking. -class LockTag { - std::string_view str_; - - public: - using is_stackonly = void; // marks that this object does not use heap. - - LockTag() = default; - explicit LockTag(std::string_view key); - - explicit operator std::string_view() const { - return str_; - } - - LockFp Fingerprint() const; - - // To make it hashable. - template friend H AbslHashValue(H h, const LockTag& tag) { - return H::combine(std::move(h), tag.str_); - } - - bool operator==(const LockTag& o) const { - return str_ == o.str_; - } -}; - -// Record non auto journal command with own txid and dbid. -void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args, - uint32_t shard_cnt = 1, bool multi_commands = false); - -// Record non auto journal command finish. Call only when command translates to multi commands. -void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt); - -// Record expiry in journal with independent transaction. Must be called from shard thread holding -// key. -void RecordExpiry(DbIndex dbid, std::string_view key); - -// Trigger journal write to sink, no journal record will be added to journal. -// Must be called from shard thread of journal to sink. 
-void TriggerJournalWriteToSink(); - struct IoMgrStats { uint64_t read_total = 0; uint64_t read_delay_usec = 0; @@ -205,8 +105,6 @@ enum class GlobalState : uint8_t { std::ostream& operator<<(std::ostream& os, const GlobalState& state); -std::ostream& operator<<(std::ostream& os, ArgSlice list); - enum class TimeUnit : uint8_t { SEC, MSEC }; inline void ToUpper(const MutableSlice* val) { @@ -414,10 +312,6 @@ inline uint32_t MemberTimeSeconds(uint64_t now_ms) { return (now_ms / 1000) - kMemberExpiryBase; } -// Checks whether the touched key is valid for a blocking transaction watching it -using KeyReadyChecker = - std::function; - struct MemoryBytesFlag { uint64_t value = 0; }; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index ba94a117c..4c007b4e1 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -11,6 +11,7 @@ #include "facade/conn_context.h" #include "facade/reply_capture.h" #include "server/common.h" +#include "server/tx_base.h" #include "server/version.h" namespace dfly { diff --git a/src/server/container_utils.cc b/src/server/container_utils.cc index 896c336c6..3b0402c7b 100644 --- a/src/server/container_utils.cc +++ b/src/server/container_utils.cc @@ -24,7 +24,7 @@ extern "C" { ABSL_FLAG(bool, singlehop_blocking, true, "Use single hop optimization for blocking commands"); namespace dfly::container_utils { - +using namespace std; namespace { struct ShardFFResult { @@ -32,16 +32,38 @@ struct ShardFFResult { ShardId sid = kInvalidSid; }; +// Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise. +// If multiple keys are found, returns the first index in the ArgSlice. +OpResult> FindFirstReadOnly(const DbSlice& db_slice, + const DbContext& cntx, + const ShardArgs& args, + int req_obj_type) { + DCHECK(!args.Empty()); + + unsigned i = 0; + for (string_view key : args) { + OpResult res = db_slice.FindReadOnly(cntx, key, req_obj_type); + if (res) + return make_pair(res.value(), i); + if (res.status() != OpStatus::KEY_NOTFOUND) + return res.status(); + ++i; + } + + VLOG(2) << "FindFirst not found"; + return OpStatus::KEY_NOTFOUND; +} + // Find first non-empty key of a single shard transaction, pass it to `func` and return the key. // If no such key exists or a wrong type is found, the apropriate status is returned. // Optimized version of `FindFirstNonEmpty` below. 
-OpResult FindFirstNonEmptySingleShard(Transaction* trans, int req_obj_type, - BlockingResultCb func) { +OpResult FindFirstNonEmptySingleShard(Transaction* trans, int req_obj_type, + BlockingResultCb func) { DCHECK_EQ(trans->GetUniqueShardCnt(), 1u); - std::string key; + string key; auto cb = [&](Transaction* t, EngineShard* shard) -> Transaction::RunnableResult { auto args = t->GetShardArgs(shard->shard_id()); - auto ff_res = shard->db_slice().FindFirstReadOnly(t->GetDbContext(), args, req_obj_type); + auto ff_res = FindFirstReadOnly(shard->db_slice(), t->GetDbContext(), args, req_obj_type); if (ff_res == OpStatus::WRONG_TYPE) return OpStatus::WRONG_TYPE; @@ -77,7 +99,7 @@ OpResult FindFirstNonEmpty(Transaction* trans, int req_obj_type) auto cb = [&](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - auto ff_res = shard->db_slice().FindFirstReadOnly(t->GetDbContext(), args, req_obj_type); + auto ff_res = FindFirstReadOnly(shard->db_slice(), t->GetDbContext(), args, req_obj_type); if (ff_res) { find_res[shard->shard_id()] = FFResult{ff_res->first->first.AsRef(), ff_res->second, shard->shard_id()}; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index ef22d5c89..6823777d5 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -411,7 +411,7 @@ OpResult DbSlice::FindMutableInternal(const Context& cntx } } -DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_view key) { +DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_view key) const { auto res = FindInternal(cntx, key, std::nullopt, UpdateStatsMode::kReadStats, LoadExternalMode::kDontLoad); return {ConstIterator(res->it, StringOrView::FromView(key)), @@ -419,7 +419,7 @@ DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_vi } OpResult DbSlice::FindReadOnly(const Context& cntx, string_view key, - unsigned req_obj_type) { + unsigned req_obj_type) const { auto res = FindInternal(cntx, key, req_obj_type, UpdateStatsMode::kReadStats, LoadExternalMode::kDontLoad); if (res.ok()) { @@ -442,7 +442,7 @@ OpResult DbSlice::FindAndFetchReadOnly(const Context& cn OpResult DbSlice::FindInternal(const Context& cntx, std::string_view key, std::optional req_obj_type, UpdateStatsMode stats_mode, - LoadExternalMode load_mode) { + LoadExternalMode load_mode) const { if (!IsDbValid(cntx.db_index)) { return OpStatus::KEY_NOTFOUND; } @@ -536,24 +536,6 @@ OpResult DbSlice::FindInternal(const Context& cntx, std: return res; } -OpResult> DbSlice::FindFirstReadOnly(const Context& cntx, - ArgSlice args, - int req_obj_type) { - DCHECK(!args.empty()); - - for (unsigned i = 0; i < args.size(); ++i) { - string_view s = args[i]; - OpResult res = FindReadOnly(cntx, s, req_obj_type); - if (res) - return make_pair(res.value(), i); - if (res.status() != OpStatus::KEY_NOTFOUND) - return res.status(); - } - - VLOG(2) << "FindFirst " << args.front() << " not found"; - return OpStatus::KEY_NOTFOUND; -} - OpResult DbSlice::AddOrFind(const Context& cntx, string_view key) { return AddOrFindInternal(cntx, key, LoadExternalMode::kDontLoad); } @@ -1082,12 +1064,12 @@ void DbSlice::PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size SendInvalidationTrackingMessage(key); } -DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) { +DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) const { auto res = ExpireIfNeeded(cntx, it.GetInnerIt()); return {.it = Iterator::FromPrime(res.it), .exp_it = 
ExpIterator::FromPrime(res.exp_it)}; } -DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) { +DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) const { if (!it->second.HasExpire()) { LOG(ERROR) << "Invalid call to ExpireIfNeeded"; return {it, ExpireIterator{}}; @@ -1124,8 +1106,9 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato doc_del_cb_(key, cntx, it->second); } - PerformDeletion(Iterator(it, StringOrView::FromView(key)), - ExpIterator(expire_it, StringOrView::FromView(key)), db.get()); + const_cast(this)->PerformDeletion(Iterator(it, StringOrView::FromView(key)), + ExpIterator(expire_it, StringOrView::FromView(key)), + db.get()); ++events_.expired_keys; return {PrimeIterator{}, ExpireIterator{}}; @@ -1490,21 +1473,6 @@ void DbSlice::ResetEvents() { events_ = {}; } -void DbSlice::TrackKeys(const facade::Connection::WeakRef& conn, const ArgSlice& keys) { - if (conn.IsExpired()) { - DVLOG(2) << "Connection expired, exiting TrackKey function."; - return; - } - - DVLOG(2) << "Start tracking keys for client ID: " << conn.GetClientId() - << " with thread ID: " << conn.Thread(); - for (auto key : keys) { - DVLOG(2) << "Inserting client ID " << conn.GetClientId() - << " into the tracking client set of key " << key; - client_tracking_map_[key].insert(conn); - } -} - void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { if (client_tracking_map_.empty()) return; diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 00a9b2454..24b88da2e 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -281,17 +281,12 @@ class DbSlice { ConstIterator it; ExpConstIterator exp_it; }; - ItAndExpConst FindReadOnly(const Context& cntx, std::string_view key); + ItAndExpConst FindReadOnly(const Context& cntx, std::string_view key) const; OpResult FindReadOnly(const Context& cntx, std::string_view key, - unsigned req_obj_type); + unsigned req_obj_type) const; OpResult FindAndFetchReadOnly(const Context& cntx, std::string_view key, unsigned req_obj_type); - // Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise. - // If multiple keys are found, returns the first index in the ArgSlice. - OpResult> FindFirstReadOnly(const Context& cntx, ArgSlice args, - int req_obj_type); - struct AddOrFindResult { Iterator it; ExpIterator exp_it; @@ -404,7 +399,7 @@ class DbSlice { Iterator it; ExpIterator exp_it; }; - ItAndExp ExpireIfNeeded(const Context& cntx, Iterator it); + ItAndExp ExpireIfNeeded(const Context& cntx, Iterator it) const; // Iterate over all expire table entries and delete expired. void ExpireAllIfNeeded(); @@ -473,7 +468,9 @@ class DbSlice { } // Track keys for the client represented by the the weak reference to its connection. - void TrackKeys(const facade::Connection::WeakRef&, const ArgSlice&); + void TrackKey(const facade::Connection::WeakRef& conn_ref, std::string_view key) { + client_tracking_map_[key].insert(conn_ref); + } // Delete a key referred by its iterator. 
void PerformDeletion(Iterator del_it, DbTable* table); @@ -517,10 +514,11 @@ class DbSlice { PrimeIterator it; ExpireIterator exp_it; }; - PrimeItAndExp ExpireIfNeeded(const Context& cntx, PrimeIterator it); + PrimeItAndExp ExpireIfNeeded(const Context& cntx, PrimeIterator it) const; OpResult FindInternal(const Context& cntx, std::string_view key, std::optional req_obj_type, - UpdateStatsMode stats_mode, LoadExternalMode load_mode); + UpdateStatsMode stats_mode, + LoadExternalMode load_mode) const; OpResult AddOrFindInternal(const Context& cntx, std::string_view key, LoadExternalMode load_mode); OpResult FindMutableInternal(const Context& cntx, std::string_view key, diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index a0e12f8b2..4b3c6b7de 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -281,11 +281,11 @@ class Renamer { void Renamer::Find(Transaction* t) { auto cb = [this](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - CHECK_EQ(1u, args.size()); + DCHECK_EQ(1u, args.Size()); FindResult* res = (shard->shard_id() == src_sid_) ? &src_res_ : &dest_res_; - res->key = args.front(); + res->key = args.Front(); auto& db_slice = EngineShard::tlocal()->db_slice(); auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), res->key); @@ -615,6 +615,40 @@ OpResult OpFieldTtl(Transaction* t, EngineShard* shard, string_view key, s return res <= 0 ? res : int32_t(res - MemberTimeSeconds(db_cntx.time_now_ms)); } +OpResult OpDel(const OpArgs& op_args, const ShardArgs& keys) { + DVLOG(1) << "Del: " << keys.Front(); + auto& db_slice = op_args.shard->db_slice(); + + uint32_t res = 0; + + for (string_view key : keys) { + auto fres = db_slice.FindMutable(op_args.db_cntx, key); + if (!IsValid(fres.it)) + continue; + fres.post_updater.Run(); + res += int(db_slice.Del(op_args.db_cntx.db_index, fres.it)); + } + + return res; +} + +OpResult OpStick(const OpArgs& op_args, const ShardArgs& keys) { + DVLOG(1) << "Stick: " << keys.Front(); + + auto& db_slice = op_args.shard->db_slice(); + + uint32_t res = 0; + for (string_view key : keys) { + auto find_res = db_slice.FindMutable(op_args.db_cntx, key); + if (IsValid(find_res.it) && !find_res.it->first.IsSticky()) { + find_res.it->first.SetSticky(true); + ++res; + } + } + + return res; +} + } // namespace void GenericFamily::Init(util::ProactorPool* pp) { @@ -631,7 +665,7 @@ void GenericFamily::Del(CmdArgList args, ConnectionContext* cntx) { bool is_mc = cntx->protocol() == Protocol::MEMCACHE; auto cb = [&result](const Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = OpDel(t->GetOpArgs(shard), args); result.fetch_add(res.value_or(0), memory_order_relaxed); @@ -683,7 +717,7 @@ void GenericFamily::Exists(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t result{0}; auto cb = [&result](Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = OpExists(t->GetOpArgs(shard), args); result.fetch_add(res.value_or(0), memory_order_relaxed); @@ -889,7 +923,7 @@ void GenericFamily::Stick(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t result{0}; auto cb = [&result](const Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = OpStick(t->GetOpArgs(shard), 
args); result.fetch_add(res.value_or(0), memory_order_relaxed); @@ -1373,30 +1407,13 @@ OpResult GenericFamily::OpTtl(Transaction* t, EngineShard* shard, stri return ttl_ms; } -OpResult GenericFamily::OpDel(const OpArgs& op_args, ArgSlice keys) { - DVLOG(1) << "Del: " << keys[0]; - auto& db_slice = op_args.shard->db_slice(); - - uint32_t res = 0; - - for (uint32_t i = 0; i < keys.size(); ++i) { - auto fres = db_slice.FindMutable(op_args.db_cntx, keys[i]); - if (!IsValid(fres.it)) - continue; - fres.post_updater.Run(); - res += int(db_slice.Del(op_args.db_cntx.db_index, fres.it)); - } - - return res; -} - -OpResult GenericFamily::OpExists(const OpArgs& op_args, ArgSlice keys) { - DVLOG(1) << "Exists: " << keys[0]; +OpResult GenericFamily::OpExists(const OpArgs& op_args, const ShardArgs& keys) { + DVLOG(1) << "Exists: " << keys.Front(); auto& db_slice = op_args.shard->db_slice(); uint32_t res = 0; - for (uint32_t i = 0; i < keys.size(); ++i) { - auto find_res = db_slice.FindReadOnly(op_args.db_cntx, keys[i]); + for (string_view key : keys) { + auto find_res = db_slice.FindReadOnly(op_args.db_cntx, key); res += IsValid(find_res.it); } return res; @@ -1462,23 +1479,6 @@ OpResult GenericFamily::OpRen(const OpArgs& op_args, string_view from_key, return OpStatus::OK; } -OpResult GenericFamily::OpStick(const OpArgs& op_args, ArgSlice keys) { - DVLOG(1) << "Stick: " << keys[0]; - - auto& db_slice = op_args.shard->db_slice(); - - uint32_t res = 0; - for (uint32_t i = 0; i < keys.size(); ++i) { - auto find_res = db_slice.FindMutable(op_args.db_cntx, keys[i]); - if (IsValid(find_res.it) && !find_res.it->first.IsSticky()) { - find_res.it->first.SetSticky(true); - ++res; - } - } - - return res; -} - // OpMove touches multiple databases (op_args.db_idx, target_db), so it assumes it runs // as a global transaction. // TODO: Allow running OpMove without a global transaction. diff --git a/src/server/generic_family.h b/src/server/generic_family.h index f7e888c67..015ed7fcb 100644 --- a/src/server/generic_family.h +++ b/src/server/generic_family.h @@ -40,7 +40,7 @@ class GenericFamily { static void Register(CommandRegistry* registry); // Accessed by Service::Exec and Service::Watch as an utility. 
- static OpResult OpExists(const OpArgs& op_args, ArgSlice keys); + static OpResult OpExists(const OpArgs& op_args, const ShardArgs& keys); private: static void Del(CmdArgList args, ConnectionContext* cntx); @@ -76,10 +76,8 @@ class GenericFamily { static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit); static OpResult OpTtl(Transaction* t, EngineShard* shard, std::string_view key); - static OpResult OpDel(const OpArgs& op_args, ArgSlice keys); static OpResult OpRen(const OpArgs& op_args, std::string_view from, std::string_view to, bool skip_exists); - static OpResult OpStick(const OpArgs& op_args, ArgSlice keys); static OpStatus OpMove(const OpArgs& op_args, std::string_view key, DbIndex target_db); }; diff --git a/src/server/hll_family.cc b/src/server/hll_family.cc index ac3c4de64..0bb88e4e6 100644 --- a/src/server/hll_family.cc +++ b/src/server/hll_family.cc @@ -169,11 +169,11 @@ OpResult CountHllsSingle(const OpArgs& op_args, string_view key) { } } -OpResult> ReadValues(const OpArgs& op_args, ArgSlice keys) { +OpResult> ReadValues(const OpArgs& op_args, const ShardArgs& keys) { try { vector values; - for (size_t i = 0; i < keys.size(); ++i) { - auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_STRING); + for (string_view key : keys) { + auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (it.ok()) { string hll; it.value()->second.GetString(&hll); @@ -210,7 +210,7 @@ OpResult PFCountMulti(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { ShardId sid = shard->shard_id(); - ArgSlice shard_args = t->GetShardArgs(shard->shard_id()); + ShardArgs shard_args = t->GetShardArgs(shard->shard_id()); auto result = ReadValues(t->GetOpArgs(shard), shard_args); if (result.ok()) { hlls[sid] = std::move(result.value()); @@ -252,7 +252,7 @@ OpResult PFMergeInternal(CmdArgList args, ConnectionContext* cntx) { atomic_bool success = true; auto cb = [&](Transaction* t, EngineShard* shard) { ShardId sid = shard->shard_id(); - ArgSlice shard_args = t->GetShardArgs(shard->shard_id()); + ShardArgs shard_args = t->GetShardArgs(shard->shard_id()); auto result = ReadValues(t->GetOpArgs(shard), shard_args); if (result.ok()) { hlls[sid] = std::move(result.value()); diff --git a/src/server/json_family.cc b/src/server/json_family.cc index 0736a268d..dff09c181 100644 --- a/src/server/json_family.cc +++ b/src/server/json_family.cc @@ -1130,19 +1130,21 @@ OpResult> OpArrIndex(const OpArgs& op_args, string_view key, Jso // Returns string vector that represents the query result of each supplied key. 
vector OpJsonMGet(JsonPathV2 expression, const Transaction* t, EngineShard* shard) { - auto args = t->GetShardArgs(shard->shard_id()); - DCHECK(!args.empty()); - vector response(args.size()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); + DCHECK(!args.Empty()); + vector response(args.Size()); auto& db_slice = shard->db_slice(); - for (size_t i = 0; i < args.size(); ++i) { - auto it_res = db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_JSON); + unsigned index = 0; + for (string_view key : args) { + auto it_res = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_JSON); + auto& dest = response[index++]; if (!it_res.ok()) continue; - auto& dest = response[i].emplace(); + dest.emplace(); JsonType* json_val = it_res.value()->second.GetJson(); - DCHECK(json_val) << "should have a valid JSON object for key " << args[i]; + DCHECK(json_val) << "should have a valid JSON object for key " << key; vector query_result; auto cb = [&query_result](const string_view& path, const JsonType& val) { @@ -1364,8 +1366,8 @@ void JsonFamily::MSet(CmdArgList args, ConnectionContext* cntx) { } auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); - LOG(INFO) << shard->shard_id() << " " << args; + ShardArgs args = t->GetShardArgs(shard->shard_id()); + (void)args; // TBD return OpStatus::OK; }; @@ -1469,12 +1471,7 @@ void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) { continue; vector& res = mget_resp[sid]; - ArgSlice slice = transaction->GetShardArgs(sid); - - DCHECK(!slice.empty()); - DCHECK_EQ(slice.size(), res.size()); - - for (size_t j = 0; j < slice.size(); ++j) { + for (size_t j = 0; j < res.size(); ++j) { if (!res[j]) continue; diff --git a/src/server/list_family.cc b/src/server/list_family.cc index f02f04cff..79adbf336 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -416,9 +416,9 @@ OpResult MoveTwoShards(Transaction* trans, string_view src, string_view // auto cb = [&](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - DCHECK_EQ(1u, args.size()); - bool is_dest = args.front() == dest; - find_res[is_dest] = Peek(t->GetOpArgs(shard), args.front(), src_dir, !is_dest); + DCHECK_EQ(1u, args.Size()); + bool is_dest = args.Front() == dest; + find_res[is_dest] = Peek(t->GetOpArgs(shard), args.Front(), src_dir, !is_dest); return OpStatus::OK; }; @@ -432,7 +432,7 @@ OpResult MoveTwoShards(Transaction* trans, string_view src, string_view // Everything is ok, lets proceed with the mutations. 
auto cb = [&](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - auto key = args.front(); + auto key = args.Front(); bool is_dest = (key == dest); OpArgs op_args = t->GetOpArgs(shard); @@ -873,7 +873,7 @@ OpResult BPopPusher::RunSingle(ConnectionContext* cntx, time_point tp) { return op_res; } - auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; }; + auto wcb = [&](Transaction* t, EngineShard* shard) { return ShardArgs{&this->pop_key_, 1}; }; const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*, std::string_view key) -> bool { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 4ad682f1c..74a1d5daa 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1127,9 +1127,25 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); } -OpResult OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, const ArgSlice& keys) { - auto& db_slice = op_args.shard->db_slice(); - db_slice.TrackKeys(cntx->conn()->Borrow(), keys); +OpResult OpTrackKeys(const OpArgs& op_args, const facade::Connection::WeakRef& conn_ref, + const ShardArgs& args) { + if (conn_ref.IsExpired()) { + DVLOG(2) << "Connection expired, exiting TrackKey function."; + return OpStatus::OK; + } + + DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() + << " with thread ID: " << conn_ref.Thread(); + + DbSlice& db_slice = op_args.shard->db_slice(); + + // TODO: There is a bug here that we track all arguments instead of tracking only keys. + for (auto key : args) { + DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() + << " into the tracking client set of key " << key; + db_slice.TrackKey(conn_ref, key); + } + return OpStatus::OK; } @@ -1236,9 +1252,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // start tracking all the updates to the keys in this read command if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() && cid->IsTransactional()) { - auto cb = [&](Transaction* t, EngineShard* shard) { - auto keys = t->GetShardArgs(shard->shard_id()); - return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, keys); + facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow(); + auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) { + return OpTrackKeys(t->GetOpArgs(shard), conn_ref, t->GetShardArgs(shard->shard_id())); }; dfly_cntx->transaction->Refurbish(); dfly_cntx->transaction->ScheduleSingleHopT(cb); @@ -1610,7 +1626,7 @@ void Service::Watch(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t keys_existed = 0; auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); for (auto k : largs) { shard->db_slice().RegisterWatchedKey(cntx->db_index(), k, &exec_info); } @@ -2018,7 +2034,7 @@ bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& regis atomic_uint32_t watch_exist_count{0}; auto cb = [&watch_exist_count](Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = GenericFamily::OpExists(t->GetOpArgs(shard), args); watch_exist_count.fetch_add(res.value_or(0), memory_order_relaxed); diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 
1a3fdb520..5e046a00f 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -587,10 +587,10 @@ class Mover { }; OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { - ArgSlice largs = t->GetShardArgs(es->shard_id()); + ShardArgs largs = t->GetShardArgs(es->shard_id()); // In case both src and dest are in the same shard, largs size will be 2. - DCHECK_LE(largs.size(), 2u); + DCHECK_LE(largs.Size(), 2u); for (auto k : largs) { unsigned index = (k == src_) ? 0 : 1; @@ -609,8 +609,8 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { } OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) { - ArgSlice largs = t->GetShardArgs(es->shard_id()); - DCHECK_LE(largs.size(), 2u); + ShardArgs largs = t->GetShardArgs(es->shard_id()); + DCHECK_LE(largs.Size(), 2u); OpArgs op_args = t->GetOpArgs(es); for (auto k : largs) { @@ -655,12 +655,13 @@ OpResult Mover::Commit(Transaction* t) { } // Read-only OpUnion op on sets. -OpResult OpUnion(const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); +OpResult OpUnion(const OpArgs& op_args, ShardArgs::Iterator start, + ShardArgs::Iterator end) { + DCHECK(start != end); absl::flat_hash_set uniques; - for (string_view key : keys) { - auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET); + for (; start != end; ++start) { + auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET); if (find_res) { const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { @@ -683,11 +684,12 @@ OpResult OpUnion(const OpArgs& op_args, ArgSlice keys) { } // Read-only OpDiff op on sets. -OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); - DVLOG(1) << "OpDiff from " << keys.front(); +OpResult OpDiff(const OpArgs& op_args, ShardArgs::Iterator start, + ShardArgs::Iterator end) { + DCHECK(start != end); + DVLOG(1) << "OpDiff from " << *start; EngineShard* es = op_args.shard; - auto find_res = es->db_slice().FindReadOnly(op_args.db_cntx, keys.front(), OBJ_SET); + auto find_res = es->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET); if (!find_res) { return find_res.status(); @@ -707,8 +709,8 @@ OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { DCHECK(!uniques.empty()); // otherwise the key would not exist. - for (size_t i = 1; i < keys.size(); ++i) { - auto diff_res = es->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_SET); + for (++start; start != end; ++start) { + auto diff_res = es->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET); if (!diff_res) { if (diff_res.status() == OpStatus::WRONG_TYPE) { return OpStatus::WRONG_TYPE; @@ -737,15 +739,16 @@ OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { // Read-only OpInter op on sets. 
OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_first) { - ArgSlice keys = t->GetShardArgs(es->shard_id()); + ShardArgs args = t->GetShardArgs(es->shard_id()); + auto it = args.begin(); if (remove_first) { - keys.remove_prefix(1); + ++it; } - DCHECK(!keys.empty()); + DCHECK(it != args.end()); StringVec result; - if (keys.size() == 1) { - auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), keys.front(), OBJ_SET); + if (args.Size() == 1 + unsigned(remove_first)) { + auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), *it, OBJ_SET); if (!find_res) return find_res.status(); @@ -763,12 +766,13 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f return result; } - vector sets(keys.size()); + vector sets(args.Size() - int(remove_first)); OpStatus status = OpStatus::OK; - - for (size_t i = 0; i < keys.size(); ++i) { - auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), keys[i], OBJ_SET); + unsigned index = 0; + for (; it != args.end(); ++it) { + auto& dest = sets[index++]; + auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), *it, OBJ_SET); if (!find_res) { if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND || find_res.status() != OpStatus::KEY_NOTFOUND) { @@ -778,7 +782,7 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f } const PrimeValue& pv = find_res.value()->second; void* ptr = pv.RObjPtr(); - sets[i] = make_pair(ptr, pv.Encoding()); + dest = make_pair(ptr, pv.Encoding()); } if (status != OpStatus::OK) @@ -1089,12 +1093,12 @@ void SDiff(CmdArgList args, ConnectionContext* cntx) { ShardId src_shard = Shard(src_key, result_set.size()); auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); if (shard->shard_id() == src_shard) { - CHECK_EQ(src_key, largs.front()); - result_set[shard->shard_id()] = OpDiff(t->GetOpArgs(shard), largs); + CHECK_EQ(src_key, largs.Front()); + result_set[shard->shard_id()] = OpDiff(t->GetOpArgs(shard), largs.begin(), largs.end()); } else { - result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs); + result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs.begin(), largs.end()); } return OpStatus::OK; @@ -1126,22 +1130,23 @@ void SDiffStore(CmdArgList args, ConnectionContext* cntx) { // read-only op auto diff_cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); - DCHECK(!largs.empty()); - + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + OpArgs op_args = t->GetOpArgs(shard); + DCHECK(!largs.Empty()); + ShardArgs::Iterator start = largs.begin(); + ShardArgs::Iterator end = largs.end(); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - largs.remove_prefix(1); - if (largs.empty()) + CHECK_EQ(*start, dest_key); + ++start; + if (start == end) return OpStatus::OK; } - OpArgs op_args = t->GetOpArgs(shard); if (shard->shard_id() == src_shard) { - CHECK_EQ(src_key, largs.front()); - result_set[shard->shard_id()] = OpDiff(op_args, largs); // Diff + CHECK_EQ(src_key, *start); + result_set[shard->shard_id()] = OpDiff(op_args, start, end); // Diff } else { - result_set[shard->shard_id()] = OpUnion(op_args, largs); // Union + result_set[shard->shard_id()] = OpUnion(op_args, start, end); // Union } return OpStatus::OK; @@ -1276,10 +1281,10 @@ void SInterStore(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t inter_shard_cnt{0}; auto inter_cb = 
[&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - if (largs.size() == 1) + CHECK_EQ(largs.Front(), dest_key); + if (largs.Size() == 1) return OpStatus::OK; } inter_shard_cnt.fetch_add(1, memory_order_relaxed); @@ -1337,8 +1342,8 @@ void SUnion(CmdArgList args, ConnectionContext* cntx) { ResultStringVec result_set(shard_set->size()); auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); - result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs.begin(), largs.end()); return OpStatus::OK; }; @@ -1363,14 +1368,15 @@ void SUnionStore(CmdArgList args, ConnectionContext* cntx) { ShardId dest_shard = Shard(dest_key, result_set.size()); auto union_cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + ShardArgs::Iterator start = largs.begin(), end = largs.end(); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - largs.remove_prefix(1); - if (largs.empty()) + CHECK_EQ(*start, dest_key); + ++start; + if (start == end) return OpStatus::OK; } - result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs); + result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), start, end); return OpStatus::OK; }; diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc index ac58427b2..3274c439c 100644 --- a/src/server/stream_family.cc +++ b/src/server/stream_family.cc @@ -798,8 +798,8 @@ stream* GetReadOnlyStream(const CompactObj& cobj) { // Returns a map of stream to the ID of the last entry in the stream. Any // streams not found are omitted from the result. OpResult>> OpLastIDs(const OpArgs& op_args, - const ArgSlice& args) { - DCHECK(!args.empty()); + const ShardArgs& args) { + DCHECK(!args.Empty()); auto& db_slice = op_args.shard->db_slice(); @@ -828,8 +828,8 @@ OpResult>> OpLastIDs(const OpArgs& op_args, // Returns the range response for each stream on this shard in order of // GetShardArgs. 
-vector OpRead(const OpArgs& op_args, const ArgSlice& args, const ReadOpts& opts) { - DCHECK(!args.empty()); +vector OpRead(const OpArgs& op_args, const ShardArgs& shard_args, const ReadOpts& opts) { + DCHECK(!shard_args.Empty()); RangeOpts range_opts; range_opts.count = opts.count; @@ -838,11 +838,11 @@ vector OpRead(const OpArgs& op_args, const ArgSlice& args, const Read .seq = UINT64_MAX, }}; - vector response(args.size()); - for (size_t i = 0; i < args.size(); ++i) { - string_view key = args[i]; - + vector response(shard_args.Size()); + unsigned index = 0; + for (string_view key : shard_args) { auto sitem = opts.stream_ids.at(key); + auto& dest = response[index++]; if (!sitem.group && opts.read_group) { continue; } @@ -858,7 +858,7 @@ vector OpRead(const OpArgs& op_args, const ArgSlice& args, const Read else range_res = OpRange(op_args, key, range_opts); if (range_res) { - response[i] = std::move(range_res.value()); + dest = std::move(range_res.value()); } } @@ -1352,15 +1352,17 @@ struct GroupConsumerPairOpts { string_view consumer; }; -vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpArgs& op_args, +vector OpGetGroupConsumerPairs(const ShardArgs& shard_args, + const OpArgs& op_args, const GroupConsumerPairOpts& opts) { - vector sid_items(slice_args.size()); - + vector sid_items(shard_args.Size()); + unsigned index = 0; // get group and consumer - for (size_t i = 0; i < slice_args.size(); i++) { - string_view key = slice_args[i]; + for (string_view key : shard_args) { streamCG* group = nullptr; streamConsumer* consumer = nullptr; + auto& dest = sid_items[index++]; + auto group_res = FindGroup(op_args, key, opts.group); if (!group_res) { continue; @@ -1376,7 +1378,7 @@ vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA consumer = streamCreateConsumer(group, op_args.shard->tmp_str1, NULL, 0, SCC_NO_NOTIFY | SCC_NO_DIRTIFY); } - sid_items[i] = {group, consumer}; + dest = {group, consumer}; } return sid_items; } @@ -2988,12 +2990,7 @@ void XReadImpl(CmdArgList args, std::optional opts, ConnectionContext* vector& results = xread_resp[sid]; - ArgSlice slice = cntx->transaction->GetShardArgs(sid); - - DCHECK(!slice.empty()); - DCHECK_EQ(slice.size(), results.size()); - - for (size_t i = 0; i < slice.size(); ++i) { + for (size_t i = 0; i < results.size(); ++i) { if (results[i].size() == 0) { continue; } @@ -3039,7 +3036,7 @@ void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) { vector> res_pairs(shard_set->size()); auto cb = [&](Transaction* t, EngineShard* shard) { auto sid = shard->shard_id(); - auto s_args = t->GetShardArgs(sid); + ShardArgs s_args = t->GetShardArgs(sid); GroupConsumerPairOpts gc_opts = {opts->group_name, opts->consumer_name}; res_pairs[sid] = OpGetGroupConsumerPairs(s_args, t->GetOpArgs(shard), gc_opts); @@ -3057,11 +3054,12 @@ void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) { if (s_item.size() == 0) { continue; } - for (size_t j = 0; j < s_args.size(); j++) { - string_view key = s_args[j]; + unsigned index = 0; + for (string_view key : s_args) { StreamIDsItem& item = opts->stream_ids.at(key); - item.consumer = s_item[j].consumer; - item.group = s_item[j].group; + item.consumer = s_item[index].consumer; + item.group = s_item[index].group; + ++index; } } } diff --git a/src/server/stream_family_test.cc b/src/server/stream_family_test.cc index 55ce88d10..2b2f2d024 100644 --- a/src/server/stream_family_test.cc +++ b/src/server/stream_family_test.cc @@ -109,8 +109,9 @@ TEST_F(StreamFamilyTest, Range) { 
} TEST_F(StreamFamilyTest, GroupCreate) { - Run({"xadd", "key", "1-*", "f1", "v1"}); - auto resp = Run({"xgroup", "create", "key", "grname", "1"}); + auto resp = Run({"xadd", "key", "1-*", "f1", "v1"}); + EXPECT_EQ(resp, "1-0"); + resp = Run({"xgroup", "create", "key", "grname", "1"}); EXPECT_EQ(resp, "OK"); resp = Run({"xgroup", "create", "test", "test", "0"}); EXPECT_THAT(resp, ErrArg("requires the key to exist")); diff --git a/src/server/string_family.cc b/src/server/string_family.cc index dd7cb4975..c8ed6edcc 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -278,19 +278,23 @@ int64_t AbsExpiryToTtl(int64_t abs_expiry_time, bool as_milli) { } // Returns true if keys were set, false otherwise. -void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) { - DCHECK(!args.empty() && args.size() % 2 == 0); +void OpMSet(const OpArgs& op_args, const ShardArgs& args, atomic_bool* success) { + DCHECK(!args.Empty() && args.Size() % 2 == 0); SetCmd::SetParams params; SetCmd sg(op_args, false); - size_t i = 0; - for (; i < args.size(); i += 2) { - DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1]; - if (sg.Set(params, args[i], args[i + 1]) != OpStatus::OK) { // OOM for example. + size_t index = 0; + for (auto it = args.begin(); it != args.end(); ++it) { + string_view key = *it; + ++it; + string_view value = *it; + DVLOG(1) << "MSet " << key << ":" << value; + if (sg.Set(params, key, value) != OpStatus::OK) { // OOM for example. success->store(false); break; } + index += 2; } if (auto journal = op_args.shard->journal(); journal) { @@ -298,14 +302,14 @@ void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) { // we replicate only what was changed. string_view cmd; ArgSlice cmd_args; - if (i == 0) { + if (index == 0) { // All shards must record the tx was executed for the replica to execute it, so we send a PING // in case nothing was changed cmd = "PING"; } else { // journal [0, i) cmd = "MSET"; - cmd_args = ArgSlice(&args[0], i); + cmd_args = ArgSlice(args.begin(), index); } RecordJournal(op_args, cmd, cmd_args, op_args.tx->GetUniqueShardCnt()); } @@ -419,27 +423,29 @@ OpResult> OpThrottle(const OpArgs& op_args, const string_view SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction* t, EngineShard* shard) { - auto keys = t->GetShardArgs(shard->shard_id()); - DCHECK(!keys.empty()); + ShardArgs keys = t->GetShardArgs(shard->shard_id()); + DCHECK(!keys.Empty()); auto& db_slice = shard->db_slice(); - SinkReplyBuilder::MGetResponse response(keys.size()); - absl::InlinedVector iters(keys.size()); + SinkReplyBuilder::MGetResponse response(keys.Size()); + absl::InlinedVector iters(keys.Size()); size_t total_size = 0; - for (size_t i = 0; i < keys.size(); ++i) { - auto it_res = db_slice.FindAndFetchReadOnly(t->GetDbContext(), keys[i], OBJ_STRING); + unsigned index = 0; + for (string_view key : keys) { + auto it_res = db_slice.FindAndFetchReadOnly(t->GetDbContext(), key, OBJ_STRING); + auto& dest = iters[index++]; if (!it_res) continue; - iters[i] = *it_res; + dest = *it_res; total_size += (*it_res)->second.Size(); } response.storage_list = SinkReplyBuilder::AllocMGetStorage(total_size); char* next = response.storage_list->data; - for (size_t i = 0; i < keys.size(); ++i) { + for (size_t i = 0; i < iters.size(); ++i) { auto it = iters[i]; if (it.is_done()) continue; @@ -1139,12 +1145,7 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) { res.storage_list = src.storage_list; src.storage_list = 
nullptr; - ArgSlice slice = transaction->GetShardArgs(sid); - - DCHECK(!slice.empty()); - DCHECK_EQ(slice.size(), src.resp_arr.size()); - - for (size_t j = 0; j < slice.size(); ++j) { + for (size_t j = 0; j < src.resp_arr.size(); ++j) { if (!src.resp_arr[j]) continue; @@ -1173,7 +1174,7 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) { atomic_bool success = true; auto cb = [&](Transaction* t, EngineShard* shard) { - auto args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); OpMSet(t->GetOpArgs(shard), args, &success); return OpStatus::OK; }; @@ -1193,8 +1194,9 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* es) { auto args = t->GetShardArgs(es->shard_id()); - for (size_t i = 0; i < args.size(); i += 2) { - auto it = es->db_slice().FindReadOnly(t->GetDbContext(), args[i]).it; + for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) { + auto it = es->db_slice().FindReadOnly(t->GetDbContext(), *arg_it).it; + ++arg_it; if (IsValid(it)) { exists.store(true, memory_order_relaxed); break; diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 7fcbd2d60..3ae781eec 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -182,8 +182,6 @@ void Transaction::InitGlobal() { } void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vector* out) { - auto args = full_args_; - auto& shard_index = *out; auto add = [this, rev_mapping = key_index.has_reverse_mapping, &shard_index](uint32_t sid, @@ -196,14 +194,14 @@ void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vectorsize())); + DCHECK_EQ(unique_shard_id_, Shard(akey, shard_set->size())); else { - unique_slot_checker_.Add(kv_args_.front()); - unique_shard_id_ = Shard(kv_args_.front(), shard_set->size()); + unique_slot_checker_.Add(akey); + unique_shard_id_ = Shard(akey, shard_set->size()); } // Multi transactions that execute commands on their own (not stubs) can't shrink the backing @@ -1178,7 +1175,7 @@ bool Transaction::CancelShardCb(EngineShard* shard) { } // runs in engine-shard thread. -ArgSlice Transaction::GetShardArgs(ShardId sid) const { +ShardArgs Transaction::GetShardArgs(ShardId sid) const { DCHECK(!multi_ || multi_->role != SQUASHER); // We can read unique_shard_cnt_ only because ShardArgsInShard is called after IsArmedInShard @@ -1188,7 +1185,7 @@ ArgSlice Transaction::GetShardArgs(ShardId sid) const { } const auto& sd = shard_data_[sid]; - return ArgSlice{kv_args_.data() + sd.arg_start, sd.arg_count}; + return ShardArgs{kv_args_.data() + sd.arg_start, sd.arg_count}; } // from local index back to original arg index skipping the command. 
@@ -1253,7 +1250,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p return result; } -OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc) { +OpStatus Transaction::WatchInShard(const ShardArgs& keys, EngineShard* shard, KeyReadyChecker krc) { auto& sd = shard_data_[SidToId(shard->shard_id())]; CHECK_EQ(0, sd.local_mask & SUSPENDED_Q); @@ -1261,12 +1258,12 @@ OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyCh sd.local_mask &= ~OUT_OF_ORDER; shard->EnsureBlockingController()->AddWatched(keys, std::move(krc), this); - DVLOG(2) << "WatchInShard " << DebugId() << ", first_key:" << keys.front(); + DVLOG(2) << "WatchInShard " << DebugId() << ", first_key:" << keys.Front(); return OpStatus::OK; } -void Transaction::ExpireShardCb(ArgSlice wkeys, EngineShard* shard) { +void Transaction::ExpireShardCb(const ShardArgs& wkeys, EngineShard* shard) { // Blocking transactions don't release keys when suspending, release them now. auto lock_args = GetLockArgs(shard->shard_id()); shard->db_slice().Release(LockMode(), lock_args); @@ -1369,9 +1366,9 @@ bool Transaction::NotifySuspended(TxId committed_txid, ShardId sid, string_view CHECK_EQ(sd.local_mask & AWAKED_Q, 0); // Find index of awakened key - auto args = GetShardArgs(sid); - auto it = find_if(args.begin(), args.end(), [key](auto arg) { return facade::ToSV(arg) == key; }); - CHECK(it != args.end()); + ShardArgs args = GetShardArgs(sid); + auto it = find_if(args.cbegin(), args.cend(), [key](string_view arg) { return arg == key; }); + CHECK(it != args.cend()); // Change state to awaked and store index of awakened key sd.local_mask &= ~SUSPENDED_Q; @@ -1427,7 +1424,7 @@ void Transaction::LogAutoJournalOnShard(EngineShard* shard, RunnableResult resul if (unique_shard_cnt_ == 1 || kv_args_.empty()) { entry_payload = make_pair(cmd, full_args_); } else { - entry_payload = make_pair(cmd, GetShardArgs(shard->shard_id())); + entry_payload = make_pair(cmd, GetShardArgs(shard->shard_id()).AsSlice()); } LogJournalOnShard(shard, std::move(entry_payload), unique_shard_cnt_, false, true); } diff --git a/src/server/transaction.h b/src/server/transaction.h index 0462b3646..07ffc6260 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -22,6 +22,7 @@ #include "server/common.h" #include "server/journal/types.h" #include "server/table.h" +#include "server/tx_base.h" #include "util/fibers/synchronization.h" namespace dfly { @@ -129,8 +130,9 @@ class Transaction { // Runnable that is run on shards during hop executions (often named callback). // Callacks should return `OpStatus` which is implicitly converitble to `RunnableResult`! using RunnableType = absl::FunctionRef; + // Provides keys to block on for specific shard. - using WaitKeysProvider = std::function; + using WaitKeysProvider = std::function; // Modes in which a multi transaction can run. enum MultiMode { @@ -176,7 +178,7 @@ class Transaction { OpStatus InitByArgs(DbIndex index, CmdArgList args); // Get command arguments for specific shard. Called from shard thread. - ArgSlice GetShardArgs(ShardId sid) const; + ShardArgs GetShardArgs(ShardId sid) const; // Map arg_index from GetShardArgs slice to index in original command slice from InitByArgs. size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const; @@ -511,12 +513,12 @@ class Transaction { void RunCallback(EngineShard* shard); // Adds itself to watched queue in the shard. Must run in that shard thread. 
@@ -511,12 +513,12 @@ class Transaction {
   void RunCallback(EngineShard* shard);

   // Adds itself to watched queue in the shard. Must run in that shard thread.
-  OpStatus WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc);
+  OpStatus WatchInShard(const ShardArgs& keys, EngineShard* shard, KeyReadyChecker krc);

   // Expire blocking transaction, unlock keys and unregister it from the blocking controller
   void ExpireBlocking(WaitKeysProvider wcb);

-  void ExpireShardCb(ArgSlice wkeys, EngineShard* shard);
+  void ExpireShardCb(const ShardArgs& wkeys, EngineShard* shard);

   // Returns true if we need to follow up with PollExecution on this shard.
   bool CancelShardCb(EngineShard* shard);
@@ -577,7 +579,6 @@ class Transaction {
     });
   }

- private:
   // Used for waiting for all hop callbacks to run.
   util::fb2::EmbeddedBlockingCounter run_barrier_{0};
diff --git a/src/server/tx_base.cc b/src/server/tx_base.cc
new file mode 100644
index 000000000..5f1993c27
--- /dev/null
+++ b/src/server/tx_base.cc
@@ -0,0 +1,61 @@
+// Copyright 2024, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "server/tx_base.h"
+
+#include "base/logging.h"
+#include "server/cluster/cluster_defs.h"
+#include "server/engine_shard_set.h"
+#include "server/journal/journal.h"
+#include "server/transaction.h"
+
+namespace dfly {
+
+using namespace std;
+
+void RecordJournal(const OpArgs& op_args, string_view cmd, ArgSlice args, uint32_t shard_cnt,
+                   bool multi_commands) {
+  VLOG(2) << "Logging command " << cmd << " from txn " << op_args.tx->txid();
+  op_args.tx->LogJournalOnShard(op_args.shard, make_pair(cmd, args), shard_cnt, multi_commands,
+                                false);
+}
+
+void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt) {
+  op_args.tx->FinishLogJournalOnShard(op_args.shard, shard_cnt);
+}
+
+void RecordExpiry(DbIndex dbid, string_view key) {
+  auto journal = EngineShard::tlocal()->journal();
+  CHECK(journal);
+  journal->RecordEntry(0, journal::Op::EXPIRED, dbid, 1, cluster::KeySlot(key),
+                       make_pair("DEL", ArgSlice{key}), false);
+}
+
+void TriggerJournalWriteToSink() {
+  auto journal = EngineShard::tlocal()->journal();
+  CHECK(journal);
+  journal->RecordEntry(0, journal::Op::NOOP, 0, 0, nullopt, {}, true);
+}
+
+std::ostream& operator<<(std::ostream& os, ArgSlice list) {
+  os << "[";
+  if (!list.empty()) {
+    std::for_each(list.begin(), list.end() - 1, [&os](const auto& val) { os << val << ", "; });
+    os << (*(list.end() - 1));
+  }
+  return os << "]";
+}
+
+LockTag::LockTag(std::string_view key) {
+  if (LockTagOptions::instance().enabled)
+    str_ = LockTagOptions::instance().Tag(key);
+  else
+    str_ = key;
+}
+
+LockFp LockTag::Fingerprint() const {
+  return XXH64(str_.data(), str_.size(), 0x1C69B3F74AC4AE35UL);
+}
+
+} // namespace dfly
diff --git a/src/server/tx_base.h b/src/server/tx_base.h
new file mode 100644
index 000000000..63f40fe20
--- /dev/null
+++ b/src/server/tx_base.h
@@ -0,0 +1,172 @@
+// Copyright 2024, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#pragma once
+
+#include
+
+#include
+
+#include "src/facade/facade_types.h"
+
+namespace dfly {
+
+class EngineShard;
+class Transaction;
+
+using DbIndex = uint16_t;
+using ShardId = uint16_t;
+using LockFp = uint64_t;  // a key fingerprint used by the LockTable.
+
+using ArgSlice = absl::Span;
+
+constexpr DbIndex kInvalidDbId = DbIndex(-1);
+constexpr ShardId kInvalidSid = ShardId(-1);
+constexpr DbIndex kMaxDbId = 1024;  // Reasonable starting point.
+
+struct KeyLockArgs {
+  DbIndex db_index = 0;
+  absl::Span fps;
+};
+
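tx_base.cc above introduces LockTag, a strong type over the lock-relevant part of a key, plus a 64-bit fingerprint that feeds structures like KeyLockArgs. The sketch below mirrors only the shape of that idea: LockTagDemo is a made-up name, std::hash stands in for the XXH64 call shown in the patch, and the tag extraction delegated to LockTagOptions::Tag() is deliberately omitted.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string_view>

    using LockFpDemo = uint64_t;

    // Strong type over the locking-relevant part of a key: a LockTagDemo cannot be
    // passed where a plain key string is expected, and vice versa.
    class LockTagDemo {
      std::string_view str_;

     public:
      // The real constructor may keep only part of the key when tagging is enabled.
      explicit LockTagDemo(std::string_view key) : str_(key) {}

      explicit operator std::string_view() const { return str_; }

      // Stand-in for a fingerprint; the patch hashes with XXH64 and a fixed seed.
      LockFpDemo Fingerprint() const { return std::hash<std::string_view>{}(str_); }

      bool operator==(const LockTagDemo& o) const { return str_ == o.str_; }
    };

    int main() {
      LockTagDemo a("user:42"), b("user:42"), c("user:43");
      std::cout << std::boolalpha << (a == b) << ' ' << (a == c) << '\n';
      std::cout << "fp(a)=" << a.Fingerprint() << " fp(c)=" << c.Fingerprint() << '\n';
    }
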
+// Describes key indices.
+struct KeyIndex {
+  unsigned start;
+  unsigned end;   // does not include this index (open limit).
+  unsigned step;  // 1 for commands like mget. 2 for commands like mset.
+
+  // if index is non-zero then adds another key index (usually 0).
+  // relevant for commands like ZUNIONSTORE/ZINTERSTORE for destination key.
+  std::optional bonus{};
+  bool has_reverse_mapping = false;
+
+  KeyIndex(unsigned s = 0, unsigned e = 0, unsigned step = 0) : start(s), end(e), step(step) {
+  }
+
+  static KeyIndex Range(unsigned start, unsigned end, unsigned step = 1) {
+    return KeyIndex{start, end, step};
+  }
+
+  bool HasSingleKey() const {
+    return !bonus && (start + step >= end);
+  }
+
+  unsigned num_args() const {
+    return end - start + bool(bonus);
+  }
+};
+
+struct DbContext {
+  DbIndex db_index = 0;
+  uint64_t time_now_ms = 0;
+};
+
+struct OpArgs {
+  EngineShard* shard;
+  const Transaction* tx;
+  DbContext db_cntx;
+
+  OpArgs() : shard(nullptr), tx(nullptr) {
+  }
+
+  OpArgs(EngineShard* s, const Transaction* tx, const DbContext& cntx)
+      : shard(s), tx(tx), db_cntx(cntx) {
+  }
+};
+
+// A strong type for a lock tag. Helps to disambiguate between keys and the parts of the
+// keys that are used for locking.
+class LockTag {
+  std::string_view str_;
+
+ public:
+  using is_stackonly = void;  // marks that this object does not use heap.
+
+  LockTag() = default;
+  explicit LockTag(std::string_view key);
+
+  explicit operator std::string_view() const {
+    return str_;
+  }
+
+  LockFp Fingerprint() const;
+
+  // To make it hashable.
+  template friend H AbslHashValue(H h, const LockTag& tag) {
+    return H::combine(std::move(h), tag.str_);
+  }
+
+  bool operator==(const LockTag& o) const {
+    return str_ == o.str_;
+  }
+};
+
+// Checks whether the touched key is valid for a blocking transaction watching it.
+using KeyReadyChecker =
+    std::function;
+
+// References arguments in another array.
+using IndexSlice = std::pair;  // (begin, end)
+
+class ShardArgs : protected ArgSlice {
+ public:
+  using ArgSlice::ArgSlice;
+  using ArgSlice::at;
+  using ArgSlice::operator=;
+  using Iterator = ArgSlice::iterator;
+
+  ShardArgs(const ArgSlice& o) : ArgSlice(o) {
+  }
+
+  size_t Size() const {
+    return ArgSlice::size();
+  }
+
+  auto cbegin() const {
+    return ArgSlice::cbegin();
+  }
+
+  auto cend() const {
+    return ArgSlice::cend();
+  }
+
+  auto begin() const {
+    return cbegin();
+  }
+
+  auto end() const {
+    return cend();
+  }
+
+  bool Empty() const {
+    return ArgSlice::empty();
+  }
+
+  std::string_view Front() const {
+    return *cbegin();
+  }
+
+  ArgSlice AsSlice() const {
+    return ArgSlice(*this);
+  }
+};
+
+// Record non auto journal command with own txid and dbid.
+void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args,
+                   uint32_t shard_cnt = 1, bool multi_commands = false);
+
+// Record non auto journal command finish. Call only when command translates to multi commands.
+void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt);
+
+// Record expiry in journal with independent transaction. Must be called from shard thread holding
+// key.
+void RecordExpiry(DbIndex dbid, std::string_view key);
+
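A worked example may help with the KeyIndex struct above: start/end/step locate the keys inside an argument vector (step 1 for MGET-like commands, step 2 for MSET-like ones, bonus for an extra index such as a destination key). The following standalone sketch is my reading of the struct's comments applied to made-up argument vectors; it is not code from the patch, and the real enumeration is done by the transaction code, not by KeyIndex itself.

    #include <iostream>
    #include <optional>
    #include <string_view>
    #include <vector>

    // Made-up mirror of the KeyIndex fields; indices refer to an argument vector.
    struct KeyIndexDemo {
      unsigned start, end, step;      // end is an open (exclusive) limit
      std::optional<unsigned> bonus;  // extra key index, e.g. a destination key

      bool HasSingleKey() const { return !bonus && (start + step >= end); }
      unsigned num_args() const { return end - start + bool(bonus); }
    };

    void PrintKeys(std::string_view name, const std::vector<std::string_view>& args,
                   const KeyIndexDemo& ki) {
      std::cout << name << " keys:";
      if (ki.bonus)
        std::cout << ' ' << args[*ki.bonus];
      for (unsigned i = ki.start; i < ki.end; i += ki.step)  // step 2 skips the values
        std::cout << ' ' << args[i];
      std::cout << "  (num_args=" << ki.num_args() << ", single_key=" << std::boolalpha
                << ki.HasSingleKey() << ")\n";
    }

    int main() {
      // MGET-like: every argument is a key, step 1.
      PrintKeys("mget", {"k1", "k2", "k3"}, {0, 3, 1, std::nullopt});
      // MSET-like: key/value pairs, step 2 -> keys at positions 0 and 2.
      PrintKeys("mset", {"k1", "v1", "k2", "v2"}, {0, 4, 2, std::nullopt});
      // ZUNIONSTORE-like: destination as the bonus index, source keys after numkeys.
      PrintKeys("zunionstore", {"dst", "2", "zs1", "zs2"}, {2, 4, 1, 0u});
    }
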
+// Trigger journal write to sink, no journal record will be added to journal.
+// Must be called from shard thread of journal to sink.
+void TriggerJournalWriteToSink();
+
+std::ostream& operator<<(std::ostream& os, ArgSlice list);
+
+} // namespace dfly
diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc
index 77147e94a..22dd8e045 100644
--- a/src/server/zset_family.cc
+++ b/src/server/zset_family.cc
@@ -822,39 +822,43 @@ double GetKeyWeight(Transaction* t, ShardId shard_id, const vector& weig
 OpResult OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
                  const vector& weights, bool store) {
-  ArgSlice keys = t->GetShardArgs(shard->shard_id());
-  DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
-  DCHECK(!keys.empty());
+  ShardArgs keys = t->GetShardArgs(shard->shard_id());
+  DCHECK(!keys.Empty());

   unsigned cmdargs_keys_offset = 1;  // after {numkeys} for ZUNION
   unsigned removed_keys = 0;
+  ShardArgs::Iterator start = keys.begin(), end = keys.end();
+
   if (store) {  // first global index is 2 after {destkey, numkeys}.
     ++cmdargs_keys_offset;
-    if (keys.front() == dest) {
-      keys.remove_prefix(1);
+    if (*start == dest) {
+      ++start;
       ++removed_keys;
     }

     // In case ONLY the destination key is hosted in this shard no work on this shard should be
     // done in this step
-    if (keys.empty()) {
+    if (start == end) {
       return OpStatus::OK;
     }
   }

   auto& db_slice = shard->db_slice();
-  KeyIterWeightVec key_weight_vec(keys.size());
-  for (unsigned j = 0; j < keys.size(); ++j) {
-    auto it_res = db_slice.FindReadOnly(t->GetDbContext(), keys[j], OBJ_ZSET);
-    if (it_res == OpStatus::WRONG_TYPE)  // TODO: support sets with default score 1.
+  KeyIterWeightVec key_weight_vec(keys.Size() - removed_keys);
+  unsigned index = 0;
+  for (; start != end; ++start) {
+    auto it_res = db_slice.FindReadOnly(t->GetDbContext(), *start, OBJ_ZSET);
+    if (it_res == OpStatus::WRONG_TYPE)  // TODO: support SET type with default score 1.
       return it_res.status();
-    if (!it_res)
+    if (!it_res) {
+      ++index;
       continue;
-
-    key_weight_vec[j] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
-                                               cmdargs_keys_offset)};
+    }
+    key_weight_vec[index] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights,
+                                                   index + removed_keys, cmdargs_keys_offset)};
+    ++index;
   }

   return UnionShardKeysWithScore(key_weight_vec, agg_type);
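
OpUnion above skips the destination key when it happens to be hosted on the shard running the callback and remembers how many keys it dropped. A reduced standalone version of that skip-and-count step, with illustrative names and no Dragonfly types, is shown below:

    #include <iostream>
    #include <string_view>
    #include <vector>

    // Returns the keys this shard should read. For a *STORE command, the destination
    // key is not a source operand: if it is the first shard-local key, skip it and
    // count it, so later index/weight lookups can be shifted by removed_keys.
    std::vector<std::string_view> SourceKeys(const std::vector<std::string_view>& shard_keys,
                                             std::string_view dest, bool store,
                                             unsigned* removed_keys) {
      auto start = shard_keys.begin(), end = shard_keys.end();
      *removed_keys = 0;
      if (store && start != end && *start == dest) {
        ++start;
        ++*removed_keys;
      }
      // May be empty: only the destination key lives on this shard.
      return std::vector<std::string_view>(start, end);
    }

    int main() {
      unsigned removed = 0;
      auto keys = SourceKeys({"dst", "zset:a", "zset:b"}, "dst", /*store=*/true, &removed);
      for (std::string_view key : keys)
        std::cout << "read " << key << '\n';
      std::cout << "removed_keys=" << removed << '\n';
    }
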
@@ -871,46 +875,48 @@ ScoredMap ZSetFromSet(const PrimeValue& pv, double weight) {
 OpResult OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
                  const vector& weights, bool store) {
-  ArgSlice keys = t->GetShardArgs(shard->shard_id());
-  DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
-  DCHECK(!keys.empty());
+  ShardArgs keys = t->GetShardArgs(shard->shard_id());
+  DCHECK(!keys.Empty());

   unsigned removed_keys = 0;
   unsigned cmdargs_keys_offset = 1;
+  ShardArgs::Iterator start = keys.begin(), end = keys.end();

   if (store) {  // first global index is 2 after {destkey, numkeys}.
     ++cmdargs_keys_offset;
-    if (keys.front() == dest) {
-      keys.remove_prefix(1);
+    if (*start == dest) {
+      ++start;
       ++removed_keys;
-    }

-    // In case ONLY the destination key is hosted in this shard no work on this shard should be
-    // done in this step
-    if (keys.empty()) {
-      return OpStatus::SKIPPED;
+      // In case ONLY the destination key is hosted in this shard no work on this shard should be
+      // done in this step
+      if (start == end) {
+        return OpStatus::SKIPPED;
+      }
     }
   }

   auto& db_slice = shard->db_slice();
-  vector> it_arr(keys.size());
-  if (it_arr.empty())          // could be when only the dest key is hosted in this shard
-    return OpStatus::SKIPPED;  // return noop
+  vector> it_arr(keys.Size() - removed_keys);

-  for (unsigned j = 0; j < keys.size(); ++j) {
-    auto it_res = db_slice.FindMutable(t->GetDbContext(), keys[j]);
-    if (!IsValid(it_res.it))
+  unsigned index = 0;
+  for (; start != end; ++start) {
+    auto it_res = db_slice.FindMutable(t->GetDbContext(), *start);
+    if (!IsValid(it_res.it)) {
+      ++index;
       continue;  // we exit in the next loop
+    }

     // sets are supported for ZINTER* commands:
     auto obj_type = it_res.it->second.ObjType();
     if (obj_type != OBJ_ZSET && obj_type != OBJ_SET)
       return OpStatus::WRONG_TYPE;

-    it_arr[j] = {std::move(it_res), GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
-                                                 cmdargs_keys_offset)};
+    it_arr[index] = {std::move(it_res), GetKeyWeight(t, shard->shard_id(), weights,
+                                                     index + removed_keys, cmdargs_keys_offset)};
+    ++index;
   }

   ScoredMap result;
@@ -1343,16 +1349,15 @@ void BZPopMinMax(CmdArgList args, ConnectionContext* cntx, bool is_max) {
 }

 vector OpFetch(EngineShard* shard, Transaction* t) {
-  ArgSlice keys = t->GetShardArgs(shard->shard_id());
-  DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
-  DCHECK(!keys.empty());
+  ShardArgs keys = t->GetShardArgs(shard->shard_id());
+  DCHECK(!keys.Empty());

   vector results;
-  results.reserve(keys.size());
+  results.reserve(keys.Size());

   auto& db_slice = shard->db_slice();
-  for (size_t i = 0; i < keys.size(); ++i) {
-    auto it = db_slice.FindReadOnly(t->GetDbContext(), keys[i], OBJ_ZSET);
+  for (string_view key : keys) {
+    auto it = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_ZSET);
     if (!it) {
       results.push_back({});
       continue;