feat: Support atomic replica takeover (#1314)

* fix(server): Initialize ServerFamily with all listeners.

- Add a test for CLIENT LIST which is the visible result of this.

* use std move

* feat: Implement replicas take over

* Basic test

* Address CR comments

* Write a better test. Sadly it fails

* chore: Expose AwaitDispatches for reuse in takeover

* Ensure that no commands can execute during or after a takeover

* CR progress

* Actually disable the expiration

* Improve tests coverage

* Fix the dispatch waiting code

* Improve testing coverage and fix a shutdown snaphot bug

* don't replicate a replica
This commit is contained in:
Roy Jacobson 2023-07-02 16:11:28 +02:00 committed by GitHub
parent e71fae7eea
commit 4babed54d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 392 additions and 68 deletions

View file

@ -179,6 +179,28 @@ void Listener::PreAcceptLoop(util::ProactorBase* pb) {
per_thread_.resize(pool()->size()); per_thread_.resize(pool()->size());
} }
bool Listener::AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter) {
absl::Time start = absl::Now();
while (absl::Now() - start < timeout) {
std::atomic<bool> any_connection_dispatching = false;
auto cb = [&any_connection_dispatching, &filter](unsigned thread_index,
util::Connection* conn) {
if (filter(conn) && static_cast<Connection*>(conn)->IsCurrentlyDispatching()) {
any_connection_dispatching.store(true);
}
};
this->TraverseConnections(cb);
if (!any_connection_dispatching.load()) {
return true;
}
VLOG(1) << "A command is still dispatching, let's wait for it";
ThisFiber::SleepFor(100us);
}
return false;
}
void Listener::PreShutdown() { void Listener::PreShutdown() {
// Iterate on all connections and allow them to finish their commands for // Iterate on all connections and allow them to finish their commands for
// a short period. // a short period.
@ -188,26 +210,8 @@ void Listener::PreShutdown() {
// at this stage since we're in SHUTDOWN mode. // at this stage since we're in SHUTDOWN mode.
// If a command is running for too long we give up and proceed. // If a command is running for too long we give up and proceed.
const absl::Duration kDispatchShutdownTimeout = absl::Milliseconds(10); const absl::Duration kDispatchShutdownTimeout = absl::Milliseconds(10);
absl::Time start = absl::Now();
bool success = false; if (!AwaitDispatches(kDispatchShutdownTimeout, [](util::Connection*) { return true; })) {
while (absl::Now() - start < kDispatchShutdownTimeout) {
std::atomic<bool> any_connection_dispatching = false;
auto cb = [&any_connection_dispatching](unsigned thread_index, util::Connection* conn) {
if (static_cast<Connection*>(conn)->IsCurrentlyDispatching()) {
any_connection_dispatching.store(true);
}
};
this->TraverseConnections(cb);
if (!any_connection_dispatching.load()) {
success = true;
break;
}
VLOG(1) << "A command is still dispatching, let's wait for it";
ThisFiber::SleepFor(100us);
}
if (!success) {
LOG(WARNING) << "Some commands are still being dispatched but didn't conclude in time. " LOG(WARNING) << "Some commands are still being dispatched but didn't conclude in time. "
"Proceeding in shutdown."; "Proceeding in shutdown.";
} }

View file

@ -24,6 +24,11 @@ class Listener : public util::ListenerInterface {
std::error_code ConfigureServerSocket(int fd) final; std::error_code ConfigureServerSocket(int fd) final;
// Wait until all connections that pass the filter have stopped dispatching or until a timeout has
// run out. Returns true if the all connections have stopped dispatching.
bool AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter);
private: private:
util::Connection* NewConnection(ProactorBase* proactor) final; util::Connection* NewConnection(ProactorBase* proactor) final;
ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final; ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final;
@ -33,7 +38,6 @@ class Listener : public util::ListenerInterface {
void PreAcceptLoop(ProactorBase* pb) final; void PreAcceptLoop(ProactorBase* pb) final;
void PreShutdown() final; void PreShutdown() final;
void PostShutdown() final; void PostShutdown() final;
std::unique_ptr<util::HttpListenerBase> http_base_; std::unique_ptr<util::HttpListenerBase> http_base_;

View file

@ -36,6 +36,8 @@ const char* DebugString(OpStatus op) {
return "ENTRIES ADDED IS TO SMALL"; return "ENTRIES ADDED IS TO SMALL";
case OpStatus::INVALID_NUMERIC_RESULT: case OpStatus::INVALID_NUMERIC_RESULT:
return "INVALID NUMERIC RESULT"; return "INVALID NUMERIC RESULT";
case OpStatus::CANCELLED:
return "CANCELLED";
} }
return "Unknown Error Code"; // we should not be here, but this is how enums works in c++ return "Unknown Error Code"; // we should not be here, but this is how enums works in c++
} }

View file

@ -26,6 +26,7 @@ enum class OpStatus : uint16_t {
STREAM_ID_SMALL, STREAM_ID_SMALL,
ENTRIES_ADDED_SMALL, ENTRIES_ADDED_SMALL,
INVALID_NUMERIC_RESULT, INVALID_NUMERIC_RESULT,
CANCELLED,
}; };
const char* DebugString(OpStatus op); const char* DebugString(OpStatus op);

View file

@ -44,6 +44,8 @@ const char* GlobalStateName(GlobalState s) {
return "SAVING"; return "SAVING";
case GlobalState::SHUTTING_DOWN: case GlobalState::SHUTTING_DOWN:
return "SHUTTING DOWN"; return "SHUTTING DOWN";
case GlobalState::TAKEN_OVER:
return "TAKEN OVER";
} }
ABSL_UNREACHABLE(); ABSL_UNREACHABLE();
} }

View file

@ -134,6 +134,7 @@ enum class GlobalState : uint8_t {
LOADING, LOADING,
SAVING, SAVING,
SHUTTING_DOWN, SHUTTING_DOWN,
TAKEN_OVER,
}; };
enum class TimeUnit : uint8_t { SEC, MSEC }; enum class TimeUnit : uint8_t { SEC, MSEC };

View file

@ -894,8 +894,8 @@ pair<PrimeIterator, ExpireIterator> DbSlice::ExpireIfNeeded(const Context& cntx,
// TODO: to employ multi-generation update of expire-base and the underlying values. // TODO: to employ multi-generation update of expire-base and the underlying values.
time_t expire_time = ExpireTime(expire_it); time_t expire_time = ExpireTime(expire_it);
// Never do expiration on replica. // Never do expiration on replica or if expiration is disabled.
if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica()) if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica() || !expire_allowed_)
return make_pair(it, expire_it); return make_pair(it, expire_it);
// Replicate expiry // Replicate expiry

View file

@ -324,6 +324,10 @@ class DbSlice {
// Resets the event counter for updates/insertions // Resets the event counter for updates/insertions
void ResetUpdateEvents(); void ResetUpdateEvents();
void SetExpireAllowed(bool is_allowed) {
expire_allowed_ = is_allowed;
}
private: private:
std::pair<PrimeIterator, bool> AddOrUpdateInternal(const Context& cntx, std::string_view key, std::pair<PrimeIterator, bool> AddOrUpdateInternal(const Context& cntx, std::string_view key,
PrimeValue obj, uint64_t expire_at_ms, PrimeValue obj, uint64_t expire_at_ms,
@ -351,6 +355,7 @@ class DbSlice {
EngineShard* owner_; EngineShard* owner_;
time_t expire_base_[2]; // Used for expire logic, represents a real clock. time_t expire_base_[2]; // Used for expire logic, represents a real clock.
bool expire_allowed_ = true;
uint64_t version_ = 1; // Used to version entries in the PrimeTable. uint64_t version_ = 1; // Used to version entries in the PrimeTable.
ssize_t memory_budget_ = SSIZE_MAX; ssize_t memory_budget_ = SSIZE_MAX;

View file

@ -14,10 +14,12 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/logging.h" #include "base/logging.h"
#include "facade/dragonfly_connection.h" #include "facade/dragonfly_connection.h"
#include "facade/dragonfly_listener.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/error.h" #include "server/error.h"
#include "server/journal/journal.h" #include "server/journal/journal.h"
#include "server/journal/streamer.h" #include "server/journal/streamer.h"
#include "server/main_service.h"
#include "server/rdb_save.h" #include "server/rdb_save.h"
#include "server/script_mgr.h" #include "server/script_mgr.h"
#include "server/server_family.h" #include "server/server_family.h"
@ -65,15 +67,27 @@ std::string_view SyncStateName(DflyCmd::SyncState sync_state) {
} }
struct TransactionGuard { struct TransactionGuard {
constexpr static auto kEmptyCb = [](Transaction* t, EngineShard* shard) { return OpStatus::OK; }; static OpStatus ExitGuardCb(Transaction* t, EngineShard* shard) {
shard->db_slice().SetExpireAllowed(true);
return OpStatus::OK;
};
TransactionGuard(Transaction* t) : t(t) { explicit TransactionGuard(Transaction* t, bool disable_expirations = false) : t(t) {
t->Schedule(); t->Schedule();
t->Execute(kEmptyCb, false); t->Execute(
[disable_expirations](Transaction* t, EngineShard* shard) {
if (disable_expirations) {
shard->db_slice().SetExpireAllowed(!disable_expirations);
}
return OpStatus::OK;
},
false);
VLOG(1) << "Transaction guard engaged";
} }
~TransactionGuard() { ~TransactionGuard() {
t->Execute(kEmptyCb, true); VLOG(1) << "Releasing transaction guard";
t->Execute(ExitGuardCb, true);
} }
Transaction* t; Transaction* t;
@ -110,6 +124,10 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
return StartStable(args, cntx); return StartStable(args, cntx);
} }
if (sub_cmd == "TAKEOVER" && args.size() == 3) {
return TakeOver(args, cntx);
}
if (sub_cmd == "EXPIRE") { if (sub_cmd == "EXPIRE") {
return Expire(args, cntx); return Expire(args, cntx);
} }
@ -316,7 +334,6 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
StopFullSyncInThread(flow, shard); StopFullSyncInThread(flow, shard);
status = StartStableSyncInThread(flow, &replica_ptr->cntx, shard); status = StartStableSyncInThread(flow, &replica_ptr->cntx, shard);
return OpStatus::OK;
}; };
shard_set->RunBlockingInParallel(std::move(cb)); shard_set->RunBlockingInParallel(std::move(cb));
@ -331,6 +348,80 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
return rb->SendOk(); return rb->SendOk();
} }
void DflyCmd::TakeOver(CmdArgList args, ConnectionContext* cntx) {
RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
string_view sync_id_str = ArgS(args, 2);
float timeout;
if (!absl::SimpleAtof(ArgS(args, 1), &timeout)) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (timeout < 0) {
return (*cntx)->SendError("timeout is negative");
}
VLOG(1) << "Got DFLY TAKEOVER " << sync_id_str;
auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb);
if (!sync_id)
return;
unique_lock lk(replica_ptr->mu);
if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::STABLE_SYNC, rb))
return;
LOG(INFO) << "Takeover initiated, locking down the database.";
sf_->service().SwitchState(GlobalState::ACTIVE, GlobalState::TAKEN_OVER);
absl::Duration timeout_dur = absl::Seconds(timeout);
absl::Time start = absl::Now();
AggregateStatus status;
// TODO: We should cancel blocking commands before awaiting all
// dispatches to finish.
if (!sf_->AwaitDispatches(timeout_dur, [self = cntx->owner()](util::Connection* conn) {
// The only command that is currently dispatching should be the takeover command -
// so we wait until this is true.
return conn != self;
})) {
LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << timeout_dur;
status = OpStatus::TIMED_OUT;
}
TransactionGuard tg{cntx->transaction, /*disable_expirations=*/true};
if (*status == OpStatus::OK) {
auto cb = [&cntx = replica_ptr->cntx, replica_ptr = replica_ptr, timeout_dur, start,
&status](EngineShard* shard) {
FlowInfo* flow = &replica_ptr->flows[shard->shard_id()];
shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, {}, true);
while (flow->last_acked_lsn < shard->journal()->GetLsn()) {
if (absl::Now() - start > timeout_dur) {
LOG(WARNING) << "Couldn't synchronize with replica for takeover in time.";
status = OpStatus::TIMED_OUT;
return;
}
if (cntx.IsCancelled()) {
status = OpStatus::CANCELLED;
return;
}
ThisFiber::SleepFor(1ms);
}
};
shard_set->RunBlockingInParallel(std::move(cb));
}
if (*status != OpStatus::OK) {
sf_->service().SwitchState(GlobalState::TAKEN_OVER, GlobalState::ACTIVE);
return rb->SendError("Takeover failed!");
}
(*cntx)->SendOk();
VLOG(1) << "Takeover accepted, shutting down.";
return sf_->ShutdownCmd({}, cntx);
}
void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) { void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder()); RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
cntx->transaction->ScheduleSingleHop([](Transaction* t, EngineShard* shard) { cntx->transaction->ScheduleSingleHop([](Transaction* t, EngineShard* shard) {

View file

@ -99,7 +99,7 @@ class DflyCmd {
struct ReplicaInfo { struct ReplicaInfo {
ReplicaInfo(unsigned flow_count, std::string address, uint32_t listening_port, ReplicaInfo(unsigned flow_count, std::string address, uint32_t listening_port,
Context::ErrHandler err_handler) Context::ErrHandler err_handler)
: state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{address}, : state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{std::move(address)},
listening_port(listening_port), flows{flow_count} { listening_port(listening_port), flows{flow_count} {
} }
@ -151,6 +151,10 @@ class DflyCmd {
// Switch to stable state replication. // Switch to stable state replication.
void StartStable(CmdArgList args, ConnectionContext* cntx); void StartStable(CmdArgList args, ConnectionContext* cntx);
// TAKEOVER <syncid>
// Shut this master down atomically with replica promotion.
void TakeOver(CmdArgList args, ConnectionContext* cntx);
// EXPIRE // EXPIRE
// Check all keys for expiry. // Check all keys for expiry.
void Expire(CmdArgList args, ConnectionContext* cntx); void Expire(CmdArgList args, ConnectionContext* cntx);

View file

@ -331,6 +331,10 @@ void EngineShardSet::RunBriefInParallel(U&& func, P&& pred) const {
template <typename U> void EngineShardSet::RunBlockingInParallel(U&& func) { template <typename U> void EngineShardSet::RunBlockingInParallel(U&& func) {
BlockingCounter bc{size()}; BlockingCounter bc{size()};
static_assert(std::is_invocable_v<U, EngineShard*>,
"Argument must be invocable EngineShard* as argument.");
static_assert(std::is_void_v<std::invoke_result_t<U, EngineShard*>>,
"Callable must not have a return value!");
for (uint32_t i = 0; i < size(); ++i) { for (uint32_t i = 0; i < size(); ++i) {
util::ProactorBase* dest = pp_->at(i); util::ProactorBase* dest = pp_->at(i);

View file

@ -114,6 +114,25 @@ error_code JournalSlice::Close() {
void JournalSlice::AddLogRecord(const Entry& entry, bool await) { void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
DCHECK(ring_buffer_); DCHECK(ring_buffer_);
if (entry.opcode != Op::NOOP) {
// TODO: This is preparation for AOC style journaling, currently unused.
RingItem item;
item.lsn = lsn_;
lsn_++;
item.opcode = entry.opcode;
item.txid = entry.txid;
VLOG(1) << "Writing item [" << item.lsn << "]: " << entry.ToString();
ring_buffer_->EmplaceOrOverride(move(item));
if (shard_file_) {
string line = absl::StrCat(item.lsn, " ", entry.txid, " ", entry.opcode, "\n");
error_code ec = shard_file_->Write(io::Buffer(line), file_offset_, 0);
CHECK_EC(ec);
file_offset_ += line.size();
}
}
{ {
std::shared_lock lk(cb_mu_); std::shared_lock lk(cb_mu_);
DVLOG(2) << "AddLogRecord: run callbacks for " << entry.ToString() DVLOG(2) << "AddLogRecord: run callbacks for " << entry.ToString()
@ -123,26 +142,6 @@ void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
k_v.second(entry, await); k_v.second(entry, await);
} }
} }
if (entry.opcode == Op::NOOP)
return;
// TODO: This is preparation for AOC style journaling, currently unused.
RingItem item;
item.lsn = lsn_;
item.opcode = entry.opcode;
item.txid = entry.txid;
VLOG(1) << "Writing item [" << item.lsn << "]: " << entry.ToString();
ring_buffer_->EmplaceOrOverride(move(item));
if (shard_file_) {
string line = absl::StrCat(lsn_, " ", entry.txid, " ", entry.opcode, "\n");
error_code ec = shard_file_->Write(io::Buffer(line), file_offset_, 0);
CHECK_EC(ec);
file_offset_ += line.size();
}
++lsn_;
} }
uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) { uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) {

View file

@ -695,9 +695,23 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionCon
bool is_trans_cmd = CO::IsTransKind(cid->name()); bool is_trans_cmd = CO::IsTransKind(cid->name());
bool under_script = bool(dfly_cntx->conn_state.script_info); bool under_script = bool(dfly_cntx->conn_state.script_info);
bool blocked_by_loading = !dfly_cntx->journal_emulated && etl.gstate() == GlobalState::LOADING && bool allowed_by_state = true;
(cid->opt_mask() & CO::LOADING) == 0; switch (etl.gstate()) {
if (blocked_by_loading || etl.gstate() == GlobalState::SHUTTING_DOWN) { case GlobalState::LOADING:
allowed_by_state = dfly_cntx->journal_emulated || (cid->opt_mask() & CO::LOADING);
break;
case GlobalState::SHUTTING_DOWN:
allowed_by_state = false;
break;
case GlobalState::TAKEN_OVER:
allowed_by_state = cid->name() == "REPLCONF" || cid->name() == "SAVE";
break;
default:
break;
}
if (!allowed_by_state) {
VLOG(1) << "Command " << cid->name() << " not executed because global state is "
<< GlobalStateName(etl.gstate());
string err = StrCat("Can not execute during ", GlobalStateName(etl.gstate())); string err = StrCat("Can not execute during ", GlobalStateName(etl.gstate()));
(*dfly_cntx)->SendError(err); (*dfly_cntx)->SendError(err);
return false; return false;

View file

@ -228,6 +228,21 @@ void Replica::Pause(bool pause) {
sock_->proactor()->Await([&] { is_paused_ = pause; }); sock_->proactor()->Await([&] { is_paused_ = pause; });
} }
std::error_code Replica::TakeOver(std::string_view timeout) {
VLOG(1) << "Taking over";
std::error_code ec;
sock_->proactor()->Await(
[this, &ec, timeout] { ec = SendNextPhaseRequest(absl::StrCat("TAKEOVER ", timeout)); });
if (ec) {
// TODO: Handle timeout more gracefully.
return cntx_.ReportError(ec);
}
// If we successfully taken over, return and let server_family stop us.
return {};
}
void Replica::MainReplicationFb() { void Replica::MainReplicationFb() {
VLOG(1) << "Main replication fiber started"; VLOG(1) << "Main replication fiber started";
// Switch shard states to replication. // Switch shard states to replication.
@ -610,7 +625,7 @@ error_code Replica::InitiateDflySync() {
RETURN_ON_ERR(cntx_.GetError()); RETURN_ON_ERR(cntx_.GetError());
// Send DFLY SYNC. // Send DFLY SYNC.
if (auto ec = SendNextPhaseRequest(false); ec) { if (auto ec = SendNextPhaseRequest("SYNC"); ec) {
return cntx_.ReportError(ec); return cntx_.ReportError(ec);
} }
@ -626,7 +641,7 @@ error_code Replica::InitiateDflySync() {
return cntx_.GetError(); return cntx_.GetError();
// Send DFLY STARTSTABLE. // Send DFLY STARTSTABLE.
if (auto ec = SendNextPhaseRequest(true); ec) { if (auto ec = SendNextPhaseRequest("STARTSTABLE"); ec) {
return cntx_.ReportError(ec); return cntx_.ReportError(ec);
} }
@ -770,11 +785,10 @@ void Replica::DefaultErrorHandler(const GenericError& err) {
CloseSocket(); CloseSocket();
} }
error_code Replica::SendNextPhaseRequest(bool stable) { error_code Replica::SendNextPhaseRequest(string_view kind) {
ReqSerializer serializer{sock_.get()}; ReqSerializer serializer{sock_.get()};
// Ask master to start sending replication stream // Ask master to start sending replication stream
string_view kind = (stable) ? "STARTSTABLE"sv : "SYNC"sv;
string request = StrCat("DFLY ", kind, " ", master_context_.dfly_session_id); string request = StrCat("DFLY ", kind, " ", master_context_.dfly_session_id);
VLOG(1) << "Sending: " << request; VLOG(1) << "Sending: " << request;

View file

@ -115,6 +115,8 @@ class Replica {
void Pause(bool pause); void Pause(bool pause);
std::error_code TakeOver(std::string_view timeout);
std::string_view MasterId() const { std::string_view MasterId() const {
return master_context_.master_repl_id; return master_context_.master_repl_id;
} }
@ -141,8 +143,8 @@ class Replica {
void JoinAllFlows(); // Join all flows if possible. void JoinAllFlows(); // Join all flows if possible.
void SetShardStates(bool replica); // Call SetReplica(replica) on all shards. void SetShardStates(bool replica); // Call SetReplica(replica) on all shards.
// Send DFLY SYNC or DFLY STARTSTABLE if stable is true. // Send DFLY ${kind} to the master instance.
std::error_code SendNextPhaseRequest(bool stable); std::error_code SendNextPhaseRequest(std::string_view kind);
void DefaultErrorHandler(const GenericError& err); void DefaultErrorHandler(const GenericError& err);

View file

@ -1054,12 +1054,13 @@ GenericError ServerFamily::DoSave(bool new_version, string_view basename, Transa
// Manage global state. // Manage global state.
GlobalState new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::SAVING); GlobalState new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::SAVING);
if (new_state != GlobalState::SAVING) { if (new_state != GlobalState::SAVING && new_state != GlobalState::TAKEN_OVER) {
return {make_error_code(errc::operation_in_progress), return {make_error_code(errc::operation_in_progress),
StrCat(GlobalStateName(new_state), " - can not save database")}; StrCat(GlobalStateName(new_state), " - can not save database")};
} }
absl::Cleanup rev_state = [this] { absl::Cleanup rev_state = [this, new_state] {
service_.SwitchState(GlobalState::SAVING, GlobalState::ACTIVE); if (new_state == GlobalState::SAVING)
service_.SwitchState(GlobalState::SAVING, GlobalState::ACTIVE);
}; };
absl::Time start = absl::Now(); absl::Time start = absl::Now();
@ -1259,6 +1260,19 @@ void ServerFamily::BreakOnShutdown() {
dfly_cmd_->BreakOnShutdown(); dfly_cmd_->BreakOnShutdown();
} }
bool ServerFamily::AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter) {
auto start = absl::Now();
for (auto* listener : listeners_) {
absl::Duration remaining_time = timeout - (absl::Now() - start);
if (remaining_time < absl::Nanoseconds(0) ||
!listener->AwaitDispatches(remaining_time, filter)) {
return false;
}
}
return true;
}
string GetPassword() { string GetPassword() {
string flag = GetFlag(FLAGS_requirepass); string flag = GetFlag(FLAGS_requirepass);
if (!flag.empty()) { if (!flag.empty()) {
@ -1933,6 +1947,41 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
} }
} }
void ServerFamily::ReplTakeOver(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "Starting take over";
VLOG(1) << "Acquire replica lock";
unique_lock lk(replicaof_mu_);
float_t timeout_sec;
if (!absl::SimpleAtof(ArgS(args, 0), &timeout_sec)) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (timeout_sec < 0) {
return (*cntx)->SendError("timeout is negative");
}
if (ServerState::tlocal()->is_master)
return (*cntx)->SendError("Already a master instance");
auto repl_ptr = replica_;
CHECK(repl_ptr);
auto info = replica_->GetInfo();
if (!info.full_sync_done) {
return (*cntx)->SendError("Full sync not done");
}
std::error_code ec = replica_->TakeOver(ArgS(args, 0));
if (ec)
return (*cntx)->SendError("Couldn't execute takeover");
LOG(INFO) << "Takeover successful, promoting this instance to master.";
service_.proactor_pool().AwaitFiberOnAll(
[&](util::ProactorBase* pb) { ServerState::tlocal()->is_master = true; });
replica_->Stop();
replica_.reset();
return (*cntx)->SendOk();
}
void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
if (args.size() % 2 == 1) if (args.size() % 2 == 1)
goto err; goto err;
@ -2083,7 +2132,7 @@ void ServerFamily::Latency(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendError(kSyntaxErr); (*cntx)->SendError(kSyntaxErr);
} }
void ServerFamily::_Shutdown(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::ShutdownCmd(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 1) { if (args.size() > 1) {
(*cntx)->SendError(kSyntaxErr); (*cntx)->SendError(kSyntaxErr);
return; return;
@ -2145,9 +2194,11 @@ void ServerFamily::Register(CommandRegistry* registry) {
<< CI{"LATENCY", CO::NOSCRIPT | CO::LOADING | CO::FAST, -2, 0, 0, 0}.HFUNC(Latency) << CI{"LATENCY", CO::NOSCRIPT | CO::LOADING | CO::FAST, -2, 0, 0, 0}.HFUNC(Latency)
<< CI{"MEMORY", kMemOpts, -2, 0, 0, 0}.HFUNC(Memory) << CI{"MEMORY", kMemOpts, -2, 0, 0, 0}.HFUNC(Memory)
<< CI{"SAVE", CO::ADMIN | CO::GLOBAL_TRANS, -1, 0, 0, 0}.HFUNC(Save) << CI{"SAVE", CO::ADMIN | CO::GLOBAL_TRANS, -1, 0, 0, 0}.HFUNC(Save)
<< CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC(_Shutdown) << CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC(
ShutdownCmd)
<< CI{"SLAVEOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf) << CI{"SLAVEOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf)
<< CI{"REPLICAOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf) << CI{"REPLICAOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf)
<< CI{"REPLTAKEOVER", CO::ADMIN | CO::GLOBAL_TRANS, 2, 0, 0, 0}.HFUNC(ReplTakeOver)
<< CI{"REPLCONF", CO::ADMIN | CO::LOADING, -1, 0, 0, 0}.HFUNC(ReplConf) << CI{"REPLCONF", CO::ADMIN | CO::LOADING, -1, 0, 0, 0}.HFUNC(ReplConf)
<< CI{"ROLE", CO::LOADING | CO::FAST | CO::NOSCRIPT, 1, 0, 0, 0}.HFUNC(Role) << CI{"ROLE", CO::LOADING | CO::FAST | CO::NOSCRIPT, 1, 0, 0, 0}.HFUNC(Role)
<< CI{"SLOWLOG", CO::ADMIN | CO::FAST, -2, 0, 0, 0}.SetHandler(SlowLog) << CI{"SLOWLOG", CO::ADMIN | CO::FAST, -2, 0, 0, 0}.SetHandler(SlowLog)

View file

@ -94,6 +94,8 @@ class ServerFamily {
void Register(CommandRegistry* registry); void Register(CommandRegistry* registry);
void Shutdown(); void Shutdown();
void ShutdownCmd(CmdArgList args, ConnectionContext* cntx);
Service& service() { Service& service() {
return service_; return service_;
} }
@ -154,6 +156,9 @@ class ServerFamily {
void BreakOnShutdown(); void BreakOnShutdown();
bool AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter);
private: private:
uint32_t shard_count() const { uint32_t shard_count() const {
return shard_set->size(); return shard_set->size();
@ -174,14 +179,13 @@ class ServerFamily {
void Latency(CmdArgList args, ConnectionContext* cntx); void Latency(CmdArgList args, ConnectionContext* cntx);
void Psync(CmdArgList args, ConnectionContext* cntx); void Psync(CmdArgList args, ConnectionContext* cntx);
void ReplicaOf(CmdArgList args, ConnectionContext* cntx); void ReplicaOf(CmdArgList args, ConnectionContext* cntx);
void ReplTakeOver(CmdArgList args, ConnectionContext* cntx);
void ReplConf(CmdArgList args, ConnectionContext* cntx); void ReplConf(CmdArgList args, ConnectionContext* cntx);
void Role(CmdArgList args, ConnectionContext* cntx); void Role(CmdArgList args, ConnectionContext* cntx);
void Save(CmdArgList args, ConnectionContext* cntx); void Save(CmdArgList args, ConnectionContext* cntx);
void Script(CmdArgList args, ConnectionContext* cntx); void Script(CmdArgList args, ConnectionContext* cntx);
void Sync(CmdArgList args, ConnectionContext* cntx); void Sync(CmdArgList args, ConnectionContext* cntx);
void _Shutdown(CmdArgList args, ConnectionContext* cntx);
void SyncGeneric(std::string_view repl_master_id, uint64_t offs, ConnectionContext* cntx); void SyncGeneric(std::string_view repl_master_id, uint64_t offs, ConnectionContext* cntx);
// Returns the number of loaded keys if successfull. // Returns the number of loaded keys if successfull.

View file

@ -507,6 +507,7 @@ bool Transaction::RunInShard(EngineShard* shard, bool txq_ooo) {
if (IsGlobal()) { if (IsGlobal()) {
DCHECK(!awaked_prerun && !became_suspended); // Global transactions can not be blocking. DCHECK(!awaked_prerun && !became_suspended); // Global transactions can not be blocking.
VLOG(2) << "Releasing shard lock";
shard->shard_lock()->Release(Mode()); shard->shard_lock()->Release(Mode());
} else { // not global. } else { // not global.
largs = GetLockArgs(idx); largs = GetLockArgs(idx);
@ -572,6 +573,7 @@ void Transaction::ScheduleInternal() {
// Lock shards // Lock shards
auto cb = [mode](EngineShard* shard) { shard->shard_lock()->Acquire(mode); }; auto cb = [mode](EngineShard* shard) { shard->shard_lock()->Acquire(mode); };
shard_set->RunBriefInParallel(std::move(cb)); shard_set->RunBriefInParallel(std::move(cb));
VLOG(1) << "Global shard lock acquired";
} else { } else {
num_shards = unique_shard_cnt_; num_shards = unique_shard_cnt_;
DCHECK_GT(num_shards, 0u); DCHECK_GT(num_shards, 0u);
@ -893,8 +895,8 @@ void Transaction::RunQuickie(EngineShard* shard) {
auto& sd = shard_data_[SidToId(unique_shard_id_)]; auto& sd = shard_data_[SidToId(unique_shard_id_)];
DCHECK_EQ(0, sd.local_mask & (KEYLOCK_ACQUIRED | OUT_OF_ORDER)); DCHECK_EQ(0, sd.local_mask & (KEYLOCK_ACQUIRED | OUT_OF_ORDER));
DVLOG(1) << "RunQuickSingle " << DebugId() << " " << shard->shard_id() << " " << args_[0]; DVLOG(1) << "RunQuickSingle " << DebugId() << " " << shard->shard_id();
DCHECK(cb_ptr_) << DebugId() << " " << shard->shard_id() << " " << args_[0]; DCHECK(cb_ptr_) << DebugId() << " " << shard->shard_id();
// Calling the callback in somewhat safe way // Calling the callback in somewhat safe way
try { try {

View file

@ -3,6 +3,7 @@ import time
import subprocess import subprocess
import aiohttp import aiohttp
from prometheus_client.parser import text_string_to_metric_families from prometheus_client.parser import text_string_to_metric_families
from redis.asyncio import Redis as RedisClient
from dataclasses import dataclass from dataclasses import dataclass
@ -32,6 +33,10 @@ class DflyInstance:
self.args = args self.args = args
self.params = params self.params = params
self.proc = None self.proc = None
self._client : Optional[RedisClient] = None
def client(self) -> RedisClient:
return RedisClient(port=self.port)
def start(self): def start(self):
self._start() self._start()

View file

@ -937,7 +937,6 @@ async def assert_lag_condition(inst, client, condition):
assert False, "Lag has never satisfied condition!" assert False, "Lag has never satisfied condition!"
@dfly_args({"proactor_threads": 2}) @dfly_args({"proactor_threads": 2})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000): async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000):
@ -1069,3 +1068,114 @@ async def test_readonly_script(df_local_factory):
assert False assert False
except aioredis.ResponseError as roe: except aioredis.ResponseError as roe:
assert 'READONLY ' in str(roe) assert 'READONLY ' in str(roe)
take_over_cases = [
[2, 2],
[2, 4],
[4, 2],
[8, 8],
]
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio
async def test_take_over_counters(df_local_factory, master_threads, replica_threads):
master = df_local_factory.create(proactor_threads=master_threads,
port=BASE_PORT,
# vmodule="journal_slice=2,dflycmd=2,main_service=1",
logtostderr=True)
replica1 = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=replica_threads)
replica2 = df_local_factory.create(
port=BASE_PORT+2, proactor_threads=replica_threads)
replica3 = df_local_factory.create(
port=BASE_PORT+3, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica1, replica2, replica3])
async with (
master.client() as c_master,
replica1.client() as c1,
master.client() as c_blocking,
replica2.client() as c2,
replica3.client() as c3,
):
await c1.execute_command(f"REPLICAOF localhost {master.port}")
await c2.execute_command(f"REPLICAOF localhost {master.port}")
await c3.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c1)
async def counter(key):
value = 0
await c_master.execute_command(f"SET {key} 0")
start = time.time()
while time.time() - start < 20:
try:
value = await c_master.execute_command(f"INCR {key}")
except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e:
break
else:
assert False, "The incrementing loop should be exited with a connection error"
return key, value
async def block_during_takeover():
"Add a blocking command during takeover to make sure it doesn't block it."
# TODO: We need to make takeover interrupt blocking commands.
return
try:
await c_blocking.execute_command("BLPOP BLOCKING_KEY1 BLOCKING_KEY2 10")
except redis.exceptions.ConnectionError:
pass
async def delayed_takeover():
await asyncio.sleep(1)
await c1.execute_command(f"REPLTAKEOVER 5")
_, _, *results = await asyncio.gather(delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)])
assert await c1.execute_command("role") == [b'master', []]
for key, client_value in results:
replicated_value = await c1.get(key)
assert client_value == int(replicated_value)
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio
async def test_take_over_seeder(df_local_factory, df_seeder_factory, master_threads, replica_threads):
master = df_local_factory.create(proactor_threads=master_threads,
port=BASE_PORT,
dbfilename=f"dump_{master_threads}_{replica_threads}",
logtostderr=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
async with (
master.client() as c_master,
replica.client() as c_replica,
):
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
async def seed():
await seeder.run(target_ops=3000)
fill_task = asyncio.create_task(seed())
# Give the seeder a bit of time.
await asyncio.sleep(1)
await c_replica.execute_command(f"REPLTAKEOVER 5")
seeder.stop()
assert await c_replica.execute_command("role") == [b'master', []]
# Need to wait a bit to give time to write the shutdown snapshot
await asyncio.sleep(1)
assert master.proc.poll() == 0, "Master process did not exit correctly."
master.start()
await wait_available_async(c_master)
capture = await seeder.capture()
assert await seeder.compare(capture, port=replica.port)

View file

@ -2,6 +2,7 @@ import itertools
import sys import sys
import asyncio import asyncio
from redis import asyncio as aioredis from redis import asyncio as aioredis
import redis
import random import random
import string import string
import itertools import itertools
@ -342,7 +343,7 @@ class DflySeeder:
assert await seeder.compare(capture, port=1112) assert await seeder.compare(capture, port=1112)
""" """
def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[]): def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[], stop_on_failure=True):
self.gen = CommandGenerator( self.gen = CommandGenerator(
keys, val_size, batch_size, max_multikey, unsupported_types keys, val_size, batch_size, max_multikey, unsupported_types
) )
@ -350,6 +351,7 @@ class DflySeeder:
self.dbcount = dbcount self.dbcount = dbcount
self.multi_transaction_probability = multi_transaction_probability self.multi_transaction_probability = multi_transaction_probability
self.stop_flag = False self.stop_flag = False
self.stop_on_failure = stop_on_failure
self.log_file = log_file self.log_file = log_file
if self.log_file is not None: if self.log_file is not None:
@ -496,6 +498,9 @@ class DflySeeder:
try: try:
await pipe.execute() await pipe.execute()
except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e:
if self.stop_on_failure:
raise SystemExit(e)
except Exception as e: except Exception as e:
raise SystemExit(e) raise SystemExit(e)
queue.task_done() queue.task_done()