feat(server): Improved cancellation (#599)

Vladislav 2022-12-27 16:01:54 +03:00 committed by GitHub
parent b48f7557b7
commit e6721d8160
7 changed files with 186 additions and 108 deletions

View file

@@ -263,6 +263,10 @@ std::string GenericError::Format() const {
   return absl::StrCat(ec_.message(), ":", details_);
 }
 
+Context::~Context() {
+  JoinErrorHandler();
+}
+
 GenericError Context::GetError() {
   std::lock_guard lk(mu_);
   return err_;
@@ -273,20 +277,45 @@ const Cancellation* Context::GetCancellation() const {
 }
 
 void Context::Cancel() {
-  Error(std::make_error_code(errc::operation_canceled), "Context cancelled");
+  ReportError(std::make_error_code(errc::operation_canceled), "Context cancelled");
 }
 
 void Context::Reset(ErrHandler handler) {
   std::lock_guard lk{mu_};
+  JoinErrorHandler();
+
   err_ = {};
   err_handler_ = std::move(handler);
   Cancellation::flag_.store(false, std::memory_order_relaxed);
 }
 
-GenericError Context::Switch(ErrHandler handler) {
+GenericError Context::SwitchErrorHandler(ErrHandler handler) {
   std::lock_guard lk{mu_};
-  if (!err_)
+  if (!err_) {
+    // No need to check for the error handler - it can't be running
+    // if no error is set.
     err_handler_ = std::move(handler);
+  }
   return err_;
 }
 
+void Context::JoinErrorHandler() {
+  if (err_handler_fb_.IsJoinable())
+    err_handler_fb_.Join();
+}
+
+GenericError Context::ReportErrorInternal(GenericError&& err) {
+  std::lock_guard lk{mu_};
+  if (err_)
+    return err_;
+  err_ = std::move(err);
+
+  // This context is either new or was Reset, where the handler was joined.
+  CHECK(!err_handler_fb_.IsJoinable());
+
+  if (err_handler_)
+    err_handler_fb_ = util::fibers_ext::Fiber{err_handler_, err_};
+
+  Cancellation::Cancel();
+  return err_;
+}
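
The new flow in `ReportErrorInternal` is the heart of the change: the first error wins, the handler runs on its own fiber so reporting never blocks on handler work, and the destructor joins that fiber. Below is a minimal, self-contained sketch of the same pattern; `MiniContext` is a hypothetical name, with `std::thread` standing in for `util::fibers_ext::Fiber` and `std::mutex` for the fiber mutex:

#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

// Illustrative stand-in for dfly::Context: the first error is stored, the
// handler is spawned on its own thread so ReportError() returns promptly,
// and the destructor joins that thread.
class MiniContext {
 public:
  using ErrHandler = std::function<void(const std::string&)>;

  explicit MiniContext(ErrHandler h) : err_handler_(std::move(h)) {}
  ~MiniContext() { JoinErrorHandler(); }

  std::string ReportError(std::string err) {
    std::lock_guard lk(mu_);
    if (!err_.empty())
      return err_;  // Only the first error is stored.
    err_ = std::move(err);
    if (err_handler_)
      err_handler_th_ = std::thread(err_handler_, err_);  // Run handler off-thread.
    cancelled_.store(true, std::memory_order_relaxed);
    return err_;
  }

  bool IsCancelled() const { return cancelled_.load(std::memory_order_relaxed); }

  void JoinErrorHandler() {
    if (err_handler_th_.joinable())
      err_handler_th_.join();
  }

 private:
  std::string err_;
  std::atomic_bool cancelled_{false};
  std::mutex mu_;
  ErrHandler err_handler_;
  std::thread err_handler_th_;
};

int main() {
  MiniContext cntx([](const std::string& err) {
    std::cout << "handler sees: " << err << "\n";  // Runs off the caller's thread.
  });
  cntx.ReportError("first");   // Stored; handler spawned; context cancelled.
  cntx.ReportError("second");  // Ignored: an error is already set.
  std::cout << "cancelled: " << cntx.IsCancelled() << "\n";
}

Compare this with the removed `Error` template in the header diff below, which invoked the handler inline under a `try_lock` workaround; moving the handler onto a separate fiber is what removes that re-entrancy problem.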

View file

@@ -15,6 +15,7 @@
 #include "facade/facade_types.h"
 #include "facade/op_status.h"
+#include "util/fibers/fiber.h"
 
 namespace dfly {
@@ -243,7 +244,8 @@ using AggregateGenericError = AggregateValue<GenericError>;
 // Context is a utility for managing error reporting and cancellation for complex tasks.
 //
 // When submitting an error with `Error`, only the first is stored (as in aggregate values).
-// Then a special error handler is run, if present, and the context is cancelled.
+// Then a special error handler is run, if present, and the context is cancelled. The error handler
+// is run in a separate fiber to free up the caller.
 //
 // Manual cancellation with `Cancel` is simulated by reporting an `errc::operation_canceled` error.
 // This allows running the error handler and representing this scenario as an error.
@@ -255,10 +257,10 @@ class Context : protected Cancellation {
   Context(ErrHandler err_handler) : Cancellation{}, err_{}, err_handler_{std::move(err_handler)} {
   }
 
-  // Cancels the context by submitting an `errc::operation_canceled` error.
-  void Cancel();
-
-  using Cancellation::IsCancelled;
+  ~Context();
+
+  void Cancel();  // Cancels the context by submitting an `errc::operation_canceled` error.
+  using Cancellation::IsCancelled;
 
   const Cancellation* GetCancellation() const;
   GenericError GetError();
@@ -266,27 +268,11 @@ class Context : protected Cancellation {
   // Report an error by submitting arguments for GenericError.
   // If this is the first error that occurred, then the error handler is run
   // and the context is cancelled.
-  //
-  // Note: this function blocks when called from inside an error handler.
-  template <typename... T> GenericError Error(T... ts) {
-    if (!mu_.try_lock())  // TODO: Maybe use two separate locks.
-      return GenericError{std::forward<T>(ts)...};
-
-    std::lock_guard lk{mu_, std::adopt_lock};
-    if (err_)
-      return err_;
-
-    GenericError new_err{std::forward<T>(ts)...};
-    if (err_handler_)
-      err_handler_(new_err);
-
-    err_ = std::move(new_err);
-    Cancellation::Cancel();
-    return err_;
+  template <typename... T> GenericError ReportError(T... ts) {
+    return ReportErrorInternal(GenericError{std::forward<T>(ts)...});
   }
 
-  // Reset error and cancellation flag, assign new error handler.
+  // Wait for error handler to stop, reset error and cancellation flag, assign new error handler.
   void Reset(ErrHandler handler);
@@ -295,12 +281,21 @@ class Context : protected Cancellation {
   // Beware, never do this manually in two steps. If you check for cancellation,
   // set the error handler and initialize resources, then the new error handler
   // will never run if the context was cancelled between the first two steps.
-  GenericError Switch(ErrHandler handler);
+  GenericError SwitchErrorHandler(ErrHandler handler);
+
+  // If any error handler is running, wait for it to stop.
+  void JoinErrorHandler();
+
+ private:
+  // Report error.
+  GenericError ReportErrorInternal(GenericError&& err);
 
  private:
   GenericError err_;
-  ErrHandler err_handler_;
   ::boost::fibers::mutex mu_;
+
+  ErrHandler err_handler_;
+  ::util::fibers_ext::Fiber err_handler_fb_;
 };
 
 struct ScanOpts {
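
The warning above `SwitchErrorHandler` describes a classic check-then-act race. Here is a sketch of why the switch must be one atomic step, in the same illustrative `MiniContext` style as before (none of these names are Dragonfly's):

#include <functional>
#include <mutex>
#include <string>

// Illustrative only. The check for a stored error and the handler swap happen
// under one lock: either the caller gets the stored error back and can bail
// out before acquiring resources, or the new handler is installed and is
// guaranteed to run if an error arrives later. A separate
// "if (!IsCancelled()) set handler" sequence would leave a window where an
// error slips in between the check and the swap, so the newly installed
// handler would never run.
class MiniContext {
 public:
  using ErrHandler = std::function<void(const std::string&)>;

  std::string SwitchErrorHandler(ErrHandler h) {
    std::lock_guard lk(mu_);
    if (err_.empty())
      err_handler_ = std::move(h);  // Install only if no error is set.
    return err_;                    // Non-empty means: bail out, don't set up.
  }

 private:
  std::string err_;
  ErrHandler err_handler_;
  std::mutex mu_;
};

Callers pair this with an early return, e.g. `if (auto err = cntx.SwitchErrorHandler(h); !err.empty()) return err;`, before setting up any resources the new handler would be responsible for; that is exactly what `RETURN_ON_ERR(cntx_.SwitchErrorHandler(...))` does in replica.cc below.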

View file

@@ -314,7 +314,7 @@ void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
 }
 
 OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
-  DCHECK(!flow->full_sync_fb.joinable());
+  DCHECK(!flow->full_sync_fb.IsJoinable());
 
   SaveMode save_mode = shard == nullptr ? SaveMode::SUMMARY : SaveMode::SINGLE_SHARD;
   flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
@@ -341,8 +341,8 @@ void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
   }
 
   // Wait for full sync to finish.
-  if (flow->full_sync_fb.joinable()) {
-    flow->full_sync_fb.join();
+  if (flow->full_sync_fb.IsJoinable()) {
+    flow->full_sync_fb.Join();
   }
 
   // Reset cleanup and saver
@@ -382,18 +382,18 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
   }
 
   if (ec) {
-    cntx->Error(ec);
+    cntx->ReportError(ec);
     return;
   }
 
   if (ec = saver->SaveBody(cntx->GetCancellation(), nullptr); ec) {
-    cntx->Error(ec);
+    cntx->ReportError(ec);
     return;
   }
 
   ec = flow->conn->socket()->Write(io::Buffer(flow->eof_token));
   if (ec) {
-    cntx->Error(ec);
+    cntx->ReportError(ec);
     return;
   }
 }
@@ -406,9 +406,8 @@ uint32_t DflyCmd::CreateSyncSession() {
   auto err_handler = [this, sync_id](const GenericError& err) {
     LOG(INFO) << "Replication error: " << err.Format();
 
-    // Stop replication in case of error.
-    // StopReplication needs to run async to prevent blocking
-    // the error handler.
+    // Spawn external fiber to allow destructing the context from outside
+    // and return from the handler immediately.
     ::boost::fibers::fiber{&DflyCmd::StopReplication, this, sync_id}.detach();
   };
@@ -473,8 +472,8 @@ void DflyCmd::CancelReplication(uint32_t sync_id, shared_ptr<ReplicaInfo> replic
       }
     }
 
-    if (flow->full_sync_fb.joinable()) {
-      flow->full_sync_fb.join();
+    if (flow->full_sync_fb.IsJoinable()) {
+      flow->full_sync_fb.Join();
     }
   });
@@ -484,6 +483,9 @@ void DflyCmd::CancelReplication(uint32_t sync_id, shared_ptr<ReplicaInfo> replic
     replica_infos_.erase(sync_id);
   }
 
+  // Wait for error handler to quit.
+  replica_ptr->cntx.JoinErrorHandler();
+
   LOG(INFO) << "Evicted sync session " << sync_id;
 }
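
The detached fiber in the handler above is load-bearing: `StopReplication` eventually destroys the replica's context, and destroying the context joins its own error-handler fiber via `JoinErrorHandler`. If the handler called `StopReplication` synchronously, it would end up joining itself. A sketch of the shape of the problem, with threads in place of fibers and hypothetical names throughout:

#include <chrono>
#include <iostream>
#include <string>
#include <thread>

// Stands in for DflyCmd::StopReplication: it tears down the session, which
// includes joining the error-handler "fiber". Running it directly inside
// the handler would mean the handler tries to join itself and deadlocks.
void StopReplicationSketch(unsigned sync_id) {
  std::cout << "stopping session " << sync_id << "\n";
  // ... CancelReplication() -> cntx.JoinErrorHandler() happens in here.
}

int main() {
  unsigned sync_id = 1;

  // The handler only schedules the teardown and returns immediately,
  // mirroring the detached ::boost::fibers::fiber in the diff above.
  auto err_handler = [sync_id](const std::string& err) {
    std::cout << "replication error: " << err << "\n";
    std::thread(StopReplicationSketch, sync_id).detach();
  };

  err_handler("connection reset");
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // let it finish
}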

View file

@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "server/conn_context.h"
+#include "util/fibers/fiber.h"
 
 namespace facade {
 class RedisReplyBuilder;
@@ -91,8 +92,8 @@ class DflyCmd {
     facade::Connection* conn;
 
-    ::boost::fibers::fiber full_sync_fb;  // Full sync fiber.
-    std::unique_ptr<RdbSaver> saver;      // Saver used by the full sync phase.
+    util::fibers_ext::Fiber full_sync_fb;  // Full sync fiber.
+    std::unique_ptr<RdbSaver> saver;       // Saver used by the full sync phase.
     std::string eof_token;
 
     std::function<void()> cleanup;  // Optional cleanup for cancellation.

View file

@@ -7,6 +7,7 @@ extern "C" {
 #include "redis/rdb.h"
 }
 
+#include <absl/cleanup/cleanup.h>
 #include <absl/functional/bind_front.h>
 #include <absl/strings/escaping.h>
 #include <absl/strings/str_cat.h>
@@ -199,23 +200,14 @@ void Replica::MainReplicationFb() {
     // 3. Initiate full sync
     if ((state_mask_ & R_SYNC_OK) == 0) {
-      // Make sure we're in LOADING state.
-      if (service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) != GlobalState::LOADING) {
-        state_mask_ = 0;
-        continue;
-      }
-
       if (HasDflyMaster())
         ec = InitiateDflySync();
       else
         ec = InitiatePSync();
 
-      service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
-
       if (ec) {
         LOG(WARNING) << "Error syncing " << ec << " " << ec.message();
         state_mask_ &= R_ENABLED;  // reset all flags besides R_ENABLED
-        JoinAllFlows();
         continue;
       }
@@ -230,10 +222,11 @@ void Replica::MainReplicationFb() {
     else
       ec = ConsumeRedisStream();
 
-    JoinAllFlows();
     state_mask_ &= ~R_SYNC_OK;
   }
 
+  cntx_.JoinErrorHandler();
+
   VLOG(1) << "Main replication fiber finished";
 }
@@ -385,6 +378,13 @@ error_code Replica::InitiatePSync() {
   SocketSource ss{sock_.get()};
   io::PrefixSource ps{io_buf.InputBuffer(), &ss};
 
+  // Set LOADING state.
+  // TODO: Flush db on retry.
+  CHECK(service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) == GlobalState::LOADING);
+  absl::Cleanup cleanup = [this]() {
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
+  };
+
   RdbLoader loader(NULL);
   loader.set_source_limit(snapshot_size);
   // TODO: to allow registering callbacks within loader to send '\n' pings back to master.
@@ -428,8 +428,16 @@
 // Initialize and start sub-replica for each flow.
 error_code Replica::InitiateDflySync() {
   DCHECK_GT(num_df_flows_, 0u);
 
+  absl::Cleanup cleanup = [this]() {
+    // We do the following operations regardless of outcome.
+    JoinAllFlows();
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
+  };
+
+  // Initialize MultiShardExecution.
   multi_shard_exe_.reset(new MultiShardExecution());
 
+  // Initialize shard flows.
   shard_flows_.resize(num_df_flows_);
   for (unsigned i = 0; i < num_df_flows_; ++i) {
     shard_flows_[i].reset(new Replica(master_context_, i, &service_, multi_shard_exe_));
@@ -438,33 +446,66 @@ error_code Replica::InitiateDflySync() {
   // Blocked on until all flows got full sync cut.
   fibers_ext::BlockingCounter sync_block{num_df_flows_};
 
   // Switch to new error handler that closes flow sockets.
   auto err_handler = [this, sync_block](const auto& ge) mutable {
-    sync_block.Cancel();      // Unblock this function.
-    DefaultErrorHandler(ge);  // Close sockets to unblock flows.
+    // Unblock this function.
+    sync_block.Cancel();
+
+    // Make sure the flows are not in a state transition.
+    lock_guard lk{flows_op_mu_};
+
+    // Unblock all sockets.
+    DefaultErrorHandler(ge);
+    for (auto& flow : shard_flows_)
+      flow->CloseSocket();
   };
-  RETURN_ON_ERR(cntx_.Switch(std::move(err_handler)));
+  RETURN_ON_ERR(cntx_.SwitchErrorHandler(std::move(err_handler)));
+
+  // Make sure we're in LOADING state.
+  // TODO: Flush db on retry.
+  CHECK(service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) == GlobalState::LOADING);
 
   // Start full sync flows.
-  auto partition = Partition(num_df_flows_);
-  shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
-    for (auto id : partition[index]) {
-      auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block, &cntx_);
-      if (ec)
-        cntx_.Error(ec);
-    }
-  });
+  {
+    auto partition = Partition(num_df_flows_);
+    auto shard_cb = [&](unsigned index, auto*) {
+      for (auto id : partition[index]) {
+        auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block, &cntx_);
+        if (ec)
+          cntx_.ReportError(ec);
+      }
+    };
+
+    // Lock to prevent the error handler from running instantly
+    // while the flows are in a mixed state.
+    lock_guard lk{flows_op_mu_};
+    shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
+  }
 
   RETURN_ON_ERR(cntx_.GetError());
 
   // Send DFLY SYNC.
-  if (auto ec = SendNextPhaseRequest(); ec) {
-    return cntx_.Error(ec);
+  if (auto ec = SendNextPhaseRequest(false); ec) {
+    return cntx_.ReportError(ec);
   }
 
   // Wait for all flows to receive full sync cut.
+  // In case of an error, this is unblocked by the error handler.
   LOG(INFO) << "Waiting for all full sync cut confirmations";
   sync_block.Wait();
-  LOG(INFO) << "Full sync finished";
 
+  // Check if we woke up due to cancellation.
+  if (cntx_.IsCancelled())
+    return cntx_.GetError();
+
+  // Send DFLY STARTSTABLE.
+  if (auto ec = SendNextPhaseRequest(true); ec) {
+    return cntx_.ReportError(ec);
+  }
+
+  // Joining flows and resetting state is done by cleanup.
+  LOG(INFO) << "Full sync finished";
 
   return cntx_.GetError();
 }
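
`sync_block.Wait()` above can be woken for two different reasons: every flow confirmed the full sync cut, or the error handler called `sync_block.Cancel()`. That is why the code re-checks `cntx_.IsCancelled()` before sending `DFLY STARTSTABLE`. Below is a condition-variable sketch of the counter's assumed contract (the real one is `util::fibers_ext::BlockingCounter`; this is a guess at its semantics for illustration, not its implementation):

#include <condition_variable>
#include <mutex>

// Wait() returns when the count hits zero OR someone calls Cancel(), so an
// error handler can always unblock the waiter.
class CancellableCounter {
 public:
  explicit CancellableCounter(unsigned count) : count_(count) {}

  void Dec() {  // Called by each flow after receiving the full sync cut.
    std::lock_guard lk(mu_);
    if (count_ > 0 && --count_ == 0)
      cv_.notify_all();
  }

  void Cancel() {  // Called by the error handler to unblock the waiter.
    std::lock_guard lk(mu_);
    cancelled_ = true;
    cv_.notify_all();
  }

  void Wait() {
    std::unique_lock lk(mu_);
    cv_.wait(lk, [this] { return count_ == 0 || cancelled_; });
  }

 private:
  unsigned count_;
  bool cancelled_ = false;
  std::mutex mu_;
  std::condition_variable cv_;
};

After `Wait()` returns, the waiter cannot tell locally why it woke, which is exactly why `InitiateDflySync` consults `cntx_.IsCancelled()` before moving to the next phase.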
@@ -515,40 +556,48 @@ error_code Replica::ConsumeRedisStream() {
 }
 
 error_code Replica::ConsumeDflyStream() {
-  // Send DFLY STARTSTABLE.
-  if (auto ec = SendNextPhaseRequest(); ec) {
-    return cntx_.Error(ec);
+  // Set new error handler that closes flow sockets.
+  auto err_handler = [this](const auto& ge) {
+    // Make sure the flows are not in a state transition.
+    lock_guard lk{flows_op_mu_};
+    DefaultErrorHandler(ge);
+    for (auto& flow : shard_flows_)
+      flow->CloseSocket();
+  };
+  RETURN_ON_ERR(cntx_.SwitchErrorHandler(std::move(err_handler)));
+
+  // Transition flows into stable sync.
+  {
+    auto partition = Partition(num_df_flows_);
+    auto shard_cb = [&](unsigned index, auto*) {
+      const auto& local_ids = partition[index];
+      for (unsigned id : local_ids) {
+        auto ec = shard_flows_[id]->StartStableSyncFlow(&cntx_);
+        if (ec)
+          cntx_.ReportError(ec);
+      }
+    };
+
+    // Lock to prevent error handler from running on mixed state.
+    lock_guard lk{flows_op_mu_};
+    shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
   }
 
   // Wait for all flows to finish full sync.
   JoinAllFlows();
 
-  RETURN_ON_ERR(cntx_.Switch(absl::bind_front(&Replica::DefaultErrorHandler, this)));
-
-  vector<vector<unsigned>> partition = Partition(num_df_flows_);
-  shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
-    const auto& local_ids = partition[index];
-    for (unsigned id : local_ids) {
-      auto ec = shard_flows_[id]->StartStableSyncFlow(&cntx_);
-      if (ec)
-        cntx_.Error(ec);
-    }
-  });
+  // The only option to unblock is to cancel the context.
+  CHECK(cntx_.GetError());
 
   return cntx_.GetError();
 }
 
-void Replica::CloseAllSockets() {
+void Replica::CloseSocket() {
   if (sock_) {
     sock_->proactor()->Await([this] {
       auto ec = sock_->Shutdown(SHUT_RDWR);
       LOG_IF(ERROR, ec) << "Could not shutdown socket " << ec;
     });
   }
-
-  for (auto& flow : shard_flows_) {
-    flow->CloseAllSockets();
-  }
 }
 
 void Replica::JoinAllFlows() {
@@ -560,16 +609,18 @@ void Replica::JoinAllFlows() {
 }
 
 void Replica::DefaultErrorHandler(const GenericError& err) {
-  CloseAllSockets();
+  CloseSocket();
 }
 
-error_code Replica::SendNextPhaseRequest() {
+error_code Replica::SendNextPhaseRequest(bool stable) {
   ReqSerializer serializer{sock_.get()};
 
   // Ask master to start sending replication stream
-  string request = (state_mask_ & R_SYNC_OK) ? "STARTSTABLE" : "SYNC";
-  RETURN_ON_ERR(
-      SendCommand(StrCat("DFLY ", request, " ", master_context_.dfly_session_id), &serializer));
+  string_view kind = (stable) ? "STARTSTABLE"sv : "SYNC"sv;
+  string request = StrCat("DFLY ", kind, " ", master_context_.dfly_session_id);
+  LOG(INFO) << "Sending: " << request;
+  RETURN_ON_ERR(SendCommand(request, &serializer));
 
   base::IoBuf io_buf{128};
   unsigned consumed = 0;
@@ -657,7 +708,7 @@ void Replica::FullSyncDflyFb(string eof_token, fibers_ext::BlockingCounter bc, C
 
   // Load incoming rdb stream.
   if (std::error_code ec = loader.Load(&ps); ec) {
-    cntx->Error(ec, "Error loading rdb format");
+    cntx->ReportError(ec, "Error loading rdb format");
     return;
   }
@@ -670,7 +721,8 @@ void Replica::FullSyncDflyFb(string eof_token, fibers_ext::BlockingCounter bc, C
     chained_tail.ReadAtLeast(io::MutableBytes{buf.get(), eof_token.size()}, eof_token.size());
 
     if (!res || *res != eof_token.size()) {
-      cntx->Error(std::make_error_code(errc::protocol_error), "Error finding eof token in stream");
+      cntx->ReportError(std::make_error_code(errc::protocol_error),
+                        "Error finding eof token in stream");
       return;
     }
   }
@@ -704,7 +756,7 @@ void Replica::StableSyncDflyFb(Context* cntx) {
   while (!cntx->IsCancelled()) {
     auto res = reader.ReadEntry(&ps);
     if (!res) {
-      cntx->Error(res.error(), "Journal format error");
+      cntx->ReportError(res.error(), "Journal format error");
       return;
     }
     ExecuteEntry(&executor, res.value());
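
A pattern worth calling out in this file: both new error handlers and both flow-transition blocks take `flows_op_mu_`. The handler can fire at any moment once `SwitchErrorHandler` returns, so the shared lock guarantees it sees the flows either entirely before or entirely after a transition, never mid-setup. A stripped-down sketch of that discipline (all types here are illustrative):

#include <memory>
#include <mutex>
#include <vector>

struct Flow {
  void Start() {}        // Stands in for StartFullSyncFlow / StartStableSyncFlow.
  void CloseSocket() {}  // Stands in for Replica::CloseSocket.
};

class FlowSet {
 public:
  // Transition: holds the mutex for the whole mixed-state window.
  void StartAll() {
    std::lock_guard lk(mu_);
    for (auto& f : flows_)
      f->Start();
  }

  // Error handler body: takes the same mutex, so it runs either before the
  // transition begins or after every flow is started, never in between.
  void OnError() {
    std::lock_guard lk(mu_);
    for (auto& f : flows_)
      f->CloseSocket();
  }

 private:
  std::vector<std::unique_ptr<Flow>> flows_;
  std::mutex mu_;
};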

View file

@@ -89,10 +89,11 @@ class Replica {
   std::error_code ConsumeRedisStream();  // Redis stable state.
   std::error_code ConsumeDflyStream();   // Dragonfly stable state.
 
-  void CloseAllSockets();  // Close all sockets.
-  void JoinAllFlows();     // Join all flows if possible.
+  void CloseSocket();   // Close replica sockets.
+  void JoinAllFlows();  // Join all flows if possible.
 
-  std::error_code SendNextPhaseRequest();  // Send DFLY SYNC or DFLY STARTSTABLE.
+  // Send DFLY SYNC or DFLY STARTSTABLE if stable is true.
+  std::error_code SendNextPhaseRequest(bool stable);
 
   void DefaultErrorHandler(const GenericError& err);
@@ -180,6 +181,9 @@ class Replica {
   ::boost::fibers::fiber sync_fb_;
   std::vector<std::unique_ptr<Replica>> shard_flows_;
 
+  // Guard operations where flows might be in a mixed state (transition/setup).
+  ::boost::fibers::mutex flows_op_mu_;
+
   std::unique_ptr<base::IoBuf> leftover_buf_;
   std::unique_ptr<facade::RedisParser> parser_;
   facade::RespVec resp_args_;

View file

@@ -276,16 +276,8 @@ async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_rando
     c_replicas = [aioredis.Redis(port=replica.port) for replica in replicas]
 
-    async def full_sync(c_replica):
-        try:
-            await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
-            await wait_available_async(c_replica)
-        except aioredis.ResponseError as e:
-            # This should mean master crashed during greet phase
-            pass
-
     async def crash_master_fs():
-        await asyncio.sleep(random.random() / 10 + 0.01)
+        await asyncio.sleep(random.random() / 10 + 0.1 * len(replicas))
         master.stop(kill=True)
 
     async def start_master():
@@ -296,8 +288,11 @@ async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_rando
     await start_master()
 
-    # Crash master during full sync
-    await asyncio.gather(*(full_sync(c) for c in c_replicas), crash_master_fs())
+    # Crash master during full sync, but with all passing initial connection phase
+    await asyncio.gather(*(c_replica.execute_command("REPLICAOF localhost " + str(master.port))
+                           for c_replica in c_replicas), crash_master_fs())
 
+    await asyncio.sleep(1 + len(replicas) * 0.5)
 
     for _ in range(n_random_crashes):
         await start_master()