feat: Support atomic replica takeover (#1314)

* fix(server): Initialize ServerFamily with all listeners.

- Add a test for CLIENT LIST, which is the visible result of this change.

* use std move

* feat: Implement replica takeover

* Basic test

* Address CR comments

* Write a better test. Sadly it fails

* chore: Expose AwaitDispatches for reuse in takeover

* Ensure that no commands can execute during or after a takeover

* CR progress

* Actually disable the expiration

* Improve test coverage

* Fix the dispatch waiting code

* Improve testing coverage and fix a shutdown snapshot bug

* don't replicate a replica
Roy Jacobson 2023-07-02 16:11:28 +02:00 committed by GitHub
parent e71fae7eea
commit 4babed54d3
21 changed files with 392 additions and 68 deletions
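
For orientation, a rough end-to-end sketch of the operator-facing flow this change enables, modeled on the Python tests added below; the host, ports, one-second sleep, and the 5 second timeout are illustrative, not part of the change:

import asyncio
from redis.asyncio import Redis

async def promote_replica():
    # Illustrative setup: master on port 6379, replica on port 6380.
    async with Redis(port=6380) as replica:
        await replica.execute_command("REPLICAOF localhost 6379")
        await asyncio.sleep(1.0)  # crude wait for full sync; the tests poll readiness instead
        # REPLTAKEOVER <timeout-seconds>: atomically promote this replica.
        # The old master stops accepting commands, waits for the replica to
        # ack the journal tail, writes a shutdown snapshot and exits.
        await replica.execute_command("REPLTAKEOVER 5")
        assert await replica.execute_command("role") == [b"master", []]

asyncio.run(promote_replica())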

View file

@ -179,6 +179,28 @@ void Listener::PreAcceptLoop(util::ProactorBase* pb) {
per_thread_.resize(pool()->size());
}
bool Listener::AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter) {
absl::Time start = absl::Now();
while (absl::Now() - start < timeout) {
std::atomic<bool> any_connection_dispatching = false;
auto cb = [&any_connection_dispatching, &filter](unsigned thread_index,
util::Connection* conn) {
if (filter(conn) && static_cast<Connection*>(conn)->IsCurrentlyDispatching()) {
any_connection_dispatching.store(true);
}
};
this->TraverseConnections(cb);
if (!any_connection_dispatching.load()) {
return true;
}
VLOG(1) << "A command is still dispatching, let's wait for it";
ThisFiber::SleepFor(100us);
}
return false;
}
void Listener::PreShutdown() {
// Iterate on all connections and allow them to finish their commands for
// a short period.
@ -188,26 +210,8 @@ void Listener::PreShutdown() {
// at this stage since we're in SHUTDOWN mode.
// If a command is running for too long we give up and proceed.
const absl::Duration kDispatchShutdownTimeout = absl::Milliseconds(10);
absl::Time start = absl::Now();
bool success = false;
while (absl::Now() - start < kDispatchShutdownTimeout) {
std::atomic<bool> any_connection_dispatching = false;
auto cb = [&any_connection_dispatching](unsigned thread_index, util::Connection* conn) {
if (static_cast<Connection*>(conn)->IsCurrentlyDispatching()) {
any_connection_dispatching.store(true);
}
};
this->TraverseConnections(cb);
if (!any_connection_dispatching.load()) {
success = true;
break;
}
VLOG(1) << "A command is still dispatching, let's wait for it";
ThisFiber::SleepFor(100us);
}
if (!success) {
if (!AwaitDispatches(kDispatchShutdownTimeout, [](util::Connection*) { return true; })) {
LOG(WARNING) << "Some commands are still being dispatched but didn't conclude in time. "
"Proceeding in shutdown.";
}

View file

@ -24,6 +24,11 @@ class Listener : public util::ListenerInterface {
std::error_code ConfigureServerSocket(int fd) final;
// Wait until all connections that pass the filter have stopped dispatching or until a timeout has
// run out. Returns true if all connections have stopped dispatching.
bool AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter);
private:
util::Connection* NewConnection(ProactorBase* proactor) final;
ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final;
@ -33,7 +38,6 @@ class Listener : public util::ListenerInterface {
void PreAcceptLoop(ProactorBase* pb) final;
void PreShutdown() final;
void PostShutdown() final;
std::unique_ptr<util::HttpListenerBase> http_base_;

View file

@ -36,6 +36,8 @@ const char* DebugString(OpStatus op) {
return "ENTRIES ADDED IS TO SMALL";
case OpStatus::INVALID_NUMERIC_RESULT:
return "INVALID NUMERIC RESULT";
case OpStatus::CANCELLED:
return "CANCELLED";
}
return "Unknown Error Code"; // we should not be here, but this is how enums works in c++
}

View file

@ -26,6 +26,7 @@ enum class OpStatus : uint16_t {
STREAM_ID_SMALL,
ENTRIES_ADDED_SMALL,
INVALID_NUMERIC_RESULT,
CANCELLED,
};
const char* DebugString(OpStatus op);

View file

@ -44,6 +44,8 @@ const char* GlobalStateName(GlobalState s) {
return "SAVING";
case GlobalState::SHUTTING_DOWN:
return "SHUTTING DOWN";
case GlobalState::TAKEN_OVER:
return "TAKEN OVER";
}
ABSL_UNREACHABLE();
}

View file

@ -134,6 +134,7 @@ enum class GlobalState : uint8_t {
LOADING,
SAVING,
SHUTTING_DOWN,
TAKEN_OVER,
};
enum class TimeUnit : uint8_t { SEC, MSEC };

View file

@ -894,8 +894,8 @@ pair<PrimeIterator, ExpireIterator> DbSlice::ExpireIfNeeded(const Context& cntx,
// TODO: to employ multi-generation update of expire-base and the underlying values.
time_t expire_time = ExpireTime(expire_it);
// Never do expiration on replica.
if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica())
// Never do expiration on replica or if expiration is disabled.
if (time_t(cntx.time_now_ms) < expire_time || owner_->IsReplica() || !expire_allowed_)
return make_pair(it, expire_it);
// Replicate expiry

View file

@ -324,6 +324,10 @@ class DbSlice {
// Resets the event counter for updates/insertions
void ResetUpdateEvents();
void SetExpireAllowed(bool is_allowed) {
expire_allowed_ = is_allowed;
}
private:
std::pair<PrimeIterator, bool> AddOrUpdateInternal(const Context& cntx, std::string_view key,
PrimeValue obj, uint64_t expire_at_ms,
@ -351,6 +355,7 @@ class DbSlice {
EngineShard* owner_;
time_t expire_base_[2]; // Used for expire logic, represents a real clock.
bool expire_allowed_ = true;
uint64_t version_ = 1; // Used to version entries in the PrimeTable.
ssize_t memory_budget_ = SSIZE_MAX;

View file

@ -14,10 +14,12 @@
#include "base/flags.h"
#include "base/logging.h"
#include "facade/dragonfly_connection.h"
#include "facade/dragonfly_listener.h"
#include "server/engine_shard_set.h"
#include "server/error.h"
#include "server/journal/journal.h"
#include "server/journal/streamer.h"
#include "server/main_service.h"
#include "server/rdb_save.h"
#include "server/script_mgr.h"
#include "server/server_family.h"
@ -65,15 +67,27 @@ std::string_view SyncStateName(DflyCmd::SyncState sync_state) {
}
struct TransactionGuard {
constexpr static auto kEmptyCb = [](Transaction* t, EngineShard* shard) { return OpStatus::OK; };
static OpStatus ExitGuardCb(Transaction* t, EngineShard* shard) {
shard->db_slice().SetExpireAllowed(true);
return OpStatus::OK;
};
TransactionGuard(Transaction* t) : t(t) {
explicit TransactionGuard(Transaction* t, bool disable_expirations = false) : t(t) {
t->Schedule();
t->Execute(kEmptyCb, false);
t->Execute(
[disable_expirations](Transaction* t, EngineShard* shard) {
if (disable_expirations) {
shard->db_slice().SetExpireAllowed(!disable_expirations);
}
return OpStatus::OK;
},
false);
VLOG(1) << "Transaction guard engaged";
}
~TransactionGuard() {
t->Execute(kEmptyCb, true);
VLOG(1) << "Releasing transaction guard";
t->Execute(ExitGuardCb, true);
}
Transaction* t;
@ -110,6 +124,10 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
return StartStable(args, cntx);
}
if (sub_cmd == "TAKEOVER" && args.size() == 3) {
return TakeOver(args, cntx);
}
if (sub_cmd == "EXPIRE") {
return Expire(args, cntx);
}
@ -316,7 +334,6 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
StopFullSyncInThread(flow, shard);
status = StartStableSyncInThread(flow, &replica_ptr->cntx, shard);
return OpStatus::OK;
};
shard_set->RunBlockingInParallel(std::move(cb));
@ -331,6 +348,80 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
return rb->SendOk();
}
void DflyCmd::TakeOver(CmdArgList args, ConnectionContext* cntx) {
RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
string_view sync_id_str = ArgS(args, 2);
float timeout;
if (!absl::SimpleAtof(ArgS(args, 1), &timeout)) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (timeout < 0) {
return (*cntx)->SendError("timeout is negative");
}
VLOG(1) << "Got DFLY TAKEOVER " << sync_id_str;
auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb);
if (!sync_id)
return;
unique_lock lk(replica_ptr->mu);
if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::STABLE_SYNC, rb))
return;
LOG(INFO) << "Takeover initiated, locking down the database.";
sf_->service().SwitchState(GlobalState::ACTIVE, GlobalState::TAKEN_OVER);
absl::Duration timeout_dur = absl::Seconds(timeout);
absl::Time start = absl::Now();
AggregateStatus status;
// TODO: We should cancel blocking commands before awaiting all
// dispatches to finish.
if (!sf_->AwaitDispatches(timeout_dur, [self = cntx->owner()](util::Connection* conn) {
// The only command that is currently dispatching should be the takeover command -
// so we wait until this is true.
return conn != self;
})) {
LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << timeout_dur;
status = OpStatus::TIMED_OUT;
}
TransactionGuard tg{cntx->transaction, /*disable_expirations=*/true};
if (*status == OpStatus::OK) {
auto cb = [&cntx = replica_ptr->cntx, replica_ptr = replica_ptr, timeout_dur, start,
&status](EngineShard* shard) {
FlowInfo* flow = &replica_ptr->flows[shard->shard_id()];
shard->journal()->RecordEntry(0, journal::Op::PING, 0, 0, {}, true);
while (flow->last_acked_lsn < shard->journal()->GetLsn()) {
if (absl::Now() - start > timeout_dur) {
LOG(WARNING) << "Couldn't synchronize with replica for takeover in time.";
status = OpStatus::TIMED_OUT;
return;
}
if (cntx.IsCancelled()) {
status = OpStatus::CANCELLED;
return;
}
ThisFiber::SleepFor(1ms);
}
};
shard_set->RunBlockingInParallel(std::move(cb));
}
if (*status != OpStatus::OK) {
sf_->service().SwitchState(GlobalState::TAKEN_OVER, GlobalState::ACTIVE);
return rb->SendError("Takeover failed!");
}
(*cntx)->SendOk();
VLOG(1) << "Takeover accepted, shutting down.";
return sf_->ShutdownCmd({}, cntx);
}
void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
cntx->transaction->ScheduleSingleHop([](Transaction* t, EngineShard* shard) {

View file

@ -99,7 +99,7 @@ class DflyCmd {
struct ReplicaInfo {
ReplicaInfo(unsigned flow_count, std::string address, uint32_t listening_port,
Context::ErrHandler err_handler)
: state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{address},
: state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{std::move(address)},
listening_port(listening_port), flows{flow_count} {
}
@ -151,6 +151,10 @@ class DflyCmd {
// Switch to stable state replication.
void StartStable(CmdArgList args, ConnectionContext* cntx);
// TAKEOVER <timeout> <syncid>
// Shut this master down atomically with replica promotion.
void TakeOver(CmdArgList args, ConnectionContext* cntx);
// EXPIRE
// Check all keys for expiry.
void Expire(CmdArgList args, ConnectionContext* cntx);

View file

@ -331,6 +331,10 @@ void EngineShardSet::RunBriefInParallel(U&& func, P&& pred) const {
template <typename U> void EngineShardSet::RunBlockingInParallel(U&& func) {
BlockingCounter bc{size()};
static_assert(std::is_invocable_v<U, EngineShard*>,
"Argument must be invocable EngineShard* as argument.");
static_assert(std::is_void_v<std::invoke_result_t<U, EngineShard*>>,
"Callable must not have a return value!");
for (uint32_t i = 0; i < size(); ++i) {
util::ProactorBase* dest = pp_->at(i);

View file

@ -114,6 +114,25 @@ error_code JournalSlice::Close() {
void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
DCHECK(ring_buffer_);
if (entry.opcode != Op::NOOP) {
// TODO: This is preparation for AOC style journaling, currently unused.
RingItem item;
item.lsn = lsn_;
lsn_++;
item.opcode = entry.opcode;
item.txid = entry.txid;
VLOG(1) << "Writing item [" << item.lsn << "]: " << entry.ToString();
ring_buffer_->EmplaceOrOverride(move(item));
if (shard_file_) {
string line = absl::StrCat(item.lsn, " ", entry.txid, " ", entry.opcode, "\n");
error_code ec = shard_file_->Write(io::Buffer(line), file_offset_, 0);
CHECK_EC(ec);
file_offset_ += line.size();
}
}
{
std::shared_lock lk(cb_mu_);
DVLOG(2) << "AddLogRecord: run callbacks for " << entry.ToString()
@ -123,26 +142,6 @@ void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
k_v.second(entry, await);
}
}
if (entry.opcode == Op::NOOP)
return;
// TODO: This is preparation for AOC style journaling, currently unused.
RingItem item;
item.lsn = lsn_;
item.opcode = entry.opcode;
item.txid = entry.txid;
VLOG(1) << "Writing item [" << item.lsn << "]: " << entry.ToString();
ring_buffer_->EmplaceOrOverride(move(item));
if (shard_file_) {
string line = absl::StrCat(lsn_, " ", entry.txid, " ", entry.opcode, "\n");
error_code ec = shard_file_->Write(io::Buffer(line), file_offset_, 0);
CHECK_EC(ec);
file_offset_ += line.size();
}
++lsn_;
}
uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) {

View file

@ -695,9 +695,23 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionCon
bool is_trans_cmd = CO::IsTransKind(cid->name());
bool under_script = bool(dfly_cntx->conn_state.script_info);
bool blocked_by_loading = !dfly_cntx->journal_emulated && etl.gstate() == GlobalState::LOADING &&
(cid->opt_mask() & CO::LOADING) == 0;
if (blocked_by_loading || etl.gstate() == GlobalState::SHUTTING_DOWN) {
bool allowed_by_state = true;
switch (etl.gstate()) {
case GlobalState::LOADING:
allowed_by_state = dfly_cntx->journal_emulated || (cid->opt_mask() & CO::LOADING);
break;
case GlobalState::SHUTTING_DOWN:
allowed_by_state = false;
break;
case GlobalState::TAKEN_OVER:
allowed_by_state = cid->name() == "REPLCONF" || cid->name() == "SAVE";
break;
default:
break;
}
if (!allowed_by_state) {
VLOG(1) << "Command " << cid->name() << " not executed because global state is "
<< GlobalStateName(etl.gstate());
string err = StrCat("Can not execute during ", GlobalStateName(etl.gstate()));
(*dfly_cntx)->SendError(err);
return false;

View file

@ -228,6 +228,21 @@ void Replica::Pause(bool pause) {
sock_->proactor()->Await([&] { is_paused_ = pause; });
}
std::error_code Replica::TakeOver(std::string_view timeout) {
VLOG(1) << "Taking over";
std::error_code ec;
sock_->proactor()->Await(
[this, &ec, timeout] { ec = SendNextPhaseRequest(absl::StrCat("TAKEOVER ", timeout)); });
if (ec) {
// TODO: Handle timeout more gracefully.
return cntx_.ReportError(ec);
}
// If we successfully took over, return and let server_family stop us.
return {};
}
void Replica::MainReplicationFb() {
VLOG(1) << "Main replication fiber started";
// Switch shard states to replication.
@ -610,7 +625,7 @@ error_code Replica::InitiateDflySync() {
RETURN_ON_ERR(cntx_.GetError());
// Send DFLY SYNC.
if (auto ec = SendNextPhaseRequest(false); ec) {
if (auto ec = SendNextPhaseRequest("SYNC"); ec) {
return cntx_.ReportError(ec);
}
@ -626,7 +641,7 @@ error_code Replica::InitiateDflySync() {
return cntx_.GetError();
// Send DFLY STARTSTABLE.
if (auto ec = SendNextPhaseRequest(true); ec) {
if (auto ec = SendNextPhaseRequest("STARTSTABLE"); ec) {
return cntx_.ReportError(ec);
}
@ -770,11 +785,10 @@ void Replica::DefaultErrorHandler(const GenericError& err) {
CloseSocket();
}
error_code Replica::SendNextPhaseRequest(bool stable) {
error_code Replica::SendNextPhaseRequest(string_view kind) {
ReqSerializer serializer{sock_.get()};
// Ask master to start sending replication stream
string_view kind = (stable) ? "STARTSTABLE"sv : "SYNC"sv;
string request = StrCat("DFLY ", kind, " ", master_context_.dfly_session_id);
VLOG(1) << "Sending: " << request;

View file

@ -115,6 +115,8 @@ class Replica {
void Pause(bool pause);
std::error_code TakeOver(std::string_view timeout);
std::string_view MasterId() const {
return master_context_.master_repl_id;
}
@ -141,8 +143,8 @@ class Replica {
void JoinAllFlows(); // Join all flows if possible.
void SetShardStates(bool replica); // Call SetReplica(replica) on all shards.
// Send DFLY SYNC or DFLY STARTSTABLE if stable is true.
std::error_code SendNextPhaseRequest(bool stable);
// Send DFLY ${kind} to the master instance.
std::error_code SendNextPhaseRequest(std::string_view kind);
void DefaultErrorHandler(const GenericError& err);

View file

@ -1054,11 +1054,12 @@ GenericError ServerFamily::DoSave(bool new_version, string_view basename, Transa
// Manage global state.
GlobalState new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::SAVING);
if (new_state != GlobalState::SAVING) {
if (new_state != GlobalState::SAVING && new_state != GlobalState::TAKEN_OVER) {
return {make_error_code(errc::operation_in_progress),
StrCat(GlobalStateName(new_state), " - can not save database")};
}
absl::Cleanup rev_state = [this] {
absl::Cleanup rev_state = [this, new_state] {
if (new_state == GlobalState::SAVING)
service_.SwitchState(GlobalState::SAVING, GlobalState::ACTIVE);
};
@ -1259,6 +1260,19 @@ void ServerFamily::BreakOnShutdown() {
dfly_cmd_->BreakOnShutdown();
}
bool ServerFamily::AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter) {
auto start = absl::Now();
for (auto* listener : listeners_) {
absl::Duration remaining_time = timeout - (absl::Now() - start);
if (remaining_time < absl::Nanoseconds(0) ||
!listener->AwaitDispatches(remaining_time, filter)) {
return false;
}
}
return true;
}
string GetPassword() {
string flag = GetFlag(FLAGS_requirepass);
if (!flag.empty()) {
@ -1933,6 +1947,41 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
}
}
void ServerFamily::ReplTakeOver(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "Starting take over";
VLOG(1) << "Acquire replica lock";
unique_lock lk(replicaof_mu_);
float_t timeout_sec;
if (!absl::SimpleAtof(ArgS(args, 0), &timeout_sec)) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (timeout_sec < 0) {
return (*cntx)->SendError("timeout is negative");
}
if (ServerState::tlocal()->is_master)
return (*cntx)->SendError("Already a master instance");
auto repl_ptr = replica_;
CHECK(repl_ptr);
auto info = replica_->GetInfo();
if (!info.full_sync_done) {
return (*cntx)->SendError("Full sync not done");
}
std::error_code ec = replica_->TakeOver(ArgS(args, 0));
if (ec)
return (*cntx)->SendError("Couldn't execute takeover");
LOG(INFO) << "Takeover successful, promoting this instance to master.";
service_.proactor_pool().AwaitFiberOnAll(
[&](util::ProactorBase* pb) { ServerState::tlocal()->is_master = true; });
replica_->Stop();
replica_.reset();
return (*cntx)->SendOk();
}
void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
if (args.size() % 2 == 1)
goto err;
@ -2083,7 +2132,7 @@ void ServerFamily::Latency(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendError(kSyntaxErr);
}
void ServerFamily::_Shutdown(CmdArgList args, ConnectionContext* cntx) {
void ServerFamily::ShutdownCmd(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 1) {
(*cntx)->SendError(kSyntaxErr);
return;
@ -2145,9 +2194,11 @@ void ServerFamily::Register(CommandRegistry* registry) {
<< CI{"LATENCY", CO::NOSCRIPT | CO::LOADING | CO::FAST, -2, 0, 0, 0}.HFUNC(Latency)
<< CI{"MEMORY", kMemOpts, -2, 0, 0, 0}.HFUNC(Memory)
<< CI{"SAVE", CO::ADMIN | CO::GLOBAL_TRANS, -1, 0, 0, 0}.HFUNC(Save)
<< CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC(_Shutdown)
<< CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC(
ShutdownCmd)
<< CI{"SLAVEOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf)
<< CI{"REPLICAOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf)
<< CI{"REPLTAKEOVER", CO::ADMIN | CO::GLOBAL_TRANS, 2, 0, 0, 0}.HFUNC(ReplTakeOver)
<< CI{"REPLCONF", CO::ADMIN | CO::LOADING, -1, 0, 0, 0}.HFUNC(ReplConf)
<< CI{"ROLE", CO::LOADING | CO::FAST | CO::NOSCRIPT, 1, 0, 0, 0}.HFUNC(Role)
<< CI{"SLOWLOG", CO::ADMIN | CO::FAST, -2, 0, 0, 0}.SetHandler(SlowLog)

View file

@ -94,6 +94,8 @@ class ServerFamily {
void Register(CommandRegistry* registry);
void Shutdown();
void ShutdownCmd(CmdArgList args, ConnectionContext* cntx);
Service& service() {
return service_;
}
@ -154,6 +156,9 @@ class ServerFamily {
void BreakOnShutdown();
bool AwaitDispatches(absl::Duration timeout,
const std::function<bool(util::Connection*)>& filter);
private:
uint32_t shard_count() const {
return shard_set->size();
@ -174,14 +179,13 @@ class ServerFamily {
void Latency(CmdArgList args, ConnectionContext* cntx);
void Psync(CmdArgList args, ConnectionContext* cntx);
void ReplicaOf(CmdArgList args, ConnectionContext* cntx);
void ReplTakeOver(CmdArgList args, ConnectionContext* cntx);
void ReplConf(CmdArgList args, ConnectionContext* cntx);
void Role(CmdArgList args, ConnectionContext* cntx);
void Save(CmdArgList args, ConnectionContext* cntx);
void Script(CmdArgList args, ConnectionContext* cntx);
void Sync(CmdArgList args, ConnectionContext* cntx);
void _Shutdown(CmdArgList args, ConnectionContext* cntx);
void SyncGeneric(std::string_view repl_master_id, uint64_t offs, ConnectionContext* cntx);
// Returns the number of loaded keys if successfull.

View file

@ -507,6 +507,7 @@ bool Transaction::RunInShard(EngineShard* shard, bool txq_ooo) {
if (IsGlobal()) {
DCHECK(!awaked_prerun && !became_suspended); // Global transactions can not be blocking.
VLOG(2) << "Releasing shard lock";
shard->shard_lock()->Release(Mode());
} else { // not global.
largs = GetLockArgs(idx);
@ -572,6 +573,7 @@ void Transaction::ScheduleInternal() {
// Lock shards
auto cb = [mode](EngineShard* shard) { shard->shard_lock()->Acquire(mode); };
shard_set->RunBriefInParallel(std::move(cb));
VLOG(1) << "Global shard lock acquired";
} else {
num_shards = unique_shard_cnt_;
DCHECK_GT(num_shards, 0u);
@ -893,8 +895,8 @@ void Transaction::RunQuickie(EngineShard* shard) {
auto& sd = shard_data_[SidToId(unique_shard_id_)];
DCHECK_EQ(0, sd.local_mask & (KEYLOCK_ACQUIRED | OUT_OF_ORDER));
DVLOG(1) << "RunQuickSingle " << DebugId() << " " << shard->shard_id() << " " << args_[0];
DCHECK(cb_ptr_) << DebugId() << " " << shard->shard_id() << " " << args_[0];
DVLOG(1) << "RunQuickSingle " << DebugId() << " " << shard->shard_id();
DCHECK(cb_ptr_) << DebugId() << " " << shard->shard_id();
// Calling the callback in somewhat safe way
try {

View file

@ -3,6 +3,7 @@ import time
import subprocess
import aiohttp
from prometheus_client.parser import text_string_to_metric_families
from redis.asyncio import Redis as RedisClient
from dataclasses import dataclass
@ -32,6 +33,10 @@ class DflyInstance:
self.args = args
self.params = params
self.proc = None
self._client : Optional[RedisClient] = None
def client(self) -> RedisClient:
return RedisClient(port=self.port)
def start(self):
self._start()

View file

@ -937,7 +937,6 @@ async def assert_lag_condition(inst, client, condition):
assert False, "Lag has never satisfied condition!"
@dfly_args({"proactor_threads": 2})
@pytest.mark.asyncio
async def test_replication_info(df_local_factory, df_seeder_factory, n_keys=2000):
@ -1069,3 +1068,114 @@ async def test_readonly_script(df_local_factory):
assert False
except aioredis.ResponseError as roe:
assert 'READONLY ' in str(roe)
take_over_cases = [
[2, 2],
[2, 4],
[4, 2],
[8, 8],
]
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio
async def test_take_over_counters(df_local_factory, master_threads, replica_threads):
master = df_local_factory.create(proactor_threads=master_threads,
port=BASE_PORT,
# vmodule="journal_slice=2,dflycmd=2,main_service=1",
logtostderr=True)
replica1 = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=replica_threads)
replica2 = df_local_factory.create(
port=BASE_PORT+2, proactor_threads=replica_threads)
replica3 = df_local_factory.create(
port=BASE_PORT+3, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica1, replica2, replica3])
async with (
master.client() as c_master,
replica1.client() as c1,
master.client() as c_blocking,
replica2.client() as c2,
replica3.client() as c3,
):
await c1.execute_command(f"REPLICAOF localhost {master.port}")
await c2.execute_command(f"REPLICAOF localhost {master.port}")
await c3.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c1)
async def counter(key):
value = 0
await c_master.execute_command(f"SET {key} 0")
start = time.time()
while time.time() - start < 20:
try:
value = await c_master.execute_command(f"INCR {key}")
except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e:
break
else:
assert False, "The incrementing loop should be exited with a connection error"
return key, value
async def block_during_takeover():
"Add a blocking command during takeover to make sure it doesn't block it."
# TODO: We need to make takeover interrupt blocking commands.
return
try:
await c_blocking.execute_command("BLPOP BLOCKING_KEY1 BLOCKING_KEY2 10")
except redis.exceptions.ConnectionError:
pass
async def delayed_takeover():
await asyncio.sleep(1)
await c1.execute_command(f"REPLTAKEOVER 5")
_, _, *results = await asyncio.gather(delayed_takeover(), block_during_takeover(), *[counter(f"key{i}") for i in range(16)])
assert await c1.execute_command("role") == [b'master', []]
for key, client_value in results:
replicated_value = await c1.get(key)
assert client_value == int(replicated_value)
@pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
@pytest.mark.asyncio
async def test_take_over_seeder(df_local_factory, df_seeder_factory, master_threads, replica_threads):
master = df_local_factory.create(proactor_threads=master_threads,
port=BASE_PORT,
dbfilename=f"dump_{master_threads}_{replica_threads}",
logtostderr=True)
replica = df_local_factory.create(
port=BASE_PORT+1, proactor_threads=replica_threads)
df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create(port=master.port, keys=1000, dbcount=5, stop_on_failure=False)
async with (
master.client() as c_master,
replica.client() as c_replica,
):
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
async def seed():
await seeder.run(target_ops=3000)
fill_task = asyncio.create_task(seed())
# Give the seeder a bit of time.
await asyncio.sleep(1)
await c_replica.execute_command(f"REPLTAKEOVER 5")
seeder.stop()
assert await c_replica.execute_command("role") == [b'master', []]
# Need to wait a bit to give time to write the shutdown snapshot
await asyncio.sleep(1)
assert master.proc.poll() == 0, "Master process did not exit correctly."
master.start()
await wait_available_async(c_master)
capture = await seeder.capture()
assert await seeder.compare(capture, port=replica.port)

View file

@ -2,6 +2,7 @@ import itertools
import sys
import asyncio
from redis import asyncio as aioredis
import redis
import random
import string
import itertools
@ -342,7 +343,7 @@ class DflySeeder:
assert await seeder.compare(capture, port=1112)
"""
def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[]):
def __init__(self, port=6379, keys=1000, val_size=50, batch_size=100, max_multikey=5, dbcount=1, multi_transaction_probability=0.3, log_file=None, unsupported_types=[], stop_on_failure=True):
self.gen = CommandGenerator(
keys, val_size, batch_size, max_multikey, unsupported_types
)
@ -350,6 +351,7 @@ class DflySeeder:
self.dbcount = dbcount
self.multi_transaction_probability = multi_transaction_probability
self.stop_flag = False
self.stop_on_failure = stop_on_failure
self.log_file = log_file
if self.log_file is not None:
@ -496,6 +498,9 @@ class DflySeeder:
try:
await pipe.execute()
except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e:
if self.stop_on_failure:
raise SystemExit(e)
except Exception as e:
raise SystemExit(e)
queue.task_done()