From 1cbfcd49123e2e15f5ff2efda7935d953a74866c Mon Sep 17 00:00:00 2001
From: Roman Gershman
Date: Wed, 7 Aug 2024 16:33:03 +0300
Subject: [PATCH] chore: add timeout to replication sockets (#3434)

* chore: add timeout for replication sockets

Master will stop the replication flow if its writes cannot progress for
longer than replication_timeout milliseconds.

---------

Signed-off-by: Roman Gershman
Signed-off-by: Roman Gershman
Co-authored-by: Shahar Mike
---
 src/server/dflycmd.cc               | 67 ++++++++++++++++++++++-------
 src/server/dflycmd.h                |  7 ++-
 src/server/engine_shard_set.cc      | 16 ++++---
 src/server/engine_shard_set.h       |  6 +--
 src/server/journal/streamer.cc      | 10 ++---
 src/server/main_service.cc          |  6 ++-
 src/server/rdb_save.cc              | 21 +++++++--
 src/server/rdb_save.h               |  4 ++
 src/server/server_family.cc         |  4 +-
 tests/dragonfly/replication_test.py | 33 +++++++++++++-
 10 files changed, 133 insertions(+), 41 deletions(-)

diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc
index 0301b2515..82b15b6a9 100644
--- a/src/server/dflycmd.cc
+++ b/src/server/dflycmd.cc
@@ -33,6 +33,7 @@ using namespace std;

 ABSL_DECLARE_FLAG(bool, info_replication_valkey_compatible);
+ABSL_DECLARE_FLAG(uint32_t, replication_timeout);

 namespace dfly {
@@ -119,6 +120,7 @@ void DflyCmd::ReplicaInfo::Cancel() {
     }
     flow->full_sync_fb.JoinIfNeeded();
+    flow->conn = nullptr;
   });

   // Wait for error handler to quit.
@@ -501,6 +503,7 @@ void DflyCmd::ReplicaOffset(CmdArgList args, ConnectionContext* cntx) {
 OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
   DCHECK(!flow->full_sync_fb.IsJoinable());
   DCHECK(shard);
+  DCHECK(flow->conn);

   // The summary contains the LUA scripts, so make sure at least (and exactly one)
   // of the flows also contain them.
@@ -527,13 +530,10 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
     return OpStatus::CANCELLED;
   }

-  // Shard can be null for io thread.
-  if (shard != nullptr) {
-    if (flow->start_partial_sync_at.has_value())
-      saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
-    else
-      saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
-  }
+  if (flow->start_partial_sync_at.has_value())
+    saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
+  else
+    saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);

   flow->full_sync_fb = fb2::Fiber("full_sync", &DflyCmd::FullSyncFb, this, flow, cntx);
   return OpStatus::OK;
@@ -555,12 +555,12 @@ void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
 OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
   // Create streamer for shard flows.
+  DCHECK(shard);
+  DCHECK(flow->conn);

-  if (shard != nullptr) {
-    flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
-    bool send_lsn = flow->version >= DflyVersion::VER4;
-    flow->streamer->Start(flow->conn->socket(), send_lsn);
-  }
+  flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
+  bool send_lsn = flow->version >= DflyVersion::VER4;
+  flow->streamer->Start(flow->conn->socket(), send_lsn);

   // Register cleanup.
   flow->cleanup = [flow]() {
@@ -577,6 +577,8 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
   error_code ec;

   if (ec = flow->saver->SaveBody(cntx, nullptr); ec) {
+    if (!flow->conn->socket()->IsOpen())
+      ec = make_error_code(errc::operation_canceled);  // we cancelled the operation.
     cntx->ReportError(ec);
     return;
   }
@@ -588,8 +590,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
   }
 }

-auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
-    -> std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> {
+auto DflyCmd::CreateSyncSession(ConnectionContext* cntx) -> std::pair<uint32_t, unsigned> {
   unique_lock lk(mu_);
   unsigned sync_id = next_sync_id_++;
@@ -612,7 +613,7 @@ auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
   auto [it, inserted] = replica_infos_.emplace(sync_id, std::move(replica_ptr));
   CHECK(inserted);

-  return *it;
+  return {it->first, flow_count};
 }

 auto DflyCmd::GetReplicaInfoFromConnection(ConnectionContext* cntx)
@@ -651,6 +652,40 @@ void DflyCmd::StopReplication(uint32_t sync_id) {
   replica_infos_.erase(sync_id);
 }

+void DflyCmd::BreakStalledFlowsInShard() {
+  unique_lock global_lock(mu_, try_to_lock);
+
+  // Give up if the lock is contended: we run this function periodically in a background fiber,
+  // so it will eventually grab the lock.
+  if (!global_lock.owns_lock())
+    return;
+
+  ShardId sid = EngineShard::tlocal()->shard_id();
+  vector<uint32_t> deleted;
+
+  for (auto [sync_id, replica_ptr] : replica_infos_) {
+    shared_lock replica_lock = replica_ptr->GetSharedLock();
+
+    if (!replica_ptr->flows[sid].saver)
+      continue;
+
+    // If the saver is present, we are currently using it for full sync.
+    int64_t last_write_ns = replica_ptr->flows[sid].saver->GetLastWriteTime();
+    int64_t timeout_ns = int64_t(absl::GetFlag(FLAGS_replication_timeout)) * 1'000'000LL;
+    int64_t now = absl::GetCurrentTimeNanos();
+    if (last_write_ns > 0 && last_write_ns + timeout_ns < now) {
+      VLOG(1) << "Breaking full sync for sync_id " << sync_id << " last_write_ts: " << last_write_ns
+              << ", now: " << now;
+      deleted.push_back(sync_id);
+      replica_lock.unlock();
+      replica_ptr->Cancel();
+    }
+  }
+
+  for (auto sync_id : deleted)
+    replica_infos_.erase(sync_id);
+}
+
 shared_ptr<DflyCmd::ReplicaInfo> DflyCmd::GetReplicaInfo(uint32_t sync_id) {
   lock_guard lk(mu_);
@@ -807,7 +842,7 @@ void DflyCmd::Shutdown() {

 void FlowInfo::TryShutdownSocket() {
   // Close socket for clean disconnect.
   if (conn->socket()->IsOpen()) {
-    (void)conn->socket()->Shutdown(SHUT_RDWR);
+    std::ignore = conn->socket()->Shutdown(SHUT_RDWR);
   }
 }

diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h
index 11a14c89a..0112871a3 100644
--- a/src/server/dflycmd.h
+++ b/src/server/dflycmd.h
@@ -147,8 +147,8 @@ class DflyCmd {
   // Stop all background processes so we can exit in an orderly manner.
   void Shutdown();

-  // Create new sync session.
-  std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> CreateSyncSession(ConnectionContext* cntx);
+  // Create a new sync session. Returns (session_id, number of flows).
+  std::pair<uint32_t, unsigned> CreateSyncSession(ConnectionContext* cntx);

   // Master-side access method to the replication info of that connection.
   std::shared_ptr<ReplicaInfo> GetReplicaInfoFromConnection(ConnectionContext* cntx);
@@ -160,6 +160,9 @@ class DflyCmd {
   // Sets metadata.
   void SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version);

+  // Tries to break flows that have been stuck on a socket write for too long.
+  void BreakStalledFlowsInShard();
+
  private:
   // JOURNAL [START/STOP]
   // Start or stop journaling.
diff --git a/src/server/engine_shard_set.cc b/src/server/engine_shard_set.cc
index 9512c80bb..a7997ff94 100644
--- a/src/server/engine_shard_set.cc
+++ b/src/server/engine_shard_set.cc
@@ -682,13 +682,14 @@ void EngineShard::RetireExpiredAndEvict() {
 }

 void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
-                              std::function<void()> global_handler) {
+                              std::function<void()> shard_handler) {
   VLOG(1) << "RunPeriodic with period " << period_ms.count() << "ms";

   bool runs_global_periodic = (shard_id() == 0);  // Only shard 0 runs global periodic.
   unsigned global_count = 0;
   int64_t last_stats_time = time(nullptr);
   int64_t last_heartbeat_ms = INT64_MAX;
+  int64_t last_handler_ms = 0;

   while (true) {
     if (fiber_periodic_done_.WaitFor(period_ms)) {
@@ -702,6 +703,10 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
     }
     Heartbeat();
     last_heartbeat_ms = fb2::ProactorBase::GetMonotonicTimeNs() / 1000000;
+    if (shard_handler && last_handler_ms + 100 < last_heartbeat_ms) {
+      last_handler_ms = last_heartbeat_ms;
+      shard_handler();
+    }

     if (runs_global_periodic) {
       ++global_count;
@@ -727,10 +732,6 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
           rss_mem_peak.store(total_rss, memory_order_relaxed);
         }
       }
-
-      if (global_handler) {
-        global_handler();
-      }
     }
   }
 }
@@ -903,7 +904,7 @@ size_t GetTieredFileLimit(size_t threads) {
   return max_shard_file_size;
 }

-void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
+void EngineShardSet::Init(uint32_t sz, std::function<void()> shard_handler) {
   CHECK_EQ(0u, size());

   shard_queue_.resize(sz);
@@ -922,7 +923,8 @@ void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
       shard->InitTieredStorage(pb, max_shard_file_size);

       // Must be last, as it accesses objects initialized above.
-      shard->StartPeriodicFiber(pb, global_handler);
+      // We cannot move shard_handler because this code is called multiple times.
+      shard->StartPeriodicFiber(pb, shard_handler);
     }
   });
 }

diff --git a/src/server/engine_shard_set.h b/src/server/engine_shard_set.h
index fe925b547..b406c73a7 100644
--- a/src/server/engine_shard_set.h
+++ b/src/server/engine_shard_set.h
@@ -203,12 +203,12 @@ class EngineShard {
   // blocks the calling fiber.
   void Shutdown();  // called before destructing EngineShard.
-  void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> global_handler);
+  void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> shard_handler);

   void Heartbeat();
   void RetireExpiredAndEvict();

-  void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> global_handler);
+  void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> shard_handler);

   void CacheStats();
@@ -288,7 +288,7 @@ class EngineShardSet {
     return pp_;
   }

-  void Init(uint32_t size, std::function<void()> global_handler);
+  void Init(uint32_t size, std::function<void()> shard_handler);

   // Shutdown sequence:
   //  - EngineShardSet.PreShutDown()

diff --git a/src/server/journal/streamer.cc b/src/server/journal/streamer.cc
index 182237c83..3640dabd0 100644
--- a/src/server/journal/streamer.cc
+++ b/src/server/journal/streamer.cc
@@ -13,9 +13,9 @@
 using namespace facade;

-ABSL_FLAG(uint32_t, replication_stream_timeout, 500,
-          "Time in milliseconds to wait for the replication output buffer go below "
-          "the throttle limit.");
+ABSL_FLAG(uint32_t, replication_timeout, 10000,
+          "Time in milliseconds to wait for stuck replication writes before breaking replication.");
+
 ABSL_FLAG(uint32_t, replication_stream_output_limit, 64_KB,
           "Time to wait for the replication output buffer go below the throttle limit");

@@ -155,8 +155,8 @@ void JournalStreamer::ThrottleIfNeeded() {
   if (IsStopped() || !IsStalled())
     return;

-  auto next = chrono::steady_clock::now() +
-              chrono::milliseconds(absl::GetFlag(FLAGS_replication_stream_timeout));
+  auto next =
+      chrono::steady_clock::now() + chrono::milliseconds(absl::GetFlag(FLAGS_replication_timeout));
   size_t inflight_start = in_flight_bytes_;
   size_t sent_start = total_sent_;

diff --git a/src/server/main_service.cc b/src/server/main_service.cc
index 16e79e7df..2d20f457d 100644
--- a/src/server/main_service.cc
+++ b/src/server/main_service.cc
@@ -887,7 +887,6 @@ void Service::Init(util::AcceptServer* acceptor, std::vector
     ServerState::tlocal()->UpdateChannelStore(cs);
   });

-  shard_set->Init(shard_num, nullptr);
   const auto tcp_disabled = GetFlag(FLAGS_port) == 0u;
   // We assume that listeners.front() is the main_listener
   // see dfly_main RunEngine
@@ -895,6 +894,11 @@ void Service::Init(util::AcceptServer* acceptor, std::vector
     acl_family_.Init(listeners.front(), &user_registry_);
   }

+  // Initialize shard_set with a callback that runs periodically in each shard thread.
+  shard_set->Init(shard_num, [this] { server_family_.GetDflyCmd()->BreakStalledFlowsInShard(); });
+
+  // shard_set must be initialized before this call, because server_family_.Init might
+  // load the snapshot.
   server_family_.Init(acceptor, std::move(listeners));
 }

diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc
index 34741b388..d5d7ce906 100644
--- a/src/server/rdb_save.cc
+++ b/src/server/rdb_save.cc
@@ -1135,10 +1135,15 @@ class RdbSaver::Impl {
     return &meta_serializer_;
   }

+  int64_t last_write_ts() const {
+    return last_write_time_ns_;
+  }
+
  private:
   unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);

   io::Sink* sink_;
+  int64_t last_write_time_ns_ = -1;  // Start of the last write call, or -1 if none is pending.
   vector<unique_ptr<SliceSnapshot>> shard_snapshots_;
   // used for serializing non-body components in the calling fiber.
   RdbSerializer meta_serializer_;
@@ -1263,10 +1268,12 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
       continue;

     DVLOG(2) << "Pulled " << record->id;
-    auto before = absl::GetCurrentTimeNanos();
+    last_write_time_ns_ = absl::GetCurrentTimeNanos();
     io_error = sink_->Write(io::Buffer(record->value));
-    stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - before) / 1'000;
+
+    stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - last_write_time_ns_) / 1'000;
     stats.rdb_save_count++;
+    last_write_time_ns_ = -1;
     if (io_error) {
       VLOG(1) << "Error writing to sink " << io_error.message();
       break;
     }
@@ -1369,7 +1376,10 @@ RdbSaver::SnapshotStats RdbSaver::Impl::GetCurrentSnapshotProgress() const {
 }

 error_code RdbSaver::Impl::FlushSerializer() {
-  return serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
+  last_write_time_ns_ = absl::GetCurrentTimeNanos();
+  auto ec = serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
+  last_write_time_ns_ = -1;
+  return ec;
 }

 RdbSaver::GlobalData RdbSaver::GetGlobalData(const Service* service) {
@@ -1482,7 +1492,6 @@ error_code RdbSaver::SaveBody(Context* cntx, RdbTypeFreqMap* freq_map) {
   VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
   error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
   if (io_error) {
-    LOG(ERROR) << "io error " << io_error;
     return io_error;
   }
   if (cntx->GetError()) {
@@ -1572,6 +1581,10 @@ RdbSaver::SnapshotStats RdbSaver::GetCurrentSnapshotProgress() const {
   return impl_->GetCurrentSnapshotProgress();
 }

+int64_t RdbSaver::GetLastWriteTime() const {
+  return impl_->last_write_ts();
+}
+
 void SerializerBase::AllocateCompressorOnce() {
   if (compressor_impl_) {
     return;

diff --git a/src/server/rdb_save.h b/src/server/rdb_save.h
index 25b49dfce..11b6c17eb 100644
--- a/src/server/rdb_save.h
+++ b/src/server/rdb_save.h
@@ -122,6 +122,10 @@ class RdbSaver {
   // Fetch global data to be serialized in summary part of a snapshot / full sync.
   static GlobalData GetGlobalData(const Service* service);

+  // Returns the start time (in nanoseconds) of the last pending write.
+  // Returns -1 if no write operation is currently pending.
+  int64_t GetLastWriteTime() const;
+
  private:
   class Impl;

diff --git a/src/server/server_family.cc b/src/server/server_family.cc
index 52cb052dc..71be48bda 100644
--- a/src/server/server_family.cc
+++ b/src/server/server_family.cc
@@ -2665,7 +2665,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
     std::string_view arg = ArgS(args, i + 1);
     if (cmd == "CAPA") {
       if (arg == "dragonfly" && args.size() == 2 && i == 0) {
-        auto [sid, replica_info] = dfly_cmd_->CreateSyncSession(cntx);
+        auto [sid, flow_count] = dfly_cmd_->CreateSyncSession(cntx);
         cntx->conn()->SetName(absl::StrCat("repl_ctrl_", sid));

         string sync_id = absl::StrCat("SYNC", sid);
@@ -2681,7 +2681,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
         rb->StartArray(4);
         rb->SendSimpleString(master_replid_);
         rb->SendSimpleString(sync_id);
-        rb->SendLong(replica_info->flows.size());
+        rb->SendLong(flow_count);
         rb->SendLong(unsigned(DflyVersion::CURRENT_VER));
         return;
       }

diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py
index 4a21fe53e..11d64ae1e 100644
--- a/tests/dragonfly/replication_test.py
+++ b/tests/dragonfly/replication_test.py
@@ -2174,7 +2174,7 @@ async def test_replica_reconnect(df_factory, break_conn):
     c_master = master.client()
     c_replica = replica.client()

-    await c_master.execute_command("set k 12345")
+    await c_master.set("k", "12345")
     await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
     await wait_available_async(c_replica)
     assert (await c_replica.info("REPLICATION"))["master_link_status"] == "up"
@@ -2230,3 +2230,34 @@ async def test_announce_ip_port(df_factory):
     host, port, _ = node[0]
     assert host == "overrode-host"
     assert port == "1337"
+
+
+async def test_master_stalled_disconnect(df_factory: DflyInstanceFactory):
+    # Master should drop the replica after its writes have been blocked for 1 second.
+    master = df_factory.create(replication_timeout=1000)
+    replica = df_factory.create()
+
+    df_factory.start_all([master, replica])
+
+    c_master = master.client()
+    c_replica = replica.client()
+
+    await c_master.execute_command("debug", "populate", "200000", "foo", "500")
+    await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
+
+    @assert_eventually
+    async def check_replica_connected():
+        repl_info = await c_master.info("replication")
+        assert "slave0" in repl_info
+
+    @assert_eventually
+    async def check_replica_disconnected():
+        repl_info = await c_master.info("replication")
+        assert "slave0" not in repl_info
+
+    await check_replica_connected()
+    await c_replica.execute_command("DEBUG REPLICA PAUSE")
+    await check_replica_connected()  # still connected
+    await asyncio.sleep(1)  # wait for the master to recognize that writes are blocked
+    await check_replica_disconnected()
+    df_factory.stop_all()
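
Note on the mechanism: the decision to drop a stalled replica is the small predicate inside BreakStalledFlowsInShard above, fed by RdbSaver::GetLastWriteTime() and the new --replication_timeout flag. The following is a minimal standalone sketch of that check, not code from this patch; the helper name and parameter list are illustrative only.

#include <cstdint>

// Illustrative restatement of the stall check in BreakStalledFlowsInShard:
// last_write_ns is the start of the last pending write (-1 when nothing is pending),
// timeout_ms is the replication_timeout flag value, now_ns is the current time in nanoseconds.
bool IsFlowStalled(int64_t last_write_ns, uint32_t timeout_ms, int64_t now_ns) {
  if (last_write_ns <= 0)  // no write is in flight, so nothing can be stuck
    return false;
  const int64_t timeout_ns = int64_t(timeout_ms) * 1'000'000LL;  // milliseconds -> nanoseconds
  return last_write_ns + timeout_ns < now_ns;  // the pending write started too long ago
}

With the default of 10000 ms, a full-sync write has to stall for ten seconds before the flow is cancelled; the test above lowers the flag to 1000 ms so the disconnect can be observed quickly.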