From 1ce3f983c9bb8ebc41c26e99f6abd45f9eb153a8 Mon Sep 17 00:00:00 2001 From: Shahar Mike Date: Mon, 11 Dec 2023 10:07:53 +0200 Subject: [PATCH] WIP: Auto `PostUpdate()` (#2268) * WIP: Auto `PostUpdate()` * More `Find()` uses * Final touches * Fixes * Fix bug and allow reassigning * Rename to AutoUpdater * Fix and add DCHECK * Also check deletion count * Use ccache instead of sccache * Try to upgrade Helio * off64_t * off64_t * Revert changes to CI --- src/core/dash.h | 14 ++-- src/server/bitops_family.cc | 10 ++- src/server/db_slice.cc | 78 +++++++++++++++++++ src/server/db_slice.h | 51 ++++++++++++ src/server/hll_family.cc | 6 +- src/server/hset_family.cc | 38 ++++----- src/server/json_family.cc | 19 +++-- src/server/list_family.cc | 69 ++++++++--------- src/server/search/doc_index.cc | 2 +- src/server/set_family.cc | 62 ++++++++------- src/server/stream_family.cc | 138 ++++++++++++++++++--------------- src/server/string_family.cc | 27 ++++--- src/server/table.h | 7 +- src/server/zset_family.cc | 108 ++++++++++++++------------ 14 files changed, 391 insertions(+), 238 deletions(-) diff --git a/src/core/dash.h b/src/core/dash.h index d99b3a400..bd2b05e09 100644 --- a/src/core/dash.h +++ b/src/core/dash.h @@ -325,6 +325,9 @@ class DashTable<_Key, _Value, Policy>::Iterator { public: using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; + using IteratorPairType = + std::conditional_t, + detail::IteratorPair>; // Copy constructor from iterator to const_iterator. template ::Iterator { return *this; } - detail::IteratorPair operator->() { + IteratorPairType operator->() const { auto* seg = owner_->segment_[seg_id_]; - return detail::IteratorPair{seg->Key(bucket_id_, slot_id_), - seg->Value(bucket_id_, slot_id_)}; - } - - const detail::IteratorPair operator->() const { - auto* seg = owner_->segment_[seg_id_]; - return detail::IteratorPair{seg->Key(bucket_id_, slot_id_), - seg->Value(bucket_id_, slot_id_)}; + return {seg->Key(bucket_id_, slot_id_), seg->Value(bucket_id_, slot_id_)}; } // Make it self-contained. Does not need container::end(). diff --git a/src/server/bitops_family.cc b/src/server/bitops_family.cc index d9159b7d7..fc0b43791 100644 --- a/src/server/bitops_family.cc +++ b/src/server/bitops_family.cc @@ -314,7 +314,7 @@ class ElementAccess { }; std::optional ElementAccess::Exists(EngineShard* shard) { - auto res = shard->db_slice().Find(context_, key_, OBJ_STRING); + auto res = shard->db_slice().FindReadOnly(context_, key_, OBJ_STRING); if (res.status() == OpStatus::WRONG_TYPE) { return {}; } @@ -458,7 +458,8 @@ OpResult RunBitOpNot(const OpArgs& op_args, ArgSlice keys) { EngineShard* es = op_args.shard; // if we found the value, just return, if not found then skip, otherwise report an error auto key = keys.front(); - OpResult find_res = es->db_slice().Find(op_args.db_cntx, key, OBJ_STRING); + OpResult find_res = + es->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (find_res) { return GetString(find_res.value()->second, es); } else { @@ -479,7 +480,8 @@ OpResult RunBitOpOnShard(std::string_view op, const OpArgs& op_args // collect all the value for this shard for (auto& key : keys) { - OpResult find_res = es->db_slice().Find(op_args.db_cntx, key, OBJ_STRING); + OpResult find_res = + es->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (find_res) { values.emplace_back(GetString(find_res.value()->second, es)); } else { @@ -1261,7 +1263,7 @@ OpResult ReadValueBitsetAt(const OpArgs& op_args, std::string_view key, ui OpResult ReadValue(const DbContext& context, std::string_view key, EngineShard* shard) { - OpResult it_res = shard->db_slice().Find(context, key, OBJ_STRING); + OpResult it_res = shard->db_slice().FindReadOnly(context, key, OBJ_STRING); if (!it_res.ok()) { return it_res.status(); } diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 556191e6e..7241b16d5 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -338,6 +338,83 @@ auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) return it; } +DbSlice::AutoUpdater::AutoUpdater() { +} + +DbSlice::AutoUpdater::AutoUpdater(AutoUpdater&& o) { + *this = std::move(o); +} + +DbSlice::AutoUpdater& DbSlice::AutoUpdater::operator=(AutoUpdater&& o) { + Run(); + fields_ = o.fields_; + o.Cancel(); + return *this; +} + +DbSlice::AutoUpdater::~AutoUpdater() { + Run(); +} + +void DbSlice::AutoUpdater::Run() { + if (fields_.action == DestructorAction::kDoNothing) { + return; + } + + // Check that AutoUpdater does not run after a key was removed. + // If this CHECK() failed for you, it probably means that you deleted a key while having an auto + // updater in scope. You'll probably want to call Run() (or Cancel() - but be careful). + DCHECK(IsValid(fields_.db_slice->db_arr_[fields_.db_ind]->prime.Find(fields_.key))) + << "Key was removed before PostUpdate() - this is a bug!"; + + // Make sure that the DB has not changed in size since this object was created. + // Adding or removing elements from the DB may invalidate iterators. + CHECK_EQ(fields_.db_size, fields_.db_slice->DbSize(fields_.db_ind)) + << "Attempting to run post-update after DB was modified"; + + CHECK_EQ(fields_.deletion_count, fields_.db_slice->deletion_count_) + << "Attempting to run post-update after a deletion was issued"; + + DCHECK(fields_.action == DestructorAction::kRun); + CHECK_NE(fields_.db_slice, nullptr); + + fields_.db_slice->PostUpdate(fields_.db_ind, fields_.it, fields_.key, fields_.key_existed); + Cancel(); // Reset to not run again +} + +void DbSlice::AutoUpdater::Cancel() { + this->fields_ = {}; +} + +DbSlice::AutoUpdater::AutoUpdater(const Fields& fields) : fields_(fields) { + DCHECK(fields_.action == DestructorAction::kRun); + fields_.db_slice->PreUpdate(fields_.db_ind, fields_.it); + fields_.db_size = fields_.db_slice->DbSize(fields_.db_ind); + fields_.deletion_count = fields_.db_slice->deletion_count_; +} + +OpResult DbSlice::FindMutable(const Context& cntx, string_view key, + unsigned req_obj_type) { + // TODO(#2252): Call an internal find version that does not handle post updates + auto it = FindExt(cntx, key).first; + + if (!IsValid(it)) + return OpStatus::KEY_NOTFOUND; + + if (it->second.ObjType() != req_obj_type) { + return OpStatus::WRONG_TYPE; + } + + return { + {it, AutoUpdater({AutoUpdater::DestructorAction::kRun, this, cntx.db_index, it, key, true})}}; +} + +auto DbSlice::FindReadOnly(const Context& cntx, string_view key, unsigned req_obj_type) const + -> OpResult { + auto res = Find(cntx, key, req_obj_type); + return res.ok() ? OpResult(res.value()) : res.status(); +} + pair DbSlice::FindExt(const Context& cntx, string_view key) const { pair res; @@ -562,6 +639,7 @@ bool DbSlice::Del(DbIndex db_ind, PrimeIterator it) { } PerformDeletion(it, shard_owner(), db.get()); + deletion_count_++; return true; } diff --git a/src/server/db_slice.h b/src/server/db_slice.h index b2ce97add..d413fd512 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -63,6 +63,44 @@ class DbSlice { void operator=(const DbSlice&) = delete; public: + class AutoUpdater { + public: + AutoUpdater(); + AutoUpdater(AutoUpdater&& o); + AutoUpdater& operator=(AutoUpdater&& o); + ~AutoUpdater(); + + void Run(); + void Cancel(); + + private: + enum class DestructorAction { + kDoNothing, + kRun, + }; + + // Wrap members in a struct to auto generate operator= + struct Fields { + DestructorAction action = DestructorAction::kDoNothing; + + DbSlice* db_slice = nullptr; + DbIndex db_ind = 0; + PrimeIterator it; + std::string_view key; + bool key_existed = false; + + size_t db_size = 0; + size_t deletion_count = 0; + // TODO(#2252): Add heap size here, and only update memory in d'tor + }; + + AutoUpdater(const Fields& fields); + + friend class DbSlice; + + Fields fields_ = {}; + }; + struct Stats { // DbStats db; std::vector db_stats; @@ -148,9 +186,20 @@ class DbSlice { return ExpirePeriod{time_ms - expire_base_[0]}; } + // TODO(#2252): Remove this in favor of FindMutable() / FindReadOnly() OpResult Find(const Context& cntx, std::string_view key, unsigned req_obj_type) const; + struct ItAndUpdater { + PrimeIterator it; + AutoUpdater post_updater; + }; + OpResult FindMutable(const Context& cntx, std::string_view key, + unsigned req_obj_type); + + OpResult FindReadOnly(const Context& cntx, std::string_view key, + unsigned req_obj_type) const; + // Returns (value, expire) dict entries if key exists, null if it does not exist or has expired. std::pair FindExt(const Context& cntx, std::string_view key) const; @@ -260,6 +309,7 @@ class DbSlice { size_t DbSize(DbIndex db_ind) const; // Callback functions called upon writing to the existing key. + // TODO(#2252): Remove these (or make them private) void PreUpdate(DbIndex db_ind, PrimeIterator it); void PostUpdate(DbIndex db_ind, PrimeIterator it, std::string_view key, bool existing_entry = true); @@ -376,6 +426,7 @@ class DbSlice { ssize_t memory_budget_ = SSIZE_MAX; size_t bytes_per_object_ = 0; size_t soft_budget_limit_ = 0; + size_t deletion_count_ = 0; mutable SliceEvents events_; // we may change this even for const operations. diff --git a/src/server/hll_family.cc b/src/server/hll_family.cc index 419bad1e1..368ac1612 100644 --- a/src/server/hll_family.cc +++ b/src/server/hll_family.cc @@ -117,7 +117,7 @@ void PFAdd(CmdArgList args, ConnectionContext* cntx) { OpResult CountHllsSingle(const OpArgs& op_args, string_view key) { auto& db_slice = op_args.shard->db_slice(); - OpResult it = db_slice.Find(op_args.db_cntx, key, OBJ_STRING); + OpResult it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (it.ok()) { string hll; string_view hll_view = it.value()->second.GetSlice(&hll); @@ -150,8 +150,8 @@ OpResult> ReadValues(const OpArgs& op_args, ArgSlice keys) { try { vector values; for (size_t i = 0; i < keys.size(); ++i) { - OpResult it = - op_args.shard->db_slice().Find(op_args.db_cntx, keys[i], OBJ_STRING); + OpResult it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_STRING); if (it.ok()) { string hll; it.value()->second.GetString(&hll); diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index d208e69f6..828bd924f 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -278,17 +278,18 @@ OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t * of returning no or very few elements. (taken from redis code at db.c line 904 */ constexpr size_t INTERATION_FACTOR = 10; - OpResult find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_HASH); + OpResult find_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (!find_res) { DVLOG(1) << "ScanOp: find failed: " << find_res << ", baling out"; return find_res.status(); } - PrimeIterator it = find_res.value(); + PrimeConstIterator it = find_res.value(); StringVec res; uint32_t count = scan_op.limit * HASH_TABLE_ENTRIES_FACTOR; - PrimeValue& pv = it->second; + const PrimeValue& pv = it->second; if (pv.Encoding() == kEncodingListPack) { uint8_t* lp = (uint8_t*)pv.RObjPtr(); @@ -342,15 +343,14 @@ OpResult OpDel(const OpArgs& op_args, string_view key, CmdArgList valu DCHECK(!values.empty()); auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_HASH); if (!it_res) return it_res.status(); - op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, (*it_res)->second); - db_slice.PreUpdate(op_args.db_cntx.db_index, *it_res); + PrimeValue& pv = it_res->it->second; + op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, pv); - PrimeValue& pv = (*it_res)->second; unsigned deleted = 0; bool key_remove = false; DbTableStats* stats = db_slice.MutableStats(op_args.db_cntx.db_index); @@ -387,7 +387,7 @@ OpResult OpDel(const OpArgs& op_args, string_view key, CmdArgList valu } } - db_slice.PostUpdate(op_args.db_cntx.db_index, *it_res, key); + it_res->post_updater.Run(); if (!key_remove) op_args.shard->search_indices()->AddDoc(key, op_args.db_cntx, pv); @@ -396,7 +396,7 @@ OpResult OpDel(const OpArgs& op_args, string_view key, CmdArgList valu if (enc == kEncodingListPack) { stats->listpack_blob_cnt--; } - db_slice.Del(op_args.db_cntx.db_index, *it_res); + db_slice.Del(op_args.db_cntx.db_index, it_res->it); } else if (enc == kEncodingListPack) { stats->listpack_bytes += lpBytes((uint8_t*)pv.RObjPtr()); } @@ -408,12 +408,12 @@ OpResult> OpHMGet(const OpArgs& op_args, std::string_view key, Cm DCHECK(!fields.empty()); auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (!it_res) return it_res.status(); - PrimeValue& pv = (*it_res)->second; + const PrimeValue& pv = (*it_res)->second; std::vector result(fields.size()); @@ -466,7 +466,7 @@ OpResult> OpHMGet(const OpArgs& op_args, std::string_view key, Cm OpResult OpLen(const OpArgs& op_args, string_view key) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (it_res) { return HMapLength(op_args.db_cntx, (*it_res)->second); @@ -479,7 +479,7 @@ OpResult OpLen(const OpArgs& op_args, string_view key) { OpResult OpExist(const OpArgs& op_args, string_view key, string_view field) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (!it_res) { if (it_res.status() == OpStatus::KEY_NOTFOUND) @@ -503,7 +503,7 @@ OpResult OpExist(const OpArgs& op_args, string_view key, string_view field) OpResult OpGet(const OpArgs& op_args, string_view key, string_view field) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (!it_res) return it_res.status(); @@ -531,7 +531,7 @@ OpResult OpGet(const OpArgs& op_args, string_view key, string_view field OpResult> OpGetAll(const OpArgs& op_args, string_view key, uint8_t mask) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (!it_res) { if (it_res.status() == OpStatus::KEY_NOTFOUND) return vector{}; @@ -582,7 +582,7 @@ OpResult> OpGetAll(const OpArgs& op_args, string_view key, uint8_ OpResult OpStrLen(const OpArgs& op_args, string_view key, string_view field) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); if (!it_res) { if (it_res.status() == OpStatus::KEY_NOTFOUND) @@ -1062,7 +1062,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { auto& db_slice = shard->db_slice(); DbContext db_context = t->GetDbContext(); - auto it_res = db_slice.Find(db_context, key, OBJ_HASH); + auto it_res = db_slice.FindReadOnly(db_context, key, OBJ_HASH); if (!it_res) return it_res.status(); @@ -1097,7 +1097,9 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { } if (string_map->Empty()) { - db_slice.Del(db_context.db_index, *it_res); + auto it_mutable = db_slice.FindMutable(db_context, key, OBJ_HASH); + it_mutable->post_updater.Run(); + db_slice.Del(db_context.db_index, it_mutable->it); return facade::OpStatus::KEY_NOTFOUND; } } else if (pv.Encoding() == kEncodingListPack) { diff --git a/src/server/json_family.cc b/src/server/json_family.cc index 996b7773c..15a65abce 100644 --- a/src/server/json_family.cc +++ b/src/server/json_family.cc @@ -146,21 +146,18 @@ error_code JsonReplace(JsonType& instance, string_view path, JsonReplaceCb callb OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_view path, JsonReplaceCb callback, JsonReplaceVerify verify_op = JsonReplaceVerifyNoOp) { - OpResult it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON); + auto it_res = op_args.shard->db_slice().FindMutable(op_args.db_cntx, key, OBJ_JSON); if (!it_res.ok()) { return it_res.status(); } - PrimeIterator entry_it = it_res.value(); - auto& db_slice = op_args.shard->db_slice(); - auto db_index = op_args.db_cntx.db_index; + PrimeConstIterator entry_it = it_res->it; JsonType* json_val = entry_it->second.GetJson(); DCHECK(json_val) << "should have a valid JSON object for key '" << key << "' the type for it is '" << entry_it->second.ObjType() << "'"; JsonType& json_entry = *json_val; op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, entry_it->second); - db_slice.PreUpdate(db_index, entry_it); // Run the update operation on this entry error_code ec = JsonReplace(json_entry, path, callback); @@ -172,7 +169,7 @@ OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_vi // Make sure that we don't have other internal issue with the operation OpStatus res = verify_op(json_entry); if (res == OpStatus::OK) { - db_slice.PostUpdate(db_index, entry_it, key); + it_res->post_updater.Run(); op_args.shard->search_indices()->AddDoc(key, op_args.db_cntx, entry_it->second); } @@ -180,7 +177,8 @@ OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_vi } OpResult GetJson(const OpArgs& op_args, string_view key) { - OpResult it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON); + OpResult it_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_JSON); if (!it_res.ok()) return it_res.status(); @@ -982,7 +980,8 @@ vector OpJsonMGet(JsonExpression expression, const Transaction* t, En auto& db_slice = shard->db_slice(); for (size_t i = 0; i < args.size(); ++i) { - OpResult it_res = db_slice.Find(t->GetDbContext(), args[i], OBJ_JSON); + OpResult it_res = + db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_JSON); if (!it_res.ok()) continue; @@ -1068,8 +1067,8 @@ OpResult OpSet(const OpArgs& op_args, string_view key, string_view path, // and its not JSON, it would return an error. if (path == "." || path == "$") { if (is_nx_condition || is_xx_condition) { - OpResult it_res = - op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON); + OpResult it_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_JSON); bool key_exists = (it_res.status() != OpStatus::KEY_NOTFOUND); if (is_nx_condition && key_exists) { return false; diff --git a/src/server/list_family.cc b/src/server/list_family.cc index 436108e7d..a47b81a71 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -184,7 +184,7 @@ std::string OpBPop(Transaction* t, EngineShard* shard, std::string_view key, Lis DVLOG(2) << "popping from " << key << " " << t->DebugId(); auto& db_slice = shard->db_slice(); - auto it_res = db_slice.Find(t->GetDbContext(), key, OBJ_LIST); + auto it_res = db_slice.FindMutable(t->GetDbContext(), key, OBJ_LIST); if (!it_res) { auto messages = debugMessages.All(); @@ -203,14 +203,13 @@ std::string OpBPop(Transaction* t, EngineShard* shard, std::string_view key, Lis CHECK(it_res) << t->DebugId() << " " << key; // must exist and must be ok. - PrimeIterator it = *it_res; + PrimeIterator it = it_res->it; quicklist* ql = GetQL(it->second); absl::StrAppend(debugMessages.Next(), "OpBPop: ", key, " by ", t->DebugId()); - db_slice.PreUpdate(t->GetDbIndex(), it); std::string value = ListPop(dir, ql); - db_slice.PostUpdate(t->GetDbIndex(), it, key); + it_res->post_updater.Run(); if (quicklistCount(ql) == 0) { DVLOG(1) << "deleting key " << key << " " << t->DebugId(); @@ -230,20 +229,18 @@ std::string OpBPop(Transaction* t, EngineShard* shard, std::string_view key, Lis OpResult OpMoveSingleShard(const OpArgs& op_args, string_view src, string_view dest, ListDir src_dir, ListDir dest_dir) { auto& db_slice = op_args.shard->db_slice(); - auto src_res = db_slice.Find(op_args.db_cntx, src, OBJ_LIST); + auto src_res = db_slice.FindMutable(op_args.db_cntx, src, OBJ_LIST); if (!src_res) return src_res.status(); - PrimeIterator src_it = *src_res; + PrimeIterator src_it = src_res->it; quicklist* src_ql = GetQL(src_it->second); if (src == dest) { // simple case. - db_slice.PreUpdate(op_args.db_cntx.db_index, src_it); string val = ListPop(src_dir, src_ql); int pos = (dest_dir == ListDir::LEFT) ? QUICKLIST_HEAD : QUICKLIST_TAIL; quicklistPush(src_ql, val.data(), val.size(), pos); - db_slice.PostUpdate(op_args.db_cntx.db_index, src_it, src); return val; } @@ -252,6 +249,7 @@ OpResult OpMoveSingleShard(const OpArgs& op_args, string_view src, strin PrimeIterator dest_it; bool new_key = false; try { + src_res->post_updater.Run(); tie(dest_it, new_key) = db_slice.AddOrFind(op_args.db_cntx, dest); } catch (bad_alloc&) { return OpStatus::OUT_OF_MEMORY; @@ -265,7 +263,9 @@ OpResult OpMoveSingleShard(const OpArgs& op_args, string_view src, strin dest_it->second.ImportRObj(obj); // Insertion of dest could invalidate src_it. Find it again. - src_it = db_slice.GetTables(op_args.db_cntx.db_index).first->Find(src); + src_res = db_slice.FindMutable(op_args.db_cntx, src, OBJ_LIST); + src_it = src_res->it; + DCHECK(IsValid(src_it)); } else { if (dest_it->second.ObjType() != OBJ_LIST) return OpStatus::WRONG_TYPE; @@ -274,13 +274,11 @@ OpResult OpMoveSingleShard(const OpArgs& op_args, string_view src, strin db_slice.PreUpdate(op_args.db_cntx.db_index, dest_it); } - db_slice.PreUpdate(op_args.db_cntx.db_index, src_it); - string val = ListPop(src_dir, src_ql); int pos = (dest_dir == ListDir::LEFT) ? QUICKLIST_HEAD : QUICKLIST_TAIL; quicklistPush(dest_ql, val.data(), val.size(), pos); - db_slice.PostUpdate(op_args.db_cntx.db_index, src_it, src); + src_res->post_updater.Run(); db_slice.PostUpdate(op_args.db_cntx.db_index, dest_it, dest, !new_key); if (quicklistCount(src_ql) == 0) { @@ -293,7 +291,7 @@ OpResult OpMoveSingleShard(const OpArgs& op_args, string_view src, strin // Read-only peek operation that determines whether the list exists and optionally // returns the first from left/right value without popping it from the list. OpResult Peek(const OpArgs& op_args, string_view key, ListDir dir, bool fetch) { - auto it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); + auto it_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST); if (!it_res) { return it_res.status(); } @@ -321,6 +319,7 @@ OpResult OpPush(const OpArgs& op_args, std::string_view key, ListDir d bool new_key = false; if (skip_notexist) { + // TODO(#2252): Move to FindMutable() once AddOrFindMutable() is implemented auto it_res = es->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); if (!it_res) return 0; // Redis returns 0 for nonexisting keys for the *PUSHX actions. @@ -384,13 +383,12 @@ OpResult OpPush(const OpArgs& op_args, std::string_view key, ListDir d OpResult OpPop(const OpArgs& op_args, string_view key, ListDir dir, uint32_t count, bool return_results, bool journal_rewrite) { auto& db_slice = op_args.shard->db_slice(); - OpResult it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST); if (!it_res) return it_res.status(); - PrimeIterator it = *it_res; + PrimeIterator it = it_res->it; quicklist* ql = GetQL(it->second); - db_slice.PreUpdate(op_args.db_cntx.db_index, it); StringVec res; if (quicklistCount(ql) < count) { @@ -408,7 +406,7 @@ OpResult OpPop(const OpArgs& op_args, string_view key, ListDir dir, u } } - db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); + it_res->post_updater.Run(); if (quicklistCount(ql) == 0) { absl::StrAppend(debugMessages.Next(), "OpPop Del: ", key, " by ", op_args.tx->DebugId()); @@ -486,7 +484,7 @@ OpResult MoveTwoShards(Transaction* trans, string_view src, string_view } OpResult OpLen(const OpArgs& op_args, std::string_view key) { - auto res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); + auto res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST); if (!res) return res.status(); @@ -496,7 +494,7 @@ OpResult OpLen(const OpArgs& op_args, std::string_view key) { } OpResult OpIndex(const OpArgs& op_args, std::string_view key, long index) { - auto res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); + auto res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST); if (!res) return res.status(); quicklist* ql = GetQL(res.value()->second); @@ -520,7 +518,8 @@ OpResult OpIndex(const OpArgs& op_args, std::string_view key, long index OpResult> OpPos(const OpArgs& op_args, std::string_view key, std::string_view element, int rank, int count, int max_len) { - OpResult it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); + OpResult it_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST); if (!it_res.ok()) return it_res.status(); @@ -564,11 +563,11 @@ OpResult> OpPos(const OpArgs& op_args, std::string_view key, OpResult OpInsert(const OpArgs& op_args, string_view key, string_view pivot, string_view elem, int insert_param) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST); if (!it_res) return it_res.status(); - quicklist* ql = GetQL(it_res.value()->second); + quicklist* ql = GetQL(it_res->it->second); quicklistEntry entry = container_utils::QLEntry(); quicklistIter* qiter = quicklistGetIterator(ql, AL_START_HEAD); bool found = false; @@ -582,14 +581,12 @@ OpResult OpInsert(const OpArgs& op_args, string_view key, string_view pivot int res = -1; if (found) { - db_slice.PreUpdate(op_args.db_cntx.db_index, *it_res); if (insert_param == LIST_TAIL) { quicklistInsertAfter(qiter, &entry, elem.data(), elem.size()); } else { DCHECK_EQ(LIST_HEAD, insert_param); quicklistInsertBefore(qiter, &entry, elem.data(), elem.size()); } - db_slice.PostUpdate(op_args.db_cntx.db_index, *it_res, key); res = quicklistCount(ql); } quicklistReleaseIterator(qiter); @@ -598,11 +595,11 @@ OpResult OpInsert(const OpArgs& op_args, string_view key, string_view pivot OpResult OpRem(const OpArgs& op_args, string_view key, string_view elem, long count) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST); if (!it_res) return it_res.status(); - PrimeIterator it = *it_res; + PrimeIterator it = it_res->it; quicklist* ql = GetQL(it->second); int iter_direction = AL_START_HEAD; @@ -618,7 +615,6 @@ OpResult OpRem(const OpArgs& op_args, string_view key, string_view ele unsigned removed = 0; const uint8_t* elem_ptr = reinterpret_cast(elem.data()); - db_slice.PreUpdate(op_args.db_cntx.db_index, it); while (quicklistNext(qiter, &entry)) { if (quicklistCompare(&entry, elem_ptr, elem.size())) { quicklistDelEntry(qiter, &entry); @@ -627,7 +623,8 @@ OpResult OpRem(const OpArgs& op_args, string_view key, string_view ele break; } } - db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); + + it_res->post_updater.Run(); quicklistReleaseIterator(qiter); @@ -640,16 +637,14 @@ OpResult OpRem(const OpArgs& op_args, string_view key, string_view ele OpStatus OpSet(const OpArgs& op_args, string_view key, string_view elem, long index) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST); if (!it_res) return it_res.status(); - PrimeIterator it = *it_res; + PrimeIterator it = it_res->it; quicklist* ql = GetQL(it->second); - db_slice.PreUpdate(op_args.db_cntx.db_index, it); int replaced = quicklistReplaceAtIndex(ql, index, elem.data(), elem.size()); - db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); if (!replaced) { return OpStatus::OUT_OF_RANGE; @@ -659,11 +654,11 @@ OpStatus OpSet(const OpArgs& op_args, string_view key, string_view elem, long in OpStatus OpTrim(const OpArgs& op_args, string_view key, long start, long end) { auto& db_slice = op_args.shard->db_slice(); - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST); if (!it_res) return it_res.status(); - PrimeIterator it = *it_res; + PrimeIterator it = it_res->it; quicklist* ql = GetQL(it->second); long llen = quicklistCount(ql); @@ -690,10 +685,10 @@ OpStatus OpTrim(const OpArgs& op_args, string_view key, long start, long end) { rtrim = llen - end - 1; } - db_slice.PreUpdate(op_args.db_cntx.db_index, it); quicklistDelRange(ql, 0, ltrim); quicklistDelRange(ql, -rtrim, rtrim); - db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); + + it_res->post_updater.Run(); if (quicklistCount(ql) == 0) { CHECK(db_slice.Del(op_args.db_cntx.db_index, it)); @@ -702,7 +697,7 @@ OpStatus OpTrim(const OpArgs& op_args, string_view key, long start, long end) { } OpResult OpRange(const OpArgs& op_args, std::string_view key, long start, long end) { - auto res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); + auto res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST); if (!res) return res.status(); diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index cb794d1a1..01693cf5b 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -213,7 +213,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa size_t expired_count = 0; for (size_t i = 0; i < search_results.ids.size(); i++) { auto key = key_index_.Get(search_results.ids[i]); - auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode()); + auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode()); if (!it || !IsValid(*it)) { // Item must have expired expired_count++; diff --git a/src/server/set_family.cc b/src/server/set_family.cc index f530b597a..b972fc7de 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -695,20 +695,18 @@ OpResult OpRem(const OpArgs& op_args, string_view key, const ArgSlice& bool journal_rewrite) { auto* es = op_args.shard; auto& db_slice = es->db_slice(); - OpResult find_res = db_slice.Find(op_args.db_cntx, key, OBJ_SET); + auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET); if (!find_res) { return find_res.status(); } - db_slice.PreUpdate(op_args.db_cntx.db_index, *find_res); - - CompactObj& co = find_res.value()->second; + CompactObj& co = find_res->it->second; auto [removed, isempty] = RemoveSet(op_args.db_cntx, vals, &co); - db_slice.PostUpdate(op_args.db_cntx.db_index, *find_res, key); + find_res->post_updater.Run(); if (isempty) { - CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res.value())); + CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res->it)); } if (journal_rewrite && op_args.shard->journal()) { vector mapped(vals.size() + 1); @@ -749,7 +747,7 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { for (auto k : largs) { unsigned index = (k == src_) ? 0 : 1; - OpResult res = es->db_slice().Find(t->GetDbContext(), k, OBJ_SET); + OpResult res = es->db_slice().FindReadOnly(t->GetDbContext(), k, OBJ_SET); if (res && index == 0) { // successful src find. DCHECK(!res->is_done()); const CompactObj& val = res.value()->second; @@ -815,10 +813,10 @@ OpResult OpUnion(const OpArgs& op_args, ArgSlice keys) { absl::flat_hash_set uniques; for (string_view key : keys) { - OpResult find_res = - op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_SET); + OpResult find_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET); if (find_res) { - PrimeValue& pv = find_res.value()->second; + const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { StringSet* ss = (StringSet*)pv.RObjPtr(); ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); @@ -843,14 +841,15 @@ OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { DCHECK(!keys.empty()); DVLOG(1) << "OpDiff from " << keys.front(); EngineShard* es = op_args.shard; - OpResult find_res = es->db_slice().Find(op_args.db_cntx, keys.front(), OBJ_SET); + OpResult find_res = + es->db_slice().FindReadOnly(op_args.db_cntx, keys.front(), OBJ_SET); if (!find_res) { return find_res.status(); } absl::flat_hash_set uniques; - PrimeValue& pv = find_res.value()->second; + const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { StringSet* ss = (StringSet*)pv.RObjPtr(); ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); @@ -864,7 +863,8 @@ OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { DCHECK(!uniques.empty()); // otherwise the key would not exist. for (size_t i = 1; i < keys.size(); ++i) { - OpResult diff_res = es->db_slice().Find(op_args.db_cntx, keys[i], OBJ_SET); + OpResult diff_res = + es->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_SET); if (!diff_res) { if (diff_res.status() == OpStatus::WRONG_TYPE) { return OpStatus::WRONG_TYPE; @@ -901,12 +901,12 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f StringVec result; if (keys.size() == 1) { - OpResult find_res = - es->db_slice().Find(t->GetDbContext(), keys.front(), OBJ_SET); + OpResult find_res = + es->db_slice().FindReadOnly(t->GetDbContext(), keys.front(), OBJ_SET); if (!find_res) return find_res.status(); - PrimeValue& pv = find_res.value()->second; + const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { StringSet* ss = (StringSet*)pv.RObjPtr(); ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); @@ -926,7 +926,8 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f OpStatus status = OpStatus::OK; for (size_t i = 0; i < keys.size(); ++i) { - OpResult find_res = es->db_slice().Find(t->GetDbContext(), keys[i], OBJ_SET); + OpResult find_res = + es->db_slice().FindReadOnly(t->GetDbContext(), keys[i], OBJ_SET); if (!find_res) { if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND || find_res.status() != OpStatus::KEY_NOTFOUND) { @@ -976,7 +977,7 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f // count - how many elements to pop. OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count) { auto& db_slice = op_args.shard->db_slice(); - OpResult find_res = db_slice.Find(op_args.db_cntx, key, OBJ_SET); + auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET); if (!find_res) return find_res.status(); @@ -984,7 +985,7 @@ OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count if (count == 0) return result; - PrimeIterator it = find_res.value(); + PrimeIterator it = find_res->it; size_t slen = it->second.Size(); /* CASE 1: @@ -1003,6 +1004,7 @@ OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count }); // Delete the set as it is now empty + find_res->post_updater.Run(); CHECK(db_slice.Del(op_args.db_cntx.db_index, it)); // Replicate as DEL. @@ -1011,7 +1013,6 @@ OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count } } else { SetType st{it->second.RObjPtr(), it->second.Encoding()}; - db_slice.PreUpdate(op_args.db_cntx.db_index, it); if (st.second == kEncodingIntSet) { intset* is = (intset*)st.first; int64_t val = 0; @@ -1035,20 +1036,19 @@ OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count std::copy(result.begin(), result.end(), mapped.begin() + 1); RecordJournal(op_args, "SREM"sv, mapped); } - - db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); } return result; } OpResult OpScan(const OpArgs& op_args, string_view key, uint64_t* cursor, const ScanOpts& scan_op) { - OpResult find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_SET); + OpResult find_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET); if (!find_res) return find_res.status(); - PrimeIterator it = find_res.value(); + PrimeConstIterator it = find_res.value(); StringVec res; if (it->second.Encoding() == kEncodingIntSet) { @@ -1094,7 +1094,8 @@ void SIsMember(CmdArgList args, ConnectionContext* cntx) { string_view val = ArgS(args, 1); auto cb = [&](Transaction* t, EngineShard* shard) { - OpResult find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET); + OpResult find_res = + shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET); if (find_res) { SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; @@ -1125,7 +1126,8 @@ void SMIsMember(CmdArgList args, ConnectionContext* cntx) { memberships.reserve(vals.size()); auto cb = [&](Transaction* t, EngineShard* shard) { - OpResult find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET); + OpResult find_res = + shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET); if (find_res) { SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; FindInSet(memberships, t->GetDbContext(), st, vals); @@ -1191,7 +1193,8 @@ void SCard(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - OpResult find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET); + OpResult find_res = + shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET); if (!find_res) { return find_res.status(); } @@ -1366,12 +1369,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) { const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { StringVec result; - OpResult find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET); + OpResult find_res = + shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET); if (!find_res) { return find_res.status(); } - PrimeValue& pv = find_res.value()->second; + const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { StringSet* ss = (StringSet*)pv.RObjPtr(); ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc index a8435677e..df48b79ab 100644 --- a/src/server/stream_family.cc +++ b/src/server/stream_family.cc @@ -611,6 +611,7 @@ OpResult OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL pair add_res; if (opts.no_mkstream) { + // TODO(#2252): Replace with FindMutable() once AddOrFindMutable() is implemented auto res_it = db_slice.Find(op_args.db_cntx, opts.key, OBJ_STREAM); if (!res_it) { return res_it.status(); @@ -668,7 +669,7 @@ OpResult OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL OpResult OpRange(const OpArgs& op_args, string_view key, const RangeOpts& opts) { auto& db_slice = op_args.shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); + OpResult res_it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); @@ -680,7 +681,7 @@ OpResult OpRange(const OpArgs& op_args, string_view key, const RangeO streamIterator si; int64_t numfields; streamID id; - CompactObj& cobj = (*res_it)->second; + const CompactObj& cobj = (*res_it)->second; stream* s = (stream*)cobj.RObjPtr(); streamID sstart = opts.start.val, send = opts.end.val; @@ -797,6 +798,14 @@ OpResult OpRangeFromConsumerPEL(const OpArgs& op_args, string_view ke return result; } +namespace { +// Our C-API doesn't use const, so we have to const cast. +// Only intended for read-only functions. +stream* GetReadOnlyStream(const CompactObj& cobj) { + return const_cast((const stream*)cobj.RObjPtr()); +} +} // namespace + // Returns a map of stream to the ID of the last entry in the stream. Any // streams not found are omitted from the result. OpResult>> OpLastIDs(const OpArgs& op_args, @@ -807,7 +816,7 @@ OpResult>> OpLastIDs(const OpArgs& op_args, vector> last_ids; for (string_view key : args) { - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); + OpResult res_it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STREAM); if (!res_it) { if (res_it.status() == OpStatus::KEY_NOTFOUND) { continue; @@ -815,8 +824,8 @@ OpResult>> OpLastIDs(const OpArgs& op_args, return res_it.status(); } - CompactObj& cobj = (*res_it)->second; - stream* s = (stream*)cobj.RObjPtr(); + const CompactObj& cobj = (*res_it)->second; + stream* s = GetReadOnlyStream(cobj); streamID last_id = s->last_id; if (s->length) { @@ -869,10 +878,10 @@ vector OpRead(const OpArgs& op_args, const ArgSlice& args, const Read OpResult OpLen(const OpArgs& op_args, string_view key) { auto& db_slice = op_args.shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); + OpResult res_it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); - CompactObj& cobj = (*res_it)->second; + const CompactObj& cobj = (*res_it)->second; stream* s = (stream*)cobj.RObjPtr(); return s->length; } @@ -880,12 +889,12 @@ OpResult OpLen(const OpArgs& op_args, string_view key) { OpResult> OpListGroups(const DbContext& db_cntx, string_view key, EngineShard* shard) { auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(db_cntx, key, OBJ_STREAM); + OpResult res_it = db_slice.FindReadOnly(db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); vector result; - CompactObj& cobj = (*res_it)->second; + const CompactObj& cobj = (*res_it)->second; stream* s = (stream*)cobj.RObjPtr(); if (s->cgroups) { @@ -1012,12 +1021,12 @@ void GetConsumers(stream* s, streamCG* cg, long long count, GroupInfo* ginfo) { OpResult OpStreams(const DbContext& db_cntx, string_view key, EngineShard* shard, int full, size_t count) { auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(db_cntx, key, OBJ_STREAM); + OpResult res_it = db_slice.FindReadOnly(db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); vector result; - CompactObj& cobj = (*res_it)->second; + const CompactObj& cobj = (*res_it)->second; stream* s = (stream*)cobj.RObjPtr(); StreamInfo sinfo; @@ -1075,13 +1084,13 @@ OpResult OpStreams(const DbContext& db_cntx, string_view key, Engine OpResult> OpConsumers(const DbContext& db_cntx, EngineShard* shard, string_view stream_name, string_view group_name) { auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(db_cntx, stream_name, OBJ_STREAM); + OpResult res_it = db_slice.FindReadOnly(db_cntx, stream_name, OBJ_STREAM); if (!res_it) return res_it.status(); vector result; - CompactObj& cobj = (*res_it)->second; - stream* s = (stream*)cobj.RObjPtr(); + const CompactObj& cobj = (*res_it)->second; + stream* s = GetReadOnlyStream(cobj); shard->tmp_str1 = sdscpylen(shard->tmp_str1, group_name.data(), group_name.length()); streamCG* cg = streamLookupCG(s, shard->tmp_str1); if (cg == NULL) { @@ -1120,6 +1129,7 @@ struct CreateOpts { OpStatus OpCreate(const OpArgs& op_args, string_view key, const CreateOpts& opts) { auto* shard = op_args.shard; auto& db_slice = shard->db_slice(); + // TODO(#2252): Replace with FindMutable() once new AddNew() is implemented OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); int64_t entries_read = SCG_INVALID_ENTRIES_READ; if (!res_it) { @@ -1158,19 +1168,24 @@ OpStatus OpCreate(const OpArgs& op_args, string_view key, const CreateOpts& opts return OpStatus::BUSY_GROUP; } -OpResult> FindGroup(const OpArgs& op_args, string_view key, - string_view gname) { +struct FindGroupResult { + stream* s = nullptr; + streamCG* cg = nullptr; + DbSlice::AutoUpdater post_updater; +}; +OpResult FindGroup(const OpArgs& op_args, string_view key, string_view gname) { auto* shard = op_args.shard; auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); - CompactObj& cobj = (*res_it)->second; - pair res; - res.first = (stream*)cobj.RObjPtr(); + CompactObj& cobj = res_it->it->second; + FindGroupResult res; + res.s = (stream*)cobj.RObjPtr(); shard->tmp_str1 = sdscpylen(shard->tmp_str1, gname.data(), gname.size()); - res.second = streamLookupCG(res.first, shard->tmp_str1); + res.cg = streamLookupCG(res.s, shard->tmp_str1); + res.post_updater = std::move(res_it->post_updater); return res; } @@ -1231,12 +1246,10 @@ void AppendClaimResultItem(ClaimInfo& result, stream* s, streamID id) { // XCLAIM key group consumer min-idle-time id OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimOpts& opts, absl::Span ids) { - OpResult> cgr_res = FindGroup(op_args, key, opts.group); + auto cgr_res = FindGroup(op_args, key, opts.group); if (!cgr_res) return cgr_res.status(); - stream* s = cgr_res->first; - streamCG* scg = cgr_res->second; - if (!scg) { + if (!cgr_res->cg) { return OpStatus::SKIPPED; } streamConsumer* consumer = nullptr; @@ -1246,8 +1259,8 @@ OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimO streamID last_id = opts.last_id; if (opts.flags & kClaimLastID) { - if (streamCompareID(&last_id, &scg->last_id) > 0) { - scg->last_id = last_id; + if (streamCompareID(&last_id, &cgr_res->cg->last_id) > 0) { + cgr_res->cg->last_id = last_id; } } @@ -1255,11 +1268,11 @@ OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimO std::array buf; StreamEncodeID(buf.begin(), &id); - streamNACK* nack = (streamNACK*)raxFind(scg->pel, buf.begin(), sizeof(buf)); - if (!streamEntryExists(s, &id)) { + streamNACK* nack = (streamNACK*)raxFind(cgr_res->cg->pel, buf.begin(), sizeof(buf)); + if (!streamEntryExists(cgr_res->s, &id)) { if (nack != raxNotFound) { /* Release the NACK */ - raxRemove(scg->pel, buf.begin(), sizeof(buf), nullptr); + raxRemove(cgr_res->cg->pel, buf.begin(), sizeof(buf), nullptr); raxRemove(nack->consumer->pel, buf.begin(), sizeof(buf), nullptr); streamFreeNACK(nack); } @@ -1271,7 +1284,7 @@ OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimO if ((opts.flags & kClaimForce) && nack == raxNotFound) { /* Create the NACK. */ nack = streamCreateNACK(nullptr); - raxInsert(scg->pel, buf.begin(), sizeof(buf), nack, nullptr); + raxInsert(cgr_res->cg->pel, buf.begin(), sizeof(buf), nack, nullptr); } // We found the nack, continue. @@ -1287,9 +1300,9 @@ OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimO // Try to get the consumer. If not found, create a new one. op_args.shard->tmp_str1 = sdscpylen(op_args.shard->tmp_str1, opts.consumer.data(), opts.consumer.size()); - if ((consumer = streamLookupConsumer(scg, op_args.shard->tmp_str1, SLC_NO_REFRESH)) == + if ((consumer = streamLookupConsumer(cgr_res->cg, op_args.shard->tmp_str1, SLC_NO_REFRESH)) == nullptr) { - consumer = streamCreateConsumer(scg, op_args.shard->tmp_str1, nullptr, 0, + consumer = streamCreateConsumer(cgr_res->cg, op_args.shard->tmp_str1, nullptr, 0, SCC_NO_NOTIFY | SCC_NO_DIRTIFY); } @@ -1319,7 +1332,7 @@ OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimO } /* Send the reply for this entry. */ - AppendClaimResultItem(result, s, id); + AppendClaimResultItem(result, cgr_res->s, id); } } return result; @@ -1327,16 +1340,13 @@ OpResult OpClaim(const OpArgs& op_args, string_view key, const ClaimO // XGROUP DESTROY key groupname OpStatus OpDestroyGroup(const OpArgs& op_args, string_view key, string_view gname) { - OpResult> cgr_res = FindGroup(op_args, key, gname); + auto cgr_res = FindGroup(op_args, key, gname); if (!cgr_res) return cgr_res.status(); - stream* s = cgr_res->first; - streamCG* scg = cgr_res->second; - - if (scg) { - raxRemove(s->cgroups, (uint8_t*)(gname.data()), gname.size(), NULL); - streamFreeCG(scg); + if (cgr_res->cg) { + raxRemove(cgr_res->s->cgroups, (uint8_t*)(gname.data()), gname.size(), NULL); + streamFreeCG(cgr_res->cg); return OpStatus::OK; } @@ -1366,7 +1376,7 @@ vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA if (!group_res) { continue; } - if (group = group_res->second; !group) { + if (group = group_res->cg; !group) { continue; } @@ -1385,10 +1395,10 @@ vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA // XGROUP CREATECONSUMER key groupname consumername OpResult OpCreateConsumer(const OpArgs& op_args, string_view key, string_view gname, string_view consumer_name) { - OpResult> cgroup_res = FindGroup(op_args, key, gname); + auto cgroup_res = FindGroup(op_args, key, gname); if (!cgroup_res) return cgroup_res.status(); - streamCG* cg = cgroup_res->second; + streamCG* cg = cgroup_res->cg; if (cg == nullptr) return OpStatus::SKIPPED; @@ -1405,11 +1415,11 @@ OpResult OpCreateConsumer(const OpArgs& op_args, string_view key, stri // XGROUP DELCONSUMER key groupname consumername OpResult OpDelConsumer(const OpArgs& op_args, string_view key, string_view gname, string_view consumer_name) { - OpResult> cgroup_res = FindGroup(op_args, key, gname); + auto cgroup_res = FindGroup(op_args, key, gname); if (!cgroup_res) return cgroup_res.status(); - streamCG* cg = cgroup_res->second; + streamCG* cg = cgroup_res->cg; if (cg == nullptr) return OpStatus::SKIPPED; @@ -1427,18 +1437,18 @@ OpResult OpDelConsumer(const OpArgs& op_args, string_view key, string_ } OpStatus OpSetId(const OpArgs& op_args, string_view key, string_view gname, string_view id) { - OpResult> cgr_res = FindGroup(op_args, key, gname); + auto cgr_res = FindGroup(op_args, key, gname); if (!cgr_res) return cgr_res.status(); - streamCG* cg = cgr_res->second; + streamCG* cg = cgr_res->cg; if (cg == nullptr) return OpStatus::SKIPPED; streamID sid; ParsedStreamId parsed_id; if (id == "$") { - sid = cgr_res->first->last_id; + sid = cgr_res->s->last_id; } else { if (ParseID(id, true, 0, &parsed_id)) { sid = parsed_id.val; @@ -1454,11 +1464,11 @@ OpStatus OpSetId(const OpArgs& op_args, string_view key, string_view gname, stri OpStatus OpSetId2(const OpArgs& op_args, string_view key, const streamID& sid) { auto* shard = op_args.shard; auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); - CompactObj& cobj = (*res_it)->second; + CompactObj& cobj = res_it->it->second; stream* stream_inst = (stream*)cobj.RObjPtr(); long long entries_added = -1; streamID max_xdel_id{0, 0}; @@ -1493,11 +1503,11 @@ OpStatus OpSetId2(const OpArgs& op_args, string_view key, const streamID& sid) { OpResult OpDel(const OpArgs& op_args, string_view key, absl::Span ids) { auto* shard = op_args.shard; auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STREAM); if (!res_it) return res_it.status(); - CompactObj& cobj = (*res_it)->second; + CompactObj& cobj = res_it->it->second; stream* stream_inst = (stream*)cobj.RObjPtr(); uint32_t deleted = 0; @@ -1536,12 +1546,11 @@ OpResult OpDel(const OpArgs& op_args, string_view key, absl::Span OpAck(const OpArgs& op_args, string_view key, string_view gname, absl::Span ids) { - OpResult> res = FindGroup(op_args, key, gname); + auto res = FindGroup(op_args, key, gname); if (!res) return res.status(); - auto [stream_inst, cg] = *res; - if (cg == nullptr || stream_inst == nullptr) { + if (res->cg == nullptr || res->s == nullptr) { return 0; } @@ -1554,9 +1563,9 @@ OpResult OpAck(const OpArgs& op_args, string_view key, string_view gna // Lookup the ID in the group PEL: it will have a reference to the // NACK structure that will have a reference to the consumer, so that // we are able to remove the entry from both PELs. - streamNACK* nack = (streamNACK*)raxFind(cg->pel, buf, sizeof(buf)); + streamNACK* nack = (streamNACK*)raxFind(res->cg->pel, buf, sizeof(buf)); if (nack != raxNotFound) { - raxRemove(cg->pel, buf, sizeof(buf), nullptr); + raxRemove(res->cg->pel, buf, sizeof(buf), nullptr); raxRemove(nack->consumer->pel, buf, sizeof(buf), nullptr); streamFreeNACK(nack); acknowledged++; @@ -1566,10 +1575,11 @@ OpResult OpAck(const OpArgs& op_args, string_view key, string_view gna } OpResult OpAutoClaim(const OpArgs& op_args, string_view key, const ClaimOpts& opts) { - OpResult> cgr_res = FindGroup(op_args, key, opts.group); + auto cgr_res = FindGroup(op_args, key, opts.group); if (!cgr_res) return cgr_res.status(); - auto [stream, group] = *cgr_res; + stream* stream = cgr_res->s; + streamCG* group = cgr_res->cg; if (stream == nullptr || group == nullptr) { return OpStatus::KEY_NOTFOUND; @@ -1770,12 +1780,12 @@ PendingExtendedResultList GetPendingExtendedResult(streamCG* cg, streamConsumer* } OpResult OpPending(const OpArgs& op_args, string_view key, const PendingOpts& opts) { - OpResult> cgroup_res = FindGroup(op_args, key, opts.group_name); + auto cgroup_res = FindGroup(op_args, key, opts.group_name); if (!cgroup_res) { return cgroup_res.status(); } - streamCG* cg = cgroup_res->second; + streamCG* cg = cgroup_res->cg; if (cg == nullptr) { return OpStatus::SKIPPED; } @@ -1923,7 +1933,7 @@ void XGroupHelp(CmdArgList args, ConnectionContext* cntx) { OpResult OpTrim(const OpArgs& op_args, const AddTrimOpts& opts) { auto* shard = op_args.shard; auto& db_slice = shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, opts.key, OBJ_STREAM); + auto res_it = db_slice.FindMutable(op_args.db_cntx, opts.key, OBJ_STREAM); if (!res_it) { if (res_it.status() == OpStatus::KEY_NOTFOUND) { return 0; @@ -1931,7 +1941,7 @@ OpResult OpTrim(const OpArgs& op_args, const AddTrimOpts& opts) { return res_it.status(); } - CompactObj& cobj = (*res_it)->second; + CompactObj& cobj = res_it->it->second; stream* s = (stream*)cobj.RObjPtr(); return StreamTrim(opts, s); diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 9dc222ed4..40f71f62f 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -80,7 +80,7 @@ OpResult OpSetRange(const OpArgs& op_args, string_view key, size_t sta size_t range_len = start + value.size(); if (range_len == 0) { - auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_STRING); + auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (it_res) { return it_res.value()->second.Size(); } else { @@ -114,7 +114,7 @@ OpResult OpSetRange(const OpArgs& op_args, string_view key, size_t sta OpResult OpGetRange(const OpArgs& op_args, string_view key, int32_t start, int32_t end) { auto& db_slice = op_args.shard->db_slice(); - OpResult it_res = db_slice.Find(op_args.db_cntx, key, OBJ_STRING); + OpResult it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING); if (!it_res.ok()) return it_res.status(); @@ -154,10 +154,7 @@ size_t ExtendExisting(const OpArgs& op_args, PrimeIterator it, string_view key, else new_val = absl::StrCat(slice, val); - auto& db_slice = shard->db_slice(); - db_slice.PreUpdate(op_args.db_cntx.db_index, it); it->second.SetString(new_val); - db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, true); return new_val.size(); } @@ -170,6 +167,7 @@ OpResult ExtendOrSet(const OpArgs& op_args, string_view key, string_vi auto [it, inserted] = db_slice.AddOrFind(op_args.db_cntx, key); if (inserted) { it->second.SetString(val); + // TODO(#2252): We currently only call PostUpdate() (no PreUpdate()), make sure this is fixed db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, false); return val.size(); @@ -178,17 +176,20 @@ OpResult ExtendOrSet(const OpArgs& op_args, string_view key, string_vi if (it->second.ObjType() != OBJ_STRING) return OpStatus::WRONG_TYPE; - return ExtendExisting(op_args, it, key, val, prepend); + db_slice.PreUpdate(op_args.db_cntx.db_index, it); + size_t res = ExtendExisting(op_args, it, key, val, prepend); + db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, true); + return res; } OpResult ExtendOrSkip(const OpArgs& op_args, string_view key, string_view val, bool prepend) { auto& db_slice = op_args.shard->db_slice(); - OpResult it_res = db_slice.Find(op_args.db_cntx, key, OBJ_STRING); + auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING); if (!it_res) { return false; } - return ExtendExisting(op_args, *it_res, key, val, prepend); + return ExtendExisting(op_args, it_res->it, key, val, prepend); } OpResult OpGet(const OpArgs& op_args, string_view key, bool del_hit = false, @@ -509,11 +510,12 @@ SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const auto& db_slice = shard->db_slice(); SinkReplyBuilder::MGetResponse response(args.size()); - absl::InlinedVector iters(args.size()); + absl::InlinedVector iters(args.size()); size_t total_size = 0; for (size_t i = 0; i < args.size(); ++i) { - OpResult it_res = db_slice.Find(t->GetDbContext(), args[i], OBJ_STRING); + OpResult it_res = + db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_STRING); if (!it_res) continue; iters[i] = *it_res; @@ -524,7 +526,7 @@ SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const char* next = response.storage_list->data; for (size_t i = 0; i < args.size(); ++i) { - PrimeIterator it = iters[i]; + PrimeConstIterator it = iters[i]; if (it.is_done()) continue; @@ -1292,7 +1294,8 @@ void StringFamily::StrLen(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - OpResult it_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_STRING); + OpResult it_res = + shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_STRING); if (!it_res.ok()) return it_res.status(); diff --git a/src/server/table.h b/src/server/table.h index 2935dbbc4..efbf83832 100644 --- a/src/server/table.h +++ b/src/server/table.h @@ -25,8 +25,9 @@ using PrimeTable = DashTable; using ExpireTable = DashTable; /// Iterators are invalidated when new keys are added to the table or some entries are deleted. -/// Iterators are still valid if a different entry in the table was mutated. +/// Iterators are still valid if a different entry in the table was mutated. using PrimeIterator = PrimeTable::iterator; +using PrimeConstIterator = PrimeTable::const_iterator; using ExpireIterator = ExpireTable::iterator; inline bool IsValid(PrimeIterator it) { @@ -37,6 +38,10 @@ inline bool IsValid(ExpireIterator it) { return !it.is_done(); } +inline bool IsValid(PrimeConstIterator it) { + return !it.is_done(); +} + struct SlotStats { uint64_t key_count = 0; uint64_t total_reads = 0; diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 04938e71e..d7065b339 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -147,7 +147,7 @@ int ZsetDel(detail::RobjWrapper* robj_wrapper, sds ele) { } // taken from t_zset.c -std::optional GetZsetScore(detail::RobjWrapper* robj_wrapper, sds member) { +std::optional GetZsetScore(const detail::RobjWrapper* robj_wrapper, sds member) { if (robj_wrapper->encoding() == OBJ_ENCODING_LISTPACK) { double score; if (zzlFind((uint8_t*)robj_wrapper->inner_obj(), member, &score) == NULL) @@ -186,6 +186,7 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& op_args size_t member_len) { auto& db_slice = op_args.shard->db_slice(); if (zparams.flags & ZADD_IN_XX) { + // TODO(#2252): Replace once AddOrFindMutable() exists return db_slice.Find(op_args.db_cntx, key, OBJ_ZSET); } @@ -721,10 +722,11 @@ void SendAtLeastOneKeyError(ConnectionContext* cntx) { enum class AggType : uint8_t { SUM, MIN, MAX, NOOP }; using ScoredMap = absl::flat_hash_map; -ScoredMap FromObject(CompactObj& co, double weight) { +ScoredMap FromObject(const CompactObj& co, double weight) { ZSetFamily::RangeParams params; params.with_scores = true; - IntervalVisitor vis(Action::RANGE, params, &co); + // RANGE is a read-only operation, but requires const_cast + IntervalVisitor vis(Action::RANGE, params, &const_cast(co)); vis(ZSetFamily::IndexInterval(0, -1)); ScoredArray arr = vis.PopResult(); @@ -795,16 +797,16 @@ void InterScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) { dest->swap(*src); } -using KeyIterWeightVec = vector>; +using KeyIterWeightVec = vector>; ScoredMap UnionShardKeysWithScore(const KeyIterWeightVec& key_iter_weight_vec, AggType agg_type) { ScoredMap result; - for (const auto& key_iter_wieght : key_iter_weight_vec) { - if (key_iter_wieght.first.is_done()) { + for (const auto& key_iter_weight : key_iter_weight_vec) { + if (key_iter_weight.first.is_done()) { continue; } - ScoredMap sm = FromObject(key_iter_wieght.first->second, key_iter_wieght.second); + ScoredMap sm = FromObject(key_iter_weight.first->second, key_iter_weight.second); if (result.empty()) { result.swap(sm); } else { @@ -852,7 +854,7 @@ OpResult OpUnion(EngineShard* shard, Transaction* t, string_view dest auto& db_slice = shard->db_slice(); KeyIterWeightVec key_weight_vec(keys.size()); for (unsigned j = 0; j < keys.size(); ++j) { - auto it_res = db_slice.Find(t->GetDbContext(), keys[j], OBJ_ZSET); + auto it_res = db_slice.FindReadOnly(t->GetDbContext(), keys[j], OBJ_ZSET); if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1. return it_res.status(); if (!it_res) @@ -1273,9 +1275,9 @@ bool ParseLimit(string_view offset_str, string_view limit_str, ZSetFamily::Range ScoredArray OpBZPop(Transaction* t, EngineShard* shard, std::string_view key, bool is_max) { auto& db_slice = shard->db_slice(); - auto it_res = db_slice.Find(t->GetDbContext(), key, OBJ_ZSET); + auto it_res = db_slice.FindMutable(t->GetDbContext(), key, OBJ_ZSET); CHECK(it_res) << t->DebugId() << " " << key; // must exist and must be ok. - PrimeIterator it = *it_res; + PrimeIterator it = it_res->it; ZSetFamily::RangeParams range_params; range_params.reverse = is_max; @@ -1291,12 +1293,12 @@ ScoredArray OpBZPop(Transaction* t, EngineShard* shard, std::string_view key, bo IntervalVisitor iv{Action::POP, range_spec.params, &pv}; std::visit(iv, range_spec.interval); - db_slice.PostUpdate(t->GetDbIndex(), *it_res, key); + it_res->post_updater.Run(); auto zlen = pv.Size(); if (zlen == 0) { DVLOG(1) << "deleting key " << key << " " << t->DebugId(); - CHECK(db_slice.Del(t->GetDbIndex(), *it_res)); + CHECK(db_slice.Del(t->GetDbIndex(), it_res->it)); } OpArgs op_args = t->GetOpArgs(shard); @@ -1364,7 +1366,7 @@ vector OpFetch(EngineShard* shard, Transaction* t) { auto& db_slice = shard->db_slice(); for (size_t i = 0; i < keys.size(); ++i) { - auto it = db_slice.Find(t->GetDbContext(), keys[i], OBJ_ZSET); + auto it = db_slice.FindReadOnly(t->GetDbContext(), keys[i], OBJ_ZSET); if (!it) { results.push_back({}); continue; @@ -1380,22 +1382,20 @@ vector OpFetch(EngineShard* shard, Transaction* t) { auto OpPopCount(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, string_view key) -> OpResult { auto& db_slice = op_args.shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it); - - PrimeValue& pv = res_it.value()->second; + PrimeValue& pv = res_it->it->second; IntervalVisitor iv{Action::POP, range_spec.params, &pv}; std::visit(iv, range_spec.interval); - db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key); + res_it->post_updater.Run(); auto zlen = pv.Size(); if (zlen == 0) { - CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value())); + CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it->it)); } return iv.PopResult(); @@ -1403,11 +1403,13 @@ auto OpPopCount(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, string_view key) -> OpResult { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - PrimeValue& pv = res_it.value()->second; + // Action::RANGE is read-only, but requires mutable pointer, thus const_cast + PrimeValue& pv = const_cast(res_it.value()->second); IntervalVisitor iv{Action::RANGE, range_spec.params, &pv}; std::visit(iv, range_spec.interval); @@ -1417,11 +1419,13 @@ auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, st auto OpRanges(const std::vector& range_specs, const OpArgs& op_args, string_view key) -> OpResult> { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - PrimeValue& pv = res_it.value()->second; + // Action::RANGE is read-only, but requires mutable pointer, thus const_cast + PrimeValue& pv = const_cast(res_it.value()->second); vector result_arrays; for (auto& range_spec : range_specs) { IntervalVisitor iv{Action::RANGE, range_spec.params, &pv}; @@ -1435,21 +1439,19 @@ auto OpRanges(const std::vector& range_specs, const OpAr OpResult OpRemRange(const OpArgs& op_args, string_view key, const ZSetFamily::ZRangeSpec& range_spec) { auto& db_slice = op_args.shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it); - - PrimeValue& pv = res_it.value()->second; + PrimeValue& pv = res_it->it->second; IntervalVisitor iv{Action::REMOVE, range_spec.params, &pv}; std::visit(iv, range_spec.interval); - db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key); + res_it->post_updater.Run(); auto zlen = pv.Size(); if (zlen == 0) { - CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value())); + CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it->it)); } return iv.removed(); @@ -1457,11 +1459,12 @@ OpResult OpRemRange(const OpArgs& op_args, string_view key, OpResult OpRank(const OpArgs& op_args, string_view key, string_view member, bool reverse) { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); + const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); if (robj_wrapper->encoding() == OBJ_ENCODING_LISTPACK) { unsigned char* zl = (uint8_t*)robj_wrapper->inner_obj(); unsigned char *eptr, *sptr; @@ -1503,11 +1506,12 @@ OpResult OpRank(const OpArgs& op_args, string_view key, string_view me OpResult OpCount(const OpArgs& op_args, std::string_view key, const ZSetFamily::ScoreInterval& interval) { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); + const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); zrangespec range = GetZrangeSpec(false, interval); unsigned count = 0; @@ -1553,13 +1557,14 @@ OpResult OpCount(const OpArgs& op_args, std::string_view key, OpResult OpLexCount(const OpArgs& op_args, string_view key, const ZSetFamily::LexInterval& interval) { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); zlexrangespec range = GetLexRange(false, interval); unsigned count = 0; - detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); + const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); if (robj_wrapper->encoding() == OBJ_ENCODING_LISTPACK) { uint8_t* zl = (uint8_t*)robj_wrapper->inner_obj(); @@ -1597,12 +1602,11 @@ OpResult OpLexCount(const OpArgs& op_args, string_view key, OpResult OpRem(const OpArgs& op_args, string_view key, ArgSlice members) { auto& db_slice = op_args.shard->db_slice(); - OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET); + auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it); - detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); + detail::RobjWrapper* robj_wrapper = res_it->it->second.GetRobjWrapper(); sds& tmp_str = op_args.shard->tmp_str1; unsigned deleted = 0; for (string_view member : members) { @@ -1610,25 +1614,26 @@ OpResult OpRem(const OpArgs& op_args, string_view key, ArgSlice member deleted += ZsetDel(robj_wrapper, tmp_str); } auto zlen = robj_wrapper->Size(); - db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key); + res_it->post_updater.Run(); if (zlen == 0) { - CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value())); + CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it->it)); } return deleted; } OpResult OpScore(const OpArgs& op_args, string_view key, string_view member) { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); - PrimeValue& pv = res_it.value()->second; + const PrimeValue& pv = res_it.value()->second; sds& tmp_str = op_args.shard->tmp_str1; tmp_str = sdscpylen(tmp_str, member.data(), member.size()); - detail::RobjWrapper* robj_wrapper = pv.GetRobjWrapper(); + const detail::RobjWrapper* robj_wrapper = pv.GetRobjWrapper(); auto res = GetZsetScore(robj_wrapper, tmp_str); if (!res) return OpStatus::KEY_NOTFOUND; @@ -1636,13 +1641,14 @@ OpResult OpScore(const OpArgs& op_args, string_view key, string_view mem } OpResult OpMScore(const OpArgs& op_args, string_view key, ArgSlice members) { - OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult res_it = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!res_it) return res_it.status(); MScoreResponse scores(members.size()); - detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); + const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper(); sds& tmp_str = op_args.shard->tmp_str1; for (size_t i = 0; i < members.size(); i++) { @@ -1657,20 +1663,21 @@ OpResult OpMScore(const OpArgs& op_args, string_view key, ArgSli OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor, const ScanOpts& scan_op) { - OpResult find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); + OpResult find_res = + op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); if (!find_res) return find_res.status(); - PrimeIterator it = find_res.value(); - PrimeValue& pv = it->second; + PrimeConstIterator it = find_res.value(); + const PrimeValue& pv = it->second; StringVec res; char buf[128]; if (pv.Encoding() == OBJ_ENCODING_LISTPACK) { ZSetFamily::RangeParams params; params.with_scores = true; - IntervalVisitor iv{Action::RANGE, params, &pv}; + IntervalVisitor iv{Action::RANGE, params, const_cast(&pv)}; iv(ZSetFamily::IndexInterval{0, kuint32max}); ScoredArray arr = iv.PopResult(); @@ -1865,7 +1872,8 @@ void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - OpResult find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_ZSET); + OpResult find_res = + shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_ZSET); if (!find_res) { return find_res.status(); }