diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 61ec2596a..a9cde2270 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1734,6 +1734,7 @@ void Connection::LaunchAsyncFiberIfNeeded() { } } +// Should never block - the callers may run in as a a brief callback. void Connection::SendAsync(MessageHandle msg) { DCHECK(cc_); DCHECK(listener()); diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index a9ba499e3..3b4ca7908 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -971,8 +971,8 @@ uint32_t DbSlice::GetMCFlag(DbIndex db_ind, const PrimeKey& key) const { auto& db = *db_arr_[db_ind]; auto it = db.mcflag.Find(key); if (it.is_done()) { - LOG(ERROR) << "Internal error, inconsistent state, mcflag should be present but not found " - << key.ToString(); + LOG(DFATAL) << "Internal error, inconsistent state, mcflag should be present but not found " + << key.ToString(); return 0; } return it->second; @@ -1186,7 +1186,9 @@ void DbSlice::PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size db.slots_stats[KeySlot(key)].total_writes += 1; } - SendInvalidationTrackingMessage(key); + if (!client_tracking_map_.empty()) { + QueueInvalidationTrackingMessageAtomic(key); + } } DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) const { @@ -1196,7 +1198,7 @@ DbSlice::ItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, Iterator it) cons DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) const { if (!it->second.HasExpire()) { - LOG(ERROR) << "Invalid call to ExpireIfNeeded"; + LOG(DFATAL) << "Invalid call to ExpireIfNeeded"; return {it, ExpireIterator{}}; } @@ -1212,10 +1214,10 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica() || !expire_allowed_) return {it, expire_it}; } else { - LOG(ERROR) << "Internal error, entry " << it->first.ToString() - << " not found in expire table, db_index: " << cntx.db_index - << ", expire table size: " << db->expire.size() - << ", prime table size: " << db->prime.size() << util::fb2::GetStacktrace(); + LOG(DFATAL) << "Internal error, entry " << it->first.ToString() + << " not found in expire table, db_index: " << cntx.db_index + << ", expire table size: " << db->expire.size() + << ", prime table size: " << db->prime.size() << util::fb2::GetStacktrace(); } string scratch; @@ -1234,9 +1236,9 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato doc_del_cb_(key, cntx, it->second); } - const_cast(this)->PerformDeletion(Iterator(it, StringOrView::FromView(key)), - ExpIterator(expire_it, StringOrView::FromView(key)), - db.get()); + const_cast(this)->PerformDeletionAtomic( + Iterator(it, StringOrView::FromView(key)), + ExpIterator(expire_it, StringOrView::FromView(key)), db.get()); ++events_.expired_keys; @@ -1258,7 +1260,7 @@ void DbSlice::ExpireAllIfNeeded() { auto cb = [&](ExpireTable::iterator exp_it) { auto prime_it = db.prime.Find(exp_it->first); if (!IsValid(prime_it)) { - LOG(ERROR) << "Expire entry " << exp_it->first.ToString() << " not found in prime table"; + LOG(DFATAL) << "Expire entry " << exp_it->first.ToString() << " not found in prime table"; return; } ExpireIfNeeded(Context{nullptr, db_index, GetCurrentTimeMs()}, prime_it); @@ -1404,6 +1406,7 @@ pair DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t s if (expired_keys_events_recording_) db_table->expired_keys_events_.emplace_back(key); } + SendQueuedInvalidationMessages(); auto time_finish = absl::GetCurrentTimeNanos(); events_.evicted_keys += evicted_items; @@ -1545,34 +1548,49 @@ void DbSlice::SetNotifyKeyspaceEvents(std::string_view notify_keyspace_events) { expired_keys_events_recording_ = !notify_keyspace_events.empty(); } -void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { - if (client_tracking_map_.empty()) - return; - +void DbSlice::QueueInvalidationTrackingMessageAtomic(std::string_view key) { auto it = client_tracking_map_.find(key); if (it == client_tracking_map_.end()) { return; } - auto& client_set = it->second; - // Notify all the clients. We copy key because we dispatch briefly below and - // we need to preserve its lifetime - // TODO this key is further copied within DispatchFiber. Fix this. - auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, - util::ProactorBase*) { - for (auto& client : client_set) { - if (client.IsExpired() || (client.Thread() != idx)) { - continue; - } - auto* conn = client.Get(); - auto* cntx = static_cast(conn->cntx()); - if (cntx && cntx->conn_state.tracking_info_.IsTrackingOn()) { - conn->SendInvalidationMessageAsync({key}); - } + + ConnectionHashSet moved_set = std::move(it->second); + client_tracking_map_.erase(it); + + auto [pend_it, inserted] = pending_send_map_.emplace(key, std::move(moved_set)); + if (!inserted) { + ConnectionHashSet& client_set = pend_it->second; + for (auto& client : moved_set) { + client_set.insert(client); } - }; - shard_set->pool()->DispatchBrief(std::move(cb)); - // remove this key from the tracking table as the key no longer exists - client_tracking_map_.erase(key); + } +} + +void DbSlice::SendQueuedInvalidationMessages() { + // We run while loop because when we block below, we might have new items added to + // pending_send_map_. + while (!pending_send_map_.empty()) { + auto local_map = std::move(pending_send_map_); + + // Notify all the clients. this function is not efficient, + // because it broadcasts to all threads unrelated to the subscribers for the key. + auto cb = [&](unsigned idx, util::ProactorBase*) { + for (auto& [key, client_list] : local_map) { + for (auto& client : client_list) { + if (client.IsExpired() || (client.Thread() != idx)) { + continue; + } + auto* conn = client.Get(); + auto* cntx = static_cast(conn->cntx()); + if (cntx && cntx->conn_state.tracking_info_.IsTrackingOn()) { + conn->SendInvalidationMessageAsync({key}); + } + } + } + }; + + shard_set->pool()->AwaitBrief(std::move(cb)); + } } void DbSlice::StartSampleTopK(DbIndex db_ind, uint32_t min_freq) { @@ -1641,7 +1659,8 @@ void DbSlice::PerformDeletion(PrimeIterator del_it, DbTable* table) { return PerformDeletion(Iterator::FromPrime(del_it), table); } -void DbSlice::PerformDeletion(Iterator del_it, ExpIterator exp_it, DbTable* table) { +void DbSlice::PerformDeletionAtomic(Iterator del_it, ExpIterator exp_it, DbTable* table) { + FiberAtomicGuard guard; size_t table_before = table->table_memory(); if (!exp_it.is_done()) { table->expire.Erase(exp_it.GetInnerIt()); @@ -1649,8 +1668,8 @@ void DbSlice::PerformDeletion(Iterator del_it, ExpIterator exp_it, DbTable* tabl if (del_it->second.HasFlag()) { if (table->mcflag.Erase(del_it->first) == 0) { - LOG(ERROR) << "Internal error, inconsistent state, mcflag should be present but not found " - << del_it->first.ToString(); + LOG(DFATAL) << "Internal error, inconsistent state, mcflag should be present but not found " + << del_it->first.ToString(); } } @@ -1700,7 +1719,9 @@ void DbSlice::PerformDeletion(Iterator del_it, ExpIterator exp_it, DbTable* tabl --entries_count_; memory_budget_ += (value_heap_size + key_size_used); - SendInvalidationTrackingMessage(del_it.key()); + if (!client_tracking_map_.empty()) { + QueueInvalidationTrackingMessageAtomic(del_it.key()); + } } void DbSlice::PerformDeletion(Iterator del_it, DbTable* table) { @@ -1710,13 +1731,17 @@ void DbSlice::PerformDeletion(Iterator del_it, DbTable* table) { DCHECK(!exp_it.is_done()); } - PerformDeletion(del_it, exp_it, table); + PerformDeletionAtomic(del_it, exp_it, table); } void DbSlice::OnCbFinish() { // TBD update bumpups logic we can not clear now after cb finish as cb can preempt // btw what do we do with inline? fetched_items_.clear(); + + if (!pending_send_map_.empty()) { + SendQueuedInvalidationMessages(); + } } void DbSlice::CallChangeCallbacks(DbIndex id, std::string_view key, const ChangeReq& cr) const { diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 631759d00..95a6c7446 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -568,11 +568,13 @@ class DbSlice { // Clear tiered storage entries for the specified indices. void ClearOffloadedEntries(absl::Span indices, const DbTableArray& db_arr); - void PerformDeletion(Iterator del_it, ExpIterator exp_it, DbTable* table); + // + void PerformDeletionAtomic(Iterator del_it, ExpIterator exp_it, DbTable* table); void PerformDeletion(PrimeIterator del_it, DbTable* table); - // Send invalidation message to the clients that are tracking the change to a key. - void SendInvalidationTrackingMessage(std::string_view key); + // Queues invalidation message to the clients that are tracking the change to a key. + void QueueInvalidationTrackingMessageAtomic(std::string_view key); + void SendQueuedInvalidationMessages(); void CreateDb(DbIndex index); @@ -679,7 +681,7 @@ class DbSlice { absl::flat_hash_map, absl::container_internal::hash_default_eq, AllocatorType> - client_tracking_map_; + client_tracking_map_, pending_send_map_; class PrimeBumpPolicy; }; diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index accdfc870..613341fcc 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -441,11 +441,9 @@ TEST_F(ServerFamilyTest, ClientTrackingUpdateKey) { std::vector keys_invalidated; for (unsigned int i = 2; i < 6; ++i) keys_invalidated.push_back(GetInvalidationMessage("IO0", i).key); - ASSERT_THAT(keys_invalidated, ElementsAre("X1", "Y3", "Z2", "Z4")); + ASSERT_THAT(keys_invalidated, UnorderedElementsAre("X1", "Y3", "Z2", "Z4")); - // The following doesn't work correctly as we currently can't mock listener. - // flushdb command - // Run({"FLUSHDB"}); + Run({"FLUSHDB"}); } TEST_F(ServerFamilyTest, ClientTrackingDeleteKey) { diff --git a/tests/dragonfly/cluster_test.py b/tests/dragonfly/cluster_test.py index c17f8f582..e0ee1cf36 100644 --- a/tests/dragonfly/cluster_test.py +++ b/tests/dragonfly/cluster_test.py @@ -2940,6 +2940,7 @@ async def test_migration_rebalance_node(df_factory: DflyInstanceFactory, df_seed assert await seeder.compare(capture, nodes[1].instance.port) +@pytest.mark.skip("Flaky test") @dfly_args({"proactor_threads": 2, "cluster_mode": "yes"}) async def test_cluster_sharded_pub_sub(df_factory: DflyInstanceFactory): nodes = [df_factory.create(port=next(next_port)) for i in range(2)]