From 4babed54d3994b6b366672212a54879cd2cd7fb8 Mon Sep 17 00:00:00 2001
From: Roy Jacobson
Date: Sun, 2 Jul 2023 16:11:28 +0200
Subject: [PATCH] feat: Support atomic replica takeover (#1314)

* fix(server): Initialize ServerFamily with all listeners.

- Add a test for CLIENT LIST which is the visible result of this.

* use std move

* feat: Implement replicas take over

* Basic test

* Address CR comments

* Write a better test. Sadly it fails

* chore: Expose AwaitDispatches for reuse in takeover

* Ensure that no commands can execute during or after a takeover

* CR progress

* Actually disable the expiration

* Improve tests coverage

* Fix the dispatch waiting code

* Improve testing coverage and fix a shutdown snapshot bug

* don't replicate a replica
---
 src/facade/dragonfly_listener.cc    |  42 ++++++-----
 src/facade/dragonfly_listener.h     |   6 +-
 src/facade/op_status.cc             |   2 +
 src/facade/op_status.h              |   1 +
 src/server/common.cc                |   2 +
 src/server/common.h                 |   1 +
 src/server/db_slice.cc              |   4 +-
 src/server/db_slice.h               |   5 ++
 src/server/dflycmd.cc               | 101 +++++++++++++++++++++++--
 src/server/dflycmd.h                |   6 +-
 src/server/engine_shard_set.h       |   4 +
 src/server/journal/journal_slice.cc |  39 +++++-----
 src/server/main_service.cc          |  20 ++++-
 src/server/replica.cc               |  22 +++++-
 src/server/replica.h                |   6 +-
 src/server/server_family.cc         |  61 +++++++++++--
 src/server/server_family.h          |   8 +-
 src/server/transaction.cc           |   6 +-
 tests/dragonfly/__init__.py         |   5 ++
 tests/dragonfly/replication_test.py | 112 +++++++++++++++++++++-
 tests/dragonfly/utility.py          |   7 +-
 21 files changed, 392 insertions(+), 68 deletions(-)

diff --git a/src/facade/dragonfly_listener.cc b/src/facade/dragonfly_listener.cc
index 5d5560dc2..6eaf24483 100644
--- a/src/facade/dragonfly_listener.cc
+++ b/src/facade/dragonfly_listener.cc
@@ -179,6 +179,28 @@ void Listener::PreAcceptLoop(util::ProactorBase* pb) {
   per_thread_.resize(pool()->size());
 }
 
+bool Listener::AwaitDispatches(absl::Duration timeout,
+                               const std::function<bool(util::Connection*)>& filter) {
+  absl::Time start = absl::Now();
+
+  while (absl::Now() - start < timeout) {
+    std::atomic<bool> any_connection_dispatching = false;
+    auto cb = [&any_connection_dispatching, &filter](unsigned thread_index,
+                                                     util::Connection* conn) {
+      if (filter(conn) && static_cast<Connection*>(conn)->IsCurrentlyDispatching()) {
+        any_connection_dispatching.store(true);
+      }
+    };
+    this->TraverseConnections(cb);
+    if (!any_connection_dispatching.load()) {
+      return true;
+    }
+    VLOG(1) << "A command is still dispatching, let's wait for it";
+    ThisFiber::SleepFor(100us);
+  }
+  return false;
+}
+
 void Listener::PreShutdown() {
   // Iterate on all connections and allow them to finish their commands for
   // a short period.
@@ -188,26 +210,8 @@ void Listener::PreShutdown() {
   // at this stage since we're in SHUTDOWN mode.
   // If a command is running for too long we give up and proceed.
   const absl::Duration kDispatchShutdownTimeout = absl::Milliseconds(10);
-  absl::Time start = absl::Now();
-  bool success = false;
-  while (absl::Now() - start < kDispatchShutdownTimeout) {
-    std::atomic<bool> any_connection_dispatching = false;
-    auto cb = [&any_connection_dispatching](unsigned thread_index, util::Connection* conn) {
-      if (static_cast<Connection*>(conn)->IsCurrentlyDispatching()) {
-        any_connection_dispatching.store(true);
-      }
-    };
-    this->TraverseConnections(cb);
-    if (!any_connection_dispatching.load()) {
-      success = true;
-      break;
-    }
-    VLOG(1) << "A command is still dispatching, let's wait for it";
-    ThisFiber::SleepFor(100us);
-  }
-
-  if (!success) {
+  if (!AwaitDispatches(kDispatchShutdownTimeout, [](util::Connection*) { return true; })) {
     LOG(WARNING) << "Some commands are still being dispatched but didn't conclude in time. "
                     "Proceeding in shutdown.";
   }
diff --git a/src/facade/dragonfly_listener.h b/src/facade/dragonfly_listener.h
index d1b0ce11b..ee8362ef9 100644
--- a/src/facade/dragonfly_listener.h
+++ b/src/facade/dragonfly_listener.h
@@ -24,6 +24,11 @@ class Listener : public util::ListenerInterface {
 
   std::error_code ConfigureServerSocket(int fd) final;
 
+  // Wait until all connections that pass the filter have stopped dispatching or until a timeout
+  // has run out. Returns true if all connections have stopped dispatching.
+  bool AwaitDispatches(absl::Duration timeout,
+                       const std::function<bool(util::Connection*)>& filter);
+
 private:
   util::Connection* NewConnection(ProactorBase* proactor) final;
   ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final;
@@ -33,7 +38,6 @@ class Listener : public util::ListenerInterface {
   void PreAcceptLoop(ProactorBase* pb) final;
 
   void PreShutdown() final;
-  void PostShutdown() final;
 
   std::unique_ptr http_base_;
diff --git a/src/facade/op_status.cc b/src/facade/op_status.cc
index e398bad00..67c208fc9 100644
--- a/src/facade/op_status.cc
+++ b/src/facade/op_status.cc
@@ -36,6 +36,8 @@ const char* DebugString(OpStatus op) {
       return "ENTRIES ADDED IS TO SMALL";
     case OpStatus::INVALID_NUMERIC_RESULT:
       return "INVALID NUMERIC RESULT";
+    case OpStatus::CANCELLED:
+      return "CANCELLED";
   }
   return "Unknown Error Code";  // we should not be here, but this is how enums works in c++
 }
diff --git a/src/facade/op_status.h b/src/facade/op_status.h
index 685ba1457..28f838c49 100644
--- a/src/facade/op_status.h
+++ b/src/facade/op_status.h
@@ -26,6 +26,7 @@ enum class OpStatus : uint16_t {
   STREAM_ID_SMALL,
   ENTRIES_ADDED_SMALL,
   INVALID_NUMERIC_RESULT,
+  CANCELLED,
 };
 
 const char* DebugString(OpStatus op);
diff --git a/src/server/common.cc b/src/server/common.cc
index 3bc955693..25e5d2075 100644
--- a/src/server/common.cc
+++ b/src/server/common.cc
@@ -44,6 +44,8 @@ const char* GlobalStateName(GlobalState s) {
       return "SAVING";
     case GlobalState::SHUTTING_DOWN:
       return "SHUTTING DOWN";
+    case GlobalState::TAKEN_OVER:
+      return "TAKEN OVER";
   }
   ABSL_UNREACHABLE();
 }
diff --git a/src/server/common.h b/src/server/common.h
index e46b16f2d..bfe762949 100644
--- a/src/server/common.h
+++ b/src/server/common.h
@@ -134,6 +134,7 @@ enum class GlobalState : uint8_t {
   LOADING,
   SAVING,
   SHUTTING_DOWN,
+  TAKEN_OVER,
 };
 
 enum class TimeUnit : uint8_t { SEC, MSEC };
diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc
index 93147fb4b..760ac22cc 100644
--- a/src/server/db_slice.cc
+++ b/src/server/db_slice.cc
@@ -894,8 +894,8 @@ pair<PrimeIterator, ExpireIterator> DbSlice::ExpireIfNeeded(const Context& cntx,
   // TODO: to employ multi-generation update of expire-base and the underlying values.
   time_t expire_time = ExpireTime(expire_it);
 
-  // Never do expiration on replica.
-  if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica())
+  // Never do expiration on replica or if expiration is disabled.
+  if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica() || !expire_allowed_)
     return make_pair(it, expire_it);
 
   // Replicate expiry
diff --git a/src/server/db_slice.h b/src/server/db_slice.h
index b0fc47e64..10d87b127 100644
--- a/src/server/db_slice.h
+++ b/src/server/db_slice.h
@@ -324,6 +324,10 @@ class DbSlice {
   // Resets the event counter for updates/insertions
   void ResetUpdateEvents();
 
+  void SetExpireAllowed(bool is_allowed) {
+    expire_allowed_ = is_allowed;
+  }
+
 private:
   std::pair<PrimeIterator, bool> AddOrUpdateInternal(const Context& cntx, std::string_view key,
                                                      PrimeValue obj, uint64_t expire_at_ms,
@@ -351,6 +355,7 @@ class DbSlice {
   EngineShard* owner_;
 
   time_t expire_base_[2];  // Used for expire logic, represents a real clock.
+  bool expire_allowed_ = true;
 
   uint64_t version_ = 1;  // Used to version entries in the PrimeTable.
   ssize_t memory_budget_ = SSIZE_MAX;
diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc
index a2b672705..c7fe6d25f 100644
--- a/src/server/dflycmd.cc
+++ b/src/server/dflycmd.cc
@@ -14,10 +14,12 @@
 #include "base/flags.h"
 #include "base/logging.h"
 #include "facade/dragonfly_connection.h"
+#include "facade/dragonfly_listener.h"
 #include "server/engine_shard_set.h"
 #include "server/error.h"
 #include "server/journal/journal.h"
 #include "server/journal/streamer.h"
+#include "server/main_service.h"
 #include "server/rdb_save.h"
 #include "server/script_mgr.h"
 #include "server/server_family.h"
@@ -65,15 +67,27 @@ std::string_view SyncStateName(DflyCmd::SyncState sync_state) {
 }
 
 struct TransactionGuard {
-  constexpr static auto kEmptyCb = [](Transaction* t, EngineShard* shard) { return OpStatus::OK; };
+  static OpStatus ExitGuardCb(Transaction* t, EngineShard* shard) {
+    shard->db_slice().SetExpireAllowed(true);
+    return OpStatus::OK;
+  };
 
-  TransactionGuard(Transaction* t) : t(t) {
+  explicit TransactionGuard(Transaction* t, bool disable_expirations = false) : t(t) {
     t->Schedule();
-    t->Execute(kEmptyCb, false);
+    t->Execute(
+        [disable_expirations](Transaction* t, EngineShard* shard) {
+          if (disable_expirations) {
+            shard->db_slice().SetExpireAllowed(!disable_expirations);
+          }
+          return OpStatus::OK;
+        },
+        false);
+    VLOG(1) << "Transaction guard engaged";
   }
 
   ~TransactionGuard() {
-    t->Execute(kEmptyCb, true);
+    VLOG(1) << "Releasing transaction guard";
+    t->Execute(ExitGuardCb, true);
   }
 
   Transaction* t;
@@ -110,6 +124,10 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
     return StartStable(args, cntx);
   }
 
+  if (sub_cmd == "TAKEOVER" && args.size() == 3) {
+    return TakeOver(args, cntx);
+  }
+
   if (sub_cmd == "EXPIRE") {
     return Expire(args, cntx);
   }
@@ -316,7 +334,6 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
       StopFullSyncInThread(flow, shard);
 
       status = StartStableSyncInThread(flow, &replica_ptr->cntx, shard);
-      return OpStatus::OK;
     };
     shard_set->RunBlockingInParallel(std::move(cb));
 
@@ -331,6 +348,80 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
   return rb->SendOk();
 }
 
+void DflyCmd::TakeOver(CmdArgList args, ConnectionContext* cntx) {
+  RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
+  string_view sync_id_str = ArgS(args, 2);
+  float timeout;
+  if (!absl::SimpleAtof(ArgS(args, 1), &timeout)) {
+    return (*cntx)->SendError(kInvalidIntErr);
+  }
+  if (timeout < 0) {
+    return
(*cntx)->SendError("timeout is negative"); + } + + VLOG(1) << "Got DFLY TAKEOVER " << sync_id_str; + + auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb); + if (!sync_id) + return; + + unique_lock lk(replica_ptr->mu); + if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::STABLE_SYNC, rb)) + return; + + LOG(INFO) << "Takeover initiated, locking down the database."; + + sf_->service().SwitchState(GlobalState::ACTIVE, GlobalState::TAKEN_OVER); + + absl::Duration timeout_dur = absl::Seconds(timeout); + absl::Time start = absl::Now(); + AggregateStatus status; + + // TODO: We should cancel blocking commands before awaiting all + // dispatches to finish. + if (!sf_->AwaitDispatches(timeout_dur, [self = cntx->owner()](util::Connection* conn) { + // The only command that is currently dispatching should be the takeover command - + // so we wait until this is true. + return conn != self; + })) { + LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << timeout_dur; + status = OpStatus::TIMED_OUT; + } + + TransactionGuard tg{cntx->transaction, /*disable_expirations=*/true}; + + if (*status == OpStatus::OK) { + auto cb = [&cntx = replica_ptr->cntx, replica_ptr = replica_ptr, timeout_dur, start, + &status](EngineShard* shard) { + FlowInfo* flow = &replica_ptr->flows[shard->shard_id()]; + + shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, {}, true); + while (flow->last_acked_lsn < shard->journal()->GetLsn()) { + if (absl::Now() - start > timeout_dur) { + LOG(WARNING) << "Couldn't synchronize with replica for takeover in time."; + status = OpStatus::TIMED_OUT; + return; + } + if (cntx.IsCancelled()) { + status = OpStatus::CANCELLED; + return; + } + ThisFiber::SleepFor(1ms); + } + }; + shard_set->RunBlockingInParallel(std::move(cb)); + } + + if (*status != OpStatus::OK) { + sf_->service().SwitchState(GlobalState::TAKEN_OVER, GlobalState::ACTIVE); + return rb->SendError("Takeover failed!"); + } + (*cntx)->SendOk(); + + VLOG(1) << "Takeover accepted, shutting down."; + return sf_->ShutdownCmd({}, cntx); +} + void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) { RedisReplyBuilder* rb = static_cast(cntx->reply_builder()); cntx->transaction->ScheduleSingleHop([](Transaction* t, EngineShard* shard) { diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h index e3ee89159..2f0212cbb 100644 --- a/src/server/dflycmd.h +++ b/src/server/dflycmd.h @@ -99,7 +99,7 @@ class DflyCmd { struct ReplicaInfo { ReplicaInfo(unsigned flow_count, std::string address, uint32_t listening_port, Context::ErrHandler err_handler) - : state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{address}, + : state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{std::move(address)}, listening_port(listening_port), flows{flow_count} { } @@ -151,6 +151,10 @@ class DflyCmd { // Switch to stable state replication. void StartStable(CmdArgList args, ConnectionContext* cntx); + // TAKEOVER + // Shut this master down atomically with replica promotion. + void TakeOver(CmdArgList args, ConnectionContext* cntx); + // EXPIRE // Check all keys for expiry. 
   void Expire(CmdArgList args, ConnectionContext* cntx);
 
diff --git a/src/server/engine_shard_set.h b/src/server/engine_shard_set.h
index 174991d57..6226aeca7 100644
--- a/src/server/engine_shard_set.h
+++ b/src/server/engine_shard_set.h
@@ -331,6 +331,10 @@ void EngineShardSet::RunBriefInParallel(U&& func, P&& pred) const {
 template <typename U>
 void EngineShardSet::RunBlockingInParallel(U&& func) {
   BlockingCounter bc{size()};
+  static_assert(std::is_invocable_v<U, EngineShard*>,
+                "Argument must be invocable with EngineShard* as argument.");
+  static_assert(std::is_void_v<std::invoke_result_t<U, EngineShard*>>,
+                "Callable must not have a return value!");
 
   for (uint32_t i = 0; i < size(); ++i) {
     util::ProactorBase* dest = pp_->at(i);
diff --git a/src/server/journal/journal_slice.cc b/src/server/journal/journal_slice.cc
index cd164a616..852006794 100644
--- a/src/server/journal/journal_slice.cc
+++ b/src/server/journal/journal_slice.cc
@@ -114,6 +114,25 @@ error_code JournalSlice::Close() {
 
 void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
   DCHECK(ring_buffer_);
+
+  if (entry.opcode != Op::NOOP) {
+    // TODO: This is preparation for AOC style journaling, currently unused.
+    RingItem item;
+    item.lsn = lsn_;
+    lsn_++;
+    item.opcode = entry.opcode;
+    item.txid = entry.txid;
+    VLOG(1) << "Writing item [" << item.lsn << "]: " << entry.ToString();
+    ring_buffer_->EmplaceOrOverride(move(item));
+
+    if (shard_file_) {
+      string line = absl::StrCat(item.lsn, " ", entry.txid, " ", entry.opcode, "\n");
+      error_code ec = shard_file_->Write(io::Buffer(line), file_offset_, 0);
+      CHECK_EC(ec);
+      file_offset_ += line.size();
+    }
+  }
+
   {
     std::shared_lock lk(cb_mu_);
     DVLOG(2) << "AddLogRecord: run callbacks for " << entry.ToString()

       k_v.second(entry, await);
     }
   }
-
-  if (entry.opcode == Op::NOOP)
-    return;
-
-  // TODO: This is preparation for AOC style journaling, currently unused.
- RingItem item; - item.lsn = lsn_; - item.opcode = entry.opcode; - item.txid = entry.txid; - VLOG(1) << "Writing item [" << item.lsn << "]: " << entry.ToString(); - ring_buffer_->EmplaceOrOverride(move(item)); - - if (shard_file_) { - string line = absl::StrCat(lsn_, " ", entry.txid, " ", entry.opcode, "\n"); - error_code ec = shard_file_->Write(io::Buffer(line), file_offset_, 0); - CHECK_EC(ec); - file_offset_ += line.size(); - } - - ++lsn_; } uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index e2e582175..6ba768ea8 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -695,9 +695,23 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionCon bool is_trans_cmd = CO::IsTransKind(cid->name()); bool under_script = bool(dfly_cntx->conn_state.script_info); - bool blocked_by_loading = !dfly_cntx->journal_emulated && etl.gstate() == GlobalState::LOADING && - (cid->opt_mask() & CO::LOADING) == 0; - if (blocked_by_loading || etl.gstate() == GlobalState::SHUTTING_DOWN) { + bool allowed_by_state = true; + switch (etl.gstate()) { + case GlobalState::LOADING: + allowed_by_state = dfly_cntx->journal_emulated || (cid->opt_mask() & CO::LOADING); + break; + case GlobalState::SHUTTING_DOWN: + allowed_by_state = false; + break; + case GlobalState::TAKEN_OVER: + allowed_by_state = cid->name() == "REPLCONF" || cid->name() == "SAVE"; + break; + default: + break; + } + if (!allowed_by_state) { + VLOG(1) << "Command " << cid->name() << " not executed because global state is " + << GlobalStateName(etl.gstate()); string err = StrCat("Can not execute during ", GlobalStateName(etl.gstate())); (*dfly_cntx)->SendError(err); return false; diff --git a/src/server/replica.cc b/src/server/replica.cc index 10ed2a9e7..7083f88b8 100644 --- a/src/server/replica.cc +++ b/src/server/replica.cc @@ -228,6 +228,21 @@ void Replica::Pause(bool pause) { sock_->proactor()->Await([&] { is_paused_ = pause; }); } +std::error_code Replica::TakeOver(std::string_view timeout) { + VLOG(1) << "Taking over"; + + std::error_code ec; + sock_->proactor()->Await( + [this, &ec, timeout] { ec = SendNextPhaseRequest(absl::StrCat("TAKEOVER ", timeout)); }); + + if (ec) { + // TODO: Handle timeout more gracefully. + return cntx_.ReportError(ec); + } + // If we successfully taken over, return and let server_family stop us. + return {}; +} + void Replica::MainReplicationFb() { VLOG(1) << "Main replication fiber started"; // Switch shard states to replication. @@ -610,7 +625,7 @@ error_code Replica::InitiateDflySync() { RETURN_ON_ERR(cntx_.GetError()); // Send DFLY SYNC. - if (auto ec = SendNextPhaseRequest(false); ec) { + if (auto ec = SendNextPhaseRequest("SYNC"); ec) { return cntx_.ReportError(ec); } @@ -626,7 +641,7 @@ error_code Replica::InitiateDflySync() { return cntx_.GetError(); // Send DFLY STARTSTABLE. - if (auto ec = SendNextPhaseRequest(true); ec) { + if (auto ec = SendNextPhaseRequest("STARTSTABLE"); ec) { return cntx_.ReportError(ec); } @@ -770,11 +785,10 @@ void Replica::DefaultErrorHandler(const GenericError& err) { CloseSocket(); } -error_code Replica::SendNextPhaseRequest(bool stable) { +error_code Replica::SendNextPhaseRequest(string_view kind) { ReqSerializer serializer{sock_.get()}; // Ask master to start sending replication stream - string_view kind = (stable) ? 
"STARTSTABLE"sv : "SYNC"sv; string request = StrCat("DFLY ", kind, " ", master_context_.dfly_session_id); VLOG(1) << "Sending: " << request; diff --git a/src/server/replica.h b/src/server/replica.h index 0c32a4dd7..a1c2ed5e7 100644 --- a/src/server/replica.h +++ b/src/server/replica.h @@ -115,6 +115,8 @@ class Replica { void Pause(bool pause); + std::error_code TakeOver(std::string_view timeout); + std::string_view MasterId() const { return master_context_.master_repl_id; } @@ -141,8 +143,8 @@ class Replica { void JoinAllFlows(); // Join all flows if possible. void SetShardStates(bool replica); // Call SetReplica(replica) on all shards. - // Send DFLY SYNC or DFLY STARTSTABLE if stable is true. - std::error_code SendNextPhaseRequest(bool stable); + // Send DFLY ${kind} to the master instance. + std::error_code SendNextPhaseRequest(std::string_view kind); void DefaultErrorHandler(const GenericError& err); diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 6b7dca81a..bff4ce65b 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1054,12 +1054,13 @@ GenericError ServerFamily::DoSave(bool new_version, string_view basename, Transa // Manage global state. GlobalState new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::SAVING); - if (new_state != GlobalState::SAVING) { + if (new_state != GlobalState::SAVING && new_state != GlobalState::TAKEN_OVER) { return {make_error_code(errc::operation_in_progress), StrCat(GlobalStateName(new_state), " - can not save database")}; } - absl::Cleanup rev_state = [this] { - service_.SwitchState(GlobalState::SAVING, GlobalState::ACTIVE); + absl::Cleanup rev_state = [this, new_state] { + if (new_state == GlobalState::SAVING) + service_.SwitchState(GlobalState::SAVING, GlobalState::ACTIVE); }; absl::Time start = absl::Now(); @@ -1259,6 +1260,19 @@ void ServerFamily::BreakOnShutdown() { dfly_cmd_->BreakOnShutdown(); } +bool ServerFamily::AwaitDispatches(absl::Duration timeout, + const std::function& filter) { + auto start = absl::Now(); + for (auto* listener : listeners_) { + absl::Duration remaining_time = timeout - (absl::Now() - start); + if (remaining_time < absl::Nanoseconds(0) || + !listener->AwaitDispatches(remaining_time, filter)) { + return false; + } + } + return true; +} + string GetPassword() { string flag = GetFlag(FLAGS_requirepass); if (!flag.empty()) { @@ -1933,6 +1947,41 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) { } } +void ServerFamily::ReplTakeOver(CmdArgList args, ConnectionContext* cntx) { + VLOG(1) << "Starting take over"; + VLOG(1) << "Acquire replica lock"; + unique_lock lk(replicaof_mu_); + + float_t timeout_sec; + if (!absl::SimpleAtof(ArgS(args, 0), &timeout_sec)) { + return (*cntx)->SendError(kInvalidIntErr); + } + if (timeout_sec < 0) { + return (*cntx)->SendError("timeout is negative"); + } + + if (ServerState::tlocal()->is_master) + return (*cntx)->SendError("Already a master instance"); + auto repl_ptr = replica_; + CHECK(repl_ptr); + + auto info = replica_->GetInfo(); + if (!info.full_sync_done) { + return (*cntx)->SendError("Full sync not done"); + } + + std::error_code ec = replica_->TakeOver(ArgS(args, 0)); + if (ec) + return (*cntx)->SendError("Couldn't execute takeover"); + + LOG(INFO) << "Takeover successful, promoting this instance to master."; + service_.proactor_pool().AwaitFiberOnAll( + [&](util::ProactorBase* pb) { ServerState::tlocal()->is_master = true; }); + replica_->Stop(); + replica_.reset(); + return (*cntx)->SendOk(); 
+} + void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) { if (args.size() % 2 == 1) goto err; @@ -2083,7 +2132,7 @@ void ServerFamily::Latency(CmdArgList args, ConnectionContext* cntx) { (*cntx)->SendError(kSyntaxErr); } -void ServerFamily::_Shutdown(CmdArgList args, ConnectionContext* cntx) { +void ServerFamily::ShutdownCmd(CmdArgList args, ConnectionContext* cntx) { if (args.size() > 1) { (*cntx)->SendError(kSyntaxErr); return; @@ -2145,9 +2194,11 @@ void ServerFamily::Register(CommandRegistry* registry) { << CI{"LATENCY", CO::NOSCRIPT | CO::LOADING | CO::FAST, -2, 0, 0, 0}.HFUNC(Latency) << CI{"MEMORY", kMemOpts, -2, 0, 0, 0}.HFUNC(Memory) << CI{"SAVE", CO::ADMIN | CO::GLOBAL_TRANS, -1, 0, 0, 0}.HFUNC(Save) - << CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC(_Shutdown) + << CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC( + ShutdownCmd) << CI{"SLAVEOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf) << CI{"REPLICAOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf) + << CI{"REPLTAKEOVER", CO::ADMIN | CO::GLOBAL_TRANS, 2, 0, 0, 0}.HFUNC(ReplTakeOver) << CI{"REPLCONF", CO::ADMIN | CO::LOADING, -1, 0, 0, 0}.HFUNC(ReplConf) << CI{"ROLE", CO::LOADING | CO::FAST | CO::NOSCRIPT, 1, 0, 0, 0}.HFUNC(Role) << CI{"SLOWLOG", CO::ADMIN | CO::FAST, -2, 0, 0, 0}.SetHandler(SlowLog) diff --git a/src/server/server_family.h b/src/server/server_family.h index 53e2256db..a56148b64 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -94,6 +94,8 @@ class ServerFamily { void Register(CommandRegistry* registry); void Shutdown(); + void ShutdownCmd(CmdArgList args, ConnectionContext* cntx); + Service& service() { return service_; } @@ -154,6 +156,9 @@ class ServerFamily { void BreakOnShutdown(); + bool AwaitDispatches(absl::Duration timeout, + const std::function& filter); + private: uint32_t shard_count() const { return shard_set->size(); @@ -174,14 +179,13 @@ class ServerFamily { void Latency(CmdArgList args, ConnectionContext* cntx); void Psync(CmdArgList args, ConnectionContext* cntx); void ReplicaOf(CmdArgList args, ConnectionContext* cntx); + void ReplTakeOver(CmdArgList args, ConnectionContext* cntx); void ReplConf(CmdArgList args, ConnectionContext* cntx); void Role(CmdArgList args, ConnectionContext* cntx); void Save(CmdArgList args, ConnectionContext* cntx); void Script(CmdArgList args, ConnectionContext* cntx); void Sync(CmdArgList args, ConnectionContext* cntx); - void _Shutdown(CmdArgList args, ConnectionContext* cntx); - void SyncGeneric(std::string_view repl_master_id, uint64_t offs, ConnectionContext* cntx); // Returns the number of loaded keys if successfull. diff --git a/src/server/transaction.cc b/src/server/transaction.cc index e6cc64371..4d7bdd854 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -507,6 +507,7 @@ bool Transaction::RunInShard(EngineShard* shard, bool txq_ooo) { if (IsGlobal()) { DCHECK(!awaked_prerun && !became_suspended); // Global transactions can not be blocking. + VLOG(2) << "Releasing shard lock"; shard->shard_lock()->Release(Mode()); } else { // not global. 
largs = GetLockArgs(idx); @@ -572,6 +573,7 @@ void Transaction::ScheduleInternal() { // Lock shards auto cb = [mode](EngineShard* shard) { shard->shard_lock()->Acquire(mode); }; shard_set->RunBriefInParallel(std::move(cb)); + VLOG(1) << "Global shard lock acquired"; } else { num_shards = unique_shard_cnt_; DCHECK_GT(num_shards, 0u); @@ -893,8 +895,8 @@ void Transaction::RunQuickie(EngineShard* shard) { auto& sd = shard_data_[SidToId(unique_shard_id_)]; DCHECK_EQ(0, sd.local_mask & (KEYLOCK_ACQUIRED | OUT_OF_ORDER)); - DVLOG(1) << "RunQuickSingle " << DebugId() << " " << shard->shard_id() << " " << args_[0]; - DCHECK(cb_ptr_) << DebugId() << " " << shard->shard_id() << " " << args_[0]; + DVLOG(1) << "RunQuickSingle " << DebugId() << " " << shard->shard_id(); + DCHECK(cb_ptr_) << DebugId() << " " << shard->shard_id(); // Calling the callback in somewhat safe way try { diff --git a/tests/dragonfly/__init__.py b/tests/dragonfly/__init__.py index 3436c4bd7..5f26746e6 100644 --- a/tests/dragonfly/__init__.py +++ b/tests/dragonfly/__init__.py @@ -3,6 +3,7 @@ import time import subprocess import aiohttp from prometheus_client.parser import text_string_to_metric_families +from redis.asyncio import Redis as RedisClient from dataclasses import dataclass @@ -32,6 +33,10 @@ class DflyInstance: self.args = args self.params = params self.proc = None + self._client : Optional[RedisClient] = None + + def client(self) -> RedisClient: + return RedisClient(port=self.port) def start(self): self._start() diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index 9275bb0ce..1e053b6ab 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -937,7 +937,6 @@ async def assert_lag_condition(inst, client, condition): assert False, "Lag has never satisfied condition!" 
- @dfly_args({"proactor_threads": 2}) @pytest.mark.asyncio async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000): @@ -1069,3 +1068,114 @@ async def test_readonly_script(df_local_factory): assert False except aioredis.ResponseError as roe: assert 'READONLY ' in str(roe) + + +take_over_cases = [ + [2, 2], + [2, 4], + [4, 2], + [8, 8], +] + + +@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases) +@pytest.mark.asyncio +async def test_take_over_counters(df_local_factory, master_threads, replica_threads): + master = df_local_factory.create(proactor_threads=master_threads, + port=BASE_PORT, + # vmodule="journal_slice=2,dflycmd=2,main_service=1", + logtostderr=True) + replica1 = df_local_factory.create( + port=BASE_PORT+1, proactor_threads=replica_threads) + replica2 = df_local_factory.create( + port=BASE_PORT+2, proactor_threads=replica_threads) + replica3 = df_local_factory.create( + port=BASE_PORT+3, proactor_threads=replica_threads) + df_local_factory.start_all([master, replica1, replica2, replica3]) + async with ( + master.client() as c_master, + replica1.client() as c1, + master.client() as c_blocking, + replica2.client() as c2, + replica3.client() as c3, + ): + await c1.execute_command(f"REPLICAOF localhost {master.port}") + await c2.execute_command(f"REPLICAOF localhost {master.port}") + await c3.execute_command(f"REPLICAOF localhost {master.port}") + + await wait_available_async(c1) + + async def counter(key): + value = 0 + await c_master.execute_command(f"SET {key} 0") + start = time.time() + while time.time() - start < 20: + try: + value = await c_master.execute_command(f"INCR {key}") + except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e: + break + else: + assert False, "The incrementing loop should be exited with a connection error" + return key, value + + async def block_during_takeover(): + "Add a blocking command during takeover to make sure it doesn't block it." + # TODO: We need to make takeover interrupt blocking commands. + return + try: + await c_blocking.execute_command("BLPOP BLOCKING_KEY1 BLOCKING_KEY2 10") + except redis.exceptions.ConnectionError: + pass + + async def delayed_takeover(): + await asyncio.sleep(1) + await c1.execute_command(f"REPLTAKEOVER 5") + + _, _, *results = await asyncio.gather(delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)]) + assert await c1.execute_command("role") == [b'master', []] + + for key, client_value in results: + replicated_value = await c1.get(key) + assert client_value == int(replicated_value) + + +@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases) +@pytest.mark.asyncio +async def test_take_over_seeder(df_local_factory, df_seeder_factory, master_threads, replica_threads): + master = df_local_factory.create(proactor_threads=master_threads, + port=BASE_PORT, + dbfilename=f"dump_{master_threads}_{replica_threads}", + logtostderr=True) + replica = df_local_factory.create( + port=BASE_PORT+1, proactor_threads=replica_threads) + df_local_factory.start_all([master, replica]) + + seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False) + async with ( + master.client() as c_master, + replica.client() as c_replica, + ): + await c_replica.execute_command(f"REPLICAOF localhost {master.port}") + await wait_available_async(c_replica) + + async def seed(): + await seeder.run(target_ops=3000) + + fill_task = asyncio.create_task(seed()) + + # Give the seeder a bit of time. 
+ await asyncio.sleep(1) + await c_replica.execute_command(f"REPLTAKEOVER 5") + seeder.stop() + + assert await c_replica.execute_command("role") == [b'master', []] + + # Need to wait a bit to give time to write the shutdown snapshot + await asyncio.sleep(1) + assert master.proc.poll() == 0, "Master process did not exit correctly." + + master.start() + await wait_available_async(c_master) + + capture = await seeder.capture() + assert await seeder.compare(capture, port=replica.port) diff --git a/tests/dragonfly/utility.py b/tests/dragonfly/utility.py index ef9eb78c9..d62d99ac6 100644 --- a/tests/dragonfly/utility.py +++ b/tests/dragonfly/utility.py @@ -2,6 +2,7 @@ import itertools import sys import asyncio from redis import asyncio as aioredis +import redis import random import string import itertools @@ -342,7 +343,7 @@ class DflySeeder: assert await seeder.compare(capture, port=1112) """ - def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[]): + def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[], stop_on_failure=True): self.gen = CommandGenerator( keys, val_size, batch_size, max_multikey, unsupported_types ) @@ -350,6 +351,7 @@ class DflySeeder: self.dbcount = dbcount self.multi_transaction_probability = multi_transaction_probability self.stop_flag = False + self.stop_on_failure = stop_on_failure self.log_file = log_file if self.log_file is not None: @@ -496,6 +498,9 @@ class DflySeeder: try: await pipe.execute() + except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e: + if self.stop_on_failure: + raise SystemExit(e) except Exception as e: raise SystemExit(e) queue.task_done()
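
A note on driving the new REPLTAKEOVER command from a client: ServerFamily::ReplTakeOver above rejects the command with "Already a master instance" or "Full sync not done" until the replica is promotable, and replies "Couldn't execute takeover" if the DFLY TAKEOVER handshake with the master fails, so callers are expected to retry. Below is a minimal, hedged sketch of such a retry loop; it only uses commands that appear in this patch (ROLE, REPLTAKEOVER), but the helper name promote_when_ready, the 0.5s retry interval, and the 30s deadline are illustrative assumptions, not part of the patch.

import asyncio

from redis import asyncio as aioredis
from redis.exceptions import ResponseError


async def promote_when_ready(replica_port: int, timeout: float = 30.0) -> None:
    # Hypothetical helper (not from this patch): retry REPLTAKEOVER until the replica
    # reports itself as master or the deadline expires.
    client = aioredis.Redis(port=replica_port)
    deadline = asyncio.get_running_loop().time() + timeout
    try:
        while asyncio.get_running_loop().time() < deadline:
            role = await client.execute_command("ROLE")
            if role[0] == b"master":
                return  # Promotion finished; the old master is shutting itself down.
            try:
                # "5" is the takeover timeout in seconds, mirroring the tests above.
                await client.execute_command("REPLTAKEOVER 5")
            except ResponseError:
                # e.g. "Full sync not done" or "Couldn't execute takeover": retry.
                await asyncio.sleep(0.5)
        raise TimeoutError("replica was not promoted in time")
    finally:
        await client.close()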
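
For the data-plane side, test_take_over_counters above relies on two observable signals: while the master is in the TAKEN_OVER state, regular commands are rejected by main_service.cc with "Can not execute during TAKEN OVER" (a ResponseError), and once the master shuts down, clients see connection errors. The sketch below shows how an application write loop might treat either signal as the cue to move to the promoted replica; write_until_takeover is a hypothetical helper for illustration only, and the guarantee it checks (no acknowledged write is lost) is the property the counter test asserts.

import asyncio

import redis
from redis import asyncio as aioredis


async def write_until_takeover(master_port: int, replica_port: int, key: str) -> int:
    # Hypothetical application-side loop (not from this patch): write to the old master
    # until the takeover locks it down or shuts it down, then read back from the
    # promoted replica and check that no acknowledged increment was lost.
    old_master = aioredis.Redis(port=master_port)
    last_acked = 0
    try:
        while True:
            try:
                last_acked = await old_master.incr(key)
            except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError):
                break  # Takeover has started or finished on the old master.
            await asyncio.sleep(0.01)
    finally:
        await old_master.close()

    new_master = aioredis.Redis(port=replica_port)
    try:
        replicated = int(await new_master.get(key))
        assert replicated >= last_acked, "an acknowledged write was lost"
        return replicated
    finally:
        await new_master.close()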