diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 14412c906..57e751451 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -31,7 +31,7 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc command_registry.cc cluster/unique_slot_checker.cc journal/tx_executor.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc - server_state.cc table.cc top_keys.cc transaction.cc + server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc ) diff --git a/src/server/bitops_family.cc b/src/server/bitops_family.cc index e8e01c6ea..11bad8896 100644 --- a/src/server/bitops_family.cc +++ b/src/server/bitops_family.cc @@ -24,6 +24,7 @@ namespace dfly { using namespace facade; +using namespace std; namespace { @@ -57,7 +58,6 @@ bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry); std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end); std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits); std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end); -OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args, ArgSlice keys); std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& values); // ------------------------------------------------------------------------- // @@ -444,12 +444,9 @@ OpResult CombineResultOp(ShardStringResults result, std::string_vie } // For bitop not - we cannot accumulate -OpResult RunBitOpNot(const OpArgs& op_args, ArgSlice keys) { - DCHECK(keys.size() == 1); - +OpResult RunBitOpNot(const OpArgs& op_args, string_view key) { EngineShard* es = op_args.shard; // if we found the value, just return, if not found then skip, otherwise report an error - auto key = keys.front(); auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, key, OBJ_STRING); if (find_res) { return GetString(find_res.value()->second); @@ -460,18 +457,18 @@ OpResult RunBitOpNot(const OpArgs& op_args, ArgSlice keys) { // Read only operation where we are running the bit operation on all the // values that belong to same shard. 
-OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); +OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args, + ShardArgs::Iterator start, ShardArgs::Iterator end) { + DCHECK(start != end); if (op == NOT_OP_NAME) { - return RunBitOpNot(op_args, keys); + return RunBitOpNot(op_args, *start); } EngineShard* es = op_args.shard; BitsStrVec values; - values.reserve(keys.size()); // collect all the value for this shard - for (auto& key : keys) { - auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, key, OBJ_STRING); + for (; start != end; ++start) { + auto find_res = es->db_slice().FindAndFetchReadOnly(op_args.db_cntx, *start, OBJ_STRING); if (find_res) { values.emplace_back(GetString(find_res.value()->second)); } else { @@ -1143,18 +1140,18 @@ void BitOp(CmdArgList args, ConnectionContext* cntx) { ShardId dest_shard = Shard(dest_key, result_set.size()); auto shard_bitop = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); - DCHECK(!largs.empty()); - + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + DCHECK(!largs.Empty()); + ShardArgs::Iterator start = largs.begin(), end = largs.end(); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - largs.remove_prefix(1); - if (largs.empty()) { // no more keys to check + CHECK_EQ(*start, dest_key); + ++start; + if (start == end) { // no more keys to check return OpStatus::OK; } } OpArgs op_args = t->GetOpArgs(shard); - result_set[shard->shard_id()] = RunBitOpOnShard(op, op_args, largs); + result_set[shard->shard_id()] = RunBitOpOnShard(op, op_args, start, end); return OpStatus::OK; }; diff --git a/src/server/blocking_controller.cc b/src/server/blocking_controller.cc index 7f1584adf..7eca07b43 100644 --- a/src/server/blocking_controller.cc +++ b/src/server/blocking_controller.cc @@ -118,7 +118,7 @@ bool BlockingController::DbWatchTable::AddAwakeEvent(string_view key) { } // Removes tx from its watch queues if tx appears there. 
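A minimal standalone sketch of the iterator-range pattern the BITOP callback above switches to, with std::vector<std::string_view> standing in for ShardArgs and a caller-supplied op; all names here are illustrative, not the project's API:

#include <cassert>
#include <string_view>
#include <vector>

using KeyView = std::vector<std::string_view>;  // stand-in for ShardArgs

// Skips the destination key (always ordered first on its own shard) and runs
// `op` on the remaining range; advancing `start` replaces the old
// ArgSlice::remove_prefix(1).
template <typename Op>
void RunOnSourceKeys(const KeyView& keys, std::string_view dest_key, bool dest_on_this_shard,
                     Op&& op) {
  auto start = keys.begin(), end = keys.end();
  if (dest_on_this_shard) {
    assert(!keys.empty() && keys.front() == dest_key);
    ++start;
    if (start == end)
      return;  // only the destination key lives on this shard
  }
  for (; start != end; ++start)
    op(*start);
}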
-void BlockingController::FinalizeWatched(ArgSlice args, Transaction* tx) { +void BlockingController::FinalizeWatched(const ShardArgs& args, Transaction* tx) { DCHECK(tx); VLOG(1) << "FinalizeBlocking [" << owner_->shard_id() << "]" << tx->DebugId(); @@ -197,7 +197,8 @@ void BlockingController::NotifyPending() { awakened_indices_.clear(); } -void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transaction* trans) { +void BlockingController::AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc, + Transaction* trans) { auto [dbit, added] = watched_dbs_.emplace(trans->GetDbIndex(), nullptr); if (added) { dbit->second.reset(new DbWatchTable); @@ -205,7 +206,7 @@ void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transact DbWatchTable& wt = *dbit->second; - for (auto key : keys) { + for (auto key : watch_keys) { auto [res, inserted] = wt.queue_map.emplace(key, nullptr); if (inserted) { res->second.reset(new WatchQueue); diff --git a/src/server/blocking_controller.h b/src/server/blocking_controller.h index 081aff9f4..251811f4a 100644 --- a/src/server/blocking_controller.h +++ b/src/server/blocking_controller.h @@ -10,6 +10,7 @@ #include "base/string_view_sso.h" #include "server/common.h" +#include "server/tx_base.h" namespace dfly { @@ -28,7 +29,7 @@ class BlockingController { return awakened_transactions_; } - void FinalizeWatched(ArgSlice args, Transaction* tx); + void FinalizeWatched(const ShardArgs& args, Transaction* tx); // go over potential wakened keys, verify them and activate watch queues. void NotifyPending(); @@ -37,7 +38,7 @@ class BlockingController { // TODO: consider moving all watched functions to // EngineShard with separate per db map. //! AddWatched adds a transaction to the blocking queue. - void AddWatched(ArgSlice watch_keys, KeyReadyChecker krc, Transaction* me); + void AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc, Transaction* me); // Called from operations that create keys like lpush, rename etc. 
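For context, a small self-contained model of the AddWatched loop above: a queue is created lazily on the first insertion of a watched key and the waiting transaction is appended to it. Plain std::map, std::string and an integer id stand in for the real queue_map, keys and Transaction pointers:

#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <vector>

struct WatchQueueModel {
  std::vector<int> waiting;  // stand-in for the real queue of transactions
};

using WatchTableModel = std::map<std::string, std::unique_ptr<WatchQueueModel>, std::less<>>;

// Adds tx_id to the queue of every key in watch_keys, creating queues lazily,
// mirroring the shape of BlockingController::AddWatched above.
void AddWatchedModel(WatchTableModel& table, const std::vector<std::string_view>& watch_keys,
                     int tx_id) {
  for (std::string_view key : watch_keys) {
    auto [it, inserted] = table.try_emplace(std::string(key));
    if (inserted)
      it->second = std::make_unique<WatchQueueModel>();
    it->second->waiting.push_back(tx_id);
  }
}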
void AwakeWatched(DbIndex db_index, std::string_view db_key); diff --git a/src/server/common.cc b/src/server/common.cc index cc96d6f64..ade4e05d1 100644 --- a/src/server/common.cc +++ b/src/server/common.cc @@ -255,30 +255,6 @@ bool ParseDouble(string_view src, double* value) { return true; } -void RecordJournal(const OpArgs& op_args, string_view cmd, ArgSlice args, uint32_t shard_cnt, - bool multi_commands) { - VLOG(2) << "Logging command " << cmd << " from txn " << op_args.tx->txid(); - op_args.tx->LogJournalOnShard(op_args.shard, make_pair(cmd, args), shard_cnt, multi_commands, - false); -} - -void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt) { - op_args.tx->FinishLogJournalOnShard(op_args.shard, shard_cnt); -} - -void RecordExpiry(DbIndex dbid, string_view key) { - auto journal = EngineShard::tlocal()->journal(); - CHECK(journal); - journal->RecordEntry(0, journal::Op::EXPIRED, dbid, 1, cluster::KeySlot(key), - make_pair("DEL", ArgSlice{key}), false); -} - -void TriggerJournalWriteToSink() { - auto journal = EngineShard::tlocal()->journal(); - CHECK(journal); - journal->RecordEntry(0, journal::Op::NOOP, 0, 0, nullopt, {}, true); -} - #define ADD(x) (x) += o.x IoMgrStats& IoMgrStats::operator+=(const IoMgrStats& rhs) { @@ -462,24 +438,4 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) { return os << GlobalStateName(state); } -std::ostream& operator<<(std::ostream& os, ArgSlice list) { - os << "["; - if (!list.empty()) { - std::for_each(list.begin(), list.end() - 1, [&os](const auto& val) { os << val << ", "; }); - os << (*(list.end() - 1)); - } - return os << "]"; -} - -LockTag::LockTag(std::string_view key) { - if (LockTagOptions::instance().enabled) - str_ = LockTagOptions::instance().Tag(key); - else - str_ = key; -} - -LockFp LockTag::Fingerprint() const { - return XXH64(str_.data(), str_.size(), 0x1C69B3F74AC4AE35UL); -} - } // namespace dfly diff --git a/src/server/common.h b/src/server/common.h index 0e7ef4c65..56fdbdf01 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -1,4 +1,4 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. +// Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // @@ -27,8 +27,6 @@ enum class ListDir : uint8_t { LEFT, RIGHT }; constexpr int64_t kMaxExpireDeadlineSec = (1u << 28) - 1; // 8.5 years constexpr int64_t kMaxExpireDeadlineMs = kMaxExpireDeadlineSec * 1000; -using DbIndex = uint16_t; -using ShardId = uint16_t; using LSN = uint64_t; using TxId = uint64_t; using TxClock = uint64_t; @@ -39,17 +37,11 @@ using facade::CmdArgVec; using facade::MutableSlice; using facade::OpResult; -using ArgSlice = absl::Span; using StringVec = std::vector; // keys are RDB_TYPE_xxx constants. using RdbTypeFreqMap = absl::flat_hash_map; -constexpr DbIndex kInvalidDbId = DbIndex(-1); -constexpr ShardId kInvalidSid = ShardId(-1); -constexpr DbIndex kMaxDbId = 1024; // Reasonable starting point. -using LockFp = uint64_t; // a key fingerprint used by the LockTable. - class CommandId; class Transaction; class EngineShard; @@ -67,98 +59,6 @@ struct LockTagOptions { static const LockTagOptions& instance(); }; -struct KeyLockArgs { - DbIndex db_index = 0; - absl::Span fps; -}; - -// Describes key indices. -struct KeyIndex { - unsigned start; - unsigned end; // does not include this index (open limit). - unsigned step; // 1 for commands like mget. 2 for commands like mset. - - // if index is non-zero then adds another key index (usually 0). 
- // relevant for for commands like ZUNIONSTORE/ZINTERSTORE for destination key. - std::optional bonus{}; - bool has_reverse_mapping = false; - - KeyIndex(unsigned s = 0, unsigned e = 0, unsigned step = 0) : start(s), end(e), step(step) { - } - - static KeyIndex Range(unsigned start, unsigned end, unsigned step = 1) { - return KeyIndex{start, end, step}; - } - - bool HasSingleKey() const { - return !bonus && (start + step >= end); - } - - unsigned num_args() const { - return end - start + bool(bonus); - } -}; - -struct DbContext { - DbIndex db_index = 0; - uint64_t time_now_ms = 0; -}; - -struct OpArgs { - EngineShard* shard; - const Transaction* tx; - DbContext db_cntx; - - OpArgs() : shard(nullptr), tx(nullptr) { - } - - OpArgs(EngineShard* s, const Transaction* tx, const DbContext& cntx) - : shard(s), tx(tx), db_cntx(cntx) { - } -}; - -// A strong type for a lock tag. Helps to disambiguate between keys and the parts of the -// keys that are used for locking. -class LockTag { - std::string_view str_; - - public: - using is_stackonly = void; // marks that this object does not use heap. - - LockTag() = default; - explicit LockTag(std::string_view key); - - explicit operator std::string_view() const { - return str_; - } - - LockFp Fingerprint() const; - - // To make it hashable. - template friend H AbslHashValue(H h, const LockTag& tag) { - return H::combine(std::move(h), tag.str_); - } - - bool operator==(const LockTag& o) const { - return str_ == o.str_; - } -}; - -// Record non auto journal command with own txid and dbid. -void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args, - uint32_t shard_cnt = 1, bool multi_commands = false); - -// Record non auto journal command finish. Call only when command translates to multi commands. -void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt); - -// Record expiry in journal with independent transaction. Must be called from shard thread holding -// key. -void RecordExpiry(DbIndex dbid, std::string_view key); - -// Trigger journal write to sink, no journal record will be added to journal. -// Must be called from shard thread of journal to sink. 
-void TriggerJournalWriteToSink(); - struct IoMgrStats { uint64_t read_total = 0; uint64_t read_delay_usec = 0; @@ -205,8 +105,6 @@ enum class GlobalState : uint8_t { std::ostream& operator<<(std::ostream& os, const GlobalState& state); -std::ostream& operator<<(std::ostream& os, ArgSlice list); - enum class TimeUnit : uint8_t { SEC, MSEC }; inline void ToUpper(const MutableSlice* val) { @@ -414,10 +312,6 @@ inline uint32_t MemberTimeSeconds(uint64_t now_ms) { return (now_ms / 1000) - kMemberExpiryBase; } -// Checks whether the touched key is valid for a blocking transaction watching it -using KeyReadyChecker = - std::function; - struct MemoryBytesFlag { uint64_t value = 0; }; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index ba94a117c..4c007b4e1 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -11,6 +11,7 @@ #include "facade/conn_context.h" #include "facade/reply_capture.h" #include "server/common.h" +#include "server/tx_base.h" #include "server/version.h" namespace dfly { diff --git a/src/server/container_utils.cc b/src/server/container_utils.cc index 896c336c6..3b0402c7b 100644 --- a/src/server/container_utils.cc +++ b/src/server/container_utils.cc @@ -24,7 +24,7 @@ extern "C" { ABSL_FLAG(bool, singlehop_blocking, true, "Use single hop optimization for blocking commands"); namespace dfly::container_utils { - +using namespace std; namespace { struct ShardFFResult { @@ -32,16 +32,38 @@ struct ShardFFResult { ShardId sid = kInvalidSid; }; +// Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise. +// If multiple keys are found, returns the first index in the ArgSlice. +OpResult> FindFirstReadOnly(const DbSlice& db_slice, + const DbContext& cntx, + const ShardArgs& args, + int req_obj_type) { + DCHECK(!args.Empty()); + + unsigned i = 0; + for (string_view key : args) { + OpResult res = db_slice.FindReadOnly(cntx, key, req_obj_type); + if (res) + return make_pair(res.value(), i); + if (res.status() != OpStatus::KEY_NOTFOUND) + return res.status(); + ++i; + } + + VLOG(2) << "FindFirst not found"; + return OpStatus::KEY_NOTFOUND; +} + // Find first non-empty key of a single shard transaction, pass it to `func` and return the key. // If no such key exists or a wrong type is found, the apropriate status is returned. // Optimized version of `FindFirstNonEmpty` below. 
-OpResult FindFirstNonEmptySingleShard(Transaction* trans, int req_obj_type, - BlockingResultCb func) { +OpResult FindFirstNonEmptySingleShard(Transaction* trans, int req_obj_type, + BlockingResultCb func) { DCHECK_EQ(trans->GetUniqueShardCnt(), 1u); - std::string key; + string key; auto cb = [&](Transaction* t, EngineShard* shard) -> Transaction::RunnableResult { auto args = t->GetShardArgs(shard->shard_id()); - auto ff_res = shard->db_slice().FindFirstReadOnly(t->GetDbContext(), args, req_obj_type); + auto ff_res = FindFirstReadOnly(shard->db_slice(), t->GetDbContext(), args, req_obj_type); if (ff_res == OpStatus::WRONG_TYPE) return OpStatus::WRONG_TYPE; @@ -77,7 +99,7 @@ OpResult FindFirstNonEmpty(Transaction* trans, int req_obj_type) auto cb = [&](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - auto ff_res = shard->db_slice().FindFirstReadOnly(t->GetDbContext(), args, req_obj_type); + auto ff_res = FindFirstReadOnly(shard->db_slice(), t->GetDbContext(), args, req_obj_type); if (ff_res) { find_res[shard->shard_id()] = FFResult{ff_res->first->first.AsRef(), ff_res->second, shard->shard_id()}; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index ef22d5c89..6823777d5 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -411,7 +411,7 @@ OpResult DbSlice::FindMutableInternal(const Context& cntx } } -DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_view key) { +DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_view key) const { auto res = FindInternal(cntx, key, std::nullopt, UpdateStatsMode::kReadStats, LoadExternalMode::kDontLoad); return {ConstIterator(res->it, StringOrView::FromView(key)), @@ -419,7 +419,7 @@ DbSlice::ItAndExpConst DbSlice::FindReadOnly(const Context& cntx, std::string_vi } OpResult DbSlice::FindReadOnly(const Context& cntx, string_view key, - unsigned req_obj_type) { + unsigned req_obj_type) const { auto res = FindInternal(cntx, key, req_obj_type, UpdateStatsMode::kReadStats, LoadExternalMode::kDontLoad); if (res.ok()) { @@ -442,7 +442,7 @@ OpResult DbSlice::FindAndFetchReadOnly(const Context& cn OpResult DbSlice::FindInternal(const Context& cntx, std::string_view key, std::optional req_obj_type, UpdateStatsMode stats_mode, - LoadExternalMode load_mode) { + LoadExternalMode load_mode) const { if (!IsDbValid(cntx.db_index)) { return OpStatus::KEY_NOTFOUND; } @@ -536,24 +536,6 @@ OpResult DbSlice::FindInternal(const Context& cntx, std: return res; } -OpResult> DbSlice::FindFirstReadOnly(const Context& cntx, - ArgSlice args, - int req_obj_type) { - DCHECK(!args.empty()); - - for (unsigned i = 0; i < args.size(); ++i) { - string_view s = args[i]; - OpResult res = FindReadOnly(cntx, s, req_obj_type); - if (res) - return make_pair(res.value(), i); - if (res.status() != OpStatus::KEY_NOTFOUND) - return res.status(); - } - - VLOG(2) << "FindFirst " << args.front() << " not found"; - return OpStatus::KEY_NOTFOUND; -} - OpResult DbSlice::AddOrFind(const Context& cntx, string_view key) { return AddOrFindInternal(cntx, key, LoadExternalMode::kDontLoad); } @@ -1082,12 +1064,12 @@ void DbSlice::PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size SendInvalidationTrackingMessage(key); } -DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) { +DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) const { auto res = ExpireIfNeeded(cntx, it.GetInnerIt()); return {.it = Iterator::FromPrime(res.it), .exp_it = 
ExpIterator::FromPrime(res.exp_it)}; } -DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) { +DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) const { if (!it->second.HasExpire()) { LOG(ERROR) << "Invalid call to ExpireIfNeeded"; return {it, ExpireIterator{}}; @@ -1124,8 +1106,9 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato doc_del_cb_(key, cntx, it->second); } - PerformDeletion(Iterator(it, StringOrView::FromView(key)), - ExpIterator(expire_it, StringOrView::FromView(key)), db.get()); + const_cast(this)->PerformDeletion(Iterator(it, StringOrView::FromView(key)), + ExpIterator(expire_it, StringOrView::FromView(key)), + db.get()); ++events_.expired_keys; return {PrimeIterator{}, ExpireIterator{}}; @@ -1490,21 +1473,6 @@ void DbSlice::ResetEvents() { events_ = {}; } -void DbSlice::TrackKeys(const facade::Connection::WeakRef& conn, const ArgSlice& keys) { - if (conn.IsExpired()) { - DVLOG(2) << "Connection expired, exiting TrackKey function."; - return; - } - - DVLOG(2) << "Start tracking keys for client ID: " << conn.GetClientId() - << " with thread ID: " << conn.Thread(); - for (auto key : keys) { - DVLOG(2) << "Inserting client ID " << conn.GetClientId() - << " into the tracking client set of key " << key; - client_tracking_map_[key].insert(conn); - } -} - void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { if (client_tracking_map_.empty()) return; diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 00a9b2454..24b88da2e 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -281,17 +281,12 @@ class DbSlice { ConstIterator it; ExpConstIterator exp_it; }; - ItAndExpConst FindReadOnly(const Context& cntx, std::string_view key); + ItAndExpConst FindReadOnly(const Context& cntx, std::string_view key) const; OpResult FindReadOnly(const Context& cntx, std::string_view key, - unsigned req_obj_type); + unsigned req_obj_type) const; OpResult FindAndFetchReadOnly(const Context& cntx, std::string_view key, unsigned req_obj_type); - // Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise. - // If multiple keys are found, returns the first index in the ArgSlice. - OpResult> FindFirstReadOnly(const Context& cntx, ArgSlice args, - int req_obj_type); - struct AddOrFindResult { Iterator it; ExpIterator exp_it; @@ -404,7 +399,7 @@ class DbSlice { Iterator it; ExpIterator exp_it; }; - ItAndExp ExpireIfNeeded(const Context& cntx, Iterator it); + ItAndExp ExpireIfNeeded(const Context& cntx, Iterator it) const; // Iterate over all expire table entries and delete expired. void ExpireAllIfNeeded(); @@ -473,7 +468,9 @@ class DbSlice { } // Track keys for the client represented by the the weak reference to its connection. - void TrackKeys(const facade::Connection::WeakRef&, const ArgSlice&); + void TrackKey(const facade::Connection::WeakRef& conn_ref, std::string_view key) { + client_tracking_map_[key].insert(conn_ref); + } // Delete a key referred by its iterator. 
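A compact, self-contained sketch of the idea behind the const FindReadOnly/ExpireIfNeeded overloads above: a read-only lookup may still have to delete an entry it discovers to be expired, and that single mutation is funneled through a const_cast on this, as the PerformDeletion call now does. Types and names below are illustrative only:

#include <ctime>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>

class CacheModel {
 public:
  // Logically const read path: callers see a const lookup, but an expired
  // entry is still removed lazily when it is encountered.
  std::optional<std::string> FindReadOnly(std::string_view key, std::time_t now) const {
    auto it = entries_.find(std::string(key));
    if (it == entries_.end())
      return std::nullopt;
    if (it->second.expire_at && *it->second.expire_at <= now) {
      // Lazy expiry: logically a read, physically a mutation.
      const_cast<CacheModel*>(this)->Erase(it);
      return std::nullopt;
    }
    return it->second.value;
  }

 private:
  struct Entry {
    std::string value;
    std::optional<std::time_t> expire_at;
  };
  using Map = std::unordered_map<std::string, Entry>;

  void Erase(Map::const_iterator it) {
    entries_.erase(it);
  }

  Map entries_;
};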
void PerformDeletion(Iterator del_it, DbTable* table); @@ -517,10 +514,11 @@ class DbSlice { PrimeIterator it; ExpireIterator exp_it; }; - PrimeItAndExp ExpireIfNeeded(const Context& cntx, PrimeIterator it); + PrimeItAndExp ExpireIfNeeded(const Context& cntx, PrimeIterator it) const; OpResult FindInternal(const Context& cntx, std::string_view key, std::optional req_obj_type, - UpdateStatsMode stats_mode, LoadExternalMode load_mode); + UpdateStatsMode stats_mode, + LoadExternalMode load_mode) const; OpResult AddOrFindInternal(const Context& cntx, std::string_view key, LoadExternalMode load_mode); OpResult FindMutableInternal(const Context& cntx, std::string_view key, diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index a0e12f8b2..4b3c6b7de 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -281,11 +281,11 @@ class Renamer { void Renamer::Find(Transaction* t) { auto cb = [this](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - CHECK_EQ(1u, args.size()); + DCHECK_EQ(1u, args.Size()); FindResult* res = (shard->shard_id() == src_sid_) ? &src_res_ : &dest_res_; - res->key = args.front(); + res->key = args.Front(); auto& db_slice = EngineShard::tlocal()->db_slice(); auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), res->key); @@ -615,6 +615,40 @@ OpResult OpFieldTtl(Transaction* t, EngineShard* shard, string_view key, s return res <= 0 ? res : int32_t(res - MemberTimeSeconds(db_cntx.time_now_ms)); } +OpResult OpDel(const OpArgs& op_args, const ShardArgs& keys) { + DVLOG(1) << "Del: " << keys.Front(); + auto& db_slice = op_args.shard->db_slice(); + + uint32_t res = 0; + + for (string_view key : keys) { + auto fres = db_slice.FindMutable(op_args.db_cntx, key); + if (!IsValid(fres.it)) + continue; + fres.post_updater.Run(); + res += int(db_slice.Del(op_args.db_cntx.db_index, fres.it)); + } + + return res; +} + +OpResult OpStick(const OpArgs& op_args, const ShardArgs& keys) { + DVLOG(1) << "Stick: " << keys.Front(); + + auto& db_slice = op_args.shard->db_slice(); + + uint32_t res = 0; + for (string_view key : keys) { + auto find_res = db_slice.FindMutable(op_args.db_cntx, key); + if (IsValid(find_res.it) && !find_res.it->first.IsSticky()) { + find_res.it->first.SetSticky(true); + ++res; + } + } + + return res; +} + } // namespace void GenericFamily::Init(util::ProactorPool* pp) { @@ -631,7 +665,7 @@ void GenericFamily::Del(CmdArgList args, ConnectionContext* cntx) { bool is_mc = cntx->protocol() == Protocol::MEMCACHE; auto cb = [&result](const Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = OpDel(t->GetOpArgs(shard), args); result.fetch_add(res.value_or(0), memory_order_relaxed); @@ -683,7 +717,7 @@ void GenericFamily::Exists(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t result{0}; auto cb = [&result](Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = OpExists(t->GetOpArgs(shard), args); result.fetch_add(res.value_or(0), memory_order_relaxed); @@ -889,7 +923,7 @@ void GenericFamily::Stick(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t result{0}; auto cb = [&result](const Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = OpStick(t->GetOpArgs(shard), 
args); result.fetch_add(res.value_or(0), memory_order_relaxed); @@ -1373,30 +1407,13 @@ OpResult GenericFamily::OpTtl(Transaction* t, EngineShard* shard, stri return ttl_ms; } -OpResult GenericFamily::OpDel(const OpArgs& op_args, ArgSlice keys) { - DVLOG(1) << "Del: " << keys[0]; - auto& db_slice = op_args.shard->db_slice(); - - uint32_t res = 0; - - for (uint32_t i = 0; i < keys.size(); ++i) { - auto fres = db_slice.FindMutable(op_args.db_cntx, keys[i]); - if (!IsValid(fres.it)) - continue; - fres.post_updater.Run(); - res += int(db_slice.Del(op_args.db_cntx.db_index, fres.it)); - } - - return res; -} - -OpResult GenericFamily::OpExists(const OpArgs& op_args, ArgSlice keys) { - DVLOG(1) << "Exists: " << keys[0]; +OpResult GenericFamily::OpExists(const OpArgs& op_args, const ShardArgs& keys) { + DVLOG(1) << "Exists: " << keys.Front(); auto& db_slice = op_args.shard->db_slice(); uint32_t res = 0; - for (uint32_t i = 0; i < keys.size(); ++i) { - auto find_res = db_slice.FindReadOnly(op_args.db_cntx, keys[i]); + for (string_view key : keys) { + auto find_res = db_slice.FindReadOnly(op_args.db_cntx, key); res += IsValid(find_res.it); } return res; @@ -1462,23 +1479,6 @@ OpResult GenericFamily::OpRen(const OpArgs& op_args, string_view from_key, return OpStatus::OK; } -OpResult GenericFamily::OpStick(const OpArgs& op_args, ArgSlice keys) { - DVLOG(1) << "Stick: " << keys[0]; - - auto& db_slice = op_args.shard->db_slice(); - - uint32_t res = 0; - for (uint32_t i = 0; i < keys.size(); ++i) { - auto find_res = db_slice.FindMutable(op_args.db_cntx, keys[i]); - if (IsValid(find_res.it) && !find_res.it->first.IsSticky()) { - find_res.it->first.SetSticky(true); - ++res; - } - } - - return res; -} - // OpMove touches multiple databases (op_args.db_idx, target_db), so it assumes it runs // as a global transaction. // TODO: Allow running OpMove without a global transaction. diff --git a/src/server/generic_family.h b/src/server/generic_family.h index f7e888c67..015ed7fcb 100644 --- a/src/server/generic_family.h +++ b/src/server/generic_family.h @@ -40,7 +40,7 @@ class GenericFamily { static void Register(CommandRegistry* registry); // Accessed by Service::Exec and Service::Watch as an utility. 
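A rough sketch of the per-shard aggregation used by the Del/Exists/Stick callbacks above, assuming a hypothetical exists predicate in place of the real FindReadOnly lookup; the hops are written sequentially here, whereas the real callbacks run on their shard threads and meet at an atomic:

#include <atomic>
#include <cstdint>
#include <functional>
#include <string_view>
#include <vector>

// Each shard counts only its own keys and folds the local count into a shared
// atomic; the coordinator reads the total once every hop has finished.
uint32_t CountExistingKeys(const std::vector<std::vector<std::string_view>>& keys_per_shard,
                           const std::function<bool(std::string_view)>& exists) {
  std::atomic<uint32_t> result{0};
  for (const auto& shard_keys : keys_per_shard) {  // in Dragonfly these run in parallel
    uint32_t local = 0;
    for (std::string_view key : shard_keys)
      local += exists(key);
    result.fetch_add(local, std::memory_order_relaxed);
  }
  return result.load(std::memory_order_relaxed);
}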
- static OpResult OpExists(const OpArgs& op_args, ArgSlice keys); + static OpResult OpExists(const OpArgs& op_args, const ShardArgs& keys); private: static void Del(CmdArgList args, ConnectionContext* cntx); @@ -76,10 +76,8 @@ class GenericFamily { static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit); static OpResult OpTtl(Transaction* t, EngineShard* shard, std::string_view key); - static OpResult OpDel(const OpArgs& op_args, ArgSlice keys); static OpResult OpRen(const OpArgs& op_args, std::string_view from, std::string_view to, bool skip_exists); - static OpResult OpStick(const OpArgs& op_args, ArgSlice keys); static OpStatus OpMove(const OpArgs& op_args, std::string_view key, DbIndex target_db); }; diff --git a/src/server/hll_family.cc b/src/server/hll_family.cc index ac3c4de64..0bb88e4e6 100644 --- a/src/server/hll_family.cc +++ b/src/server/hll_family.cc @@ -169,11 +169,11 @@ OpResult CountHllsSingle(const OpArgs& op_args, string_view key) { } } -OpResult> ReadValues(const OpArgs& op_args, ArgSlice keys) { +OpResult> ReadValues(const OpArgs& op_args, const ShardArgs& keys) { try { vector values; - for (size_t i = 0; i < keys.size(); ++i) { - auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_STRING); + for (string_view key : keys) { + auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (it.ok()) { string hll; it.value()->second.GetString(&hll); @@ -210,7 +210,7 @@ OpResult PFCountMulti(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { ShardId sid = shard->shard_id(); - ArgSlice shard_args = t->GetShardArgs(shard->shard_id()); + ShardArgs shard_args = t->GetShardArgs(shard->shard_id()); auto result = ReadValues(t->GetOpArgs(shard), shard_args); if (result.ok()) { hlls[sid] = std::move(result.value()); @@ -252,7 +252,7 @@ OpResult PFMergeInternal(CmdArgList args, ConnectionContext* cntx) { atomic_bool success = true; auto cb = [&](Transaction* t, EngineShard* shard) { ShardId sid = shard->shard_id(); - ArgSlice shard_args = t->GetShardArgs(shard->shard_id()); + ShardArgs shard_args = t->GetShardArgs(shard->shard_id()); auto result = ReadValues(t->GetOpArgs(shard), shard_args); if (result.ok()) { hlls[sid] = std::move(result.value()); diff --git a/src/server/json_family.cc b/src/server/json_family.cc index 0736a268d..dff09c181 100644 --- a/src/server/json_family.cc +++ b/src/server/json_family.cc @@ -1130,19 +1130,21 @@ OpResult> OpArrIndex(const OpArgs& op_args, string_view key, Jso // Returns string vector that represents the query result of each supplied key. 
vector OpJsonMGet(JsonPathV2 expression, const Transaction* t, EngineShard* shard) { - auto args = t->GetShardArgs(shard->shard_id()); - DCHECK(!args.empty()); - vector response(args.size()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); + DCHECK(!args.Empty()); + vector response(args.Size()); auto& db_slice = shard->db_slice(); - for (size_t i = 0; i < args.size(); ++i) { - auto it_res = db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_JSON); + unsigned index = 0; + for (string_view key : args) { + auto it_res = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_JSON); + auto& dest = response[index++]; if (!it_res.ok()) continue; - auto& dest = response[i].emplace(); + dest.emplace(); JsonType* json_val = it_res.value()->second.GetJson(); - DCHECK(json_val) << "should have a valid JSON object for key " << args[i]; + DCHECK(json_val) << "should have a valid JSON object for key " << key; vector query_result; auto cb = [&query_result](const string_view& path, const JsonType& val) { @@ -1364,8 +1366,8 @@ void JsonFamily::MSet(CmdArgList args, ConnectionContext* cntx) { } auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); - LOG(INFO) << shard->shard_id() << " " << args; + ShardArgs args = t->GetShardArgs(shard->shard_id()); + (void)args; // TBD return OpStatus::OK; }; @@ -1469,12 +1471,7 @@ void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) { continue; vector& res = mget_resp[sid]; - ArgSlice slice = transaction->GetShardArgs(sid); - - DCHECK(!slice.empty()); - DCHECK_EQ(slice.size(), res.size()); - - for (size_t j = 0; j < slice.size(); ++j) { + for (size_t j = 0; j < res.size(); ++j) { if (!res[j]) continue; diff --git a/src/server/list_family.cc b/src/server/list_family.cc index f02f04cff..79adbf336 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -416,9 +416,9 @@ OpResult MoveTwoShards(Transaction* trans, string_view src, string_view // auto cb = [&](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - DCHECK_EQ(1u, args.size()); - bool is_dest = args.front() == dest; - find_res[is_dest] = Peek(t->GetOpArgs(shard), args.front(), src_dir, !is_dest); + DCHECK_EQ(1u, args.Size()); + bool is_dest = args.Front() == dest; + find_res[is_dest] = Peek(t->GetOpArgs(shard), args.Front(), src_dir, !is_dest); return OpStatus::OK; }; @@ -432,7 +432,7 @@ OpResult MoveTwoShards(Transaction* trans, string_view src, string_view // Everything is ok, lets proceed with the mutations. 
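A small sketch of the index-tracked loop that replaces args[i] in OpJsonMGet above (and in OpRead/OpMGet further down): the view is walked with a range-for since it exposes no operator[], while a running index keeps the response positionally aligned with the shard's key order. Lookup is a hypothetical per-key query:

#include <functional>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

std::vector<std::optional<std::string>> CollectPerKey(
    const std::vector<std::string_view>& shard_keys,
    const std::function<std::optional<std::string>(std::string_view)>& Lookup) {
  std::vector<std::optional<std::string>> response(shard_keys.size());
  unsigned index = 0;
  for (std::string_view key : shard_keys) {
    auto& dest = response[index++];  // grab the slot first, then fill it if found
    if (auto res = Lookup(key); res)
      dest = std::move(*res);
  }
  return response;
}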
auto cb = [&](Transaction* t, EngineShard* shard) { auto args = t->GetShardArgs(shard->shard_id()); - auto key = args.front(); + auto key = args.Front(); bool is_dest = (key == dest); OpArgs op_args = t->GetOpArgs(shard); @@ -873,7 +873,7 @@ OpResult BPopPusher::RunSingle(ConnectionContext* cntx, time_point tp) { return op_res; } - auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; }; + auto wcb = [&](Transaction* t, EngineShard* shard) { return ShardArgs{&this->pop_key_, 1}; }; const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*, std::string_view key) -> bool { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 4ad682f1c..74a1d5daa 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1127,9 +1127,25 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); } -OpResult OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, const ArgSlice& keys) { - auto& db_slice = op_args.shard->db_slice(); - db_slice.TrackKeys(cntx->conn()->Borrow(), keys); +OpResult OpTrackKeys(const OpArgs& op_args, const facade::Connection::WeakRef& conn_ref, + const ShardArgs& args) { + if (conn_ref.IsExpired()) { + DVLOG(2) << "Connection expired, exiting TrackKey function."; + return OpStatus::OK; + } + + DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() + << " with thread ID: " << conn_ref.Thread(); + + DbSlice& db_slice = op_args.shard->db_slice(); + + // TODO: There is a bug here that we track all arguments instead of tracking only keys. + for (auto key : args) { + DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() + << " into the tracking client set of key " << key; + db_slice.TrackKey(conn_ref, key); + } + return OpStatus::OK; } @@ -1236,9 +1252,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // start tracking all the updates to the keys in this read command if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() && cid->IsTransactional()) { - auto cb = [&](Transaction* t, EngineShard* shard) { - auto keys = t->GetShardArgs(shard->shard_id()); - return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, keys); + facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow(); + auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) { + return OpTrackKeys(t->GetOpArgs(shard), conn_ref, t->GetShardArgs(shard->shard_id())); }; dfly_cntx->transaction->Refurbish(); dfly_cntx->transaction->ScheduleSingleHopT(cb); @@ -1610,7 +1626,7 @@ void Service::Watch(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t keys_existed = 0; auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); for (auto k : largs) { shard->db_slice().RegisterWatchedKey(cntx->db_index(), k, &exec_info); } @@ -2018,7 +2034,7 @@ bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& regis atomic_uint32_t watch_exist_count{0}; auto cb = [&watch_exist_count](Transaction* t, EngineShard* shard) { - ArgSlice args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); auto res = GenericFamily::OpExists(t->GetOpArgs(shard), args); watch_exist_count.fetch_add(res.value_or(0), memory_order_relaxed); diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 
1a3fdb520..5e046a00f 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -587,10 +587,10 @@ class Mover { }; OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { - ArgSlice largs = t->GetShardArgs(es->shard_id()); + ShardArgs largs = t->GetShardArgs(es->shard_id()); // In case both src and dest are in the same shard, largs size will be 2. - DCHECK_LE(largs.size(), 2u); + DCHECK_LE(largs.Size(), 2u); for (auto k : largs) { unsigned index = (k == src_) ? 0 : 1; @@ -609,8 +609,8 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { } OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) { - ArgSlice largs = t->GetShardArgs(es->shard_id()); - DCHECK_LE(largs.size(), 2u); + ShardArgs largs = t->GetShardArgs(es->shard_id()); + DCHECK_LE(largs.Size(), 2u); OpArgs op_args = t->GetOpArgs(es); for (auto k : largs) { @@ -655,12 +655,13 @@ OpResult Mover::Commit(Transaction* t) { } // Read-only OpUnion op on sets. -OpResult OpUnion(const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); +OpResult OpUnion(const OpArgs& op_args, ShardArgs::Iterator start, + ShardArgs::Iterator end) { + DCHECK(start != end); absl::flat_hash_set uniques; - for (string_view key : keys) { - auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET); + for (; start != end; ++start) { + auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET); if (find_res) { const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { @@ -683,11 +684,12 @@ OpResult OpUnion(const OpArgs& op_args, ArgSlice keys) { } // Read-only OpDiff op on sets. -OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); - DVLOG(1) << "OpDiff from " << keys.front(); +OpResult OpDiff(const OpArgs& op_args, ShardArgs::Iterator start, + ShardArgs::Iterator end) { + DCHECK(start != end); + DVLOG(1) << "OpDiff from " << *start; EngineShard* es = op_args.shard; - auto find_res = es->db_slice().FindReadOnly(op_args.db_cntx, keys.front(), OBJ_SET); + auto find_res = es->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET); if (!find_res) { return find_res.status(); @@ -707,8 +709,8 @@ OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { DCHECK(!uniques.empty()); // otherwise the key would not exist. - for (size_t i = 1; i < keys.size(); ++i) { - auto diff_res = es->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_SET); + for (++start; start != end; ++start) { + auto diff_res = es->db_slice().FindReadOnly(op_args.db_cntx, *start, OBJ_SET); if (!diff_res) { if (diff_res.status() == OpStatus::WRONG_TYPE) { return OpStatus::WRONG_TYPE; @@ -737,15 +739,16 @@ OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { // Read-only OpInter op on sets. 
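A self-contained sketch of the first-minus-rest shape that OpDiff above keeps while moving from indexed keys to an iterator range; plain value sets stand in for the sets the real code looks up by key:

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

using SetModel = std::unordered_set<std::string>;

// The first set in the range seeds the result and every following set erases
// its members from it; iterators replace the old keys.front()/keys[i] indexing.
SetModel DiffFirstAgainstRest(std::vector<SetModel>::const_iterator start,
                              std::vector<SetModel>::const_iterator end) {
  assert(start != end);
  SetModel uniques = *start;
  for (++start; start != end; ++start)
    for (const std::string& member : *start)
      uniques.erase(member);
  return uniques;
}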
OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_first) { - ArgSlice keys = t->GetShardArgs(es->shard_id()); + ShardArgs args = t->GetShardArgs(es->shard_id()); + auto it = args.begin(); if (remove_first) { - keys.remove_prefix(1); + ++it; } - DCHECK(!keys.empty()); + DCHECK(it != args.end()); StringVec result; - if (keys.size() == 1) { - auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), keys.front(), OBJ_SET); + if (args.Size() == 1 + unsigned(remove_first)) { + auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), *it, OBJ_SET); if (!find_res) return find_res.status(); @@ -763,12 +766,13 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f return result; } - vector sets(keys.size()); + vector sets(args.Size() - int(remove_first)); OpStatus status = OpStatus::OK; - - for (size_t i = 0; i < keys.size(); ++i) { - auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), keys[i], OBJ_SET); + unsigned index = 0; + for (; it != args.end(); ++it) { + auto& dest = sets[index++]; + auto find_res = es->db_slice().FindReadOnly(t->GetDbContext(), *it, OBJ_SET); if (!find_res) { if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND || find_res.status() != OpStatus::KEY_NOTFOUND) { @@ -778,7 +782,7 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f } const PrimeValue& pv = find_res.value()->second; void* ptr = pv.RObjPtr(); - sets[i] = make_pair(ptr, pv.Encoding()); + dest = make_pair(ptr, pv.Encoding()); } if (status != OpStatus::OK) @@ -1089,12 +1093,12 @@ void SDiff(CmdArgList args, ConnectionContext* cntx) { ShardId src_shard = Shard(src_key, result_set.size()); auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); if (shard->shard_id() == src_shard) { - CHECK_EQ(src_key, largs.front()); - result_set[shard->shard_id()] = OpDiff(t->GetOpArgs(shard), largs); + CHECK_EQ(src_key, largs.Front()); + result_set[shard->shard_id()] = OpDiff(t->GetOpArgs(shard), largs.begin(), largs.end()); } else { - result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs); + result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs.begin(), largs.end()); } return OpStatus::OK; @@ -1126,22 +1130,23 @@ void SDiffStore(CmdArgList args, ConnectionContext* cntx) { // read-only op auto diff_cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); - DCHECK(!largs.empty()); - + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + OpArgs op_args = t->GetOpArgs(shard); + DCHECK(!largs.Empty()); + ShardArgs::Iterator start = largs.begin(); + ShardArgs::Iterator end = largs.end(); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - largs.remove_prefix(1); - if (largs.empty()) + CHECK_EQ(*start, dest_key); + ++start; + if (start == end) return OpStatus::OK; } - OpArgs op_args = t->GetOpArgs(shard); if (shard->shard_id() == src_shard) { - CHECK_EQ(src_key, largs.front()); - result_set[shard->shard_id()] = OpDiff(op_args, largs); // Diff + CHECK_EQ(src_key, *start); + result_set[shard->shard_id()] = OpDiff(op_args, start, end); // Diff } else { - result_set[shard->shard_id()] = OpUnion(op_args, largs); // Union + result_set[shard->shard_id()] = OpUnion(op_args, start, end); // Union } return OpStatus::OK; @@ -1276,10 +1281,10 @@ void SInterStore(CmdArgList args, ConnectionContext* cntx) { atomic_uint32_t inter_shard_cnt{0}; auto inter_cb = 
[&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - if (largs.size() == 1) + CHECK_EQ(largs.Front(), dest_key); + if (largs.Size() == 1) return OpStatus::OK; } inter_shard_cnt.fetch_add(1, memory_order_relaxed); @@ -1337,8 +1342,8 @@ void SUnion(CmdArgList args, ConnectionContext* cntx) { ResultStringVec result_set(shard_set->size()); auto cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); - result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs.begin(), largs.end()); return OpStatus::OK; }; @@ -1363,14 +1368,15 @@ void SUnionStore(CmdArgList args, ConnectionContext* cntx) { ShardId dest_shard = Shard(dest_key, result_set.size()); auto union_cb = [&](Transaction* t, EngineShard* shard) { - ArgSlice largs = t->GetShardArgs(shard->shard_id()); + ShardArgs largs = t->GetShardArgs(shard->shard_id()); + ShardArgs::Iterator start = largs.begin(), end = largs.end(); if (shard->shard_id() == dest_shard) { - CHECK_EQ(largs.front(), dest_key); - largs.remove_prefix(1); - if (largs.empty()) + CHECK_EQ(*start, dest_key); + ++start; + if (start == end) return OpStatus::OK; } - result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), largs); + result_set[shard->shard_id()] = OpUnion(t->GetOpArgs(shard), start, end); return OpStatus::OK; }; diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc index ac58427b2..3274c439c 100644 --- a/src/server/stream_family.cc +++ b/src/server/stream_family.cc @@ -798,8 +798,8 @@ stream* GetReadOnlyStream(const CompactObj& cobj) { // Returns a map of stream to the ID of the last entry in the stream. Any // streams not found are omitted from the result. OpResult>> OpLastIDs(const OpArgs& op_args, - const ArgSlice& args) { - DCHECK(!args.empty()); + const ShardArgs& args) { + DCHECK(!args.Empty()); auto& db_slice = op_args.shard->db_slice(); @@ -828,8 +828,8 @@ OpResult>> OpLastIDs(const OpArgs& op_args, // Returns the range response for each stream on this shard in order of // GetShardArgs. 
-vector OpRead(const OpArgs& op_args, const ArgSlice& args, const ReadOpts& opts) { - DCHECK(!args.empty()); +vector OpRead(const OpArgs& op_args, const ShardArgs& shard_args, const ReadOpts& opts) { + DCHECK(!shard_args.Empty()); RangeOpts range_opts; range_opts.count = opts.count; @@ -838,11 +838,11 @@ vector OpRead(const OpArgs& op_args, const ArgSlice& args, const Read .seq = UINT64_MAX, }}; - vector response(args.size()); - for (size_t i = 0; i < args.size(); ++i) { - string_view key = args[i]; - + vector response(shard_args.Size()); + unsigned index = 0; + for (string_view key : shard_args) { auto sitem = opts.stream_ids.at(key); + auto& dest = response[index++]; if (!sitem.group && opts.read_group) { continue; } @@ -858,7 +858,7 @@ vector OpRead(const OpArgs& op_args, const ArgSlice& args, const Read else range_res = OpRange(op_args, key, range_opts); if (range_res) { - response[i] = std::move(range_res.value()); + dest = std::move(range_res.value()); } } @@ -1352,15 +1352,17 @@ struct GroupConsumerPairOpts { string_view consumer; }; -vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpArgs& op_args, +vector OpGetGroupConsumerPairs(const ShardArgs& shard_args, + const OpArgs& op_args, const GroupConsumerPairOpts& opts) { - vector sid_items(slice_args.size()); - + vector sid_items(shard_args.Size()); + unsigned index = 0; // get group and consumer - for (size_t i = 0; i < slice_args.size(); i++) { - string_view key = slice_args[i]; + for (string_view key : shard_args) { streamCG* group = nullptr; streamConsumer* consumer = nullptr; + auto& dest = sid_items[index++]; + auto group_res = FindGroup(op_args, key, opts.group); if (!group_res) { continue; @@ -1376,7 +1378,7 @@ vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA consumer = streamCreateConsumer(group, op_args.shard->tmp_str1, NULL, 0, SCC_NO_NOTIFY | SCC_NO_DIRTIFY); } - sid_items[i] = {group, consumer}; + dest = {group, consumer}; } return sid_items; } @@ -2988,12 +2990,7 @@ void XReadImpl(CmdArgList args, std::optional opts, ConnectionContext* vector& results = xread_resp[sid]; - ArgSlice slice = cntx->transaction->GetShardArgs(sid); - - DCHECK(!slice.empty()); - DCHECK_EQ(slice.size(), results.size()); - - for (size_t i = 0; i < slice.size(); ++i) { + for (size_t i = 0; i < results.size(); ++i) { if (results[i].size() == 0) { continue; } @@ -3039,7 +3036,7 @@ void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) { vector> res_pairs(shard_set->size()); auto cb = [&](Transaction* t, EngineShard* shard) { auto sid = shard->shard_id(); - auto s_args = t->GetShardArgs(sid); + ShardArgs s_args = t->GetShardArgs(sid); GroupConsumerPairOpts gc_opts = {opts->group_name, opts->consumer_name}; res_pairs[sid] = OpGetGroupConsumerPairs(s_args, t->GetOpArgs(shard), gc_opts); @@ -3057,11 +3054,12 @@ void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) { if (s_item.size() == 0) { continue; } - for (size_t j = 0; j < s_args.size(); j++) { - string_view key = s_args[j]; + unsigned index = 0; + for (string_view key : s_args) { StreamIDsItem& item = opts->stream_ids.at(key); - item.consumer = s_item[j].consumer; - item.group = s_item[j].group; + item.consumer = s_item[index].consumer; + item.group = s_item[index].group; + ++index; } } } diff --git a/src/server/stream_family_test.cc b/src/server/stream_family_test.cc index 55ce88d10..2b2f2d024 100644 --- a/src/server/stream_family_test.cc +++ b/src/server/stream_family_test.cc @@ -109,8 +109,9 @@ TEST_F(StreamFamilyTest, Range) { 
} TEST_F(StreamFamilyTest, GroupCreate) { - Run({"xadd", "key", "1-*", "f1", "v1"}); - auto resp = Run({"xgroup", "create", "key", "grname", "1"}); + auto resp = Run({"xadd", "key", "1-*", "f1", "v1"}); + EXPECT_EQ(resp, "1-0"); + resp = Run({"xgroup", "create", "key", "grname", "1"}); EXPECT_EQ(resp, "OK"); resp = Run({"xgroup", "create", "test", "test", "0"}); EXPECT_THAT(resp, ErrArg("requires the key to exist")); diff --git a/src/server/string_family.cc b/src/server/string_family.cc index dd7cb4975..c8ed6edcc 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -278,19 +278,23 @@ int64_t AbsExpiryToTtl(int64_t abs_expiry_time, bool as_milli) { } // Returns true if keys were set, false otherwise. -void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) { - DCHECK(!args.empty() && args.size() % 2 == 0); +void OpMSet(const OpArgs& op_args, const ShardArgs& args, atomic_bool* success) { + DCHECK(!args.Empty() && args.Size() % 2 == 0); SetCmd::SetParams params; SetCmd sg(op_args, false); - size_t i = 0; - for (; i < args.size(); i += 2) { - DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1]; - if (sg.Set(params, args[i], args[i + 1]) != OpStatus::OK) { // OOM for example. + size_t index = 0; + for (auto it = args.begin(); it != args.end(); ++it) { + string_view key = *it; + ++it; + string_view value = *it; + DVLOG(1) << "MSet " << key << ":" << value; + if (sg.Set(params, key, value) != OpStatus::OK) { // OOM for example. success->store(false); break; } + index += 2; } if (auto journal = op_args.shard->journal(); journal) { @@ -298,14 +302,14 @@ void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) { // we replicate only what was changed. string_view cmd; ArgSlice cmd_args; - if (i == 0) { + if (index == 0) { // All shards must record the tx was executed for the replica to execute it, so we send a PING // in case nothing was changed cmd = "PING"; } else { // journal [0, i) cmd = "MSET"; - cmd_args = ArgSlice(&args[0], i); + cmd_args = ArgSlice(args.begin(), index); } RecordJournal(op_args, cmd, cmd_args, op_args.tx->GetUniqueShardCnt()); } @@ -419,27 +423,29 @@ OpResult> OpThrottle(const OpArgs& op_args, const string_view SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction* t, EngineShard* shard) { - auto keys = t->GetShardArgs(shard->shard_id()); - DCHECK(!keys.empty()); + ShardArgs keys = t->GetShardArgs(shard->shard_id()); + DCHECK(!keys.Empty()); auto& db_slice = shard->db_slice(); - SinkReplyBuilder::MGetResponse response(keys.size()); - absl::InlinedVector iters(keys.size()); + SinkReplyBuilder::MGetResponse response(keys.Size()); + absl::InlinedVector iters(keys.Size()); size_t total_size = 0; - for (size_t i = 0; i < keys.size(); ++i) { - auto it_res = db_slice.FindAndFetchReadOnly(t->GetDbContext(), keys[i], OBJ_STRING); + unsigned index = 0; + for (string_view key : keys) { + auto it_res = db_slice.FindAndFetchReadOnly(t->GetDbContext(), key, OBJ_STRING); + auto& dest = iters[index++]; if (!it_res) continue; - iters[i] = *it_res; + dest = *it_res; total_size += (*it_res)->second.Size(); } response.storage_list = SinkReplyBuilder::AllocMGetStorage(total_size); char* next = response.storage_list->data; - for (size_t i = 0; i < keys.size(); ++i) { + for (size_t i = 0; i < iters.size(); ++i) { auto it = iters[i]; if (it.is_done()) continue; @@ -1139,12 +1145,7 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) { res.storage_list = src.storage_list; src.storage_list = 
nullptr; - ArgSlice slice = transaction->GetShardArgs(sid); - - DCHECK(!slice.empty()); - DCHECK_EQ(slice.size(), src.resp_arr.size()); - - for (size_t j = 0; j < slice.size(); ++j) { + for (size_t j = 0; j < src.resp_arr.size(); ++j) { if (!src.resp_arr[j]) continue; @@ -1173,7 +1174,7 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) { atomic_bool success = true; auto cb = [&](Transaction* t, EngineShard* shard) { - auto args = t->GetShardArgs(shard->shard_id()); + ShardArgs args = t->GetShardArgs(shard->shard_id()); OpMSet(t->GetOpArgs(shard), args, &success); return OpStatus::OK; }; @@ -1193,8 +1194,9 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* es) { auto args = t->GetShardArgs(es->shard_id()); - for (size_t i = 0; i < args.size(); i += 2) { - auto it = es->db_slice().FindReadOnly(t->GetDbContext(), args[i]).it; + for (auto arg_it = args.begin(); arg_it != args.end(); ++arg_it) { + auto it = es->db_slice().FindReadOnly(t->GetDbContext(), *arg_it).it; + ++arg_it; if (IsValid(it)) { exists.store(true, memory_order_relaxed); break; diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 7fcbd2d60..3ae781eec 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -182,8 +182,6 @@ void Transaction::InitGlobal() { } void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vector* out) { - auto args = full_args_; - auto& shard_index = *out; auto add = [this, rev_mapping = key_index.has_reverse_mapping, &shard_index](uint32_t sid, @@ -196,14 +194,14 @@ void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vectorsize())); + DCHECK_EQ(unique_shard_id_, Shard(akey, shard_set->size())); else { - unique_slot_checker_.Add(kv_args_.front()); - unique_shard_id_ = Shard(kv_args_.front(), shard_set->size()); + unique_slot_checker_.Add(akey); + unique_shard_id_ = Shard(akey, shard_set->size()); } // Multi transactions that execute commands on their own (not stubs) can't shrink the backing @@ -1178,7 +1175,7 @@ bool Transaction::CancelShardCb(EngineShard* shard) { } // runs in engine-shard thread. -ArgSlice Transaction::GetShardArgs(ShardId sid) const { +ShardArgs Transaction::GetShardArgs(ShardId sid) const { DCHECK(!multi_ || multi_->role != SQUASHER); // We can read unique_shard_cnt_ only because ShardArgsInShard is called after IsArmedInShard @@ -1188,7 +1185,7 @@ ArgSlice Transaction::GetShardArgs(ShardId sid) const { } const auto& sd = shard_data_[sid]; - return ArgSlice{kv_args_.data() + sd.arg_start, sd.arg_count}; + return ShardArgs{kv_args_.data() + sd.arg_start, sd.arg_count}; } // from local index back to original arg index skipping the command. 
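A standalone sketch of the paired iteration that OpMSet (and MSetNx) above switch to: the shard view is a flat key,value,key,value sequence, so the iterator advances twice per round, and the number of applied elements is returned so the caller can journal only the prefix that actually took effect. SetOne is a hypothetical setter that may fail (e.g. on OOM):

#include <cstddef>
#include <functional>
#include <string_view>
#include <vector>

// Assumes an even number of elements, as the command parser guarantees.
size_t ApplyPairs(const std::vector<std::string_view>& flat_args,
                  const std::function<bool(std::string_view, std::string_view)>& SetOne) {
  size_t applied = 0;  // counts elements, not pairs
  for (auto it = flat_args.begin(); it != flat_args.end(); ++it) {
    std::string_view key = *it;
    std::string_view value = *++it;  // the value always follows its key
    if (!SetOne(key, value))
      break;
    applied += 2;
  }
  return applied;
}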
@@ -1253,7 +1250,7 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p return result; } -OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc) { +OpStatus Transaction::WatchInShard(const ShardArgs& keys, EngineShard* shard, KeyReadyChecker krc) { auto& sd = shard_data_[SidToId(shard->shard_id())]; CHECK_EQ(0, sd.local_mask & SUSPENDED_Q); @@ -1261,12 +1258,12 @@ OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyCh sd.local_mask &= ~OUT_OF_ORDER; shard->EnsureBlockingController()->AddWatched(keys, std::move(krc), this); - DVLOG(2) << "WatchInShard " << DebugId() << ", first_key:" << keys.front(); + DVLOG(2) << "WatchInShard " << DebugId() << ", first_key:" << keys.Front(); return OpStatus::OK; } -void Transaction::ExpireShardCb(ArgSlice wkeys, EngineShard* shard) { +void Transaction::ExpireShardCb(const ShardArgs& wkeys, EngineShard* shard) { // Blocking transactions don't release keys when suspending, release them now. auto lock_args = GetLockArgs(shard->shard_id()); shard->db_slice().Release(LockMode(), lock_args); @@ -1369,9 +1366,9 @@ bool Transaction::NotifySuspended(TxId committed_txid, ShardId sid, string_view CHECK_EQ(sd.local_mask & AWAKED_Q, 0); // Find index of awakened key - auto args = GetShardArgs(sid); - auto it = find_if(args.begin(), args.end(), [key](auto arg) { return facade::ToSV(arg) == key; }); - CHECK(it != args.end()); + ShardArgs args = GetShardArgs(sid); + auto it = find_if(args.cbegin(), args.cend(), [key](string_view arg) { return arg == key; }); + CHECK(it != args.cend()); // Change state to awaked and store index of awakened key sd.local_mask &= ~SUSPENDED_Q; @@ -1427,7 +1424,7 @@ void Transaction::LogAutoJournalOnShard(EngineShard* shard, RunnableResult resul if (unique_shard_cnt_ == 1 || kv_args_.empty()) { entry_payload = make_pair(cmd, full_args_); } else { - entry_payload = make_pair(cmd, GetShardArgs(shard->shard_id())); + entry_payload = make_pair(cmd, GetShardArgs(shard->shard_id()).AsSlice()); } LogJournalOnShard(shard, std::move(entry_payload), unique_shard_cnt_, false, true); } diff --git a/src/server/transaction.h b/src/server/transaction.h index 0462b3646..07ffc6260 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -22,6 +22,7 @@ #include "server/common.h" #include "server/journal/types.h" #include "server/table.h" +#include "server/tx_base.h" #include "util/fibers/synchronization.h" namespace dfly { @@ -129,8 +130,9 @@ class Transaction { // Runnable that is run on shards during hop executions (often named callback). // Callacks should return `OpStatus` which is implicitly converitble to `RunnableResult`! using RunnableType = absl::FunctionRef; + // Provides keys to block on for specific shard. - using WaitKeysProvider = std::function; + using WaitKeysProvider = std::function; // Modes in which a multi transaction can run. enum MultiMode { @@ -176,7 +178,7 @@ class Transaction { OpStatus InitByArgs(DbIndex index, CmdArgList args); // Get command arguments for specific shard. Called from shard thread. - ArgSlice GetShardArgs(ShardId sid) const; + ShardArgs GetShardArgs(ShardId sid) const; // Map arg_index from GetShardArgs slice to index in original command slice from InitByArgs. size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const; @@ -511,12 +513,12 @@ class Transaction { void RunCallback(EngineShard* shard); // Adds itself to watched queue in the shard. Must run in that shard thread. 
-  OpStatus WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc);
+  OpStatus WatchInShard(const ShardArgs& keys, EngineShard* shard, KeyReadyChecker krc);
 
   // Expire blocking transaction, unlock keys and unregister it from the blocking controller
   void ExpireBlocking(WaitKeysProvider wcb);
-  void ExpireShardCb(ArgSlice wkeys, EngineShard* shard);
+  void ExpireShardCb(const ShardArgs& wkeys, EngineShard* shard);
 
   // Returns true if we need to follow up with PollExecution on this shard.
   bool CancelShardCb(EngineShard* shard);
@@ -577,7 +579,6 @@ class Transaction {
     });
   }
 
- private:
   // Used for waiting for all hop callbacks to run.
   util::fb2::EmbeddedBlockingCounter run_barrier_{0};
diff --git a/src/server/tx_base.cc b/src/server/tx_base.cc
new file mode 100644
index 000000000..5f1993c27
--- /dev/null
+++ b/src/server/tx_base.cc
@@ -0,0 +1,61 @@
+// Copyright 2024, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#include "server/tx_base.h"
+
+#include "base/logging.h"
+#include "server/cluster/cluster_defs.h"
+#include "server/engine_shard_set.h"
+#include "server/journal/journal.h"
+#include "server/transaction.h"
+
+namespace dfly {
+
+using namespace std;
+
+void RecordJournal(const OpArgs& op_args, string_view cmd, ArgSlice args, uint32_t shard_cnt,
+                   bool multi_commands) {
+  VLOG(2) << "Logging command " << cmd << " from txn " << op_args.tx->txid();
+  op_args.tx->LogJournalOnShard(op_args.shard, make_pair(cmd, args), shard_cnt, multi_commands,
+                                false);
+}
+
+void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt) {
+  op_args.tx->FinishLogJournalOnShard(op_args.shard, shard_cnt);
+}
+
+void RecordExpiry(DbIndex dbid, string_view key) {
+  auto journal = EngineShard::tlocal()->journal();
+  CHECK(journal);
+  journal->RecordEntry(0, journal::Op::EXPIRED, dbid, 1, cluster::KeySlot(key),
+                       make_pair("DEL", ArgSlice{key}), false);
+}
+
+void TriggerJournalWriteToSink() {
+  auto journal = EngineShard::tlocal()->journal();
+  CHECK(journal);
+  journal->RecordEntry(0, journal::Op::NOOP, 0, 0, nullopt, {}, true);
+}
+
+std::ostream& operator<<(std::ostream& os, ArgSlice list) {
+  os << "[";
+  if (!list.empty()) {
+    std::for_each(list.begin(), list.end() - 1, [&os](const auto& val) { os << val << ", "; });
+    os << (*(list.end() - 1));
+  }
+  return os << "]";
+}
+
+LockTag::LockTag(std::string_view key) {
+  if (LockTagOptions::instance().enabled)
+    str_ = LockTagOptions::instance().Tag(key);
+  else
+    str_ = key;
+}
+
+LockFp LockTag::Fingerprint() const {
+  return XXH64(str_.data(), str_.size(), 0x1C69B3F74AC4AE35UL);
+}
+
+}  // namespace dfly
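The new tx_base.cc gathers transaction-adjacent helpers: thin journaling wrappers over Transaction::LogJournalOnShard, and LockTag, which reduces a key to its lock tag (honoring LockTagOptions) and hashes it with XXH64 into a LockFp fingerprint. KeyLockArgs, declared in the header that follows, then carries fingerprints rather than raw keys. A rough sketch of that flow; the keys and db index are illustrative values, not code from this patch:

    std::vector<LockFp> fps;
    for (std::string_view key : {"plain-key", "{user1}:field"}) {
      LockTag tag(key);                  // applies hashtag extraction when enabled
      fps.push_back(tag.Fingerprint());  // XXH64 over the tag bytes
    }
    KeyLockArgs lock_args;
    lock_args.db_index = 0;
    lock_args.fps = fps;                 // span over the fingerprints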
diff --git a/src/server/tx_base.h b/src/server/tx_base.h
new file mode 100644
index 000000000..63f40fe20
--- /dev/null
+++ b/src/server/tx_base.h
@@ -0,0 +1,172 @@
+// Copyright 2024, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+
+#pragma once
+
+#include
+
+#include
+
+#include "src/facade/facade_types.h"
+
+namespace dfly {
+
+class EngineShard;
+class Transaction;
+
+using DbIndex = uint16_t;
+using ShardId = uint16_t;
+using LockFp = uint64_t;  // a key fingerprint used by the LockTable.
+
+using ArgSlice = absl::Span<const std::string_view>;
+
+constexpr DbIndex kInvalidDbId = DbIndex(-1);
+constexpr ShardId kInvalidSid = ShardId(-1);
+constexpr DbIndex kMaxDbId = 1024;  // Reasonable starting point.
+
+struct KeyLockArgs {
+  DbIndex db_index = 0;
+  absl::Span<const LockFp> fps;
+};
+
+// Describes key indices.
+struct KeyIndex {
+  unsigned start;
+  unsigned end;   // does not include this index (open limit).
+  unsigned step;  // 1 for commands like mget. 2 for commands like mset.
+
+  // if index is non-zero then adds another key index (usually 0).
+  // relevant for commands like ZUNIONSTORE/ZINTERSTORE for destination key.
+  std::optional bonus{};
+  bool has_reverse_mapping = false;
+
+  KeyIndex(unsigned s = 0, unsigned e = 0, unsigned step = 0) : start(s), end(e), step(step) {
+  }
+
+  static KeyIndex Range(unsigned start, unsigned end, unsigned step = 1) {
+    return KeyIndex{start, end, step};
+  }
+
+  bool HasSingleKey() const {
+    return !bonus && (start + step >= end);
+  }
+
+  unsigned num_args() const {
+    return end - start + bool(bonus);
+  }
+};
+
+struct DbContext {
+  DbIndex db_index = 0;
+  uint64_t time_now_ms = 0;
+};
+
+struct OpArgs {
+  EngineShard* shard;
+  const Transaction* tx;
+  DbContext db_cntx;
+
+  OpArgs() : shard(nullptr), tx(nullptr) {
+  }
+
+  OpArgs(EngineShard* s, const Transaction* tx, const DbContext& cntx)
+      : shard(s), tx(tx), db_cntx(cntx) {
+  }
+};
+
+// A strong type for a lock tag. Helps to disambiguate between keys and the parts of the
+// keys that are used for locking.
+class LockTag {
+  std::string_view str_;
+
+ public:
+  using is_stackonly = void;  // marks that this object does not use heap.
+
+  LockTag() = default;
+  explicit LockTag(std::string_view key);
+
+  explicit operator std::string_view() const {
+    return str_;
+  }
+
+  LockFp Fingerprint() const;
+
+  // To make it hashable.
+  template <typename H> friend H AbslHashValue(H h, const LockTag& tag) {
+    return H::combine(std::move(h), tag.str_);
+  }
+
+  bool operator==(const LockTag& o) const {
+    return str_ == o.str_;
+  }
+};
+
+// Checks whether the touched key is valid for a blocking transaction watching it.
+using KeyReadyChecker =
+    std::function;
+
+// References arguments in another array.
+using IndexSlice = std::pair;  // (begin, end)
+
+class ShardArgs : protected ArgSlice {
+ public:
+  using ArgSlice::ArgSlice;
+  using ArgSlice::at;
+  using ArgSlice::operator=;
+  using Iterator = ArgSlice::iterator;
+
+  ShardArgs(const ArgSlice& o) : ArgSlice(o) {
+  }
+
+  size_t Size() const {
+    return ArgSlice::size();
+  }
+
+  auto cbegin() const {
+    return ArgSlice::cbegin();
+  }
+
+  auto cend() const {
+    return ArgSlice::cend();
+  }
+
+  auto begin() const {
+    return cbegin();
+  }
+
+  auto end() const {
+    return cend();
+  }
+
+  bool Empty() const {
+    return ArgSlice::empty();
+  }
+
+  std::string_view Front() const {
+    return *cbegin();
+  }
+
+  ArgSlice AsSlice() const {
+    return ArgSlice(*this);
+  }
+};
+
+// Record non auto journal command with own txid and dbid.
+void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args,
+                   uint32_t shard_cnt = 1, bool multi_commands = false);
+
+// Record non auto journal command finish. Call only when command translates to multi commands.
+void RecordJournalFinish(const OpArgs& op_args, uint32_t shard_cnt);
+
+// Record expiry in journal with independent transaction. Must be called from shard thread holding
+// key.
+void RecordExpiry(DbIndex dbid, std::string_view key);
+
+// Trigger journal write to sink, no journal record will be added to journal.
+// Must be called from shard thread of journal to sink.
+void TriggerJournalWriteToSink();
+
+std::ostream& operator<<(std::ostream& os, ArgSlice list);
+
+}  // namespace dfly
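ShardArgs inherits from ArgSlice only as an implementation detail (protected), so code that still needs a flat span, such as the journaling path, must unwrap it explicitly with AsSlice(); the LogAutoJournalOnShard hunk above does exactly that. A short sketch of the same pattern for a manually journaled operation, assuming t, shard and op_args are the usual shard-callback variables (illustrative, not part of the patch):

    ShardArgs args = t->GetShardArgs(shard->shard_id());
    // RecordJournal expects an ArgSlice, so the ShardArgs view is unwrapped explicitly.
    RecordJournal(op_args, "DEL", args.AsSlice());  // shard_cnt defaults to 1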
diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc
index 77147e94a..22dd8e045 100644
--- a/src/server/zset_family.cc
+++ b/src/server/zset_family.cc
@@ -822,39 +822,43 @@ double GetKeyWeight(Transaction* t, ShardId shard_id, const vector& weig
 
 OpResult OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
                  const vector& weights, bool store) {
-  ArgSlice keys = t->GetShardArgs(shard->shard_id());
-  DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
-  DCHECK(!keys.empty());
+  ShardArgs keys = t->GetShardArgs(shard->shard_id());
+  DCHECK(!keys.Empty());
 
   unsigned cmdargs_keys_offset = 1;  // after {numkeys} for ZUNION
   unsigned removed_keys = 0;
+  ShardArgs::Iterator start = keys.begin(), end = keys.end();
+
   if (store) {  // first global index is 2 after {destkey, numkeys}.
     ++cmdargs_keys_offset;
-    if (keys.front() == dest) {
-      keys.remove_prefix(1);
+    if (*start == dest) {
+      ++start;
       ++removed_keys;
     }
 
     // In case ONLY the destination key is hosted in this shard no work on this shard should be
    // done in this step
-    if (keys.empty()) {
+    if (start == end) {
      return OpStatus::OK;
    }
  }

  auto& db_slice = shard->db_slice();
-  KeyIterWeightVec key_weight_vec(keys.size());
-  for (unsigned j = 0; j < keys.size(); ++j) {
-    auto it_res = db_slice.FindReadOnly(t->GetDbContext(), keys[j], OBJ_ZSET);
-    if (it_res == OpStatus::WRONG_TYPE)  // TODO: support sets with default score 1.
+  KeyIterWeightVec key_weight_vec(keys.Size() - removed_keys);
+  unsigned index = 0;
+  for (; start != end; ++start) {
+    auto it_res = db_slice.FindReadOnly(t->GetDbContext(), *start, OBJ_ZSET);
+    if (it_res == OpStatus::WRONG_TYPE)  // TODO: support SET type with default score 1.
       return it_res.status();
-    if (!it_res)
+    if (!it_res) {
+      ++index;
       continue;
-
-    key_weight_vec[j] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
-                                               cmdargs_keys_offset)};
+    }
+    key_weight_vec[index] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights,
+                                                   index + removed_keys, cmdargs_keys_offset)};
+    ++index;
   }
 
   return UnionShardKeysWithScore(key_weight_vec, agg_type);
@@ -871,46 +875,48 @@ ScoredMap ZSetFromSet(const PrimeValue& pv, double weight) {
 
 OpResult OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
                  const vector& weights, bool store) {
-  ArgSlice keys = t->GetShardArgs(shard->shard_id());
-  DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
-  DCHECK(!keys.empty());
+  ShardArgs keys = t->GetShardArgs(shard->shard_id());
+  DCHECK(!keys.Empty());
 
   unsigned removed_keys = 0;
   unsigned cmdargs_keys_offset = 1;
+  ShardArgs::Iterator start = keys.begin(), end = keys.end();
 
   if (store) {  // first global index is 2 after {destkey, numkeys}.
     ++cmdargs_keys_offset;
-    if (keys.front() == dest) {
-      keys.remove_prefix(1);
+    if (*start == dest) {
+      ++start;
       ++removed_keys;
-    }
-    // In case ONLY the destination key is hosted in this shard no work on this shard should be
-    // done in this step
-    if (keys.empty()) {
-      return OpStatus::SKIPPED;
+      // In case ONLY the destination key is hosted in this shard no work on this shard should be
+      // done in this step
+      if (start == end) {
+        return OpStatus::SKIPPED;
+      }
     }
   }
 
   auto& db_slice = shard->db_slice();
-  vector> it_arr(keys.size());
-  if (it_arr.empty())  // could be when only the dest key is hosted in this shard
-    return OpStatus::SKIPPED;  // return noop
+  vector> it_arr(keys.Size() - removed_keys);
 
-  for (unsigned j = 0; j < keys.size(); ++j) {
-    auto it_res = db_slice.FindMutable(t->GetDbContext(), keys[j]);
-    if (!IsValid(it_res.it))
+  unsigned index = 0;
+  for (; start != end; ++start) {
+    auto it_res = db_slice.FindMutable(t->GetDbContext(), *start);
+    if (!IsValid(it_res.it)) {
+      ++index;
       continue;  // we exit in the next loop
+    }
 
     // sets are supported for ZINTER* commands:
     auto obj_type = it_res.it->second.ObjType();
     if (obj_type != OBJ_ZSET && obj_type != OBJ_SET)
       return OpStatus::WRONG_TYPE;
 
-    it_arr[j] = {std::move(it_res), GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
-                                                 cmdargs_keys_offset)};
+    it_arr[index] = {std::move(it_res), GetKeyWeight(t, shard->shard_id(), weights,
+                                                     index + removed_keys, cmdargs_keys_offset)};
+    ++index;
   }
 
   ScoredMap result;
@@ -1343,16 +1349,15 @@ void BZPopMinMax(CmdArgList args, ConnectionContext* cntx, bool is_max) {
 }
 
 vector OpFetch(EngineShard* shard, Transaction* t) {
-  ArgSlice keys = t->GetShardArgs(shard->shard_id());
-  DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << keys;
-  DCHECK(!keys.empty());
+  ShardArgs keys = t->GetShardArgs(shard->shard_id());
+  DCHECK(!keys.Empty());
 
   vector results;
-  results.reserve(keys.size());
+  results.reserve(keys.Size());
 
   auto& db_slice = shard->db_slice();
-  for (size_t i = 0; i < keys.size(); ++i) {
-    auto it = db_slice.FindReadOnly(t->GetDbContext(), keys[i], OBJ_ZSET);
+  for (string_view key : keys) {
+    auto it = db_slice.FindReadOnly(t->GetDbContext(), key, OBJ_ZSET);
     if (!it) {
       results.push_back({});
       continue;