feat(replication): First iteration on partial sync. (#1836)

First iteration on partial sync.
Roy Jacobson 2023-09-26 10:35:50 +03:00 committed by GitHub
parent d9f4ca8003
commit d50b492e1f
21 changed files with 524 additions and 86 deletions

@ -115,7 +115,7 @@ GenericError RdbSnapshot::Start(SaveMode save_mode, const std::string& path,
}
error_code RdbSnapshot::SaveBody() {
return saver_->SaveBody(&cll_, &freq_map_);
return saver_->SaveBody(&cntx_, &freq_map_);
}
error_code RdbSnapshot::Close() {
@ -126,7 +126,7 @@ error_code RdbSnapshot::Close() {
}
void RdbSnapshot::StartInShard(EngineShard* shard) {
saver_->StartSnapshotInShard(false, &cll_, shard);
saver_->StartSnapshotInShard(false, cntx_.GetCancellation(), shard);
started_ = true;
}

@ -63,7 +63,7 @@ class RdbSnapshot {
unique_ptr<RdbSaver> saver_;
RdbTypeFreqMap freq_map_;
Cancellation cll_{};
Context cntx_{};
};
struct SaveStagesController : public SaveStagesInputs {

@ -8,9 +8,11 @@
#include <absl/strings/strip.h>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include "absl/strings/numbers.h"
#include "base/flags.h"
#include "base/logging.h"
#include "facade/dragonfly_connection.h"
@ -114,7 +116,7 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
return Thread(args, cntx);
}
if (sub_cmd == "FLOW" && args.size() == 4) {
if (sub_cmd == "FLOW" && (args.size() == 4 || args.size() == 5)) {
return Flow(args, cntx);
}
@ -241,7 +243,16 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
string_view sync_id_str = ArgS(args, 2);
string_view flow_id_str = ArgS(args, 3);
VLOG(1) << "Got DFLY FLOW " << master_id << " " << sync_id_str << " " << flow_id_str;
std::optional<LSN> seqid;
if (args.size() == 5) {
seqid.emplace();
if (!absl::SimpleAtoi(ArgS(args, 4), &seqid.value())) {
return rb->SendError(facade::kInvalidIntErr);
}
}
VLOG(1) << "Got DFLY FLOW master_id: " << master_id << " sync_id: " << sync_id_str
<< " flow: " << flow_id_str << " seq: " << seqid.value_or(-1);
if (master_id != sf_->master_id()) {
return rb->SendError(kBadMasterId);
@ -268,13 +279,33 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
absl::InsecureBitGen gen;
string eof_token = GetRandomHex(gen, 40);
cntx->replication_flow = &replica_ptr->flows[flow_id];
replica_ptr->flows[flow_id].conn = cntx->owner();
replica_ptr->flows[flow_id].eof_token = eof_token;
auto& flow = replica_ptr->flows[flow_id];
cntx->replication_flow = &flow;
flow.conn = cntx->owner();
flow.eof_token = eof_token;
flow.version = replica_ptr->version;
cntx->owner()->Migrate(shard_set->pool()->at(flow_id));
sf_->journal()->StartInThread();
std::string_view sync_type = "FULL";
if (seqid.has_value()) {
if (sf_->journal()->IsLSNInBuffer(*seqid) || sf_->journal()->GetLsn() == *seqid) {
flow.start_partial_sync_at = *seqid;
VLOG(1) << "Partial sync requested from LSN=" << flow.start_partial_sync_at.value()
<< " and is available. (current_lsn=" << sf_->journal()->GetLsn() << ")";
sync_type = "PARTIAL";
} else {
LOG(INFO) << "Partial sync requested from stale LSN=" << *seqid
<< " that the replication buffer no longer contains (current_lsn="
<< sf_->journal()->GetLsn() << "). Will perform a full sync of the data.";
LOG(INFO) << "If this happens often, you can control the replication buffer's size with the "
"--shard_repl_backlog_len option";
}
}
rb->StartArray(2);
rb->SendSimpleString("FULL");
rb->SendSimpleString(sync_type);
rb->SendSimpleString(eof_token);
}
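
For orientation, the admission rule above can be read as a pure predicate over the journal's retained LSN range: offer a partial sync only if the requested LSN is still buffered or the replica is already fully caught up. A minimal standalone sketch; the [oldest_buffered, next_lsn) range model is an assumption about the buffer semantics, not code from this change:

```cpp
#include <cassert>
#include <cstdint>

using LSN = uint64_t;

// Assumed model: the per-shard journal retains entries with LSNs in
// [oldest_buffered, next_lsn), where next_lsn corresponds to GetLsn().
bool CanServePartialSync(LSN requested, LSN oldest_buffered, LSN next_lsn) {
  bool still_buffered = requested >= oldest_buffered && requested < next_lsn;
  bool fully_caught_up = requested == next_lsn;  // GetLsn() == *seqid above
  return still_buffered || fully_caught_up;
}

int main() {
  assert(CanServePartialSync(980, 950, 1000));   // still in the backlog -> PARTIAL
  assert(CanServePartialSync(1000, 950, 1000));  // replica is caught up -> PARTIAL
  assert(!CanServePartialSync(100, 950, 1000));  // fell out of the backlog -> FULL
}
```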
@ -309,7 +340,7 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
return rb->SendError(kInvalidState);
}
LOG(INFO) << "Started full sync with replica " << replica_ptr->address << ":"
LOG(INFO) << "Started sync with replica " << replica_ptr->address << ":"
<< replica_ptr->listening_port;
replica_ptr->state.store(SyncState::FULL_SYNC, memory_order_relaxed);
@ -468,7 +499,7 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
// of the flows also contain them.
SaveMode save_mode =
shard->shard_id() == 0 ? SaveMode::SINGLE_SHARD_WITH_SUMMARY : SaveMode::SINGLE_SHARD;
flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
flow->saver = std::make_unique<RdbSaver>(flow->conn->socket(), save_mode, false);
flow->cleanup = [flow]() {
flow->saver->Cancel();
@ -477,11 +508,12 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
flow->saver.reset();
};
sf_->journal()->StartInThread();
// Shard can be null for io thread.
if (shard != nullptr) {
flow->saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
if (flow->start_partial_sync_at.has_value())
flow->saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
else
flow->saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
}
flow->full_sync_fb = fb2::Fiber("full_sync", &DflyCmd::FullSyncFb, this, flow, cntx);
@ -532,7 +564,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
// Always send original body (with header & without auto async calls) that determines the sha,
// It's stored only if it's different from the post-processed version.
string& body = data.orig_body.empty() ? data.body : data.orig_body;
script_bodies.push_back(move(body));
script_bodies.push_back(std::move(body));
}
ec = saver->SaveHeader({script_bodies, {}});
} else {
@ -544,7 +576,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
return;
}
if (ec = saver->SaveBody(cntx->GetCancellation(), nullptr); ec) {
if (ec = saver->SaveBody(cntx, nullptr); ec) {
cntx->ReportError(ec);
return;
}
@ -711,6 +743,14 @@ std::map<uint32_t, LSN> DflyCmd::ReplicationLags() const {
return rv;
}
void DflyCmd::SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version) {
auto replica_ptr = GetReplicaInfo(cntx->conn_state.replication_info.repl_session_id);
VLOG(1) << "Client version for session_id=" << cntx->conn_state.replication_info.repl_session_id
<< " is " << int(version);
replica_ptr->version = version;
}
bool DflyCmd::CheckReplicaStateOrReply(const ReplicaInfo& sync_info, SyncState expected,
RedisReplyBuilder* rb) {
if (sync_info.state != expected) {

@ -40,7 +40,9 @@ struct FlowInfo {
std::unique_ptr<RdbSaver> saver; // Saver used by the full sync phase.
std::unique_ptr<JournalStreamer> streamer;
std::string eof_token;
DflyVersion version;
std::optional<LSN> start_partial_sync_at;
uint64_t last_acked_lsn;
std::function<void()> cleanup; // Optional cleanup for cancellation.
@ -99,8 +101,11 @@ class DflyCmd {
struct ReplicaInfo {
ReplicaInfo(unsigned flow_count, std::string address, uint32_t listening_port,
Context::ErrHandler err_handler)
: state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{std::move(address)},
listening_port(listening_port), flows{flow_count} {
: state{SyncState::PREPARATION},
cntx{std::move(err_handler)},
address{std::move(address)},
listening_port(listening_port),
flows{flow_count} {
}
std::atomic<SyncState> state;
@ -108,6 +113,7 @@ class DflyCmd {
std::string address;
uint32_t listening_port;
DflyVersion version = DflyVersion::VER0;
std::vector<FlowInfo> flows;
Mutex mu; // See top of header for locking levels.
@ -130,6 +136,9 @@ class DflyCmd {
std::vector<ReplicaRoleInfo> GetReplicasRoleInfo();
// Sets metadata.
void SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version);
private:
// JOURNAL [START/STOP]
// Start or stop journaling.
@ -139,8 +148,11 @@ class DflyCmd {
// Return connection thread index or migrate to another thread.
void Thread(CmdArgList args, ConnectionContext* cntx);
// FLOW <masterid> <syncid> <flowid>
// FLOW <masterid> <syncid> <flowid> [<seqid>]
// Register connection as flow for sync session.
// If seqid is given, it means the client wants to try partial sync.
// If it is possible, return Ok and prepare for a partial sync, else
// return error and ask the replica to execute FLOW again.
void Flow(CmdArgList args, ConnectionContext* cntx);
// SYNC <syncid>

@ -20,15 +20,17 @@ using facade::kWrongTypeErr;
#ifndef RETURN_ON_ERR
#define RETURN_ON_ERR(x) \
#define RETURN_ON_ERR_T(T, x) \
do { \
std::error_code __ec = (x); \
if (__ec) { \
DLOG(ERROR) << "Error " << __ec << " while calling " #x; \
return __ec; \
return (T)(__ec); \
} \
} while (0)
#define RETURN_ON_ERR(x) RETURN_ON_ERR_T(std::error_code, x)
#endif // RETURN_ON_ERR
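
To see why the macro now takes a T parameter, here is a self-contained sketch: the same early-return-on-error pattern can be reused in a function that returns a Result-like value instead of std::error_code, by passing a wrapping callable as T, which is how RETURN_ON_ERR_T(make_unexpected, ...) is used later in this change. BoolResult and MakeUnexpected below are illustrative stand-ins (not the project's io::Result), and DLOG is swapped for std::cerr so the sketch compiles on its own:

```cpp
#include <iostream>
#include <optional>
#include <system_error>

#define RETURN_ON_ERR_T(T, x)                               \
  do {                                                      \
    std::error_code __ec = (x);                             \
    if (__ec) {                                             \
      std::cerr << "Error " << __ec << " calling " #x "\n"; \
      return (T)(__ec);                                     \
    }                                                       \
  } while (0)

#define RETURN_ON_ERR(x) RETURN_ON_ERR_T(std::error_code, x)

// Illustrative stand-in for an expected/Result-like return type.
struct BoolResult {
  std::optional<bool> value;
  std::error_code error;
};

// Plays the role of make_unexpected: wraps an error_code into the result type.
BoolResult MakeUnexpected(std::error_code ec) { return {std::nullopt, ec}; }

std::error_code MightFail(bool fail) {
  return fail ? std::make_error_code(std::errc::io_error) : std::error_code{};
}

// Returns std::error_code: the original macro keeps working unchanged.
std::error_code PlainCaller(bool fail) {
  RETURN_ON_ERR(MightFail(fail));
  return {};
}

// Returns a Result-like value: (T)(__ec) expands to MakeUnexpected(__ec).
BoolResult ResultCaller(bool fail) {
  RETURN_ON_ERR_T(MakeUnexpected, MightFail(fail));
  return {true, {}};
}

int main() {
  std::cout << PlainCaller(true).message() << '\n';
  std::cout << ResultCaller(true).error.message() << '\n';
  std::cout << *ResultCaller(false).value << '\n';
}
```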
namespace rdb {

@ -16,7 +16,7 @@
#include "base/logging.h"
#include "server/journal/serializer.h"
ABSL_FLAG(int, shard_repl_backlog_len, 1 << 10,
ABSL_FLAG(uint32_t, shard_repl_backlog_len, 1 << 10,
"The length of the circular replication log per shard");
namespace dfly {
@ -33,6 +33,14 @@ string ShardName(std::string_view base, unsigned index) {
}
*/
uint32_t NextPowerOf2(uint32_t x) {
if (x < 2) {
return 1;
}
int log = 32 - __builtin_clz(x - 1);
return 1 << log;
}
} // namespace
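
As a quick sanity check on the helper above, a standalone sketch showing how a few backlog flag values map to ring-buffer capacities (illustrative only; the function body is copied from the diff):

```cpp
#include <cstdint>
#include <iostream>

// Smallest power of two >= x (same logic as NextPowerOf2 above).
uint32_t NextPowerOf2(uint32_t x) {
  if (x < 2) {
    return 1;
  }
  int log = 32 - __builtin_clz(x - 1);
  return 1u << log;
}

int main() {
  // The default flag value 1 << 10 stays 1024; a value like 1000 is rounded up.
  for (uint32_t v : {1u, 3u, 1000u, 1024u, 1025u}) {
    std::cout << v << " -> " << NextPowerOf2(v) << '\n';  // 1, 4, 1024, 1024, 2048
  }
}
```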
#define CHECK_EC(x) \
@ -53,7 +61,7 @@ void JournalSlice::Init(unsigned index) {
return;
slice_index_ = index;
ring_buffer_.emplace(absl::GetFlag(FLAGS_shard_repl_backlog_len));
ring_buffer_.emplace(NextPowerOf2(absl::GetFlag(FLAGS_shard_repl_backlog_len)));
}
#if 0
@ -144,6 +152,7 @@ void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
item->opcode = entry.opcode;
item->data = "";
} else {
FiberAtomicGuard fg;
// GetTail gives a pointer to a new tail entry in the buffer, possibly overriding the last entry
// if the buffer is full.
item = ring_buffer_->GetTail(true);

@ -148,10 +148,13 @@ class ProtocolClient {
/**
* A convenience macro to use with ProtocolClient instances for protocol input validation.
*/
#define PC_RETURN_ON_BAD_RESPONSE(x) \
do { \
if (!(x)) { \
LOG(ERROR) << "Bad response to \"" << last_cmd_ << "\": \"" << absl::CEscape(last_resp_); \
return std::make_error_code(errc::bad_message); \
} \
#define PC_RETURN_ON_BAD_RESPONSE_T(T, x) \
do { \
if (!(x)) { \
LOG(ERROR) << "Bad response to \"" << last_cmd_ << "\": \"" << absl::CEscape(last_resp_) \
<< "\""; \
return (T)(std::make_error_code(errc::bad_message)); \
} \
} while (false)
#define PC_RETURN_ON_BAD_RESPONSE(x) PC_RETURN_ON_BAD_RESPONSE_T(std::error_code, x)

@ -1843,6 +1843,7 @@ error_code RdbLoader::Load(io::Source* src) {
auto cb = mem_buf_->InputBuffer();
if (memcmp(cb.data(), "REDIS", 5) != 0) {
VLOG(1) << "Bad header: " << absl::CHexEscape(facade::ToSV(cb));
return RdbError(errc::wrong_signature);
}

@ -189,8 +189,8 @@ class RdbLoader : protected RdbLoaderBase {
// Return the offset that was received with a RDB_OPCODE_JOURNAL_OFFSET command,
// or nullopt if no offset was received.
uint64_t journal_offset() const {
return journal_offset_.value_or(0);
std::optional<uint64_t> journal_offset() const {
return journal_offset_;
}
// Set callback for receiving RDB_OPCODE_FULLSYNC_END.

@ -673,6 +673,7 @@ error_code RdbSerializer::SaveStreamConsumers(streamCG* cg) {
}
error_code RdbSerializer::SendJournalOffset(uint64_t journal_offset) {
VLOG(2) << "SendJournalOffset";
RETURN_ON_ERR(WriteOpcode(RDB_OPCODE_JOURNAL_OFFSET));
uint8_t buf[sizeof(uint64_t)];
absl::little_endian::Store64(buf, journal_offset);
@ -680,6 +681,7 @@ error_code RdbSerializer::SendJournalOffset(uint64_t journal_offset) {
}
error_code RdbSerializer::SendFullSyncCut() {
VLOG(2) << "SendFullSyncCut";
RETURN_ON_ERR(WriteOpcode(RDB_OPCODE_FULLSYNC_END));
// RDB_OPCODE_FULLSYNC_END followed by 8 bytes of 0.
@ -734,6 +736,7 @@ io::Bytes RdbSerializer::PrepareFlush() {
}
error_code RdbSerializer::WriteJournalEntry(std::string_view serialized_entry) {
VLOG(2) << "WriteJournalEntry";
RETURN_ON_ERR(WriteOpcode(RDB_OPCODE_JOURNAL_BLOB));
RETURN_ON_ERR(SaveLen(1));
RETURN_ON_ERR(SaveString(serialized_entry));
@ -893,6 +896,7 @@ class RdbSaver::Impl {
SaveMode save_mode, io::Sink* sink);
void StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard);
void StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, LSN start_lsn);
void StopSnapshotting(EngineShard* shard);
@ -1053,11 +1057,19 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll,
EngineShard* shard) {
auto& s = GetSnapshot(shard);
s.reset(new SliceSnapshot(&shard->db_slice(), &channel_, compression_mode_));
s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_);
s->Start(stream_journal, cll);
}
void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard,
LSN start_lsn) {
auto& s = GetSnapshot(shard);
s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_);
s->StartIncremental(cntx, start_lsn);
}
void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) {
GetSnapshot(shard)->Stop();
}
@ -1142,6 +1154,10 @@ void RdbSaver::StartSnapshotInShard(bool stream_journal, const Cancellation* cll
impl_->StartSnapshotting(stream_journal, cll, shard);
}
void RdbSaver::StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn) {
impl_->StartIncrementalSnapshotting(cntx, shard, start_lsn);
}
void RdbSaver::StopSnapshotInShard(EngineShard* shard) {
impl_->StopSnapshotting(shard);
}
@ -1159,18 +1175,21 @@ error_code RdbSaver::SaveHeader(const GlobalData& glob_state) {
return error_code{};
}
error_code RdbSaver::SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map) {
error_code RdbSaver::SaveBody(Context* cntx, RdbTypeFreqMap* freq_map) {
RETURN_ON_ERR(impl_->serializer()->FlushToSink(impl_->sink()));
if (save_mode_ == SaveMode::SUMMARY) {
impl_->serializer()->SendFullSyncCut();
} else {
VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
error_code io_error = impl_->ConsumeChannel(cll);
error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
if (io_error) {
LOG(ERROR) << "io error " << io_error;
return io_error;
}
if (cntx->GetError()) {
return cntx->GetError();
}
}
RETURN_ON_ERR(SaveEpilog());

@ -88,6 +88,9 @@ class RdbSaver {
// TODO: to implement break functionality to allow stopping early.
void StartSnapshotInShard(bool stream_journal, const Cancellation* cll, EngineShard* shard);
// Send only the incremental snapshot since start_lsn.
void StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn);
// Stops serialization in journal streaming mode in the shard's thread.
void StopSnapshotInShard(EngineShard* shard);
@ -97,7 +100,7 @@ class RdbSaver {
// Writes the RDB file into sink. Waits for the serialization to finish.
// Fills freq_map with the histogram of rdb types.
// freq_map can optionally be null.
std::error_code SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map);
std::error_code SaveBody(Context* cntx, RdbTypeFreqMap* freq_map);
void Cancel();
@ -122,7 +125,7 @@ class CompressorImpl;
class RdbSerializer {
public:
RdbSerializer(CompressionMode compression_mode);
explicit RdbSerializer(CompressionMode compression_mode);
~RdbSerializer();

@ -290,6 +290,10 @@ std::error_code Replica::HandleCapaDflyResp() {
return make_error_code(errc::bad_message);
}
// If we're syncing a different replication ID, drop the saved LSNs.
if (master_context_.master_repl_id != ToSV(LastResponseArgs()[0].GetBuf())) {
last_journal_LSNs_.reset();
}
master_context_.master_repl_id = ToSV(LastResponseArgs()[0].GetBuf());
master_context_.dfly_session_id = ToSV(LastResponseArgs()[1].GetBuf());
num_df_flows_ = param_num_flows;
@ -414,9 +418,10 @@ error_code Replica::InitiateDflySync() {
absl::Cleanup cleanup = [this]() {
// We do the following operations regardless of outcome.
JoinAllFlows();
JoinDflyFlows();
service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
state_mask_.fetch_and(~R_SYNCING);
last_journal_LSNs_.reset();
};
// Initialize MultiShardExecution.
@ -450,25 +455,43 @@ error_code Replica::InitiateDflySync() {
// Make sure we're in LOADING state.
CHECK(service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) == GlobalState::LOADING);
// Flush dbs.
JournalExecutor{&service_}.FlushAll();
// Start full sync flows.
state_mask_.fetch_or(R_SYNCING);
std::string_view sync_type = "full";
{
// Going out of the way to avoid using std::vector<bool>...
auto is_full_sync = std::make_unique<bool[]>(num_df_flows_);
auto partition = Partition(num_df_flows_);
CHECK(!last_journal_LSNs_ || last_journal_LSNs_->size() == shard_flows_.size());
auto shard_cb = [&](unsigned index, auto*) {
for (auto id : partition[index]) {
auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block, &cntx_);
if (ec)
cntx_.ReportError(ec);
auto ec = shard_flows_[id]->StartSyncFlow(sync_block, &cntx_,
last_journal_LSNs_.has_value()
? std::optional((*last_journal_LSNs_)[id])
: std::nullopt);
if (ec.has_value())
is_full_sync[id] = ec.value();
else
cntx_.ReportError(ec.error());
}
};
// Lock to prevent the error handler from running instantly
// while the flows are in a mixed state.
lock_guard lk{flows_op_mu_};
shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
size_t num_full_flows =
std::accumulate(is_full_sync.get(), is_full_sync.get() + num_df_flows_, 0);
if (num_full_flows == num_df_flows_) {
JournalExecutor{&service_}.FlushAll();
} else if (num_full_flows == 0) {
sync_type = "partial";
} else {
last_journal_LSNs_.reset();
cntx_.ReportError(std::make_error_code(errc::state_not_recoverable),
"Won't do a partial sync: some flows must fully resync");
}
}
RETURN_ON_ERR(cntx_.GetError());
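
Stepping back, the block above implements an all-or-nothing rule across flows: partial sync is used only when every flow negotiated PARTIAL, an all-FULL answer triggers the local flush, and a mixed outcome is treated as unrecoverable. A compact standalone sketch of that rule (DecideSync and the enum are hypothetical names, not project API):

```cpp
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

enum class SyncDecision { kFullSync, kPartialSync, kMixedError };

// is_full_sync[i] is true when flow i answered FULL to DFLY FLOW.
SyncDecision DecideSync(const std::vector<bool>& is_full_sync) {
  size_t num_full = std::accumulate(is_full_sync.begin(), is_full_sync.end(), size_t{0});
  if (num_full == is_full_sync.size())
    return SyncDecision::kFullSync;     // every flow resyncs -> flush local data first
  if (num_full == 0)
    return SyncDecision::kPartialSync;  // every flow resumes from its saved LSN
  return SyncDecision::kMixedError;     // inconsistent answers -> abort and retry from scratch
}

int main() {
  std::cout << int(DecideSync({true, true, true})) << '\n';     // 0: full sync
  std::cout << int(DecideSync({false, false, false})) << '\n';  // 1: partial sync
  std::cout << int(DecideSync({true, false, false})) << '\n';   // 2: mixed -> error
}
```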
@ -478,7 +501,7 @@ error_code Replica::InitiateDflySync() {
return cntx_.ReportError(ec);
}
LOG(INFO) << absl::StrCat("Started full sync with ", server().Description());
LOG(INFO) << "Started " << sync_type << " sync with " << server().Description();
// Wait for all flows to receive full sync cut.
// In case of an error, this is unblocked by the error handler.
@ -497,7 +520,7 @@ error_code Replica::InitiateDflySync() {
// Joining flows and resetting state is done by cleanup.
double seconds = double(absl::ToInt64Milliseconds(absl::Now() - start_time)) / 1000;
LOG(INFO) << "Full sync finished in " << strings::HumanReadableElapsedTime(seconds);
LOG(INFO) << sync_type << " sync finished in " << strings::HumanReadableElapsedTime(seconds);
return cntx_.GetError();
}
@ -590,7 +613,12 @@ error_code Replica::ConsumeDflyStream() {
lock_guard lk{flows_op_mu_};
shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
}
JoinAllFlows();
JoinDflyFlows();
last_journal_LSNs_.emplace();
for (auto& flow : shard_flows_) {
last_journal_LSNs_->push_back(flow->JournalExecutedCount());
}
LOG(INFO) << "Exit stable sync";
// The only option to unblock is to cancel the context.
@ -599,7 +627,7 @@ error_code Replica::ConsumeDflyStream() {
return cntx_.GetError();
}
void Replica::JoinAllFlows() {
void Replica::JoinDflyFlows() {
for (auto& flow : shard_flows_) {
flow->JoinFlow();
}
@ -625,30 +653,41 @@ error_code Replica::SendNextPhaseRequest(string_view kind) {
return std::error_code{};
}
error_code DflyShardReplica::StartFullSyncFlow(BlockingCounter sb, Context* cntx) {
io::Result<bool> DflyShardReplica::StartSyncFlow(BlockingCounter sb, Context* cntx,
std::optional<LSN> lsn) {
using nonstd::make_unexpected;
DCHECK(!master_context_.master_repl_id.empty() && !master_context_.dfly_session_id.empty());
RETURN_ON_ERR(ConnectAndAuth(absl::GetFlag(FLAGS_master_connect_timeout_ms) * 1ms, &cntx_));
RETURN_ON_ERR_T(make_unexpected,
ConnectAndAuth(absl::GetFlag(FLAGS_master_connect_timeout_ms) * 1ms, &cntx_));
VLOG(1) << "Sending on flow " << master_context_.master_repl_id << " "
<< master_context_.dfly_session_id << " " << flow_id_;
auto cmd = StrCat("DFLY FLOW ", master_context_.master_repl_id, " ",
master_context_.dfly_session_id, " ", flow_id_);
std::string cmd = StrCat("DFLY FLOW ", master_context_.master_repl_id, " ",
master_context_.dfly_session_id, " ", flow_id_);
// Try to negotiate a partial sync if possible.
if (lsn.has_value() && master_context_.version > DflyVersion::VER1) {
absl::StrAppend(&cmd, " ", *lsn);
}
ResetParser(/*server_mode=*/false);
leftover_buf_.emplace(128);
RETURN_ON_ERR(SendCommand(cmd));
RETURN_ON_ERR_T(make_unexpected, SendCommand(cmd));
auto read_resp = ReadRespReply(&*leftover_buf_);
if (!read_resp.has_value()) {
return read_resp.error();
return make_unexpected(read_resp.error());
}
PC_RETURN_ON_BAD_RESPONSE(CheckRespFirstTypes({RespExpr::STRING, RespExpr::STRING}));
PC_RETURN_ON_BAD_RESPONSE_T(make_unexpected,
CheckRespFirstTypes({RespExpr::STRING, RespExpr::STRING}));
string_view flow_directive = ToSV(LastResponseArgs()[0].GetBuf());
string eof_token;
PC_RETURN_ON_BAD_RESPONSE(flow_directive == "FULL");
PC_RETURN_ON_BAD_RESPONSE_T(make_unexpected,
flow_directive == "FULL" || flow_directive == "PARTIAL");
bool is_full_sync = flow_directive == "FULL";
eof_token = ToSV(LastResponseArgs()[1].GetBuf());
leftover_buf_->ConsumeInput(read_resp->left_in_buffer);
@ -658,7 +697,7 @@ error_code DflyShardReplica::StartFullSyncFlow(BlockingCounter sb, Context* cntx
sync_fb_ = fb2::Fiber("shard_full_sync", &DflyShardReplica::FullSyncDflyFb, this,
std::move(eof_token), sb, cntx);
return error_code{};
return is_full_sync;
}
error_code DflyShardReplica::StartStableSyncFlow(Context* cntx) {
@ -680,7 +719,7 @@ error_code DflyShardReplica::StartStableSyncFlow(Context* cntx) {
return std::error_code{};
}
void DflyShardReplica::FullSyncDflyFb(const string& eof_token, BlockingCounter bc, Context* cntx) {
void DflyShardReplica::FullSyncDflyFb(std::string eof_token, BlockingCounter bc, Context* cntx) {
DCHECK(leftover_buf_);
io::PrefixSource ps{leftover_buf_->InputBuffer(), Sock()};
@ -722,7 +761,13 @@ void DflyShardReplica::FullSyncDflyFb(const string& eof_token, BlockingCounter b
leftover_buf_.reset();
}
this->journal_rec_executed_.store(loader.journal_offset());
if (auto jo = loader.journal_offset(); jo.has_value()) {
this->journal_rec_executed_.store(*jo);
} else {
if (master_context_.version > DflyVersion::VER0)
cntx->ReportError(std::make_error_code(errc::protocol_error),
"Error finding journal offset in stream");
}
VLOG(1) << "FullSyncDflyFb finished after reading " << loader.bytes_read() << " bytes";
}
@ -855,10 +900,11 @@ void DflyShardReplica::ExecuteTxWithNoShardSync(TransactionData&& tx_data, Conte
}
bool DflyShardReplica::InsertTxToSharedMap(const TransactionData& tx_data) {
std::lock_guard lk{multi_shard_exe_->map_mu};
std::unique_lock lk(multi_shard_exe_->map_mu);
auto [it, was_insert] =
multi_shard_exe_->tx_sync_execution.emplace(tx_data.txid, tx_data.shard_cnt);
lk.unlock();
VLOG(2) << "txid: " << tx_data.txid << " unique_shard_cnt_: " << tx_data.shard_cnt
<< " was_insert: " << was_insert;
it->second.block.Dec();
@ -902,11 +948,11 @@ void DflyShardReplica::ExecuteTx(TransactionData&& tx_data, bool inserted_by_me,
}
VLOG(2) << "Execute txid: " << tx_data.txid;
multi_shard_exe_->map_mu.lock();
std::unique_lock lk(multi_shard_exe_->map_mu);
auto it = multi_shard_exe_->tx_sync_execution.find(tx_data.txid);
DCHECK(it != multi_shard_exe_->tx_sync_execution.end());
auto& multi_shard_data = it->second;
multi_shard_exe_->map_mu.unlock();
lk.unlock();
VLOG(2) << "Execute txid: " << tx_data.txid << " waiting for data in all shards";
// Wait until shards flows got transaction data and inserted to map.

@ -109,7 +109,10 @@ class Replica : ProtocolClient {
void RedisStreamAcksFb();
void JoinAllFlows(); // Join all flows if possible.
// Joins all the flows when doing sharded replication. This is called in two
// places: once at the end of full sync to join the full sync fibers, and again
// if a stable sync is interrupted to join the cancelled stable sync fibers.
void JoinDflyFlows();
void SetShardStates(bool replica); // Call SetReplica(replica) on all shards.
// Send DFLY ${kind} to the master instance.
@ -164,6 +167,9 @@ class Replica : ProtocolClient {
EventCount waker_;
std::vector<std::unique_ptr<DflyShardReplica>> shard_flows_;
// A vector of the last executed LSNs when replication is interrupted.
// Allows partial sync on reconnects.
std::optional<std::vector<LSN>> last_journal_LSNs_;
std::shared_ptr<MultiShardExecution> multi_shard_exe_;
// Guard operations where flows might be in a mixed state (transition/setup)
@ -221,13 +227,14 @@ class DflyShardReplica : public ProtocolClient {
void JoinFlow();
// Start replica initialized as dfly flow.
std::error_code StartFullSyncFlow(BlockingCounter block, Context* cntx);
// On success, returns whether a full sync (true) or a partial sync (false) was started.
io::Result<bool> StartSyncFlow(BlockingCounter block, Context* cntx, std::optional<LSN>);
// Transition into stable state mode as dfly flow.
std::error_code StartStableSyncFlow(Context* cntx);
// Single flow full sync fiber spawned by StartFullSyncFlow.
void FullSyncDflyFb(const std::string& eof_token, BlockingCounter block, Context* cntx);
void FullSyncDflyFb(std::string eof_token, BlockingCounter block, Context* cntx);
// Single flow stable state sync fiber spawned by StartStableSyncFlow.
void StableSyncDflyReadFb(Context* cntx);

@ -1827,9 +1827,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
if (!absl::SimpleAtoi(arg, &version)) {
return (*cntx)->SendError(kInvalidIntErr);
}
VLOG(1) << "Client version for session_id="
<< cntx->conn_state.replication_info.repl_session_id << " is " << version;
cntx->conn_state.replication_info.repl_version = DflyVersion(version);
dfly_cmd_->SetDflyClientVersion(cntx, DflyVersion(version));
} else if (cmd == "ACK" && args.size() == 2) {
// Don't send error/Ok back through the socket, because we don't want to interleave with
// the journal writes that we write into the same socket.

@ -37,16 +37,16 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
DCHECK(!snapshot_fb_.IsJoinable());
auto db_cb = absl::bind_front(&SliceSnapshot::OnDbChange, this);
snapshot_version_ = db_slice_->RegisterOnChange(move(db_cb));
snapshot_version_ = db_slice_->RegisterOnChange(std::move(db_cb));
if (stream_journal) {
auto* journal = db_slice_->shard_owner()->journal();
DCHECK(journal);
auto journal_cb = absl::bind_front(&SliceSnapshot::OnJournalEntry, this);
journal_cb_id_ = journal->RegisterOnChange(move(journal_cb));
journal_cb_id_ = journal->RegisterOnChange(std::move(journal_cb));
}
serializer_.reset(new RdbSerializer(compression_mode_));
serializer_ = std::make_unique<RdbSerializer>(compression_mode_);
VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_;
@ -61,6 +61,55 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
});
}
void SliceSnapshot::StartIncremental(Context* cntx, LSN start_lsn) {
auto* journal = db_slice_->shard_owner()->journal();
DCHECK(journal);
serializer_ = std::make_unique<RdbSerializer>(compression_mode_);
snapshot_fb_ =
fb2::Fiber("incremental_snapshot", [this, journal, cntx, lsn = start_lsn]() mutable {
DCHECK(lsn <= journal->GetLsn()) << "The replica tried to sync from the future.";
VLOG(1) << "Starting incremental snapshot from lsn=" << lsn;
// The replica sends the LSN of the next entry it wants to receive.
while (!cntx->IsCancelled() && journal->IsLSNInBuffer(lsn)) {
serializer_->WriteJournalEntry(journal->GetEntry(lsn));
PushSerializedToChannel(false);
lsn++;
}
VLOG(1) << "Last LSN sent in incremental snapshot was " << (lsn - 1);
// This check is safe, but it is not trivially safe.
// We rely here on the fact that JournalSlice::AddLogRecord can
// only preempt while holding the callback lock.
// That guarantees that if we have processed the last LSN the callback
// will only be added after JournalSlice::AddLogRecord has finished
// iterating its callbacks and we won't process the record twice.
// We have to make sure we don't preempt ourselves before registering the callback!
// GetLsn() is always the next lsn that we expect to create.
if (journal->GetLsn() == lsn) {
{
FiberAtomicGuard fg;
serializer_->SendFullSyncCut();
}
auto journal_cb = absl::bind_front(&SliceSnapshot::OnJournalEntry, this);
journal_cb_id_ = journal->RegisterOnChange(std::move(journal_cb));
PushSerializedToChannel(true);
} else {
// We stopped but we didn't manage to send the whole stream.
cntx->ReportError(
std::make_error_code(errc::state_not_recoverable),
absl::StrCat("Partial sync was unsuccessful because entry #", lsn,
" was dropped from the buffer. Current lsn=", journal->GetLsn()));
Cancel();
}
});
}
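
A simplified, standalone model of the catch-up loop above: entries are streamed while they are still buffered, the attempt succeeds only if the reader reaches the journal head, and it fails if the entry it needs was already dropped. The deque-backed FakeJournal is an assumption used for illustration, not the project's journal type:

```cpp
#include <cstdint>
#include <deque>
#include <iostream>
#include <string>

using LSN = uint64_t;

// Hypothetical bounded journal: retains only the last `capacity` entries.
struct FakeJournal {
  LSN oldest = 0;  // LSN of the first retained entry
  std::deque<std::string> buf;
  size_t capacity = 4;

  void Append(std::string entry) {
    buf.push_back(std::move(entry));
    if (buf.size() > capacity) {  // the oldest entry falls out of the backlog
      buf.pop_front();
      ++oldest;
    }
  }
  LSN NextLsn() const { return oldest + buf.size(); }  // GetLsn() analogue
  bool Has(LSN lsn) const { return lsn >= oldest && lsn < NextLsn(); }  // IsLSNInBuffer analogue
  const std::string& Get(LSN lsn) const { return buf[lsn - oldest]; }
};

// Mirrors the loop in StartIncremental: stream buffered entries starting at `lsn`.
bool CatchUp(const FakeJournal& j, LSN lsn) {
  while (j.Has(lsn)) {
    std::cout << "send lsn=" << lsn << " payload=" << j.Get(lsn) << '\n';
    ++lsn;
  }
  if (j.NextLsn() == lsn) {
    std::cout << "caught up; switch to journal streaming\n";
    return true;
  }
  std::cout << "lsn=" << lsn << " was dropped from the backlog; partial sync fails\n";
  return false;
}

int main() {
  FakeJournal j;
  for (int i = 0; i < 6; ++i) j.Append("entry" + std::to_string(i));  // LSNs 2..5 remain buffered
  CatchUp(j, 3);  // succeeds: 3, 4, 5 are still buffered
  CatchUp(j, 1);  // fails: entry 1 already fell out of the buffer
}
```

The real fiber additionally has to handle writes that arrive while it is draining the buffer, which is what the callback-lock reasoning above is about.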
void SliceSnapshot::Stop() {
// Wait for serialization to finish in any case.
Join();
@ -91,8 +140,7 @@ void SliceSnapshot::Cancel() {
void SliceSnapshot::Join() {
// Fiber could have already been joined by Stop.
if (snapshot_fb_.IsJoinable())
snapshot_fb_.Join();
snapshot_fb_.JoinIfNeeded();
}
// The algorithm is to go over all the buckets and serialize those with

@ -61,6 +61,13 @@ class SliceSnapshot {
// In journal streaming mode it needs to be stopped by either Stop or Cancel.
void Start(bool stream_journal, const Cancellation* cll);
// Initialize a snapshot that sends only the missing journal updates
// since start_lsn and then registers a callback that switches into the
// journal streaming mode until stopped.
// If we're slower than the buffer and can't continue, `Cancel()` is
// called.
void StartIncremental(Context* cntx, LSN start_lsn);
// Stop snapshot. Only needs to be called for journal streaming mode.
void Stop();

@ -17,16 +17,20 @@ const char* GetVersion();
// Please document for each new entry what the behavior changes are
// and to which released versions this corresponds.
enum class DflyVersion {
// Versions <=1.3
// ver <= 1.3
VER0,
// Versions 1.4<=
// 1.4 <= ver <= 1.10
// - Supports receiving ACKs from replicas
// - Sends version back on REPLCONF capa dragonfly
VER1,
// 1.11 <= ver
// Supports limited partial sync
VER2,
// Always points to the latest version
CURRENT_VER = VER1,
CURRENT_VER = VER2,
};
} // namespace dfly
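
For completeness, a tiny sketch of how the new enumerator is meant to be gated on; it mirrors the replica-side check `master_context_.version > DflyVersion::VER1` seen earlier in this diff (SupportsPartialSync is a made-up helper name):

```cpp
#include <iostream>

namespace dfly {

enum class DflyVersion {
  VER0,  // ver <= 1.3
  VER1,  // 1.4 <= ver <= 1.10: replica ACKs, version exchange
  VER2,  // 1.11 <= ver: limited partial sync
  CURRENT_VER = VER2,
};

// A replica should only append the optional seqid to DFLY FLOW when the
// master advertises VER2 or newer.
bool SupportsPartialSync(DflyVersion master_version) {
  return master_version > DflyVersion::VER1;
}

}  // namespace dfly

int main() {
  std::cout << dfly::SupportsPartialSync(dfly::DflyVersion::VER1) << '\n';  // 0
  std::cout << dfly::SupportsPartialSync(dfly::DflyVersion::VER2) << '\n';  // 1
}
```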

@ -4,8 +4,9 @@ import subprocess
import aiohttp
import logging
import os
import re
import psutil
from typing import Optional
from typing import Optional, List
from prometheus_client.parser import text_string_to_metric_families
from redis.asyncio import Redis as RedisClient
@ -42,6 +43,7 @@ class DflyInstance:
self.params = params
self.proc: Optional[subprocess.Popen] = None
self._client: Optional[RedisClient] = None
self.log_files: List[str] = []
self.dynamic_port = False
if self.params.existing_port:
@ -55,6 +57,12 @@ class DflyInstance:
self._port = None
self.dynamic_port = True
# Some tests check the log files, so make sure the log files
# exist even when people try to debug their test.
if "logtostderr" in self.params.args:
self.params.args.remove("logtostderr")
self.params.args.append("alsologtostderr")
def __del__(self):
assert self.proc == None
@ -94,6 +102,7 @@ class DflyInstance:
time.sleep(0.05)
else:
raise DflyStartException("Process didn't start listening on port in time")
self.log_files = self.get_logs_from_psutil()
def stop(self, kill=False):
proc, self.proc = self.proc, None
@ -180,6 +189,14 @@ class DflyInstance:
return ports.pop()
raise RuntimeError("Couldn't parse port")
def get_logs_from_psutil(self) -> List[str]:
p = psutil.Process(self.proc.pid)
rv = []
for file in p.open_files():
if ".log." in file.path and "dragonfly" in file.path:
rv.append(file.path)
return rv
@staticmethod
def format_args(args):
out = []
@ -200,6 +217,17 @@ class DflyInstance:
for metric_family in text_string_to_metric_families(data)
}
def is_in_logs(self, pattern):
if self.proc is not None:
raise RuntimeError("Must close server first")
matcher = re.compile(pattern)
for path in self.log_files:
for line in open(path):
if matcher.search(line):
return True
return False
class DflyInstanceFactory:
"""

tests/dragonfly/proxy.py (new file, 75 lines)

@ -0,0 +1,75 @@
import asyncio
import random
class Proxy:
def __init__(self, host, port, remote_host, remote_port):
self.host = host
self.port = port
self.remote_host = remote_host
self.remote_port = remote_port
self.stop_connections = []
self.server = None
def __del__(self):
self.close()
async def handle(self, reader, writer):
remote_reader, remote_writer = await asyncio.open_connection(
self.remote_host, self.remote_port
)
async def forward(reader, writer):
while True:
data = await reader.read(1024)
if not data:
break
writer.write(data)
await writer.drain()
writer.close()
task1 = asyncio.ensure_future(forward(reader, remote_writer))
task2 = asyncio.ensure_future(forward(remote_reader, writer))
def cleanup():
task1.cancel()
task2.cancel()
writer.close()
remote_writer.close()
self.stop_connections.append(cleanup)
try:
await asyncio.gather(task1, task2)
except (asyncio.CancelledError, ConnectionResetError):
pass
finally:
cleanup()
if cleanup in self.stop_connections:
self.stop_connections.remove(cleanup)
async def start(self):
self.server = await asyncio.start_server(self.handle, self.host, self.port)
if self.port == 0:
_, port = self.server.sockets[0].getsockname()[:2]
self.port = port
async with self.server:
await self.server.serve_forever()
def drop_connection(self):
"""
Randomly drop one connection
"""
if self.stop_connections:
cb = random.choice(self.stop_connections)
self.stop_connections.remove(cb)
cb()
def close(self):
if self.server is not None:
self.server.close()
self.server = None
for cb in self.stop_connections:
cb()

@ -9,6 +9,7 @@ from .utility import *
from . import DflyInstanceFactory, dfly_args
import pymemcache
import logging
from .proxy import Proxy
ADMIN_PORT = 1211
@ -1364,13 +1365,17 @@ async def test_tls_replication(
# busy wait for 'replica' instance to have replication status 'status'
async def wait_for_replica_status(replica: aioredis.Redis, status: str, wait_for_seconds=0.01):
while True:
async def wait_for_replica_status(
replica: aioredis.Redis, status: str, wait_for_seconds=0.01, timeout=20
):
start = time.time()
while (time.time() - start) < timeout:
await asyncio.sleep(wait_for_seconds)
info = await replica.info("replication")
if info["master_link_status"] == status:
return
raise RuntimeError("Replica did not reach the expected status in time!")
@pytest.mark.asyncio
@ -1397,7 +1402,8 @@ async def test_replicaof_flag(df_local_factory):
c_replica = aioredis.Redis(port=replica.port)
await wait_available_async(c_replica) # give it time to startup
await wait_for_replica_status(c_replica, status="up") # wait until we have a connection
# wait until we have a connection
await wait_for_replica_status(c_replica, status="up")
await check_all_replicas_finished([c_replica], c_master)
dbsize = await c_replica.dbsize()
@ -1486,7 +1492,7 @@ async def test_replicaof_flag_disconnect(df_local_factory):
val = await c_replica.get("KEY")
assert b"VALUE" == val
await c_replica.replicaof("no", "one") # disconnect
await c_replica.replicaof("no", "one") # disconnect
role = await c_replica.role()
assert role[0] == b"master"
@ -1548,3 +1554,129 @@ async def test_df_crash_on_replicaof_flag(df_local_factory):
master.stop()
replica.stop()
async def test_network_disconnect(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=6)
replica = df_local_factory.create(proactor_threads=4)
df_local_factory.start_all([replica, master])
seeder = df_seeder_factory.create(port=master.port)
async with replica.client() as c_replica:
await seeder.run(target_deviation=0.1)
proxy = Proxy("localhost", 1111, "localhost", master.port)
task = asyncio.create_task(proxy.start())
await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
for _ in range(10):
await asyncio.sleep(random.randint(0, 10) / 10)
proxy.drop_connection()
# Give time to detect dropped connection and reconnect
await asyncio.sleep(1.0)
await wait_for_replica_status(c_replica, status="up")
await wait_available_async(c_replica)
capture = await seeder.capture()
assert await seeder.compare(capture, replica.port)
proxy.close()
try:
await task
except asyncio.exceptions.CancelledError:
pass
master.stop()
replica.stop()
assert replica.is_in_logs("partial sync finished in")
async def test_network_disconnect_active_stream(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=4, shard_repl_backlog_len=10000)
replica = df_local_factory.create(proactor_threads=4)
df_local_factory.start_all([replica, master])
seeder = df_seeder_factory.create(port=master.port)
async with replica.client() as c_replica, master.client() as c_master:
await seeder.run(target_deviation=0.1)
proxy = Proxy("localhost", 1112, "localhost", master.port)
task = asyncio.create_task(proxy.start())
await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
fill_task = asyncio.create_task(seeder.run(target_ops=10000))
for _ in range(3):
await asyncio.sleep(random.randint(10, 20) / 10)
proxy.drop_connection()
seeder.stop()
await fill_task
# Give time to detect dropped connection and reconnect
await asyncio.sleep(1.0)
await wait_for_replica_status(c_replica, status="up")
await wait_available_async(c_replica)
logging.debug(await c_replica.execute_command("INFO REPLICATION"))
logging.debug(await c_master.execute_command("INFO REPLICATION"))
capture = await seeder.capture()
assert await seeder.compare(capture, replica.port)
proxy.close()
try:
await task
except asyncio.exceptions.CancelledError:
pass
master.stop()
replica.stop()
assert replica.is_in_logs("partial sync finished in")
async def test_network_disconnect_small_buffer(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=4, shard_repl_backlog_len=1)
replica = df_local_factory.create(proactor_threads=4)
df_local_factory.start_all([replica, master])
seeder = df_seeder_factory.create(port=master.port)
async with replica.client() as c_replica, master.client() as c_master:
await seeder.run(target_deviation=0.1)
proxy = Proxy("localhost", 1113, "localhost", master.port)
task = asyncio.create_task(proxy.start())
await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
fill_task = asyncio.create_task(seeder.run(target_ops=10000))
for _ in range(3):
await asyncio.sleep(random.randint(5, 10) / 10)
proxy.drop_connection()
seeder.stop()
await fill_task
# Give time to detect dropped connection and reconnect
await asyncio.sleep(1.0)
await wait_for_replica_status(c_replica, status="up")
await wait_available_async(c_replica)
# logging.debug(await c_replica.execute_command("INFO REPLICATION"))
# logging.debug(await c_master.execute_command("INFO REPLICATION"))
capture = await seeder.capture()
assert await seeder.compare(capture, replica.port)
proxy.close()
try:
await task
except asyncio.exceptions.CancelledError:
pass
master.stop()
replica.stop()
assert master.is_in_logs("Partial sync requested from stale LSN")

@ -1,4 +1,5 @@
import itertools
import logging
import sys
import asyncio
from redis import asyncio as aioredis
@ -39,10 +40,11 @@ def batch_fill_data(client, gen, batch_size=100):
client.mset({k: v for k, v, in group})
async def wait_available_async(client: aioredis.Redis):
async def wait_available_async(client: aioredis.Redis, timeout=10):
"""Block until instance exits loading phase"""
its = 0
while True:
start = time.time()
while (time.time() - start) < timeout:
try:
await client.get("key")
return
@ -56,6 +58,7 @@ async def wait_available_async(client: aioredis.Redis):
print("W", end="", flush=True)
await asyncio.sleep(0.01)
its += 1
raise RuntimeError("Client did not become available in time!")
class SizeChange(Enum):
@ -374,7 +377,7 @@ class DflySeeder:
Run a seeding cycle on all dbs either until stop(), a fixed number of commands (target_ops)
or until reaching an allowed deviation from the target number of keys (target_deviation)
"""
print(f"Running ops:{target_ops} deviation:{target_deviation}")
logging.debug(f"Running ops:{target_ops} deviation:{target_deviation}")
self.stop_flag = False
queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)]
producer = asyncio.create_task(
@ -392,7 +395,7 @@ class DflySeeder:
took = time.time() - time_start
qps = round(cmdcount * self.dbcount / took, 2)
print(f"Filling took: {took}, QPS: {qps}")
logging.debug(f"Filling took: {took}, QPS: {qps}")
def stop(self):
"""Stop all invocations to run"""
@ -407,6 +410,7 @@ class DflySeeder:
if port is None:
port = self.port
logging.debug(f"Starting capture from {port=}")
keys = sorted(list(self.gen.keys_and_types()))
captures = await asyncio.gather(