chore: add timeout to replication sockets (#3434)

* chore: add timeout to replication sockets

The master will stop the replication flow if its writes cannot make progress for more than `replication_timeout` milliseconds.

---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
Signed-off-by: Roman Gershman <romange@gmail.com>
Co-authored-by: Shahar Mike <chakaz@users.noreply.github.com>
Roman Gershman 2024-08-07 16:33:03 +03:00 committed by GitHub
parent 7c84b8e524
commit 1cbfcd4912
10 changed files with 133 additions and 41 deletions
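
The change boils down to a stalled-write watchdog: the RDB saver records a timestamp right before each potentially blocking socket write and clears it once the write returns, while a periodic per-shard callback compares that timestamp against the `replication_timeout` flag and cancels the replica if a write has been pending for too long. Below is a minimal, self-contained sketch of that pattern; it uses plain `std::thread` instead of Dragonfly's fibers, and the `StalledWriteWatchdog` name and surrounding scaffolding are illustrative only, not part of the actual code.

#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <thread>

using namespace std::chrono;

// Illustrative stand-in for the saver: it publishes the start time of the
// write that is currently in flight, or -1 when no write is pending.
class StalledWriteWatchdog {
 public:
  explicit StalledWriteWatchdog(int64_t timeout_ms) : timeout_ms_(timeout_ms) {}

  // Wraps a (possibly blocking) write call with the timestamp bookkeeping.
  template <typename WriteFn> void Write(WriteFn&& write_fn) {
    last_write_ns_.store(NowNs(), std::memory_order_relaxed);
    write_fn();                                            // may block on a slow peer
    last_write_ns_.store(-1, std::memory_order_relaxed);   // write completed
  }

  // Called periodically from another thread; true if the pending write
  // has been stuck for longer than the configured timeout.
  bool IsStalled() const {
    int64_t started = last_write_ns_.load(std::memory_order_relaxed);
    return started > 0 && NowNs() - started > timeout_ms_ * 1'000'000;
  }

 private:
  static int64_t NowNs() {
    return duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count();
  }

  std::atomic<int64_t> last_write_ns_{-1};
  int64_t timeout_ms_;
};

int main() {
  StalledWriteWatchdog watchdog(/*timeout_ms=*/100);

  // Simulate a replica that stopped reading: the write "blocks" for a second.
  std::thread writer([&] { watchdog.Write([] { std::this_thread::sleep_for(seconds(1)); }); });

  // Periodic check, analogous to the shard_handler invoked from RunPeriodic().
  for (int i = 0; i < 10; ++i) {
    std::this_thread::sleep_for(milliseconds(50));
    if (watchdog.IsStalled()) {
      std::cout << "write stalled beyond timeout; the master would cancel this replica\n";
      break;
    }
  }
  writer.join();
  return 0;
}

The actual implementation hooks the check into EngineShard::RunPeriodic (roughly every 100 ms) and reads the timestamp via RdbSaver::GetLastWriteTime(), as shown in the diffs below.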

src/server/dfly_cmd.cc

@@ -33,6 +33,7 @@
 using namespace std;

 ABSL_DECLARE_FLAG(bool, info_replication_valkey_compatible);
+ABSL_DECLARE_FLAG(uint32_t, replication_timeout);

 namespace dfly {

@@ -119,6 +120,7 @@ void DflyCmd::ReplicaInfo::Cancel() {
     }

     flow->full_sync_fb.JoinIfNeeded();
+    flow->conn = nullptr;
   });

   // Wait for error handler to quit.

@@ -501,6 +503,7 @@ void DflyCmd::ReplicaOffset(CmdArgList args, ConnectionContext* cntx) {
 OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
   DCHECK(!flow->full_sync_fb.IsJoinable());
   DCHECK(shard);
+  DCHECK(flow->conn);

   // The summary contains the LUA scripts, so make sure at least (and exactly one)
   // of the flows also contain them.

@@ -527,13 +530,10 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
     return OpStatus::CANCELLED;
   }

-  // Shard can be null for io thread.
-  if (shard != nullptr) {
-    if (flow->start_partial_sync_at.has_value())
-      saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
-    else
-      saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
-  }
+  if (flow->start_partial_sync_at.has_value())
+    saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
+  else
+    saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);

   flow->full_sync_fb = fb2::Fiber("full_sync", &DflyCmd::FullSyncFb, this, flow, cntx);
   return OpStatus::OK;

@@ -555,12 +555,12 @@ void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
 OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
   // Create streamer for shard flows.
+  DCHECK(shard);
+  DCHECK(flow->conn);

-  if (shard != nullptr) {
-    flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
-    bool send_lsn = flow->version >= DflyVersion::VER4;
-    flow->streamer->Start(flow->conn->socket(), send_lsn);
-  }
+  flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
+  bool send_lsn = flow->version >= DflyVersion::VER4;
+  flow->streamer->Start(flow->conn->socket(), send_lsn);

   // Register cleanup.
   flow->cleanup = [flow]() {

@@ -577,6 +577,8 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
   error_code ec;

   if (ec = flow->saver->SaveBody(cntx, nullptr); ec) {
+    if (!flow->conn->socket()->IsOpen())
+      ec = make_error_code(errc::operation_canceled);  // we cancelled the operation.
     cntx->ReportError(ec);
     return;
   }

@@ -588,8 +590,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
   }
 }

-auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
-    -> std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> {
+auto DflyCmd::CreateSyncSession(ConnectionContext* cntx) -> std::pair<uint32_t, unsigned> {
   unique_lock lk(mu_);
   unsigned sync_id = next_sync_id_++;

@@ -612,7 +613,7 @@ auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
   auto [it, inserted] = replica_infos_.emplace(sync_id, std::move(replica_ptr));
   CHECK(inserted);

-  return *it;
+  return {it->first, flow_count};
 }

 auto DflyCmd::GetReplicaInfoFromConnection(ConnectionContext* cntx)

@@ -651,6 +652,40 @@ void DflyCmd::StopReplication(uint32_t sync_id) {
   replica_infos_.erase(sync_id);
 }

+void DflyCmd::BreakStalledFlowsInShard() {
+  unique_lock global_lock(mu_, try_to_lock);
+
+  // give up on blocking because we run this function periodically in a background fiber,
+  // so it will eventually grab the lock.
+  if (!global_lock.owns_lock())
+    return;
+
+  ShardId sid = EngineShard::tlocal()->shard_id();
+
+  vector<uint32_t> deleted;
+
+  for (auto [sync_id, replica_ptr] : replica_infos_) {
+    shared_lock replica_lock = replica_ptr->GetSharedLock();
+
+    if (!replica_ptr->flows[sid].saver)
+      continue;
+
+    // If saver is present - we are currently using it for full sync.
+    int64_t last_write_ns = replica_ptr->flows[sid].saver->GetLastWriteTime();
+    int64_t timeout_ns = int64_t(absl::GetFlag(FLAGS_replication_timeout)) * 1'000'000LL;
+    int64_t now = absl::GetCurrentTimeNanos();
+    if (last_write_ns > 0 && last_write_ns + timeout_ns < now) {
+      VLOG(1) << "Breaking full sync for sync_id " << sync_id << " last_write_ts: " << last_write_ns
+              << ", now: " << now;
+      deleted.push_back(sync_id);
+      replica_lock.unlock();
+      replica_ptr->Cancel();
+    }
+  }
+
+  for (auto sync_id : deleted)
+    replica_infos_.erase(sync_id);
+}
+
 shared_ptr<DflyCmd::ReplicaInfo> DflyCmd::GetReplicaInfo(uint32_t sync_id) {
   lock_guard lk(mu_);

@@ -807,7 +842,7 @@ void DflyCmd::Shutdown() {
 void FlowInfo::TryShutdownSocket() {
   // Close socket for clean disconnect.
   if (conn->socket()->IsOpen()) {
-    (void)conn->socket()->Shutdown(SHUT_RDWR);
+    std::ignore = conn->socket()->Shutdown(SHUT_RDWR);
   }
 }

src/server/dfly_cmd.h

@@ -147,8 +147,8 @@ class DflyCmd {
   // Stop all background processes so we can exit in orderly manner.
   void Shutdown();

-  // Create new sync session.
-  std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> CreateSyncSession(ConnectionContext* cntx);
+  // Create new sync session. Returns (session_id, number of flows)
+  std::pair<uint32_t, unsigned> CreateSyncSession(ConnectionContext* cntx);

   // Master side acces method to replication info of that connection.
   std::shared_ptr<ReplicaInfo> GetReplicaInfoFromConnection(ConnectionContext* cntx);

@@ -160,6 +160,9 @@ class DflyCmd {
   // Sets metadata.
   void SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version);

+  // Tries to break those flows that stuck on socket write for too long time.
+  void BreakStalledFlowsInShard();
+
  private:
   // JOURNAL [START/STOP]
   // Start or stop journaling.

src/server/engine_shard_set.cc

@@ -682,13 +682,14 @@ void EngineShard::RetireExpiredAndEvict() {
 }

 void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
-                              std::function<void()> global_handler) {
+                              std::function<void()> shard_handler) {
   VLOG(1) << "RunPeriodic with period " << period_ms.count() << "ms";

   bool runs_global_periodic = (shard_id() == 0);  // Only shard 0 runs global periodic.
   unsigned global_count = 0;
   int64_t last_stats_time = time(nullptr);
   int64_t last_heartbeat_ms = INT64_MAX;
+  int64_t last_handler_ms = 0;

   while (true) {
     if (fiber_periodic_done_.WaitFor(period_ms)) {

@@ -702,6 +703,10 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
     }
     Heartbeat();
     last_heartbeat_ms = fb2::ProactorBase::GetMonotonicTimeNs() / 1000000;
+    if (shard_handler && last_handler_ms + 100 < last_heartbeat_ms) {
+      last_handler_ms = last_heartbeat_ms;
+      shard_handler();
+    }

     if (runs_global_periodic) {
       ++global_count;

@@ -727,10 +732,6 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
         rss_mem_peak.store(total_rss, memory_order_relaxed);
       }
     }
-
-    if (global_handler) {
-      global_handler();
-    }
   }
 }

@@ -903,7 +904,7 @@ size_t GetTieredFileLimit(size_t threads) {
   return max_shard_file_size;
 }

-void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
+void EngineShardSet::Init(uint32_t sz, std::function<void()> shard_handler) {
   CHECK_EQ(0u, size());

   shard_queue_.resize(sz);

@@ -922,7 +923,8 @@ void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
       shard->InitTieredStorage(pb, max_shard_file_size);

       // Must be last, as it accesses objects initialized above.
-      shard->StartPeriodicFiber(pb, global_handler);
+      // We can not move shard_handler because this code is called multiple times.
+      shard->StartPeriodicFiber(pb, shard_handler);
     }
   });
 }

src/server/engine_shard_set.h

@@ -203,12 +203,12 @@ class EngineShard {
   // blocks the calling fiber.
   void Shutdown();  // called before destructing EngineShard.

-  void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> global_handler);
+  void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> shard_handler);

   void Heartbeat();
   void RetireExpiredAndEvict();

-  void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> global_handler);
+  void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> shard_handler);

   void CacheStats();

@@ -288,7 +288,7 @@ class EngineShardSet {
     return pp_;
   }

-  void Init(uint32_t size, std::function<void()> global_handler);
+  void Init(uint32_t size, std::function<void()> shard_handler);

   // Shutdown sequence:
   // - EngineShardSet.PreShutDown()

src/server/journal/streamer.cc

@@ -13,9 +13,9 @@
 using namespace facade;

-ABSL_FLAG(uint32_t, replication_stream_timeout, 500,
-          "Time in milliseconds to wait for the replication output buffer go below "
-          "the throttle limit.");
+ABSL_FLAG(uint32_t, replication_timeout, 10000,
+          "Time in milliseconds to wait for the replication writes being stuck.");
 ABSL_FLAG(uint32_t, replication_stream_output_limit, 64_KB,
           "Time to wait for the replication output buffer go below the throttle limit");

@@ -155,8 +155,8 @@ void JournalStreamer::ThrottleIfNeeded() {
   if (IsStopped() || !IsStalled())
     return;

-  auto next = chrono::steady_clock::now() +
-              chrono::milliseconds(absl::GetFlag(FLAGS_replication_stream_timeout));
+  auto next =
+      chrono::steady_clock::now() + chrono::milliseconds(absl::GetFlag(FLAGS_replication_timeout));
   size_t inflight_start = in_flight_bytes_;
   size_t sent_start = total_sent_;

src/server/main_service.cc

@@ -887,7 +887,6 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
     ServerState::tlocal()->UpdateChannelStore(cs);
   });

-  shard_set->Init(shard_num, nullptr);
   const auto tcp_disabled = GetFlag(FLAGS_port) == 0u;
   // We assume that listeners.front() is the main_listener
   // see dfly_main RunEngine

@@ -895,6 +894,11 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
     acl_family_.Init(listeners.front(), &user_registry_);
   }

+  // Initialize shard_set with a global callback running once in a while in the shard threads.
+  shard_set->Init(shard_num, [this] { server_family_.GetDflyCmd()->BreakStalledFlowsInShard(); });
+
+  // Requires that shard_set will be initialized before because server_family_.Init might
+  // load the snapshot.
   server_family_.Init(acceptor, std::move(listeners));
 }

src/server/rdb_save.cc

@@ -1135,10 +1135,15 @@ class RdbSaver::Impl {
     return &meta_serializer_;
   }

+  int64_t last_write_ts() const {
+    return last_write_time_ns_;
+  }
+
  private:
   unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);

   io::Sink* sink_;
+  int64_t last_write_time_ns_ = -1;  // last write call.
   vector<unique_ptr<SliceSnapshot>> shard_snapshots_;
   // used for serializing non-body components in the calling fiber.
   RdbSerializer meta_serializer_;

@@ -1263,10 +1268,12 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
       continue;

     DVLOG(2) << "Pulled " << record->id;
-    auto before = absl::GetCurrentTimeNanos();
+    last_write_time_ns_ = absl::GetCurrentTimeNanos();
     io_error = sink_->Write(io::Buffer(record->value));
-    stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - before) / 1'000;
+
+    stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - last_write_time_ns_) / 1'000;
     stats.rdb_save_count++;
+    last_write_time_ns_ = -1;
     if (io_error) {
       VLOG(1) << "Error writing to sink " << io_error.message();
       break;

@@ -1369,7 +1376,10 @@ RdbSaver::SnapshotStats RdbSaver::Impl::GetCurrentSnapshotProgress() const {
 }

 error_code RdbSaver::Impl::FlushSerializer() {
-  return serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
+  last_write_time_ns_ = absl::GetCurrentTimeNanos();
+  auto ec = serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
+  last_write_time_ns_ = -1;
+  return ec;
 }

 RdbSaver::GlobalData RdbSaver::GetGlobalData(const Service* service) {

@@ -1482,7 +1492,6 @@ error_code RdbSaver::SaveBody(Context* cntx, RdbTypeFreqMap* freq_map) {
   VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();

   error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
   if (io_error) {
-    LOG(ERROR) << "io error " << io_error;
     return io_error;
   }
   if (cntx->GetError()) {

@@ -1572,6 +1581,10 @@ RdbSaver::SnapshotStats RdbSaver::GetCurrentSnapshotProgress() const {
   return impl_->GetCurrentSnapshotProgress();
 }

+int64_t RdbSaver::GetLastWriteTime() const {
+  return impl_->last_write_ts();
+}
+
 void SerializerBase::AllocateCompressorOnce() {
   if (compressor_impl_) {
     return;

src/server/rdb_save.h

@@ -122,6 +122,10 @@ class RdbSaver {
   // Fetch global data to be serialized in summary part of a snapshot / full sync.
   static GlobalData GetGlobalData(const Service* service);

+  // Returns time in nanos of start of the last pending write interaction.
+  // Returns -1 if no write operations are currently pending.
+  int64_t GetLastWriteTime() const;
+
  private:
  class Impl;

src/server/server_family.cc

@@ -2665,7 +2665,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
     std::string_view arg = ArgS(args, i + 1);
     if (cmd == "CAPA") {
       if (arg == "dragonfly" && args.size() == 2 && i == 0) {
-        auto [sid, replica_info] = dfly_cmd_->CreateSyncSession(cntx);
+        auto [sid, flow_count] = dfly_cmd_->CreateSyncSession(cntx);
         cntx->conn()->SetName(absl::StrCat("repl_ctrl_", sid));

         string sync_id = absl::StrCat("SYNC", sid);

@@ -2681,7 +2681,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
         rb->StartArray(4);
         rb->SendSimpleString(master_replid_);
         rb->SendSimpleString(sync_id);
-        rb->SendLong(replica_info->flows.size());
+        rb->SendLong(flow_count);
         rb->SendLong(unsigned(DflyVersion::CURRENT_VER));
         return;
       }

tests/dragonfly/replication_test.py

@@ -2174,7 +2174,7 @@ async def test_replica_reconnect(df_factory, break_conn):
     c_master = master.client()
     c_replica = replica.client()

-    await c_master.execute_command("set k 12345")
+    await c_master.set("k", "12345")
     await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
     await wait_available_async(c_replica)
     assert (await c_replica.info("REPLICATION"))["master_link_status"] == "up"

@@ -2230,3 +2230,34 @@ async def test_announce_ip_port(df_factory):
         host, port, _ = node[0]
         assert host == "overrode-host"
         assert port == "1337"
+
+
+async def test_master_stalled_disconnect(df_factory: DflyInstanceFactory):
+    # disconnect after 1 second of being blocked
+    master = df_factory.create(replication_timeout=1000)
+    replica = df_factory.create()
+    df_factory.start_all([master, replica])
+
+    c_master = master.client()
+    c_replica = replica.client()
+
+    await c_master.execute_command("debug", "populate", "200000", "foo", "500")
+    await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
+
+    @assert_eventually
+    async def check_replica_connected():
+        repl_info = await c_master.info("replication")
+        assert "slave0" in repl_info
+
+    @assert_eventually
+    async def check_replica_disconnected():
+        repl_info = await c_master.info("replication")
+        assert "slave0" not in repl_info
+
+    await check_replica_connected()
+    await c_replica.execute_command("DEBUG REPLICA PAUSE")
+    await check_replica_connected()  # still connected
+    await asyncio.sleep(1)  # wait for the master to recognize it's being blocked
+    await check_replica_disconnected()
+    df_factory.stop_all()