chore: add db_slice lock to protect segments from preemptions (#3406)

DashTable::Traverse is error-prone when the passed callback preempts, because the segment being traversed might change underneath it. We need atomicity while traversing segments with a preemptible callback. The fix is to add a Traverse wrapper to DbSlice and protect the traversal with a ThreadLocalMutex (a short usage sketch follows the list below).

* add ThreadLocalMutex to DbSlice
* add Traverse in DbSlice and protect it with the ThreadLocalMutex
* remove ConditionFlag/ConditionGuard usage from snapshot
* remove ConditionFlag/ConditionGuard usage from streamer
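
A rough usage sketch of the resulting API (not part of the commit; SerializeWithPreemption and ScanAllKeys are hypothetical helpers, and includes/types are assumed from the surrounding Dragonfly sources):

// A flow that serializes entries and may preempt (yield) holds the
// shard-local mutex for its whole critical section ...
void SerializeWithPreemption(DbSlice* slice) {
  std::unique_lock lk(*slice->GetSerializationMutex());
  // ... serialize bucket contents; the fiber may yield here safely ...
}

// ... while traversals go through DbSlice::Traverse, which acquires the
// same mutex before delegating to DashTable::Traverse, so a traversal
// never observes a segment that is mid-serialization.
void ScanAllKeys(DbSlice* slice, PrimeTable* pt) {
  PrimeTable::Cursor cursor;
  do {
    cursor = slice->Traverse(pt, cursor, [](PrimeIterator it) {
      // inspect *it; the callback itself may preempt
    });
  } while (cursor);
}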

---------

Signed-off-by: kostas <kostas@dragonflydb.io>
Author: Kostas Kyrimis, 2024-07-30 15:02:54 +03:00, committed by GitHub
parent f536f8afbd
commit aa02070e3d
13 changed files with 65 additions and 71 deletions


@@ -452,4 +452,25 @@ RandomPick UniquePicksGenerator::Generate() {
   return max_index;
 }
 
+ThreadLocalMutex::ThreadLocalMutex() {
+  shard_ = EngineShard::tlocal();
+}
+
+ThreadLocalMutex::~ThreadLocalMutex() {
+  DCHECK_EQ(EngineShard::tlocal(), shard_);
+}
+
+void ThreadLocalMutex::lock() {
+  DCHECK_EQ(EngineShard::tlocal(), shard_);
+  util::fb2::NoOpLock noop_lk_;
+  cond_var_.wait(noop_lk_, [this]() { return !flag_; });
+  flag_ = true;
+}
+
+void ThreadLocalMutex::unlock() {
+  DCHECK_EQ(EngineShard::tlocal(), shard_);
+  flag_ = false;
+  cond_var_.notify_one();
+}
+
 }  // namespace dfly


@@ -365,45 +365,18 @@ struct ConditionFlag {
 };
 
 // Helper class used to guarantee atomicity between serialization of buckets
-class ConditionGuard {
+class ThreadLocalMutex {
  public:
-  explicit ConditionGuard(ConditionFlag* enclosing) : enclosing_(enclosing) {
-    util::fb2::NoOpLock noop_lk_;
-    enclosing_->cond_var.wait(noop_lk_, [this]() { return !enclosing_->flag; });
-    enclosing_->flag = true;
-  }
-
-  ~ConditionGuard() {
-    enclosing_->flag = false;
-    enclosing_->cond_var.notify_one();
-  }
-
- private:
-  ConditionFlag* enclosing_;
-};
-
-class LocalBlockingCounter {
- public:
-  void lock() {
-    ++mutating_;
-  }
-
-  void unlock() {
-    DCHECK(mutating_ > 0);
-    --mutating_;
-    if (mutating_ == 0) {
-      cond_var_.notify_one();
-    }
-  }
-
-  void Wait() {
-    util::fb2::NoOpLock noop_lk_;
-    cond_var_.wait(noop_lk_, [this]() { return mutating_ == 0; });
-  }
+  ThreadLocalMutex();
+  ~ThreadLocalMutex();
+
+  void lock();
+  void unlock();
 
  private:
+  EngineShard* shard_;
   util::fb2::CondVarAny cond_var_;
-  size_t mutating_ = 0;
+  bool flag_ = false;
 };
 
 }  // namespace dfly


@@ -744,7 +744,7 @@ void DbSlice::FlushSlotsFb(const cluster::SlotSet& slot_ids) {
   PrimeTable::Cursor cursor;
   uint64_t i = 0;
   do {
-    PrimeTable::Cursor next = pt->Traverse(cursor, del_entry_cb);
+    PrimeTable::Cursor next = Traverse(pt, cursor, del_entry_cb);
     ++i;
     cursor = next;
     if (i % 100 == 0) {
@@ -1149,7 +1149,7 @@ void DbSlice::ExpireAllIfNeeded() {
     ExpireTable::Cursor cursor;
     do {
-      cursor = db.expire.Traverse(cursor, cb);
+      cursor = Traverse(&db.expire, cursor, cb);
     } while (cursor);
   }
 }
@@ -1160,7 +1160,6 @@ uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) {
 
 void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_t upper_bound) {
   FetchedItemsRestorer fetched_restorer(&fetched_items_);
-  std::unique_lock<LocalBlockingCounter> lk(block_counter_);
 
   uint64_t bucket_version = it.GetVersion();
   // change_cb_ is ordered by version.
@@ -1184,7 +1183,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_
 
 //! Unregisters the callback.
 void DbSlice::UnregisterOnChange(uint64_t id) {
-  block_counter_.Wait();
+  std::unique_lock lk(local_mu_);
   auto it = find_if(change_cb_.begin(), change_cb_.end(),
                     [id](const auto& cb) { return cb.first == id; });
   CHECK(it != change_cb_.end());
@@ -1216,13 +1215,13 @@ auto DbSlice::DeleteExpiredStep(const Context& cntx, unsigned count) -> DeleteEx
   unsigned i = 0;
   for (; i < count / 3; ++i) {
-    db.expire_cursor = db.expire.Traverse(db.expire_cursor, cb);
+    db.expire_cursor = Traverse(&db.expire, db.expire_cursor, cb);
   }
 
   // continue traversing only if we had strong deletion rate based on the first sample.
   if (result.deleted * 4 > result.traversed) {
     for (; i < count; ++i) {
-      db.expire_cursor = db.expire.Traverse(db.expire_cursor, cb);
+      db.expire_cursor = Traverse(&db.expire, db.expire_cursor, cb);
     }
   }
@@ -1388,7 +1387,7 @@ void DbSlice::ClearOffloadedEntries(absl::Span<const DbIndex> indices, const DbT
   // Delete all tiered entries
   PrimeTable::Cursor cursor;
   do {
-    cursor = db_ptr->prime.Traverse(cursor, [&](PrimeIterator it) {
+    cursor = Traverse(&db_ptr->prime, cursor, [&](PrimeIterator it) {
       if (it->second.IsExternal()) {
         tiered_storage->Delete(index, &it->second);
       } else if (it->second.HasStashPending()) {
@@ -1515,7 +1514,7 @@ void DbSlice::CallChangeCallbacks(DbIndex id, std::string_view key, const Change
   DVLOG(2) << "Running callbacks for key " << key << " in dbid " << id;
 
   FetchedItemsRestorer fetched_restorer(&fetched_items_);
-  std::unique_lock<LocalBlockingCounter> lk(block_counter_);
+  std::unique_lock lk(local_mu_);
 
   const size_t limit = change_cb_.size();
   auto ccb = change_cb_.begin();


@@ -497,6 +497,20 @@ class DbSlice {
   void PerformDeletion(Iterator del_it, DbTable* table);
   void PerformDeletion(PrimeIterator del_it, DbTable* table);
 
+  // Provides access to the internal lock of db_slice for flows that serialize
+  // entries with preemption and need to synchronize with Traverse below which
+  // acquires the same lock.
+  ThreadLocalMutex* GetSerializationMutex() {
+    return &local_mu_;
+  }
+
+  // Wrapper around DashTable::Traverse that allows preemptions
+  template <typename Cb, typename DashTable>
+  PrimeTable::Cursor Traverse(DashTable* pt, PrimeTable::Cursor cursor, Cb&& cb) {
+    std::unique_lock lk(local_mu_);
+    return pt->Traverse(cursor, std::forward<Cb>(cb));
+  }
+
  private:
   void PreUpdate(DbIndex db_ind, Iterator it, std::string_view key);
   void PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size_t orig_size);
@@ -550,13 +564,8 @@ class DbSlice {
   void CallChangeCallbacks(DbIndex id, std::string_view key, const ChangeReq& cr) const;
 
-  // We need this because registered callbacks might yield. If RegisterOnChange
-  // gets called after we preempt while iterating over the registered callbacks
-  // (let's say in FlushChangeToEarlierCallbacks) we will get UB, because we pushed
-  // into a vector which might get resized, invalidating the iterators that are being
-  // used by the preempted FlushChangeToEarlierCallbacks. LocalBlockingCounter
-  // protects us against this case.
-  mutable LocalBlockingCounter block_counter_;
+  // Used to provide exclusive access while Traversing segments
+  mutable ThreadLocalMutex local_mu_;
 
   ShardId shard_id_;
   uint8_t caching_mode_ : 1;


@@ -272,7 +272,7 @@ void DoBuildObjHist(EngineShard* shard, ConnectionContext* cntx, ObjHistMap* obj
       continue;
     PrimeTable::Cursor cursor;
     do {
-      cursor = dbt->prime.Traverse(cursor, [&](PrimeIterator it) {
+      cursor = db_slice.Traverse(&dbt->prime, cursor, [&](PrimeIterator it) {
         unsigned obj_type = it->second.ObjType();
         auto& hist_ptr = (*obj_hist_map)[obj_type];
         if (!hist_ptr) {


@@ -317,7 +317,7 @@ bool EngineShard::DoDefrag() {
   uint64_t attempts = 0;
 
   do {
-    cur = prime_table->Traverse(cur, [&](PrimeIterator it) {
+    cur = slice.Traverse(prime_table, cur, [&](PrimeIterator it) {
      // for each value check whether we should move it because it
      // seats on underutilized page of memory, and if so, do it.
      bool did = it->second.DefragIfNeeded(threshold);


@@ -582,8 +582,9 @@ void OpScan(const OpArgs& op_args, const ScanOpts& scan_opts, uint64_t* cursor,
   auto [prime_table, expire_table] = db_slice.GetTables(op_args.db_cntx.db_index);
   string scratch;
   do {
-    cur = prime_table->Traverse(
-        cur, [&](PrimeIterator it) { cnt += ScanCb(op_args, it, scan_opts, &scratch, vec); });
+    cur = db_slice.Traverse(prime_table, cur, [&](PrimeIterator it) {
+      cnt += ScanCb(op_args, it, scan_opts, &scratch, vec);
+    });
   } while (cur && cnt < scan_opts.limit);
 
   VLOG(1) << "OpScan " << db_slice.shard_id() << " cursor: " << cur.value();


@@ -213,9 +213,7 @@ void RestoreStreamer::Run() {
       return;
 
     bool written = false;
-    cursor = pt->Traverse(cursor, [&](PrimeTable::bucket_iterator it) {
-      ConditionGuard guard(&bucket_ser_);
-
+    cursor = db_slice_->Traverse(pt, cursor, [&](PrimeTable::bucket_iterator it) {
      db_slice_->FlushChangeToEarlierCallbacks(0 /*db_id always 0 for cluster*/,
                                               DbSlice::Iterator::FromPrime(it), snapshot_version_);
      if (WriteBucket(it)) {
@@ -313,8 +311,6 @@ bool RestoreStreamer::WriteBucket(PrimeTable::bucket_iterator it) {
 
 void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) {
   DCHECK_EQ(db_index, 0) << "Restore migration only allowed in cluster mode in db0";
-  ConditionGuard guard(&bucket_ser_);
-
   PrimeTable* table = db_slice_->GetTables(0).first;
 
   if (const PrimeTable::bucket_iterator* bit = req.update()) {


@@ -107,8 +107,6 @@ class RestoreStreamer : public JournalStreamer {
   cluster::SlotSet my_slots_;
   bool fiber_cancelled_ = false;
   bool snapshot_finished_ = false;
-
-  ConditionFlag bucket_ser_;
 };
 
 }  // namespace dfly


@@ -221,7 +221,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn
       return;
 
     PrimeTable::Cursor next =
-        pt->Traverse(cursor, absl::bind_front(&SliceSnapshot::BucketSaveCb, this));
+        db_slice_->Traverse(pt, cursor, absl::bind_front(&SliceSnapshot::BucketSaveCb, this));
     cursor = next;
     PushSerializedToChannel(false);
@@ -253,8 +253,6 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn
 }
 
 bool SliceSnapshot::BucketSaveCb(PrimeIterator it) {
-  ConditionGuard guard(&bucket_ser_);
-
   ++stats_.savecb_calls;
 
   auto check = [&](auto v) {
@@ -364,8 +362,6 @@ bool SliceSnapshot::PushSerializedToChannel(bool force) {
 }
 
 void SliceSnapshot::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) {
-  ConditionGuard guard(&bucket_ser_);
-
   PrimeTable* table = db_slice_->GetTables(db_index).first;
   const PrimeTable::bucket_iterator* bit = req.update();
@@ -390,7 +386,7 @@ void SliceSnapshot::OnJournalEntry(const journal::JournalItem& item, bool await)
   // To enable journal flushing to sync after non auto journal command is executed we call
   // TriggerJournalWriteToSink. This call uses the NOOP opcode with await=true. Since there is no
   // additional journal change to serialize, it simply invokes PushSerializedToChannel.
-  ConditionGuard guard(&bucket_ser_);
+  std::unique_lock lk(*db_slice_->GetSerializationMutex());
   if (item.opcode != journal::Op::NOOP) {
     serializer_->WriteJournalEntry(item.data);
   }
@@ -403,7 +399,7 @@ void SliceSnapshot::OnJournalEntry(const journal::JournalItem& item, bool await)
 }
 
 void SliceSnapshot::CloseRecordChannel() {
-  ConditionGuard guard(&bucket_ser_);
+  std::unique_lock lk(*db_slice_->GetSerializationMutex());
   CHECK(!serialize_bucket_running_);
 
   // Make sure we close the channel only once with a CAS check.


@@ -179,8 +179,6 @@ class SliceSnapshot {
     size_t savecb_calls = 0;
     size_t keys_total = 0;
   } stats_;
-
-  ConditionFlag bucket_ser_;
 };
 
 }  // namespace dfly


@@ -87,6 +87,9 @@ class DflyInstance:
         if threads > 1:
             self.args["num_shards"] = threads - 1
 
+        # Add 1 byte limit for big values
+        self.args["serialization_max_chunk_size"] = 1
+
     def __del__(self):
         assert self.proc == None
@@ -163,7 +166,7 @@ class DflyInstance:
                 proc.kill()
             else:
                 proc.terminate()
-            proc.communicate(timeout=15)
+            proc.communicate(timeout=120)
         # if the return code is 0 it means normal termination
         # if the return code is negative it means termination by signal
         # if the return code is positive it means abnormal exit


@@ -107,7 +107,7 @@ async def test_replication_all(
     )
 
     # Wait for all replicas to transition into stable sync
-    async with async_timeout.timeout(20):
+    async with async_timeout.timeout(240):
         await wait_for_replicas_state(*c_replicas)
 
     # Stop streaming data once every replica is in stable sync