fix: possible preemption under FiberAtomicGuard (#4692)

1. Fix FreeMemWithEvictionStep that could preempt under FiberAtomicGuard.
   This could happen during the return from the inner loop. Now, we break
   from the guard first and then preempt in a safe place.
2. Rename LocalBlockingCounter to LocalLatch
   because it's a variation of latch (see std::latch for example).
3. Rename PreUpdate to PreUpdateBlocking to emphasize it can block.
4. Fix mutations counting: count only insertions or changes to an existing entry.
   Previously, this counter was also incremented on misses.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2025-03-04 14:28:34 +02:00 committed by GitHub
parent debb2eb9e8
commit ea6fdadd67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 68 additions and 68 deletions

View file

@ -481,15 +481,15 @@ BorrowedInterpreter::~BorrowedInterpreter() {
ServerState::tlocal()->ReturnInterpreter(interpreter_);
}
void LocalBlockingCounter::unlock() {
DCHECK(mutating_ > 0);
void LocalLatch::unlock() {
DCHECK_GT(mutating_, 0u);
--mutating_;
if (mutating_ == 0) {
cond_var_.notify_all();
}
}
void LocalBlockingCounter::Wait() {
void LocalLatch::Wait() {
util::fb2::NoOpLock noop_lk_;
cond_var_.wait(noop_lk_, [this]() { return mutating_ == 0; });
}

View file

@ -393,7 +393,9 @@ struct BorrowedInterpreter {
bool owned_ = false;
};
class LocalBlockingCounter {
// A single-threaded latch that passes a waiter fiber if its count is 0.
// Fibers that increase/decrease the count do not wait on the latch.
class LocalLatch {
public:
void lock() {
++mutating_;

View file

@ -215,7 +215,7 @@ unsigned PrimeEvictionPolicy::Evict(const PrimeTable::HotspotBuckets& eb, PrimeT
// log the evicted keys to journal.
if (auto journal = db_slice_->shard_owner()->journal(); journal) {
RecordExpiry(cntx_.db_index, key);
RecordExpiryBlocking(cntx_.db_index, key);
}
db_slice_->PerformDeletion(DbSlice::Iterator(last_slot_it, StringOrView::FromView(key)), table);
@ -505,7 +505,7 @@ OpResult<DbSlice::ItAndUpdater> DbSlice::FindMutableInternal(const Context& cntx
auto it = Iterator(res->it, StringOrView::FromView(key));
auto exp_it = ExpIterator(res->exp_it, StringOrView::FromView(key));
PreUpdate(cntx.db_index, it, key);
PreUpdateBlocking(cntx.db_index, it, key);
// PreUpdateBlocking() might have caused a deletion of `it`
if (res->it.IsOccupied()) {
return {{it, exp_it,
@ -534,29 +534,20 @@ OpResult<DbSlice::ConstIterator> DbSlice::FindReadOnly(const Context& cntx, stri
return res.status();
}
OpResult<DbSlice::PrimeItAndExp> DbSlice::FindInternal(const Context& cntx, std::string_view key,
std::optional<unsigned> req_obj_type,
UpdateStatsMode stats_mode) const {
if (!IsDbValid(cntx.db_index)) {
auto DbSlice::FindInternal(const Context& cntx, string_view key, optional<unsigned> req_obj_type,
UpdateStatsMode stats_mode) const -> OpResult<PrimeItAndExp> {
if (!IsDbValid(cntx.db_index)) { // Can it even happen?
LOG(DFATAL) << "Invalid db index " << cntx.db_index;
return OpStatus::KEY_NOTFOUND;
}
DbSlice::PrimeItAndExp res;
auto& db = *db_arr_[cntx.db_index];
PrimeItAndExp res;
res.it = db.prime.Find(key);
absl::Cleanup update_stats_on_miss = [&]() {
switch (stats_mode) {
case UpdateStatsMode::kMutableStats:
events_.mutations++;
break;
case UpdateStatsMode::kReadStats:
events_.misses++;
break;
}
};
int miss_weight = (stats_mode == UpdateStatsMode::kReadStats);
if (!IsValid(res.it)) {
events_.misses += miss_weight;
return OpStatus::KEY_NOTFOUND;
}
@ -564,17 +555,20 @@ OpResult<DbSlice::PrimeItAndExp> DbSlice::FindInternal(const Context& cntx, std:
TouchHllIfNeeded(key, db.dense_hll);
if (req_obj_type.has_value() && res.it->second.ObjType() != req_obj_type.value()) {
events_.misses += miss_weight;
return OpStatus::WRONG_TYPE;
}
if (res.it->second.HasExpire()) { // check expiry state
res = ExpireIfNeeded(cntx, res.it);
if (!IsValid(res.it)) {
events_.misses += miss_weight;
return OpStatus::KEY_NOTFOUND;
}
}
if (IsCacheMode() && IsValid(res.it)) {
DCHECK(IsValid(res.it));
if (IsCacheMode()) {
if (!change_cb_.empty()) {
auto bump_cb = [&](PrimeTable::bucket_iterator bit) {
CallChangeCallbacks(cntx.db_index, key, bit);
@ -582,7 +576,8 @@ OpResult<DbSlice::PrimeItAndExp> DbSlice::FindInternal(const Context& cntx, std:
db.prime.CVCUponBump(change_cb_.back().first, res.it, bump_cb);
}
block_counter_.Wait(); // We must not change the bucket's internal order during serialization
// We must not change the bucket's internal order during serialization
serialization_latch_.Wait();
auto bump_it = db.prime.BumpUp(res.it, PrimeBumpPolicy{&fetched_items_});
if (bump_it != res.it) { // the item was bumped
res.it = bump_it;
@ -590,7 +585,6 @@ OpResult<DbSlice::PrimeItAndExp> DbSlice::FindInternal(const Context& cntx, std:
}
}
std::move(update_stats_on_miss).Cancel();
switch (stats_mode) {
case UpdateStatsMode::kMutableStats:
events_.mutations++;
@ -646,7 +640,8 @@ OpResult<DbSlice::AddOrFindResult> DbSlice::AddOrFindInternal(const Context& cnt
if (res.ok()) {
Iterator it(res->it, StringOrView::FromView(key));
ExpIterator exp_it(res->exp_it, StringOrView::FromView(key));
PreUpdate(cntx.db_index, it, key);
PreUpdateBlocking(cntx.db_index, it, key);
// PreUpdateBlocking() might have caused a deletion of `it`
if (res->it.IsOccupied()) {
return DbSlice::AddOrFindResult{
@ -724,6 +719,7 @@ OpResult<DbSlice::AddOrFindResult> DbSlice::AddOrFindInternal(const Context& cnt
return OpStatus::OUT_OF_MEMORY;
}
events_.mutations++;
ssize_t table_increase = db.prime.mem_usage() - table_before;
memory_budget_ -= table_increase;
@ -1158,7 +1154,7 @@ bool DbSlice::CheckLock(IntentLock::Mode mode, DbIndex dbid, uint64_t fp) const
return true;
}
void DbSlice::PreUpdate(DbIndex db_ind, Iterator it, std::string_view key) {
void DbSlice::PreUpdateBlocking(DbIndex db_ind, Iterator it, std::string_view key) {
CallChangeCallbacks(db_ind, key, ChangeReq{it.GetInnerIt()});
it.GetInnerIt().SetVersion(NextVersion());
}
@ -1225,7 +1221,7 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato
// Replicate expiry
if (auto journal = owner_->journal(); journal) {
RecordExpiry(cntx.db_index, key);
RecordExpiryBlocking(cntx.db_index, key);
}
if (expired_keys_events_recording_)
@ -1248,7 +1244,7 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato
void DbSlice::ExpireAllIfNeeded() {
// We hold no locks to any of the keys so we should Wait() here such that
// we don't preempt in ExpireIfNeeded
block_counter_.Wait();
serialization_latch_.Wait();
// Disable flush journal changes to prevent preemption in traverse.
journal::JournalFlushGuard journal_flush_guard(owner_->journal());
@ -1278,7 +1274,7 @@ uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) {
}
void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_t upper_bound) {
unique_lock<LocalBlockingCounter> lk(block_counter_);
unique_lock<LocalLatch> lk(serialization_latch_);
uint64_t bucket_version = it.GetVersion();
// change_cb_ is ordered by version.
@ -1302,7 +1298,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_
//! Unregisters the callback.
void DbSlice::UnregisterOnChange(uint64_t id) {
block_counter_.Wait();
serialization_latch_.Wait();
auto it = find_if(change_cb_.begin(), change_cb_.end(),
[id](const auto& cb) { return cb.first == id; });
CHECK(it != change_cb_.end());
@ -1396,24 +1392,6 @@ pair<uint64_t, size_t> DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t s
bool record_keys = owner_->journal() != nullptr || expired_keys_events_recording_;
vector<string> keys_to_journal;
auto return_cb = [&, this]() mutable {
// send the deletion to the replicas.
// fiber preemption could happen in this phase.
for (string_view key : keys_to_journal) {
if (auto journal = owner_->journal(); journal)
RecordExpiry(db_ind, key);
if (expired_keys_events_recording_)
db_table->expired_keys_events_.emplace_back(key);
}
SendQueuedInvalidationMessages();
auto time_finish = absl::GetCurrentTimeNanos();
events_.evicted_keys += evicted_items;
DVLOG(2) << "Eviction time (us): " << (time_finish - time_start) / 1000;
return pair<uint64_t, size_t>{evicted_items, evicted_bytes};
};
{
FiberAtomicGuard guard;
for (int32_t slot_id = num_slots - 1; slot_id >= 0; --slot_id) {
@ -1449,13 +1427,29 @@ pair<uint64_t, size_t> DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t s
// returns when whichever condition is met first
if ((evicted_items == max_eviction_per_hb) || (evicted_bytes >= increase_goal_bytes))
return return_cb();
goto finish;
}
}
}
} // FiberAtomicGuard
finish:
// send the deletion to the replicas.
// fiber preemption could happen in this phase.
for (string_view key : keys_to_journal) {
if (auto journal = owner_->journal(); journal)
RecordExpiryBlocking(db_ind, key);
if (expired_keys_events_recording_)
db_table->expired_keys_events_.emplace_back(key);
}
return return_cb();
SendQueuedInvalidationMessages();
auto time_finish = absl::GetCurrentTimeNanos();
events_.evicted_keys += evicted_items;
DVLOG(2) << "Eviction time (us): " << (time_finish - time_start) / 1000;
return pair<uint64_t, size_t>{evicted_items, evicted_bytes};
}
void DbSlice::CreateDb(DbIndex db_ind) {
@ -1745,10 +1739,12 @@ void DbSlice::OnCbFinish() {
}
void DbSlice::CallChangeCallbacks(DbIndex id, std::string_view key, const ChangeReq& cr) const {
std::unique_lock<LocalBlockingCounter> lk(block_counter_);
if (change_cb_.empty())
return;
// does not preempt, just increments the counter.
unique_lock<LocalLatch> lk(serialization_latch_);
DVLOG(2) << "Running callbacks for key " << key << " in dbid " << id;
const size_t limit = change_cb_.size();

View file

@ -527,11 +527,11 @@ class DbSlice {
void SetNotifyKeyspaceEvents(std::string_view notify_keyspace_events);
bool WillBlockOnJournalWrite() const {
return block_counter_.IsBlocked();
return serialization_latch_.IsBlocked();
}
LocalBlockingCounter* BlockingCounter() {
return &block_counter_;
LocalLatch* GetLatch() {
return &serialization_latch_;
}
void StartSampleTopK(DbIndex db_ind, uint32_t min_freq);
@ -547,7 +547,7 @@ class DbSlice {
size_t StopSampleKeys(DbIndex db_ind);
private:
void PreUpdate(DbIndex db_ind, Iterator it, std::string_view key);
void PreUpdateBlocking(DbIndex db_ind, Iterator it, std::string_view key);
void PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size_t orig_size);
bool DelEmptyPrimeValue(const Context& cntx, Iterator it);
@ -606,8 +606,8 @@ class DbSlice {
// We need this because registered callbacks might yield and when they do so we want
// to avoid Heartbeat or Flushing the db.
// This counter protects us against this case.
mutable LocalBlockingCounter block_counter_;
// This latch protects us against this case.
mutable LocalLatch serialization_latch_;
ShardId shard_id_;
uint8_t cache_mode_ : 1;

View file

@ -610,7 +610,8 @@ void OpScan(const OpArgs& op_args, const ScanOpts& scan_opts, uint64_t* cursor,
// ScanCb can preempt due to journaling expired entries and we need to make sure that
// we enter the callback at a time when journaling will not cause preemptions. Otherwise,
// the bucket might change as we Traverse and yield.
db_slice.BlockingCounter()->Wait();
db_slice.GetLatch()->Wait();
// Disable flush journal changes to prevent preemption in traverse.
journal::JournalFlushGuard journal_flush_guard(op_args.shard->journal());
unsigned cnt = 0;
@ -793,6 +794,9 @@ OpStatus OpMove(const OpArgs& op_args, string_view key, DbIndex target_db) {
if (!IsValid(from_res.it))
return OpStatus::KEY_NOTFOUND;
// Ensure target database exists.
db_slice.ActivateDb(target_db);
// Fetch value at key in target db.
DbContext target_cntx = op_args.db_cntx;
target_cntx.db_index = target_db;
@ -800,9 +804,6 @@ OpStatus OpMove(const OpArgs& op_args, string_view key, DbIndex target_db) {
if (IsValid(to_res.it))
return OpStatus::KEY_EXISTS;
// Ensure target database exists.
db_slice.ActivateDb(target_db);
bool sticky = from_res.it->first.IsSticky();
uint64_t exp_ts = db_slice.ExpireTime(from_res.exp_it);
from_res.post_updater.Run();

View file

@ -223,8 +223,8 @@ void RestoreStreamer::Run() {
std::lock_guard guard(big_value_mu_);
// Locking this never preempts. See snapshot.cc for why we need it.
auto* blocking_counter = db_slice_->BlockingCounter();
std::lock_guard blocking_counter_guard(*blocking_counter);
auto* blocking_counter = db_slice_->GetLatch();
lock_guard blocking_counter_guard(*blocking_counter);
stats_.buckets_loop += WriteBucket(it);
});

View file

@ -272,7 +272,8 @@ bool SliceSnapshot::BucketSaveCb(DbIndex db_index, PrimeTable::bucket_iterator i
db_slice_->FlushChangeToEarlierCallbacks(db_index, DbSlice::Iterator::FromPrime(it),
snapshot_version_);
auto* blocking_counter = db_slice_->BlockingCounter();
auto* blocking_counter = db_slice_->GetLatch();
// Locking this never preempts. We merely increment the underlying counter such that
// if SerializeBucket preempts, Heartbeat() won't run because the blocking counter is not
// zero.

View file

@ -71,7 +71,7 @@ TEST_F(StringFamilyTest, Incr) {
ASSERT_THAT(Run({"incrby", "ne", "0"}), IntArg(0));
ASSERT_THAT(Run({"decrby", "a", "-9223372036854775808"}), ErrArg("overflow"));
auto metrics = GetMetrics();
EXPECT_EQ(10, metrics.events.mutations);
EXPECT_EQ(9, metrics.events.mutations);
EXPECT_EQ(0, metrics.events.misses);
EXPECT_EQ(0, metrics.events.hits);
}

View file

@ -62,7 +62,7 @@ void RecordJournal(const OpArgs& op_args, std::string_view cmd, facade::ArgSlice
op_args.tx->LogJournalOnShard(op_args.shard, Payload(cmd, args), shard_cnt);
}
void RecordExpiry(DbIndex dbid, string_view key) {
void RecordExpiryBlocking(DbIndex dbid, string_view key) {
auto journal = EngineShard::tlocal()->journal();
CHECK(journal);

View file

@ -225,7 +225,7 @@ void RecordJournal(const OpArgs& op_args, std::string_view cmd, ArgSlice args,
// Record expiry in journal with independent transaction.
// Must be called from shard thread owning key.
// Might block the calling fiber unless Journal::SetFlushMode(false) is called.
void RecordExpiry(DbIndex dbid, std::string_view key);
void RecordExpiryBlocking(DbIndex dbid, std::string_view key);
// Trigger journal write to sink, no journal record will be added to journal.
// Must be called from shard thread of journal to sink.