feat(server): Replication errors & cancellation (#501)

Vladislav 2022-11-22 19:17:31 +03:00 committed by GitHub
parent 77ed4a22dd
commit 893c741c14
14 changed files with 543 additions and 199 deletions


@ -8,6 +8,7 @@
#include <absl/strings/str_cat.h>
#include <absl/types/span.h>
#include <atomic>
#include <boost/fiber/mutex.hpp>
#include <string_view>
#include <vector>
@ -197,6 +198,82 @@ using AggregateStatus = AggregateValue<facade::OpStatus>;
static_assert(facade::OpStatus::OK == facade::OpStatus{},
"Default intitialization should be OK value");
// Re-usable component for signaling cancellation.
// Simple wrapper around atomic flag.
struct Cancellation {
void Cancel() {
flag_.store(true, std::memory_order_relaxed);
}
bool IsCancelled() const {
return flag_.load(std::memory_order_relaxed);
}
private:
std::atomic_bool flag_;
};
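
A minimal usage sketch (illustration only, not part of the diff; the function name is hypothetical and only the Cancellation type defined above is assumed): a long-running task polls IsCancelled() and exits cooperatively once another fiber calls Cancel().

// Illustrative sketch: cooperative shutdown driven by the Cancellation flag.
void SerializeUntilCancelled(const Cancellation* cll) {
  while (!cll->IsCancelled()) {
    // ... serialize one bucket / write one chunk ...
  }
  // Exit cooperatively; the owner joins the fiber afterwards.
}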
// Error wrapper that stores an error_code and an optional string message.
class GenericError {
public:
GenericError() = default;
GenericError(std::error_code ec) : ec_{ec}, details_{} {
}
GenericError(std::error_code ec, std::string details) : ec_{ec}, details_{std::move(details)} {
}
std::pair<std::error_code, const std::string&> Get() const {
return {ec_, details_};
}
std::error_code GetError() const {
return ec_;
}
const std::string& GetDetails() const {
return details_;
}
operator bool() const {
return bool(ec_);
}
private:
std::error_code ec_;
std::string details_;
};
using AggregateGenericError = AggregateValue<GenericError>;
// Context combines Cancellation and AggregateGenericError in one class.
// Allows setting an error_handler to run on errors.
class Context : public Cancellation {
public:
// The error handler should return false if the error should be ignored.
using ErrHandler = std::function<bool(const GenericError&)>;
Context() = default;
Context(ErrHandler err_handler) : Cancellation{}, err_handler_{std::move(err_handler)} {
}
template <typename... T> void Error(T... ts) {
std::lock_guard lk{mu_};
if (err_)
return;
GenericError new_err{std::forward<T>(ts)...};
if (!err_handler_ || err_handler_(new_err)) {
err_ = std::move(new_err);
Cancel();
}
}
private:
GenericError err_;
ErrHandler err_handler_;
::boost::fibers::mutex mu_;
};
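
A hedged sketch of how a task can report failures through Context (illustration only; ReadChunk and PumpData are hypothetical names, and the handler shown treats every error as fatal):

#include <system_error>

static std::error_code ReadChunk() {  // placeholder for a real I/O step
  return {};
}

// The handler decides whether a reported error is fatal; returning true makes
// Error() store it and cancel the context.
void PumpData(Context* cntx) {
  while (!cntx->IsCancelled()) {
    if (std::error_code ec = ReadChunk(); ec) {
      cntx->Error(ec, "reading replication chunk");  // stores the error, calls Cancel()
      return;
    }
  }
}

// Wiring it up:
//   Context cntx{[](const GenericError& err) { return true; }};
//   PumpData(&cntx);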
struct ScanOpts {
std::string_view pattern;
size_t limit = 10;


@ -98,26 +98,6 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
rb->SendError(kSyntaxErr);
}
void DflyCmd::OnClose(ConnectionContext* cntx) {
unsigned session_id = cntx->conn_state.repl_session_id;
unsigned flow_id = cntx->conn_state.repl_flow_id;
if (!session_id)
return;
if (flow_id == kuint32max) {
DeleteSyncSession(session_id);
} else {
shared_ptr<SyncInfo> sync_info = GetSyncInfo(session_id);
if (sync_info) {
lock_guard lk(sync_info->mu);
if (sync_info->state != SyncState::CANCELLED) {
UnregisterFlow(&sync_info->flows[flow_id]);
}
}
}
}
void DflyCmd::Journal(CmdArgList args, ConnectionContext* cntx) {
DCHECK_GE(args.size(), 3u);
ToUpper(&args[2]);
@ -227,12 +207,12 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
return rb->SendError(facade::kInvalidIntErr);
}
auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb);
if (!sync_id)
return;
unique_lock lk(sync_info->mu);
if (sync_info->state != SyncState::PREPARATION)
unique_lock lk(replica_ptr->mu);
if (replica_ptr->state != SyncState::PREPARATION)
return rb->SendError(kInvalidState);
// Set meta info on connection.
@ -243,7 +223,7 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
absl::InsecureBitGen gen;
string eof_token = GetRandomHex(gen, 40);
sync_info->flows[flow_id] = FlowInfo{cntx->owner(), eof_token};
replica_ptr->flows[flow_id] = FlowInfo{cntx->owner(), eof_token};
listener_->Migrate(cntx->owner(), shard_set->pool()->at(flow_id));
rb->StartArray(2);
@ -257,12 +237,12 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "Got DFLY SYNC " << sync_id_str;
auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb);
if (!sync_id)
return;
unique_lock lk(sync_info->mu);
if (!CheckReplicaStateOrReply(*sync_info, SyncState::PREPARATION, rb))
unique_lock lk(replica_ptr->mu);
if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::PREPARATION, rb))
return;
// Start full sync.
@ -270,8 +250,9 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
TransactionGuard tg{cntx->transaction};
AggregateStatus status;
auto cb = [this, &status, sync_info = sync_info](unsigned index, auto*) {
status = StartFullSyncInThread(&sync_info->flows[index], EngineShard::tlocal());
auto cb = [this, &status, replica_ptr](unsigned index, auto*) {
status = StartFullSyncInThread(&replica_ptr->flows[index], &replica_ptr->cntx,
EngineShard::tlocal());
};
shard_set->pool()->AwaitFiberOnAll(std::move(cb));
@ -280,7 +261,7 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
return rb->SendError(kInvalidState);
}
sync_info->state = SyncState::FULL_SYNC;
replica_ptr->state = SyncState::FULL_SYNC;
return rb->SendOk();
}
@ -290,20 +271,24 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "Got DFLY STARTSTABLE " << sync_id_str;
auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb);
if (!sync_id)
return;
unique_lock lk(sync_info->mu);
if (!CheckReplicaStateOrReply(*sync_info, SyncState::FULL_SYNC, rb))
unique_lock lk(replica_ptr->mu);
if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::FULL_SYNC, rb))
return;
{
TransactionGuard tg{cntx->transaction};
AggregateStatus status;
auto cb = [this, &status, sync_info = sync_info](unsigned index, auto*) {
status = StartStableSyncInThread(&sync_info->flows[index], EngineShard::tlocal());
auto cb = [this, &status, replica_ptr](unsigned index, auto*) {
EngineShard* shard = EngineShard::tlocal();
FlowInfo* flow = &replica_ptr->flows[index];
StopFullSyncInThread(flow, shard);
status = StartStableSyncInThread(flow, shard);
return OpStatus::OK;
};
shard_set->pool()->AwaitFiberOnAll(std::move(cb));
@ -312,7 +297,7 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
return rb->SendError(kInvalidState);
}
sync_info->state = SyncState::STABLE_SYNC;
replica_ptr->state = SyncState::STABLE_SYNC;
return rb->SendOk();
}
@ -326,49 +311,64 @@ void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
return rb->SendOk();
}
OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
DCHECK(!flow->fb.joinable());
OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
DCHECK(!flow->full_sync_fb.joinable());
SaveMode save_mode = shard == nullptr ? SaveMode::SUMMARY : SaveMode::SINGLE_SHARD;
flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
flow->cleanup = [flow]() {
flow->saver->Cancel();
flow->TryShutdownSocket();
};
// Shard can be null for io thread.
if (shard != nullptr) {
auto ec = sf_->journal()->OpenInThread(false, string_view());
CHECK(!ec);
flow->saver->StartSnapshotInShard(true, shard);
CHECK(!sf_->journal()->OpenInThread(false, ""sv)); // can only happen in persistent mode.
flow->saver->StartSnapshotInShard(true, cntx, shard);
}
flow->fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow);
flow->full_sync_fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow, cntx);
return OpStatus::OK;
}
OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, EngineShard* shard) {
void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
// Shard can be null for io thread.
if (shard != nullptr) {
flow->saver->StopSnapshotInShard(shard);
}
// Wait for full sync to finish.
if (flow->fb.joinable()) {
flow->fb.join();
if (flow->full_sync_fb.joinable()) {
flow->full_sync_fb.join();
}
if (shard != nullptr) {
// Reset cleanup and saver
flow->cleanup = []() {};
flow->saver.reset();
}
// TODO: Add cancellation.
auto cb = sf_->journal()->RegisterOnChange([flow](const journal::Entry& je) {
OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, EngineShard* shard) {
// Register journal listener and cleanup.
uint32_t cb_id = 0;
if (shard != nullptr) {
cb_id = sf_->journal()->RegisterOnChange([flow](const journal::Entry& je) {
// TODO: Serialize event.
ReqSerializer serializer{flow->conn->socket()};
serializer.SendCommand(absl::StrCat("SET ", je.key, " ", je.pval_ptr->ToString()));
});
}
flow->cleanup = [flow, this, cb_id]() {
if (cb_id)
sf_->journal()->Unregister(cb_id);
flow->TryShutdownSocket();
};
return OpStatus::OK;
}
void DflyCmd::FullSyncFb(FlowInfo* flow) {
void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
error_code ec;
RdbSaver* saver = flow->saver.get();
@ -380,92 +380,121 @@ void DflyCmd::FullSyncFb(FlowInfo* flow) {
}
if (ec) {
LOG(ERROR) << ec;
return;
return cntx->Error(ec);
}
// TODO: we should be able to stop earlier if requested.
ec = saver->SaveBody(nullptr);
if (ec) {
LOG(ERROR) << ec;
return;
if (ec = saver->SaveBody(cntx, nullptr); ec) {
return cntx->Error(ec);
}
VLOG(1) << "Sending full sync EOF";
ec = flow->conn->socket()->Write(io::Buffer(flow->eof_token));
if (ec) {
LOG(ERROR) << ec;
return;
return cntx->Error(ec);
}
}
uint32_t DflyCmd::CreateSyncSession() {
unique_lock lk(mu_);
unsigned sync_id = next_sync_id_++;
auto sync_info = make_shared<SyncInfo>();
sync_info->flows.resize(shard_set->size() + 1);
unsigned flow_count = shard_set->size() + 1;
auto err_handler = [this, sync_id](const GenericError& err) {
LOG(INFO) << "Replication error: " << err.GetError().message() << " " << err.GetDetails();
auto [it, inserted] = sync_infos_.emplace(next_sync_id_, std::move(sync_info));
// Stop replication in case of error.
// StopReplication needs to run async to prevent blocking
// the error handler.
::boost::fibers::fiber{&DflyCmd::StopReplication, this, sync_id}.detach();
return true; // Cancel context
};
auto replica_ptr = make_shared<ReplicaInfo>(flow_count, std::move(err_handler));
auto [it, inserted] = replica_infos_.emplace(sync_id, std::move(replica_ptr));
CHECK(inserted);
return next_sync_id_++;
return sync_id;
}
void DflyCmd::UnregisterFlow(FlowInfo* flow) {
// TODO: Cancel saver operations.
flow->conn = nullptr;
flow->saver.reset();
}
void DflyCmd::DeleteSyncSession(uint32_t sync_id) {
shared_ptr<SyncInfo> sync_info;
// Remove sync_info from map.
// Store by value to keep alive.
{
unique_lock lk(mu_);
auto it = sync_infos_.find(sync_id);
if (it == sync_infos_.end())
void DflyCmd::OnClose(ConnectionContext* cntx) {
unsigned session_id = cntx->conn_state.repl_session_id;
if (!session_id)
return;
sync_info = it->second;
sync_infos_.erase(it);
}
auto replica_ptr = GetReplicaInfo(session_id);
if (!replica_ptr)
return;
// Wait for all operations to finish.
// Set state to CANCELLED so no other operations will run.
{
unique_lock lk(sync_info->mu);
sync_info->state = SyncState::CANCELLED;
}
// Try to cleanup flows.
for (auto& flow : sync_info->flows) {
if (flow.conn != nullptr) {
VLOG(1) << "Flow connection " << flow.conn->GetName() << " is still alive"
<< " on sync_id " << sync_id;
}
// TODO: Implement cancellation.
if (flow.fb.joinable()) {
VLOG(1) << "Force joining fiber on on sync_id " << sync_id;
flow.fb.join();
}
}
// Because CancelReplication holds the per-replica mutex,
// aborting connection will block here until cancellation finishes.
// This allows keeping resources alive during the cleanup phase.
CancelReplication(session_id, replica_ptr);
}
shared_ptr<DflyCmd::SyncInfo> DflyCmd::GetSyncInfo(uint32_t sync_id) {
void DflyCmd::StopReplication(uint32_t sync_id) {
auto replica_ptr = GetReplicaInfo(sync_id);
if (!replica_ptr)
return;
CancelReplication(sync_id, replica_ptr);
}
void DflyCmd::CancelReplication(uint32_t sync_id, shared_ptr<ReplicaInfo> replica_ptr) {
lock_guard lk(replica_ptr->mu);
if (replica_ptr->state == SyncState::CANCELLED) {
return;
}
LOG(INFO) << "Cancelling sync session " << sync_id;
// Update replica_ptr state and cancel context.
replica_ptr->state = SyncState::CANCELLED;
replica_ptr->cntx.Cancel();
// Run cleanup for shard threads.
shard_set->AwaitRunningOnShardQueue([replica_ptr](EngineShard* shard) {
FlowInfo* flow = &replica_ptr->flows[shard->shard_id()];
if (flow->cleanup) {
flow->cleanup();
}
});
// Wait for tasks to finish.
shard_set->pool()->AwaitFiberOnAll([replica_ptr](unsigned index, auto*) {
FlowInfo* flow = &replica_ptr->flows[index];
// Cleanup hasn't been run for io-thread.
if (EngineShard::tlocal() == nullptr) {
if (flow->cleanup) {
flow->cleanup();
}
}
if (flow->full_sync_fb.joinable()) {
flow->full_sync_fb.join();
}
});
// Remove ReplicaInfo from global map
{
lock_guard lk(mu_);
replica_infos_.erase(sync_id);
}
LOG(INFO) << "Evicted sync session " << sync_id;
}
shared_ptr<DflyCmd::ReplicaInfo> DflyCmd::GetReplicaInfo(uint32_t sync_id) {
unique_lock lk(mu_);
auto it = sync_infos_.find(sync_id);
if (it != sync_infos_.end())
auto it = replica_infos_.find(sync_id);
if (it != replica_infos_.end())
return it->second;
return {};
}
pair<uint32_t, shared_ptr<DflyCmd::SyncInfo>> DflyCmd::GetSyncInfoOrReply(std::string_view id_str,
RedisReplyBuilder* rb) {
pair<uint32_t, shared_ptr<DflyCmd::ReplicaInfo>> DflyCmd::GetReplicaInfoOrReply(
std::string_view id_str, RedisReplyBuilder* rb) {
unique_lock lk(mu_);
uint32_t sync_id;
@ -474,8 +503,8 @@ pair<uint32_t, shared_ptr<DflyCmd::SyncInfo>> DflyCmd::GetSyncInfoOrReply(std::s
return {0, nullptr};
}
auto sync_it = sync_infos_.find(sync_id);
if (sync_it == sync_infos_.end()) {
auto sync_it = replica_infos_.find(sync_id);
if (sync_it == replica_infos_.end()) {
rb->SendError(kIdNotFound);
return {0, nullptr};
}
@ -483,7 +512,7 @@ pair<uint32_t, shared_ptr<DflyCmd::SyncInfo>> DflyCmd::GetSyncInfoOrReply(std::s
return {sync_id, sync_it->second};
}
bool DflyCmd::CheckReplicaStateOrReply(const SyncInfo& sync_info, SyncState expected,
bool DflyCmd::CheckReplicaStateOrReply(const ReplicaInfo& sync_info, SyncState expected,
RedisReplyBuilder* rb) {
if (sync_info.state != expected) {
rb->SendError(kInvalidState);
@ -506,4 +535,11 @@ void DflyCmd::BreakOnShutdown() {
VLOG(1) << "BreakOnShutdown";
}
void DflyCmd::FlowInfo::TryShutdownSocket() {
// Close socket for clean disconnect.
if (conn->socket()->IsOpen()) {
conn->socket()->Shutdown(SHUT_RDWR);
}
}
} // namespace dfly


@ -5,9 +5,11 @@
#pragma once
#include <absl/container/btree_map.h>
#include <memory.h>
#include <atomic>
#include <boost/fiber/fiber.hpp>
#include <boost/fiber/mutex.hpp>
#include <memory>
#include "server/conn_context.h"
@ -29,29 +31,84 @@ namespace journal {
class Journal;
} // namespace journal
// DflyCmd is responsible for managing replication. A master instance can be connected
// to many replica instances; moreover, each of them can open multiple connections.
// This is why it's important to understand replica lifecycle management before making
// any crucial changes.
//
// A ReplicaInfo instance is responsible for managing a replica's state and is accessible by its
// sync_id. Each per-thread connection is called a Flow and is represented by the FlowInfo
// instance, accessible by its index.
//
// An important aspect is synchronization and efficient locking. Two levels of locking are used:
// 1. Global locking.
// Member mutex `mu_` is used for synchronizing operations connected with internal data
// structures.
// 2. Per-replica locking
// ReplicaInfo contains a separate mutex that is used for replica-only routines. It is held
// during state transitions (start full sync, start stable state sync), cancellation and member
// access.
//
// Upon first connection from the replica, a new ReplicaInfo is created.
// It transitions through the following phases:
// 1. Preparation
// During this start phase the "flows" are set up - one connection for every master thread. Those
// connections are registered by the FLOW command sent from each newly opened connection.
// 2. Full sync
// This phase is initiated by the SYNC command. It makes sure all flows are connected and the
// replica is in a valid state.
// 3. Stable state sync
// After the replica has received confirmation that each flow is ready to transition, it sends a
// STARTSTABLE command. This transitions the replica into streaming journal changes.
// 4. Cancellation
// This can happen due to an error at any phase or through a normal abort. For properly releasing
// resources we need to run a multi-step cancellation procedure:
// 1. Transition state
// We obtain the ReplicaInfo lock, transition into the cancelled state and cancel the context.
// 2. Joining tasks
// Running tasks will stop on receiving the cancellation flag. Each FlowInfo also has an
// optional cleanup handler that is invoked after cancellation. This should allow recovering
// from any state. The flow's task will be awaited and joined if present.
// 3. Unlocking the mutex
// Now that all tasks have finished and all cleanup handlers have run, we can safely release
// the per-replica mutex, so that all OnClose handlers will unblock and internal resources
// will be released by dragonfly. Then the ReplicaInfo is removed from the global map.
//
//
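
To make the phases above concrete, the rough command sequence driven by a replica looks like the following (a sketch only; the sync-id is obtained during the initial handshake, which is not shown, and argument lists are abbreviated and not authoritative):

  DFLY FLOW ... <sync-id> <flow-id>   # Preparation: one call per master thread/flow
  DFLY SYNC <sync-id>                 # start full sync on all registered flows
  DFLY STARTSTABLE <sync-id>          # switch to stable-state journal streaming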
class DflyCmd {
public:
// See header comments for state descriptions.
enum class SyncState { PREPARATION, FULL_SYNC, STABLE_SYNC, CANCELLED };
// Stores information related to a single flow.
struct FlowInfo {
FlowInfo() = default;
FlowInfo(facade::Connection* conn, const std::string& eof_token)
: conn(conn), eof_token(eof_token){};
: conn{conn}, eof_token{eof_token} {};
// Shut down the associated socket if it's still open.
void TryShutdownSocket();
facade::Connection* conn;
::boost::fibers::fiber full_sync_fb; // Full sync fiber.
std::unique_ptr<RdbSaver> saver; // Saver used by the full sync phase.
std::string eof_token;
std::unique_ptr<RdbSaver> saver;
::boost::fibers::fiber fb;
std::function<void()> cleanup; // Optional cleanup for cancellation.
};
struct SyncInfo {
SyncState state = SyncState::PREPARATION;
// Stores information related to a single replica.
struct ReplicaInfo {
ReplicaInfo(unsigned flow_count, Context::ErrHandler err_handler)
: state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, flows{flow_count} {
}
SyncState state;
Context cntx;
std::vector<FlowInfo> flows;
::boost::fibers::mutex mu; // guard operations on replica.
::boost::fibers::mutex mu; // See top of header for locking levels.
};
public:
@ -93,39 +150,44 @@ class DflyCmd {
void Expire(CmdArgList args, ConnectionContext* cntx);
// Start full sync in thread. Start FullSyncFb. Called for each flow.
facade::OpStatus StartFullSyncInThread(FlowInfo* flow, EngineShard* shard);
facade::OpStatus StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard);
// Stop full sync in thread. Run state switch cleanup.
void StopFullSyncInThread(FlowInfo* flow, EngineShard* shard);
// Start stable sync in thread. Called for each flow.
facade::OpStatus StartStableSyncInThread(FlowInfo* flow, EngineShard* shard);
// Fiber that runs full sync for each flow.
void FullSyncFb(FlowInfo* flow);
void FullSyncFb(FlowInfo* flow, Context* cntx);
// Unregister flow. Must be called when flow disconnects.
void UnregisterFlow(FlowInfo*);
// Main entrypoint for stopping replication.
void StopReplication(uint32_t sync_id);
// Delete sync session. Cleanup flows.
void DeleteSyncSession(uint32_t sync_id);
// Transition into cancelled state, run cleanup.
void CancelReplication(uint32_t sync_id, std::shared_ptr<ReplicaInfo> replica_info_ptr);
// Get SyncInfo by sync_id.
std::shared_ptr<SyncInfo> GetSyncInfo(uint32_t sync_id);
// Get ReplicaInfo by sync_id.
std::shared_ptr<ReplicaInfo> GetReplicaInfo(uint32_t sync_id);
// Find sync info by id or send error reply.
std::pair<uint32_t, std::shared_ptr<SyncInfo>> GetSyncInfoOrReply(std::string_view id,
facade::RedisReplyBuilder* rb);
bool CheckReplicaStateOrReply(const SyncInfo& si, SyncState expected,
std::pair<uint32_t, std::shared_ptr<ReplicaInfo>> GetReplicaInfoOrReply(
std::string_view id, facade::RedisReplyBuilder* rb);
// Check replica is in expected state and flows are set-up correctly.
bool CheckReplicaStateOrReply(const ReplicaInfo& ri, SyncState expected,
facade::RedisReplyBuilder* rb);
private:
ServerFamily* sf_;
util::ListenerInterface* listener_;
TxId journal_txid_ = 0;
absl::btree_map<uint32_t, std::shared_ptr<SyncInfo>> sync_infos_;
uint32_t next_sync_id_ = 1;
absl::btree_map<uint32_t, std::shared_ptr<ReplicaInfo>> replica_infos_;
::boost::fibers::mutex mu_; // guard sync info and journal operations.
::boost::fibers::mutex mu_; // Guard global operations. See header top for locking levels.
};
} // namespace dfly


@ -225,11 +225,26 @@ class EngineShardSet {
RunBriefInParallel(std::forward<U>(func), [](auto i) { return true; });
}
// Runs a brief function on selected shards. Waits for it to complete.
// Runs a brief function on selected shard thread. Waits for it to complete.
template <typename U, typename P> void RunBriefInParallel(U&& func, P&& pred) const;
template <typename U> void RunBlockingInParallel(U&& func);
// Runs func on all shards via the same shard queue that's been used by transactions framework.
// The functions running inside the shard queue run atomically (sequentially)
// with respect to each other on the same shard.
template <typename U> void AwaitRunningOnShardQueue(U&& func) {
util::fibers_ext::BlockingCounter bc{unsigned(shard_queue_.size())};
for (size_t i = 0; i < shard_queue_.size(); ++i) {
Add(i, [&func, bc]() mutable {
func(EngineShard::tlocal());
bc.Dec();
});
}
bc.Wait();
}
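
A usage sketch (illustration only; it mirrors how DflyCmd::CancelReplication uses this helper earlier in the diff):

shard_set->AwaitRunningOnShardQueue([](EngineShard* shard) {
  // Runs once on every shard thread, serialized with transactional work there.
  VLOG(1) << "running cleanup on shard " << shard->shard_id();
});
// Returns only after the callback has completed on all shards.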
// Used in tests
void TEST_EnableHeartBeat();
void TEST_EnableCacheMode();


@ -118,9 +118,11 @@ error_code JournalSlice::Close() {
void JournalSlice::AddLogRecord(const Entry& entry) {
DCHECK(ring_buffer_);
iterating_cb_arr_ = true;
for (const auto& k_v : change_cb_arr_) {
k_v.second(entry);
}
iterating_cb_arr_ = false;
RingItem item;
item.lsn = lsn_;
@ -146,12 +148,12 @@ uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) {
}
void JournalSlice::Unregister(uint32_t id) {
for (auto it = change_cb_arr_.begin(); it != change_cb_arr_.end(); ++it) {
if (it->first == id) {
CHECK(!iterating_cb_arr_);
auto it = find_if(change_cb_arr_.begin(), change_cb_arr_.end(),
[id](const auto& e) { return e.first == id; });
CHECK(it != change_cb_arr_.end());
change_cb_arr_.erase(it);
break;
}
}
}
} // namespace journal


@ -48,12 +48,13 @@ class JournalSlice {
void Unregister(uint32_t);
private:
struct RingItem;
std::string shard_path_;
std::unique_ptr<util::uring::LinuxFile> shard_file_;
std::optional<base::RingBuffer<RingItem>> ring_buffer_;
bool iterating_cb_arr_ = false;
std::vector<std::pair<uint32_t, ChangeCallback>> change_cb_arr_;
size_t file_offset_ = 0;


@ -739,11 +739,11 @@ class RdbSaver::Impl {
// correct closing semantics - channel is closing when K producers marked it as closed.
Impl(bool align_writes, unsigned producers_len, io::Sink* sink);
void StartSnapshotting(bool stream_journal, EngineShard* shard);
void StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard);
void StopSnapshotting(EngineShard* shard);
error_code ConsumeChannel();
error_code ConsumeChannel(const Cancellation* cll);
error_code Flush() {
if (aligned_buf_)
@ -764,6 +764,8 @@ class RdbSaver::Impl {
return &meta_serializer_;
}
void Cancel();
private:
unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);
@ -797,7 +799,7 @@ error_code RdbSaver::Impl::SaveAuxFieldStrStr(string_view key, string_view val)
return error_code{};
}
error_code RdbSaver::Impl::ConsumeChannel() {
error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
error_code io_error;
uint8_t buf[16];
@ -812,10 +814,13 @@ error_code RdbSaver::Impl::ConsumeChannel() {
auto& channel = channel_;
while (channel.Pop(record)) {
if (io_error)
if (io_error || cll->IsCancelled())
continue;
do {
if (cll->IsCancelled())
continue;
if (record.db_index != last_db_index) {
unsigned enclen = SerializeLen(record.db_index, buf + 1);
string_view str{(char*)buf, enclen + 1};
@ -855,17 +860,32 @@ error_code RdbSaver::Impl::ConsumeChannel() {
return io_error;
}
void RdbSaver::Impl::StartSnapshotting(bool stream_journal, EngineShard* shard) {
void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll,
EngineShard* shard) {
auto& s = GetSnapshot(shard);
s.reset(new SliceSnapshot(&shard->db_slice(), &channel_));
s->Start(stream_journal);
s->Start(stream_journal, cll);
}
void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) {
GetSnapshot(shard)->Stop();
}
void RdbSaver::Impl::Cancel() {
auto* shard = EngineShard::tlocal();
if (!shard)
return;
auto& snapshot = GetSnapshot(shard);
if (snapshot)
snapshot->Cancel();
dfly::SliceSnapshot::DbRecord rec;
while (channel_.Pop(rec)) {
}
}
void RdbSaver::Impl::FillFreqMap(RdbTypeFreqMap* dest) const {
for (auto& ptr : shard_snapshots_) {
const RdbTypeFreqMap& src_map = ptr->freq_map();
@ -905,8 +925,9 @@ RdbSaver::RdbSaver(::io::Sink* sink, SaveMode save_mode, bool align_writes) {
RdbSaver::~RdbSaver() {
}
void RdbSaver::StartSnapshotInShard(bool stream_journal, EngineShard* shard) {
impl_->StartSnapshotting(stream_journal, shard);
void RdbSaver::StartSnapshotInShard(bool stream_journal, const Cancellation* cll,
EngineShard* shard) {
impl_->StartSnapshotting(stream_journal, cll, shard);
}
void RdbSaver::StopSnapshotInShard(EngineShard* shard) {
@ -924,14 +945,14 @@ error_code RdbSaver::SaveHeader(const StringVec& lua_scripts) {
return error_code{};
}
error_code RdbSaver::SaveBody(RdbTypeFreqMap* freq_map) {
error_code RdbSaver::SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map) {
RETURN_ON_ERR(impl_->serializer()->FlushMem());
if (save_mode_ == SaveMode::SUMMARY) {
impl_->serializer()->SendFullSyncCut();
} else {
VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
error_code io_error = impl_->ConsumeChannel();
error_code io_error = impl_->ConsumeChannel(cll);
if (io_error) {
LOG(ERROR) << "io error " << io_error;
return io_error;
@ -1001,4 +1022,8 @@ error_code RdbSaver::SaveAuxFieldStrInt(string_view key, int64_t val) {
return impl_->SaveAuxFieldStrStr(key, string_view(buf, vlen));
}
void RdbSaver::Cancel() {
impl_->Cancel();
}
} // namespace dfly


@ -72,7 +72,7 @@ class RdbSaver {
// Initiates the serialization in the shard's thread.
// TODO: to implement break functionality to allow stopping early.
void StartSnapshotInShard(bool stream_journal, EngineShard* shard);
void StartSnapshotInShard(bool stream_journal, const Cancellation* cll, EngineShard* shard);
// Stops serialization in journal streaming mode in the shard's thread.
void StopSnapshotInShard(EngineShard* shard);
@ -83,7 +83,9 @@ class RdbSaver {
// Writes the RDB file into sink. Waits for the serialization to finish.
// Fills freq_map with the histogram of rdb types.
// freq_map can optionally be null.
std::error_code SaveBody(RdbTypeFreqMap* freq_map);
std::error_code SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map);
void Cancel();
SaveMode Mode() const {
return save_mode_;


@ -149,6 +149,15 @@ void Replica::Stop() {
LOG_IF(ERROR, ec) << "Could not shutdown socket " << ec;
});
}
// Close sub flows.
auto partition = Partition(num_df_flows_);
shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
for (auto id : partition[index]) {
shard_flows_[id]->Stop();
}
});
if (sync_fb_.joinable())
sync_fb_.join();
}


@ -195,6 +195,8 @@ class RdbSnapshot {
std::unique_ptr<io::Sink> io_sink_;
std::unique_ptr<RdbSaver> saver_;
RdbTypeFreqMap freq_map_;
Cancellation cll_{};
};
io::Result<size_t> LinuxWriteWrapper::WriteSome(const iovec* v, uint32_t len) {
@ -229,7 +231,7 @@ error_code RdbSnapshot::Start(SaveMode save_mode, const std::string& path,
}
error_code RdbSnapshot::SaveBody() {
return saver_->SaveBody(&freq_map_);
return saver_->SaveBody(&cll_, &freq_map_);
}
error_code RdbSnapshot::Close() {
@ -241,7 +243,7 @@ error_code RdbSnapshot::Close() {
}
void RdbSnapshot::StartInShard(EngineShard* shard) {
saver_->StartSnapshotInShard(false, shard);
saver_->StartSnapshotInShard(false, &cll_, shard);
started_ = true;
}


@ -34,7 +34,7 @@ SliceSnapshot::SliceSnapshot(DbSlice* slice, RecordChannel* dest) : db_slice_(sl
SliceSnapshot::~SliceSnapshot() {
}
void SliceSnapshot::Start(bool stream_journal) {
void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
DCHECK(!snapshot_fb_.joinable());
auto on_change = [this](DbIndex db_index, const DbSlice::ChangeReq& req) {
@ -54,9 +54,11 @@ void SliceSnapshot::Start(bool stream_journal) {
sfile_.reset(new io::StringFile);
rdb_serializer_.reset(new RdbSerializer(sfile_.get()));
snapshot_fb_ = fiber([this, stream_journal] {
SerializeEntriesFb();
if (!stream_journal) {
snapshot_fb_ = fiber([this, stream_journal, cll] {
SerializeEntriesFb(cll);
if (cll->IsCancelled()) {
Cancel();
} else if (!stream_journal) {
CloseRecordChannel();
}
db_slice_->UnregisterOnChange(snapshot_version_);
@ -75,6 +77,14 @@ void SliceSnapshot::Stop() {
CloseRecordChannel();
}
void SliceSnapshot::Cancel() {
CloseRecordChannel();
if (journal_cb_id_) {
db_slice_->shard_owner()->journal()->Unregister(journal_cb_id_);
journal_cb_id_ = 0;
}
}
void SliceSnapshot::Join() {
// Fiber could have already been joined by Stop.
if (snapshot_fb_.joinable())
@ -82,12 +92,15 @@ void SliceSnapshot::Join() {
}
// Serializes all the entries with version less than snapshot_version_.
void SliceSnapshot::SerializeEntriesFb() {
void SliceSnapshot::SerializeEntriesFb(const Cancellation* cll) {
this_fiber::properties<FiberProps>().set_name(
absl::StrCat("SliceSnapshot", ProactorBase::GetIndex()));
PrimeTable::Cursor cursor;
for (DbIndex db_indx = 0; db_indx < db_array_.size(); ++db_indx) {
if (cll->IsCancelled())
return;
if (!db_array_[db_indx])
continue;
@ -100,6 +113,9 @@ void SliceSnapshot::SerializeEntriesFb() {
mu_.unlock();
do {
if (cll->IsCancelled())
return;
PrimeTable::Cursor next = pt->Traverse(cursor, [this](auto it) { this->SaveCb(move(it)); });
cursor = next;
@ -126,6 +142,7 @@ void SliceSnapshot::SerializeEntriesFb() {
mu_.lock();
mu_.unlock();
for (unsigned i = 10; i > 1; i--)
CHECK(!rdb_serializer_->SendFullSyncCut());
FlushSfile(true);
@ -138,7 +155,12 @@ void SliceSnapshot::CloseRecordChannel() {
// Can not think of anything more elegant.
mu_.lock();
mu_.unlock();
// Make sure we close the channel only once with a CAS check.
bool actual = false;
if (closed_chan_.compare_exchange_strong(actual, true)) {
dest_->StartClosing();
}
}
// This function should not block and should not preempt because it's called


@ -4,6 +4,7 @@
#pragma once
#include <atomic>
#include <bitset>
#include "io/file.h"
@ -36,12 +37,14 @@ class SliceSnapshot {
SliceSnapshot(DbSlice* slice, RecordChannel* dest);
~SliceSnapshot();
void Start(bool stream_journal);
void Start(bool stream_journal, const Cancellation* cll);
void Stop(); // only needs to be called in journal streaming mode.
void Join();
void Cancel();
uint64_t snapshot_version() const {
return snapshot_version_;
}
@ -61,7 +64,7 @@ class SliceSnapshot {
private:
void CloseRecordChannel();
void SerializeEntriesFb();
void SerializeEntriesFb(const Cancellation* cll);
void SerializeSingleEntry(DbIndex db_index, const PrimeKey& pk, const PrimeValue& pv,
RdbSerializer* serializer);
@ -98,6 +101,8 @@ class SliceSnapshot {
uint32_t journal_cb_id_ = 0;
::boost::fibers::fiber snapshot_fb_;
std::atomic_bool closed_chan_{false};
};
} // namespace dfly


@ -843,7 +843,7 @@ bool Transaction::ScheduleUniqueShard(EngineShard* shard) {
sd.pq_pos = shard->txq()->Insert(this);
DCHECK_EQ(0, sd.local_mask & KEYLOCK_ACQUIRED);
bool lock_acquired = shard->db_slice().Acquire(mode, lock_args);
shard->db_slice().Acquire(mode, lock_args);
sd.local_mask |= KEYLOCK_ACQUIRED;
DVLOG(1) << "Rescheduling into TxQueue " << DebugId();


@ -2,8 +2,11 @@
import pytest
import asyncio
import aioredis
import random
from itertools import count, chain, repeat
from .utility import *
from . import dfly_args
BASE_PORT = 1111
@ -12,6 +15,10 @@ BASE_PORT = 1111
Test full replication pipeline. Test full sync with streaming changes and stable state streaming.
"""
# 1. Number of master threads
# 2. Number of threads for each replica
# 3. Number of keys stored and sent in full sync
# 4. Number of keys overwritten during full sync
replication_cases = [
(8, [8], 20000, 5000),
(8, [8], 10000, 10000),
@ -80,61 +87,140 @@ async def test_replication_all(df_local_factory, t_master, t_replicas, n_keys, n
"""
Test replica crash during full sync on multiple replicas without altering data during replication.
Test disconnecting replicas during different phases with constantly streaming changes to master.
Three types are tested:
1. Replicas crashing during full sync state
2. Replicas crashing during stable sync state
3. Replicas disconnecting normally with REPLICAOF NO ONE during stable state
"""
# (threads_master, threads_replicas, n entries)
simple_full_sync_multi_crash_cases = [
(5, [1] * 15, 5000),
(5, [1] * 20, 5000),
(5, [1] * 25, 5000)
# 1. Number of master threads
# 2. Number of threads for each replica that crashes during full sync
# 3. Number of threads for each replica that crashes during stable sync
# 4. Number of threads for each replica that disconnects normally
# 5. Number of distinct keys that are constantly streamed
disconnect_cases = [
# balanced
(8, [4, 4], [4, 4], [4], 10000),
(8, [2] * 6, [2] * 6, [2, 2], 10000),
# full sync heavy
(8, [4] * 6, [], [], 10000),
(8, [2] * 12, [], [], 10000),
# stable state heavy
(8, [], [4] * 6, [], 10000),
(8, [], [2] * 12, [], 10000),
# disconnect only
(8, [], [], [2] * 6, 10000)
]
@pytest.mark.asyncio
@pytest.mark.skip(reason="test is currently crashing")
@pytest.mark.parametrize("t_master, t_replicas, n_keys", simple_full_sync_multi_crash_cases)
async def test_simple_full_sync_mutli_crash(df_local_factory, t_master, t_replicas, n_keys):
def data_gen(): return gen_test_data(n_keys)
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
@pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases)
async def test_disconnect(df_local_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master,logtostdout="")
replicas = [
df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
for i, t in enumerate(t_replicas)
(df_local_factory.create(
port=BASE_PORT+i+1, proactor_threads=t), crash_fs)
for i, (t, crash_fs) in enumerate(
chain(
zip(t_crash_fs, repeat(0)),
zip(t_crash_ss, repeat(1)),
zip(t_disonnect, repeat(2))
)
)
]
# Start master and fill with test data
# Start master
master.start()
c_master = aioredis.Redis(port=master.port, single_connection_client=True)
await batch_fill_data_async(c_master, data_gen())
# Start replica tasks in parallel
tasks = [
asyncio.create_task(run_sfs_crash_replica(
replica, master, data_gen), name="replica-"+str(replica.port))
for replica in replicas
# Start replicas and create clients
for replica, _ in replicas:
replica.start()
c_replicas = [
(replica, aioredis.Redis(port=replica.port), crash_type)
for replica, crash_type in replicas
]
for task in tasks:
assert await task
def replicas_of_type(tfunc):
return [
args for args in c_replicas
if tfunc(args[2])
]
# Check master is ok
await batch_check_data_async(c_master, data_gen())
# Start data fill loop
async def fill_loop():
local_c = aioredis.Redis(
port=master.port, single_connection_client=True)
for seed in count(1):
await batch_fill_data_async(local_c, gen_test_data(n_keys, seed=seed))
await c_master.connection_pool.disconnect()
async def run_sfs_crash_replica(replica, master, data_gen):
replica.start()
c_replica = aioredis.Redis(
port=replica.port, single_connection_client=None)
fill_task = asyncio.create_task(fill_loop())
# Run full sync
async def full_sync(replica, c_replica, crash_type):
c_replica = aioredis.Redis(port=replica.port)
await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
if crash_type == 0:
await asyncio.sleep(random.random()/100+0.01)
replica.stop(kill=True)
else:
await wait_available_async(c_replica)
# Kill the replica after a short delay
await asyncio.sleep(0.0)
await asyncio.gather(*(full_sync(*args) for args in c_replicas))
# Wait for master to stream a bit more
await asyncio.sleep(0.1)
# Check master survived full sync crashes
assert await c_master.ping()
# Check phase-2 replicas survived
for _, c_replica, _ in replicas_of_type(lambda t: t > 0):
assert await c_replica.ping()
# Run stable state crashes
async def stable_sync(replica, c_replica, crash_type):
await asyncio.sleep(random.random() / 100)
replica.stop(kill=True)
await c_replica.connection_pool.disconnect()
return True
await asyncio.gather(*(stable_sync(*args) for args
in replicas_of_type(lambda t: t == 1)))
# Check master survived all crashes
assert await c_master.ping()
# Check phase 3 replica survived
for _, c_replica, _ in replicas_of_type(lambda t: t > 1):
assert await c_replica.ping()
# Stop streaming
fill_task.cancel()
# Check master survived all crashes
assert await c_master.ping()
# Check phase 3 replicas are up-to-date and there is no gap or lag
def check_gen(): return gen_test_data(n_keys//5, seed=0)
await batch_fill_data_async(c_master, check_gen())
await asyncio.sleep(0.1)
for _, c_replica, _ in replicas_of_type(lambda t: t > 1):
await batch_check_data_async(c_replica, check_gen())
# Check disconnects
async def disconnect(replica, c_replica, crash_type):
await asyncio.sleep(random.random() / 100)
await c_replica.execute_command("REPLICAOF NO ONE")
await asyncio.gather(*(disconnect(*args) for args
in replicas_of_type(lambda t: t == 2)))
# Check phase 3 replica survived
for _, c_replica, _ in replicas_of_type(lambda t: t == 2):
assert await c_replica.ping()
await batch_check_data_async(c_replica, check_gen())
# Check master survived all disconnects
assert await c_master.ping()