diff --git a/src/server/common.cc b/src/server/common.cc index 9274211b7..658f29153 100644 --- a/src/server/common.cc +++ b/src/server/common.cc @@ -452,4 +452,25 @@ RandomPick UniquePicksGenerator::Generate() { return max_index; } +ThreadLocalMutex::ThreadLocalMutex() { + shard_ = EngineShard::tlocal(); +} + +ThreadLocalMutex::~ThreadLocalMutex() { + DCHECK_EQ(EngineShard::tlocal(), shard_); +} + +void ThreadLocalMutex::lock() { + DCHECK_EQ(EngineShard::tlocal(), shard_); + util::fb2::NoOpLock noop_lk_; + cond_var_.wait(noop_lk_, [this]() { return !flag_; }); + flag_ = true; +} + +void ThreadLocalMutex::unlock() { + DCHECK_EQ(EngineShard::tlocal(), shard_); + flag_ = false; + cond_var_.notify_one(); +} + } // namespace dfly diff --git a/src/server/common.h b/src/server/common.h index 084ba8864..ac04e256c 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -365,45 +365,18 @@ struct ConditionFlag { }; // Helper class used to guarantee atomicity between serialization of buckets -class ConditionGuard { +class ThreadLocalMutex { public: - explicit ConditionGuard(ConditionFlag* enclosing) : enclosing_(enclosing) { - util::fb2::NoOpLock noop_lk_; - enclosing_->cond_var.wait(noop_lk_, [this]() { return !enclosing_->flag; }); - enclosing_->flag = true; - } + ThreadLocalMutex(); + ~ThreadLocalMutex(); - ~ConditionGuard() { - enclosing_->flag = false; - enclosing_->cond_var.notify_one(); - } - - private: - ConditionFlag* enclosing_; -}; - -class LocalBlockingCounter { - public: - void lock() { - ++mutating_; - } - - void unlock() { - DCHECK(mutating_ > 0); - --mutating_; - if (mutating_ == 0) { - cond_var_.notify_one(); - } - } - - void Wait() { - util::fb2::NoOpLock noop_lk_; - cond_var_.wait(noop_lk_, [this]() { return mutating_ == 0; }); - } + void lock(); + void unlock(); private: + EngineShard* shard_; util::fb2::CondVarAny cond_var_; - size_t mutating_ = 0; + bool flag_ = false; }; } // namespace dfly diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index cb04c6e2d..309d675cb 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -744,7 +744,7 @@ void DbSlice::FlushSlotsFb(const cluster::SlotSet& slot_ids) { PrimeTable::Cursor cursor; uint64_t i = 0; do { - PrimeTable::Cursor next = pt->Traverse(cursor, del_entry_cb); + PrimeTable::Cursor next = Traverse(pt, cursor, del_entry_cb); ++i; cursor = next; if (i % 100 == 0) { @@ -1149,7 +1149,7 @@ void DbSlice::ExpireAllIfNeeded() { ExpireTable::Cursor cursor; do { - cursor = db.expire.Traverse(cursor, cb); + cursor = Traverse(&db.expire, cursor, cb); } while (cursor); } } @@ -1160,7 +1160,6 @@ uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) { void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_t upper_bound) { FetchedItemsRestorer fetched_restorer(&fetched_items_); - std::unique_lock lk(block_counter_); uint64_t bucket_version = it.GetVersion(); // change_cb_ is ordered by version. @@ -1184,7 +1183,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_ //! Unregisters the callback. void DbSlice::UnregisterOnChange(uint64_t id) { - block_counter_.Wait(); + std::unique_lock lk(local_mu_); auto it = find_if(change_cb_.begin(), change_cb_.end(), [id](const auto& cb) { return cb.first == id; }); CHECK(it != change_cb_.end()); @@ -1216,13 +1215,13 @@ auto DbSlice::DeleteExpiredStep(const Context& cntx, unsigned count) -> DeleteEx unsigned i = 0; for (; i < count / 3; ++i) { - db.expire_cursor = db.expire.Traverse(db.expire_cursor, cb); + db.expire_cursor = Traverse(&db.expire, db.expire_cursor, cb); } // continue traversing only if we had strong deletion rate based on the first sample. if (result.deleted * 4 > result.traversed) { for (; i < count; ++i) { - db.expire_cursor = db.expire.Traverse(db.expire_cursor, cb); + db.expire_cursor = Traverse(&db.expire, db.expire_cursor, cb); } } @@ -1388,7 +1387,7 @@ void DbSlice::ClearOffloadedEntries(absl::Span indices, const DbT // Delete all tiered entries PrimeTable::Cursor cursor; do { - cursor = db_ptr->prime.Traverse(cursor, [&](PrimeIterator it) { + cursor = Traverse(&db_ptr->prime, cursor, [&](PrimeIterator it) { if (it->second.IsExternal()) { tiered_storage->Delete(index, &it->second); } else if (it->second.HasStashPending()) { @@ -1515,7 +1514,7 @@ void DbSlice::CallChangeCallbacks(DbIndex id, std::string_view key, const Change DVLOG(2) << "Running callbacks for key " << key << " in dbid " << id; FetchedItemsRestorer fetched_restorer(&fetched_items_); - std::unique_lock lk(block_counter_); + std::unique_lock lk(local_mu_); const size_t limit = change_cb_.size(); auto ccb = change_cb_.begin(); diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 03f90ad2f..034d76d0d 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -497,6 +497,20 @@ class DbSlice { void PerformDeletion(Iterator del_it, DbTable* table); void PerformDeletion(PrimeIterator del_it, DbTable* table); + // Provides access to the internal lock of db_slice for flows that serialize + // entries with preemption and need to synchronize with Traverse below which + // acquires the same lock. + ThreadLocalMutex* GetSerializationMutex() { + return &local_mu_; + } + + // Wrapper around DashTable::Traverse that allows preemptions + template + PrimeTable::Cursor Traverse(DashTable* pt, PrimeTable::Cursor cursor, Cb&& cb) { + std::unique_lock lk(local_mu_); + return pt->Traverse(cursor, std::forward(cb)); + } + private: void PreUpdate(DbIndex db_ind, Iterator it, std::string_view key); void PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size_t orig_size); @@ -550,13 +564,8 @@ class DbSlice { void CallChangeCallbacks(DbIndex id, std::string_view key, const ChangeReq& cr) const; - // We need this because registered callbacks might yield. If RegisterOnChange - // gets called after we preempt while iterating over the registered callbacks - // (let's say in FlushChangeToEarlierCallbacks) we will get UB, because we pushed - // into a vector which might get resized, invalidating the iterators that are being - // used by the preempted FlushChangeToEarlierCallbacks. LocalBlockingCounter - // protects us against this case. - mutable LocalBlockingCounter block_counter_; + // Used to provide exclusive access while Traversing segments + mutable ThreadLocalMutex local_mu_; ShardId shard_id_; uint8_t caching_mode_ : 1; diff --git a/src/server/debugcmd.cc b/src/server/debugcmd.cc index a7f4a187c..8bdf23da5 100644 --- a/src/server/debugcmd.cc +++ b/src/server/debugcmd.cc @@ -272,7 +272,7 @@ void DoBuildObjHist(EngineShard* shard, ConnectionContext* cntx, ObjHistMap* obj continue; PrimeTable::Cursor cursor; do { - cursor = dbt->prime.Traverse(cursor, [&](PrimeIterator it) { + cursor = db_slice.Traverse(&dbt->prime, cursor, [&](PrimeIterator it) { unsigned obj_type = it->second.ObjType(); auto& hist_ptr = (*obj_hist_map)[obj_type]; if (!hist_ptr) { diff --git a/src/server/engine_shard_set.cc b/src/server/engine_shard_set.cc index 2ece3eae3..c20e5946b 100644 --- a/src/server/engine_shard_set.cc +++ b/src/server/engine_shard_set.cc @@ -317,7 +317,7 @@ bool EngineShard::DoDefrag() { uint64_t attempts = 0; do { - cur = prime_table->Traverse(cur, [&](PrimeIterator it) { + cur = slice.Traverse(prime_table, cur, [&](PrimeIterator it) { // for each value check whether we should move it because it // seats on underutilized page of memory, and if so, do it. bool did = it->second.DefragIfNeeded(threshold); diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index b3f99498b..fca2a27e9 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -582,8 +582,9 @@ void OpScan(const OpArgs& op_args, const ScanOpts& scan_opts, uint64_t* cursor, auto [prime_table, expire_table] = db_slice.GetTables(op_args.db_cntx.db_index); string scratch; do { - cur = prime_table->Traverse( - cur, [&](PrimeIterator it) { cnt += ScanCb(op_args, it, scan_opts, &scratch, vec); }); + cur = db_slice.Traverse(prime_table, cur, [&](PrimeIterator it) { + cnt += ScanCb(op_args, it, scan_opts, &scratch, vec); + }); } while (cur && cnt < scan_opts.limit); VLOG(1) << "OpScan " << db_slice.shard_id() << " cursor: " << cur.value(); diff --git a/src/server/journal/streamer.cc b/src/server/journal/streamer.cc index 0553aa01e..182237c83 100644 --- a/src/server/journal/streamer.cc +++ b/src/server/journal/streamer.cc @@ -213,9 +213,7 @@ void RestoreStreamer::Run() { return; bool written = false; - cursor = pt->Traverse(cursor, [&](PrimeTable::bucket_iterator it) { - ConditionGuard guard(&bucket_ser_); - + cursor = db_slice_->Traverse(pt, cursor, [&](PrimeTable::bucket_iterator it) { db_slice_->FlushChangeToEarlierCallbacks(0 /*db_id always 0 for cluster*/, DbSlice::Iterator::FromPrime(it), snapshot_version_); if (WriteBucket(it)) { @@ -313,8 +311,6 @@ bool RestoreStreamer::WriteBucket(PrimeTable::bucket_iterator it) { void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) { DCHECK_EQ(db_index, 0) << "Restore migration only allowed in cluster mode in db0"; - ConditionGuard guard(&bucket_ser_); - PrimeTable* table = db_slice_->GetTables(0).first; if (const PrimeTable::bucket_iterator* bit = req.update()) { diff --git a/src/server/journal/streamer.h b/src/server/journal/streamer.h index f9af83de6..ce60f1071 100644 --- a/src/server/journal/streamer.h +++ b/src/server/journal/streamer.h @@ -107,8 +107,6 @@ class RestoreStreamer : public JournalStreamer { cluster::SlotSet my_slots_; bool fiber_cancelled_ = false; bool snapshot_finished_ = false; - - ConditionFlag bucket_ser_; }; } // namespace dfly diff --git a/src/server/snapshot.cc b/src/server/snapshot.cc index d6bedf56f..910320f41 100644 --- a/src/server/snapshot.cc +++ b/src/server/snapshot.cc @@ -221,7 +221,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn return; PrimeTable::Cursor next = - pt->Traverse(cursor, absl::bind_front(&SliceSnapshot::BucketSaveCb, this)); + db_slice_->Traverse(pt, cursor, absl::bind_front(&SliceSnapshot::BucketSaveCb, this)); cursor = next; PushSerializedToChannel(false); @@ -253,8 +253,6 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn } bool SliceSnapshot::BucketSaveCb(PrimeIterator it) { - ConditionGuard guard(&bucket_ser_); - ++stats_.savecb_calls; auto check = [&](auto v) { @@ -364,8 +362,6 @@ bool SliceSnapshot::PushSerializedToChannel(bool force) { } void SliceSnapshot::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) { - ConditionGuard guard(&bucket_ser_); - PrimeTable* table = db_slice_->GetTables(db_index).first; const PrimeTable::bucket_iterator* bit = req.update(); @@ -390,7 +386,7 @@ void SliceSnapshot::OnJournalEntry(const journal::JournalItem& item, bool await) // To enable journal flushing to sync after non auto journal command is executed we call // TriggerJournalWriteToSink. This call uses the NOOP opcode with await=true. Since there is no // additional journal change to serialize, it simply invokes PushSerializedToChannel. - ConditionGuard guard(&bucket_ser_); + std::unique_lock lk(*db_slice_->GetSerializationMutex()); if (item.opcode != journal::Op::NOOP) { serializer_->WriteJournalEntry(item.data); } @@ -403,7 +399,7 @@ void SliceSnapshot::OnJournalEntry(const journal::JournalItem& item, bool await) } void SliceSnapshot::CloseRecordChannel() { - ConditionGuard guard(&bucket_ser_); + std::unique_lock lk(*db_slice_->GetSerializationMutex()); CHECK(!serialize_bucket_running_); // Make sure we close the channel only once with a CAS check. diff --git a/src/server/snapshot.h b/src/server/snapshot.h index d83fb9737..7aaaea5b4 100644 --- a/src/server/snapshot.h +++ b/src/server/snapshot.h @@ -179,8 +179,6 @@ class SliceSnapshot { size_t savecb_calls = 0; size_t keys_total = 0; } stats_; - - ConditionFlag bucket_ser_; }; } // namespace dfly diff --git a/tests/dragonfly/instance.py b/tests/dragonfly/instance.py index 6807974d2..c8eaa184b 100644 --- a/tests/dragonfly/instance.py +++ b/tests/dragonfly/instance.py @@ -87,6 +87,9 @@ class DflyInstance: if threads > 1: self.args["num_shards"] = threads - 1 + # Add 1 byte limit for big values + self.args["serialization_max_chunk_size"] = 1 + def __del__(self): assert self.proc == None @@ -163,7 +166,7 @@ class DflyInstance: proc.kill() else: proc.terminate() - proc.communicate(timeout=15) + proc.communicate(timeout=120) # if the return code is 0 it means normal termination # if the return code is negative it means termination by signal # if the return code is positive it means abnormal exit diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index a28eb79dc..da027b8d7 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -107,7 +107,7 @@ async def test_replication_all( ) # Wait for all replicas to transition into stable sync - async with async_timeout.timeout(20): + async with async_timeout.timeout(240): await wait_for_replicas_state(*c_replicas) # Stop streaming data once every replica is in stable sync