chore: add timeout to replication sockets (#3434)

* chore: add timeout fo replication sockets

Master will stop the replication flow if writes could not progress for more than K millis.

---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
Signed-off-by: Roman Gershman <romange@gmail.com>
Co-authored-by: Shahar Mike <chakaz@users.noreply.github.com>
This commit is contained in:
Roman Gershman 2024-08-07 16:33:03 +03:00 committed by GitHub
parent 7c84b8e524
commit 1cbfcd4912
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 133 additions and 41 deletions

View file

@ -33,6 +33,7 @@
using namespace std;
ABSL_DECLARE_FLAG(bool, info_replication_valkey_compatible);
ABSL_DECLARE_FLAG(uint32_t, replication_timeout);
namespace dfly {
@ -119,6 +120,7 @@ void DflyCmd::ReplicaInfo::Cancel() {
}
flow->full_sync_fb.JoinIfNeeded();
flow->conn = nullptr;
});
// Wait for error handler to quit.
@ -501,6 +503,7 @@ void DflyCmd::ReplicaOffset(CmdArgList args, ConnectionContext* cntx) {
OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
DCHECK(!flow->full_sync_fb.IsJoinable());
DCHECK(shard);
DCHECK(flow->conn);
// The summary contains the LUA scripts, so make sure at least (and exactly one)
// of the flows also contain them.
@ -527,13 +530,10 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
return OpStatus::CANCELLED;
}
// Shard can be null for io thread.
if (shard != nullptr) {
if (flow->start_partial_sync_at.has_value())
saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
else
saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
}
if (flow->start_partial_sync_at.has_value())
saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
else
saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
flow->full_sync_fb = fb2::Fiber("full_sync", &DflyCmd::FullSyncFb, this, flow, cntx);
return OpStatus::OK;
@ -555,12 +555,12 @@ void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
// Create streamer for shard flows.
DCHECK(shard);
DCHECK(flow->conn);
if (shard != nullptr) {
flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
bool send_lsn = flow->version >= DflyVersion::VER4;
flow->streamer->Start(flow->conn->socket(), send_lsn);
}
flow->streamer.reset(new JournalStreamer(sf_->journal(), cntx));
bool send_lsn = flow->version >= DflyVersion::VER4;
flow->streamer->Start(flow->conn->socket(), send_lsn);
// Register cleanup.
flow->cleanup = [flow]() {
@ -577,6 +577,8 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
error_code ec;
if (ec = flow->saver->SaveBody(cntx, nullptr); ec) {
if (!flow->conn->socket()->IsOpen())
ec = make_error_code(errc::operation_canceled); // we cancelled the operation.
cntx->ReportError(ec);
return;
}
@ -588,8 +590,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
}
}
auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
-> std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> {
auto DflyCmd::CreateSyncSession(ConnectionContext* cntx) -> std::pair<uint32_t, unsigned> {
unique_lock lk(mu_);
unsigned sync_id = next_sync_id_++;
@ -612,7 +613,7 @@ auto DflyCmd::CreateSyncSession(ConnectionContext* cntx)
auto [it, inserted] = replica_infos_.emplace(sync_id, std::move(replica_ptr));
CHECK(inserted);
return *it;
return {it->first, flow_count};
}
auto DflyCmd::GetReplicaInfoFromConnection(ConnectionContext* cntx)
@ -651,6 +652,40 @@ void DflyCmd::StopReplication(uint32_t sync_id) {
replica_infos_.erase(sync_id);
}
void DflyCmd::BreakStalledFlowsInShard() {
unique_lock global_lock(mu_, try_to_lock);
// give up on blocking because we run this function periodically in a background fiber,
// so it will eventually grab the lock.
if (!global_lock.owns_lock())
return;
ShardId sid = EngineShard::tlocal()->shard_id();
vector<uint32_t> deleted;
for (auto [sync_id, replica_ptr] : replica_infos_) {
shared_lock replica_lock = replica_ptr->GetSharedLock();
if (!replica_ptr->flows[sid].saver)
continue;
// If saver is present - we are currently using it for full sync.
int64_t last_write_ns = replica_ptr->flows[sid].saver->GetLastWriteTime();
int64_t timeout_ns = int64_t(absl::GetFlag(FLAGS_replication_timeout)) * 1'000'000LL;
int64_t now = absl::GetCurrentTimeNanos();
if (last_write_ns > 0 && last_write_ns + timeout_ns < now) {
VLOG(1) << "Breaking full sync for sync_id " << sync_id << " last_write_ts: " << last_write_ns
<< ", now: " << now;
deleted.push_back(sync_id);
replica_lock.unlock();
replica_ptr->Cancel();
}
}
for (auto sync_id : deleted)
replica_infos_.erase(sync_id);
}
shared_ptr<DflyCmd::ReplicaInfo> DflyCmd::GetReplicaInfo(uint32_t sync_id) {
lock_guard lk(mu_);
@ -807,7 +842,7 @@ void DflyCmd::Shutdown() {
void FlowInfo::TryShutdownSocket() {
// Close socket for clean disconnect.
if (conn->socket()->IsOpen()) {
(void)conn->socket()->Shutdown(SHUT_RDWR);
std::ignore = conn->socket()->Shutdown(SHUT_RDWR);
}
}

View file

@ -147,8 +147,8 @@ class DflyCmd {
// Stop all background processes so we can exit in orderly manner.
void Shutdown();
// Create new sync session.
std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> CreateSyncSession(ConnectionContext* cntx);
// Create new sync session. Returns (session_id, number of flows)
std::pair<uint32_t, unsigned> CreateSyncSession(ConnectionContext* cntx);
// Master side acces method to replication info of that connection.
std::shared_ptr<ReplicaInfo> GetReplicaInfoFromConnection(ConnectionContext* cntx);
@ -160,6 +160,9 @@ class DflyCmd {
// Sets metadata.
void SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version);
// Tries to break those flows that stuck on socket write for too long time.
void BreakStalledFlowsInShard();
private:
// JOURNAL [START/STOP]
// Start or stop journaling.

View file

@ -682,13 +682,14 @@ void EngineShard::RetireExpiredAndEvict() {
}
void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
std::function<void()> global_handler) {
std::function<void()> shard_handler) {
VLOG(1) << "RunPeriodic with period " << period_ms.count() << "ms";
bool runs_global_periodic = (shard_id() == 0); // Only shard 0 runs global periodic.
unsigned global_count = 0;
int64_t last_stats_time = time(nullptr);
int64_t last_heartbeat_ms = INT64_MAX;
int64_t last_handler_ms = 0;
while (true) {
if (fiber_periodic_done_.WaitFor(period_ms)) {
@ -702,6 +703,10 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
}
Heartbeat();
last_heartbeat_ms = fb2::ProactorBase::GetMonotonicTimeNs() / 1000000;
if (shard_handler && last_handler_ms + 100 < last_heartbeat_ms) {
last_handler_ms = last_heartbeat_ms;
shard_handler();
}
if (runs_global_periodic) {
++global_count;
@ -727,10 +732,6 @@ void EngineShard::RunPeriodic(std::chrono::milliseconds period_ms,
rss_mem_peak.store(total_rss, memory_order_relaxed);
}
}
if (global_handler) {
global_handler();
}
}
}
}
@ -903,7 +904,7 @@ size_t GetTieredFileLimit(size_t threads) {
return max_shard_file_size;
}
void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
void EngineShardSet::Init(uint32_t sz, std::function<void()> shard_handler) {
CHECK_EQ(0u, size());
shard_queue_.resize(sz);
@ -922,7 +923,8 @@ void EngineShardSet::Init(uint32_t sz, std::function<void()> global_handler) {
shard->InitTieredStorage(pb, max_shard_file_size);
// Must be last, as it accesses objects initialized above.
shard->StartPeriodicFiber(pb, global_handler);
// We can not move shard_handler because this code is called multiple times.
shard->StartPeriodicFiber(pb, shard_handler);
}
});
}

View file

@ -203,12 +203,12 @@ class EngineShard {
// blocks the calling fiber.
void Shutdown(); // called before destructing EngineShard.
void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> global_handler);
void StartPeriodicFiber(util::ProactorBase* pb, std::function<void()> shard_handler);
void Heartbeat();
void RetireExpiredAndEvict();
void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> global_handler);
void RunPeriodic(std::chrono::milliseconds period_ms, std::function<void()> shard_handler);
void CacheStats();
@ -288,7 +288,7 @@ class EngineShardSet {
return pp_;
}
void Init(uint32_t size, std::function<void()> global_handler);
void Init(uint32_t size, std::function<void()> shard_handler);
// Shutdown sequence:
// - EngineShardSet.PreShutDown()

View file

@ -13,9 +13,9 @@
using namespace facade;
ABSL_FLAG(uint32_t, replication_stream_timeout, 500,
"Time in milliseconds to wait for the replication output buffer go below "
"the throttle limit.");
ABSL_FLAG(uint32_t, replication_timeout, 10000,
"Time in milliseconds to wait for the replication writes being stuck.");
ABSL_FLAG(uint32_t, replication_stream_output_limit, 64_KB,
"Time to wait for the replication output buffer go below the throttle limit");
@ -155,8 +155,8 @@ void JournalStreamer::ThrottleIfNeeded() {
if (IsStopped() || !IsStalled())
return;
auto next = chrono::steady_clock::now() +
chrono::milliseconds(absl::GetFlag(FLAGS_replication_stream_timeout));
auto next =
chrono::steady_clock::now() + chrono::milliseconds(absl::GetFlag(FLAGS_replication_timeout));
size_t inflight_start = in_flight_bytes_;
size_t sent_start = total_sent_;

View file

@ -887,7 +887,6 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
ServerState::tlocal()->UpdateChannelStore(cs);
});
shard_set->Init(shard_num, nullptr);
const auto tcp_disabled = GetFlag(FLAGS_port) == 0u;
// We assume that listeners.front() is the main_listener
// see dfly_main RunEngine
@ -895,6 +894,11 @@ void Service::Init(util::AcceptServer* acceptor, std::vector<facade::Listener*>
acl_family_.Init(listeners.front(), &user_registry_);
}
// Initialize shard_set with a global callback running once in a while in the shard threads.
shard_set->Init(shard_num, [this] { server_family_.GetDflyCmd()->BreakStalledFlowsInShard(); });
// Requires that shard_set will be initialized before because server_family_.Init might
// load the snapshot.
server_family_.Init(acceptor, std::move(listeners));
}

View file

@ -1135,10 +1135,15 @@ class RdbSaver::Impl {
return &meta_serializer_;
}
int64_t last_write_ts() const {
return last_write_time_ns_;
}
private:
unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);
io::Sink* sink_;
int64_t last_write_time_ns_ = -1; // last write call.
vector<unique_ptr<SliceSnapshot>> shard_snapshots_;
// used for serializing non-body components in the calling fiber.
RdbSerializer meta_serializer_;
@ -1263,10 +1268,12 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
continue;
DVLOG(2) << "Pulled " << record->id;
auto before = absl::GetCurrentTimeNanos();
last_write_time_ns_ = absl::GetCurrentTimeNanos();
io_error = sink_->Write(io::Buffer(record->value));
stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - before) / 1'000;
stats.rdb_save_usec += (absl::GetCurrentTimeNanos() - last_write_time_ns_) / 1'000;
stats.rdb_save_count++;
last_write_time_ns_ = -1;
if (io_error) {
VLOG(1) << "Error writing to sink " << io_error.message();
break;
@ -1369,7 +1376,10 @@ RdbSaver::SnapshotStats RdbSaver::Impl::GetCurrentSnapshotProgress() const {
}
error_code RdbSaver::Impl::FlushSerializer() {
return serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
last_write_time_ns_ = absl::GetCurrentTimeNanos();
auto ec = serializer()->FlushToSink(sink_, SerializerBase::FlushState::kFlushMidEntry);
last_write_time_ns_ = -1;
return ec;
}
RdbSaver::GlobalData RdbSaver::GetGlobalData(const Service* service) {
@ -1482,7 +1492,6 @@ error_code RdbSaver::SaveBody(Context* cntx, RdbTypeFreqMap* freq_map) {
VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
if (io_error) {
LOG(ERROR) << "io error " << io_error;
return io_error;
}
if (cntx->GetError()) {
@ -1572,6 +1581,10 @@ RdbSaver::SnapshotStats RdbSaver::GetCurrentSnapshotProgress() const {
return impl_->GetCurrentSnapshotProgress();
}
int64_t RdbSaver::GetLastWriteTime() const {
return impl_->last_write_ts();
}
void SerializerBase::AllocateCompressorOnce() {
if (compressor_impl_) {
return;

View file

@ -122,6 +122,10 @@ class RdbSaver {
// Fetch global data to be serialized in summary part of a snapshot / full sync.
static GlobalData GetGlobalData(const Service* service);
// Returns time in nanos of start of the last pending write interaction.
// Returns -1 if no write operations are currently pending.
int64_t GetLastWriteTime() const;
private:
class Impl;

View file

@ -2665,7 +2665,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
std::string_view arg = ArgS(args, i + 1);
if (cmd == "CAPA") {
if (arg == "dragonfly" && args.size() == 2 && i == 0) {
auto [sid, replica_info] = dfly_cmd_->CreateSyncSession(cntx);
auto [sid, flow_count] = dfly_cmd_->CreateSyncSession(cntx);
cntx->conn()->SetName(absl::StrCat("repl_ctrl_", sid));
string sync_id = absl::StrCat("SYNC", sid);
@ -2681,7 +2681,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
rb->StartArray(4);
rb->SendSimpleString(master_replid_);
rb->SendSimpleString(sync_id);
rb->SendLong(replica_info->flows.size());
rb->SendLong(flow_count);
rb->SendLong(unsigned(DflyVersion::CURRENT_VER));
return;
}

View file

@ -2174,7 +2174,7 @@ async def test_replica_reconnect(df_factory, break_conn):
c_master = master.client()
c_replica = replica.client()
await c_master.execute_command("set k 12345")
await c_master.set("k", "12345")
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
assert (await c_replica.info("REPLICATION"))["master_link_status"] == "up"
@ -2230,3 +2230,34 @@ async def test_announce_ip_port(df_factory):
host, port, _ = node[0]
assert host == "overrode-host"
assert port == "1337"
async def test_master_stalled_disconnect(df_factory: DflyInstanceFactory):
# disconnect after 1 second of being blocked
master = df_factory.create(replication_timeout=1000)
replica = df_factory.create()
df_factory.start_all([master, replica])
c_master = master.client()
c_replica = replica.client()
await c_master.execute_command("debug", "populate", "200000", "foo", "500")
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
@assert_eventually
async def check_replica_connected():
repl_info = await c_master.info("replication")
assert "slave0" in repl_info
@assert_eventually
async def check_replica_disconnected():
repl_info = await c_master.info("replication")
assert "slave0" not in repl_info
await check_replica_connected()
await c_replica.execute_command("DEBUG REPLICA PAUSE")
await check_replica_connected() # still connected
await asyncio.sleep(1) # wait for the master to recognize it's being blocked
await check_replica_disconnected()
df_factory.stop_all()