feat(replication): First iteration on partial sync. (#1836)

First iteration on partial sync.
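The change lets a replica that lost its connection resume replication from the master's journal instead of re-downloading a full snapshot. The negotiation rides on the existing DFLY FLOW handshake; a sketch of the exchange (values illustrative, not from a real session):

    replica > DFLY FLOW <master_id> <sync_id> <flow_id> 1500   # 1500 = next LSN the replica wants
    master  < *2
    master  < +PARTIAL                  # LSN 1500 is still in the journal ring buffer
    master  < +<40-char eof_token>

    # If LSN 1500 has already fallen out of the per-shard backlog, the master answers
    master  < +FULL                     # and the replica falls back to a full snapshot

The master grants PARTIAL only when the requested LSN is still in the shard's replication backlog (or equals the current LSN); otherwise it logs the stale-LSN warning seen below and performs a full sync.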
Roy Jacobson committed on 2023-09-26 10:35:50 +03:00 (committed by GitHub)
parent d9f4ca8003
commit d50b492e1f
21 changed files with 524 additions and 86 deletions


@@ -115,7 +115,7 @@ GenericError RdbSnapshot::Start(SaveMode save_mode, const std::string& path,
 }
 
 error_code RdbSnapshot::SaveBody() {
-  return saver_->SaveBody(&cll_, &freq_map_);
+  return saver_->SaveBody(&cntx_, &freq_map_);
 }
 
 error_code RdbSnapshot::Close() {
@@ -126,7 +126,7 @@ error_code RdbSnapshot::Close() {
 }
 
 void RdbSnapshot::StartInShard(EngineShard* shard) {
-  saver_->StartSnapshotInShard(false, &cll_, shard);
+  saver_->StartSnapshotInShard(false, cntx_.GetCancellation(), shard);
   started_ = true;
 }


@@ -63,7 +63,7 @@ class RdbSnapshot {
   unique_ptr<RdbSaver> saver_;
   RdbTypeFreqMap freq_map_;
 
-  Cancellation cll_{};
+  Context cntx_{};
 };
 
 struct SaveStagesController : public SaveStagesInputs {


@@ -8,9 +8,11 @@
 #include <absl/strings/strip.h>
 
 #include <limits>
+#include <memory>
 #include <optional>
 #include <utility>
 
+#include "absl/strings/numbers.h"
 #include "base/flags.h"
 #include "base/logging.h"
 #include "facade/dragonfly_connection.h"
@@ -114,7 +116,7 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
     return Thread(args, cntx);
   }
 
-  if (sub_cmd == "FLOW" && args.size() == 4) {
+  if (sub_cmd == "FLOW" && (args.size() == 4 || args.size() == 5)) {
     return Flow(args, cntx);
   }
@@ -241,7 +243,16 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
   string_view sync_id_str = ArgS(args, 2);
   string_view flow_id_str = ArgS(args, 3);
 
-  VLOG(1) << "Got DFLY FLOW " << master_id << " " << sync_id_str << " " << flow_id_str;
+  std::optional<LSN> seqid;
+  if (args.size() == 5) {
+    seqid.emplace();
+    if (!absl::SimpleAtoi(ArgS(args, 4), &seqid.value())) {
+      return rb->SendError(facade::kInvalidIntErr);
+    }
+  }
+
+  VLOG(1) << "Got DFLY FLOW master_id: " << master_id << " sync_id: " << sync_id_str
+          << " flow: " << flow_id_str << " seq: " << seqid.value_or(-1);
 
   if (master_id != sf_->master_id()) {
     return rb->SendError(kBadMasterId);
@@ -268,13 +279,33 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
   absl::InsecureBitGen gen;
   string eof_token = GetRandomHex(gen, 40);
 
-  cntx->replication_flow = &replica_ptr->flows[flow_id];
-  replica_ptr->flows[flow_id].conn = cntx->owner();
-  replica_ptr->flows[flow_id].eof_token = eof_token;
+  auto& flow = replica_ptr->flows[flow_id];
+  cntx->replication_flow = &flow;
+  flow.conn = cntx->owner();
+  flow.eof_token = eof_token;
+  flow.version = replica_ptr->version;
+
   cntx->owner()->Migrate(shard_set->pool()->at(flow_id));
+  sf_->journal()->StartInThread();
+
+  std::string_view sync_type = "FULL";
+  if (seqid.has_value()) {
+    if (sf_->journal()->IsLSNInBuffer(*seqid) || sf_->journal()->GetLsn() == *seqid) {
+      flow.start_partial_sync_at = *seqid;
+      VLOG(1) << "Partial sync requested from LSN=" << flow.start_partial_sync_at.value()
+              << " and is available. (current_lsn=" << sf_->journal()->GetLsn() << ")";
+      sync_type = "PARTIAL";
+    } else {
+      LOG(INFO) << "Partial sync requested from stale LSN=" << *seqid
+                << " which is no longer in the replication buffer (current_lsn="
+                << sf_->journal()->GetLsn() << "). Will perform a full sync of the data.";
+      LOG(INFO) << "If this happens often you can control the replication buffer's size with the "
+                   "--shard_repl_backlog_len option";
+    }
+  }
 
   rb->StartArray(2);
-  rb->SendSimpleString("FULL");
+  rb->SendSimpleString(sync_type);
   rb->SendSimpleString(eof_token);
 }
@@ -309,7 +340,7 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
     return rb->SendError(kInvalidState);
   }
 
-  LOG(INFO) << "Started full sync with replica " << replica_ptr->address << ":"
+  LOG(INFO) << "Started sync with replica " << replica_ptr->address << ":"
             << replica_ptr->listening_port;
 
   replica_ptr->state.store(SyncState::FULL_SYNC, memory_order_relaxed);
@@ -468,7 +499,7 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
   // of the flows also contain them.
   SaveMode save_mode =
       shard->shard_id() == 0 ? SaveMode::SINGLE_SHARD_WITH_SUMMARY : SaveMode::SINGLE_SHARD;
-  flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
+  flow->saver = std::make_unique<RdbSaver>(flow->conn->socket(), save_mode, false);
 
   flow->cleanup = [flow]() {
     flow->saver->Cancel();
@@ -477,11 +508,12 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineSha
     flow->saver.reset();
   };
 
-  sf_->journal()->StartInThread();
-
   // Shard can be null for io thread.
   if (shard != nullptr) {
-    flow->saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
+    if (flow->start_partial_sync_at.has_value())
+      flow->saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
+    else
+      flow->saver->StartSnapshotInShard(true, cntx->GetCancellation(), shard);
   }
 
   flow->full_sync_fb = fb2::Fiber("full_sync", &DflyCmd::FullSyncFb, this, flow, cntx);
@@ -532,7 +564,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
       // Always send original body (with header & without auto async calls) that determines the sha,
       // It's stored only if it's different from the post-processed version.
       string& body = data.orig_body.empty() ? data.body : data.orig_body;
-      script_bodies.push_back(move(body));
+      script_bodies.push_back(std::move(body));
     }
     ec = saver->SaveHeader({script_bodies, {}});
   } else {
@@ -544,7 +576,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
     return;
   }
 
-  if (ec = saver->SaveBody(cntx->GetCancellation(), nullptr); ec) {
+  if (ec = saver->SaveBody(cntx, nullptr); ec) {
     cntx->ReportError(ec);
    return;
  }
@@ -711,6 +743,14 @@ std::map<uint32_t, LSN> DflyCmd::ReplicationLags() const {
   return rv;
 }
 
+void DflyCmd::SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version) {
+  auto replica_ptr = GetReplicaInfo(cntx->conn_state.replication_info.repl_session_id);
+  VLOG(1) << "Client version for session_id=" << cntx->conn_state.replication_info.repl_session_id
+          << " is " << int(version);
+
+  replica_ptr->version = version;
+}
+
 bool DflyCmd::CheckReplicaStateOrReply(const ReplicaInfo& sync_info, SyncState expected,
                                        RedisReplyBuilder* rb) {
   if (sync_info.state != expected) {


@@ -40,7 +40,9 @@ struct FlowInfo {
   std::unique_ptr<RdbSaver> saver;  // Saver used by the full sync phase.
   std::unique_ptr<JournalStreamer> streamer;
   std::string eof_token;
+  DflyVersion version;
 
+  std::optional<LSN> start_partial_sync_at;
   uint64_t last_acked_lsn;
   std::function<void()> cleanup;  // Optional cleanup for cancellation.
@@ -99,8 +101,11 @@ class DflyCmd {
   struct ReplicaInfo {
     ReplicaInfo(unsigned flow_count, std::string address, uint32_t listening_port,
                 Context::ErrHandler err_handler)
-        : state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, address{std::move(address)},
-          listening_port(listening_port), flows{flow_count} {
+        : state{SyncState::PREPARATION},
+          cntx{std::move(err_handler)},
+          address{std::move(address)},
+          listening_port(listening_port),
+          flows{flow_count} {
     }
 
     std::atomic<SyncState> state;
@@ -108,6 +113,7 @@ class DflyCmd {
     std::string address;
     uint32_t listening_port;
+    DflyVersion version = DflyVersion::VER0;
     std::vector<FlowInfo> flows;
 
     Mutex mu;  // See top of header for locking levels.
@@ -130,6 +136,9 @@ class DflyCmd {
   std::vector<ReplicaRoleInfo> GetReplicasRoleInfo();
 
+  // Sets metadata.
+  void SetDflyClientVersion(ConnectionContext* cntx, DflyVersion version);
+
  private:
   // JOURNAL [START/STOP]
   // Start or stop journaling.
@@ -139,8 +148,11 @@ class DflyCmd {
   // Return connection thread index or migrate to another thread.
   void Thread(CmdArgList args, ConnectionContext* cntx);
 
-  // FLOW <masterid> <syncid> <flowid>
+  // FLOW <masterid> <syncid> <flowid> [<seqid>]
   // Register connection as flow for sync session.
+  // If seqid is given, the client wants to try a partial sync.
+  // If that is possible, reply OK and prepare for a partial sync; otherwise
+  // return an error and ask the replica to execute FLOW again.
   void Flow(CmdArgList args, ConnectionContext* cntx);
 
   // SYNC <syncid>


@@ -20,15 +20,17 @@ using facade::kWrongTypeErr;
 
 #ifndef RETURN_ON_ERR
-#define RETURN_ON_ERR(x)                                         \
-  do {                                                           \
-    std::error_code __ec = (x);                                  \
-    if (__ec) {                                                  \
-      DLOG(ERROR) << "Error " << __ec << " while calling " #x;   \
-      return __ec;                                               \
-    }                                                            \
+#define RETURN_ON_ERR_T(T, x)                                    \
+  do {                                                           \
+    std::error_code __ec = (x);                                  \
+    if (__ec) {                                                  \
+      DLOG(ERROR) << "Error " << __ec << " while calling " #x;   \
+      return (T)(__ec);                                          \
+    }                                                            \
   } while (0)
 
+#define RETURN_ON_ERR(x) RETURN_ON_ERR_T(std::error_code, x)
+
 #endif  // RETURN_ON_ERR
 
 namespace rdb {
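The templated variant exists so the same validation can be reused in functions that do not return std::error_code: StartSyncFlow in replica.cc returns io::Result<bool> and passes nonstd::make_unexpected so errors are wrapped on return. A minimal usage sketch (Connect and Dial are hypothetical, not from this diff):

    io::Result<bool> Connect() {
      using nonstd::make_unexpected;
      // On error expands to: return (make_unexpected)(__ec);
      RETURN_ON_ERR_T(make_unexpected, Dial());
      return true;
    }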


@@ -16,7 +16,7 @@
 #include "base/logging.h"
 #include "server/journal/serializer.h"
 
-ABSL_FLAG(int, shard_repl_backlog_len, 1 << 10,
+ABSL_FLAG(uint32_t, shard_repl_backlog_len, 1 << 10,
           "The length of the circular replication log per shard");
 
 namespace dfly {
@@ -33,6 +33,14 @@ string ShardName(std::string_view base, unsigned index) {
 }
 */
 
+uint32_t NextPowerOf2(uint32_t x) {
+  if (x < 2) {
+    return 1;
+  }
+  int log = 32 - __builtin_clz(x - 1);
+  return 1 << log;
+}
+
 }  // namespace
 
 #define CHECK_EC(x) \
@@ -53,7 +61,7 @@ void JournalSlice::Init(unsigned index) {
     return;
 
   slice_index_ = index;
-  ring_buffer_.emplace(absl::GetFlag(FLAGS_shard_repl_backlog_len));
+  ring_buffer_.emplace(NextPowerOf2(absl::GetFlag(FLAGS_shard_repl_backlog_len)));
 }
 
 #if 0
@@ -144,6 +152,7 @@ void JournalSlice::AddLogRecord(const Entry& entry, bool await) {
     item->opcode = entry.opcode;
     item->data = "";
   } else {
+    FiberAtomicGuard fg;
     // GetTail gives a pointer to a new tail entry in the buffer, possibly overriding the last entry
     // if the buffer is full.
     item = ring_buffer_->GetTail(true);
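NextPowerOf2 rounds the configured backlog length up before sizing the ring buffer, so any flag value yields a valid capacity (power-of-two capacities typically allow cheap index masking in a ring buffer; that rationale is an assumption, the diff only shows the rounding). Worked values of the function above:

    NextPowerOf2(0)    == 1     // x < 2 short-circuits to 1
    NextPowerOf2(1)    == 1
    NextPowerOf2(1000) == 1024  // 32 - __builtin_clz(999) == 10, so 1 << 10
    NextPowerOf2(1024) == 1024  // exact powers of two are preserved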


@@ -148,10 +148,13 @@ class ProtocolClient {
 /**
  * A convenience macro to use with ProtocolClient instances for protocol input validation.
  */
-#define PC_RETURN_ON_BAD_RESPONSE(x)                                                            \
-  do {                                                                                          \
-    if (!(x)) {                                                                                 \
-      LOG(ERROR) << "Bad response to \"" << last_cmd_ << "\": \"" << absl::CEscape(last_resp_); \
-      return std::make_error_code(errc::bad_message);                                          \
-    }                                                                                           \
+#define PC_RETURN_ON_BAD_RESPONSE_T(T, x)                                                       \
+  do {                                                                                          \
+    if (!(x)) {                                                                                 \
+      LOG(ERROR) << "Bad response to \"" << last_cmd_ << "\": \"" << absl::CEscape(last_resp_)  \
+                 << "\"";                                                                       \
+      return (T)(std::make_error_code(errc::bad_message));                                      \
+    }                                                                                           \
   } while (false)
 
+#define PC_RETURN_ON_BAD_RESPONSE(x) PC_RETURN_ON_BAD_RESPONSE_T(std::error_code, x)


@@ -1843,6 +1843,7 @@ error_code RdbLoader::Load(io::Source* src) {
   auto cb = mem_buf_->InputBuffer();
 
   if (memcmp(cb.data(), "REDIS", 5) != 0) {
+    VLOG(1) << "Bad header: " << absl::CHexEscape(facade::ToSV(cb));
     return RdbError(errc::wrong_signature);
   }


@@ -189,8 +189,8 @@ class RdbLoader : protected RdbLoaderBase {
   // Return the offset that was received with a RDB_OPCODE_JOURNAL_OFFSET command,
   // or 0 if no offset was received.
-  uint64_t journal_offset() const {
-    return journal_offset_.value_or(0);
+  std::optional<uint64_t> journal_offset() const {
+    return journal_offset_;
   }
 
   // Set callback for receiving RDB_OPCODE_FULLSYNC_END.


@@ -673,6 +673,7 @@ error_code RdbSerializer::SaveStreamConsumers(streamCG* cg) {
 }
 
 error_code RdbSerializer::SendJournalOffset(uint64_t journal_offset) {
+  VLOG(2) << "SendJournalOffset";
   RETURN_ON_ERR(WriteOpcode(RDB_OPCODE_JOURNAL_OFFSET));
   uint8_t buf[sizeof(uint64_t)];
   absl::little_endian::Store64(buf, journal_offset);
@@ -680,6 +681,7 @@ error_code RdbSerializer::SendJournalOffset(uint64_t journal_offset) {
 }
 
 error_code RdbSerializer::SendFullSyncCut() {
+  VLOG(2) << "SendFullSyncCut";
   RETURN_ON_ERR(WriteOpcode(RDB_OPCODE_FULLSYNC_END));
 
   // RDB_OPCODE_FULLSYNC_END followed by 8 bytes of 0.
@@ -734,6 +736,7 @@ io::Bytes RdbSerializer::PrepareFlush() {
 }
 
 error_code RdbSerializer::WriteJournalEntry(std::string_view serialized_entry) {
+  VLOG(2) << "WriteJournalEntry";
   RETURN_ON_ERR(WriteOpcode(RDB_OPCODE_JOURNAL_BLOB));
   RETURN_ON_ERR(SaveLen(1));
   RETURN_ON_ERR(SaveString(serialized_entry));
@@ -893,6 +896,7 @@ class RdbSaver::Impl {
        SaveMode save_mode, io::Sink* sink);
 
   void StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard);
+  void StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, LSN start_lsn);
 
   void StopSnapshotting(EngineShard* shard);
@@ -1053,11 +1057,19 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
 void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll,
                                        EngineShard* shard) {
   auto& s = GetSnapshot(shard);
-  s.reset(new SliceSnapshot(&shard->db_slice(), &channel_, compression_mode_));
+  s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_);
   s->Start(stream_journal, cll);
 }
 
+void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard,
+                                                  LSN start_lsn) {
+  auto& s = GetSnapshot(shard);
+  s = std::make_unique<SliceSnapshot>(&shard->db_slice(), &channel_, compression_mode_);
+  s->StartIncremental(cntx, start_lsn);
+}
+
 void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) {
   GetSnapshot(shard)->Stop();
 }
@@ -1142,6 +1154,10 @@ void RdbSaver::StartSnapshotInShard(bool stream_journal, const Cancellation* cll
   impl_->StartSnapshotting(stream_journal, cll, shard);
 }
 
+void RdbSaver::StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn) {
+  impl_->StartIncrementalSnapshotting(cntx, shard, start_lsn);
+}
+
 void RdbSaver::StopSnapshotInShard(EngineShard* shard) {
   impl_->StopSnapshotting(shard);
 }
@@ -1159,18 +1175,21 @@ error_code RdbSaver::SaveHeader(const GlobalData& glob_state) {
   return error_code{};
 }
 
-error_code RdbSaver::SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map) {
+error_code RdbSaver::SaveBody(Context* cntx, RdbTypeFreqMap* freq_map) {
   RETURN_ON_ERR(impl_->serializer()->FlushToSink(impl_->sink()));
 
   if (save_mode_ == SaveMode::SUMMARY) {
     impl_->serializer()->SendFullSyncCut();
   } else {
     VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
-    error_code io_error = impl_->ConsumeChannel(cll);
+    error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
     if (io_error) {
       LOG(ERROR) << "io error " << io_error;
       return io_error;
     }
+    if (cntx->GetError()) {
+      return cntx->GetError();
+    }
   }
 
   RETURN_ON_ERR(SaveEpilog());


@@ -88,6 +88,9 @@ class RdbSaver {
   // TODO: to implement break functionality to allow stopping early.
   void StartSnapshotInShard(bool stream_journal, const Cancellation* cll, EngineShard* shard);
 
+  // Send only the incremental snapshot since start_lsn.
+  void StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn);
+
   // Stops serialization in journal streaming mode in the shard's thread.
   void StopSnapshotInShard(EngineShard* shard);
@@ -97,7 +100,7 @@ class RdbSaver {
   // Writes the RDB file into sink. Waits for the serialization to finish.
   // Fills freq_map with the histogram of rdb types.
   // freq_map can optionally be null.
-  std::error_code SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map);
+  std::error_code SaveBody(Context* cntx, RdbTypeFreqMap* freq_map);
 
   void Cancel();
@@ -122,7 +125,7 @@ class CompressorImpl;
 class RdbSerializer {
  public:
-  RdbSerializer(CompressionMode compression_mode);
+  explicit RdbSerializer(CompressionMode compression_mode);
 
   ~RdbSerializer();


@@ -290,6 +290,10 @@ std::error_code Replica::HandleCapaDflyResp() {
     return make_error_code(errc::bad_message);
   }
 
+  // If we're syncing a different replication ID, drop the saved LSNs.
+  if (master_context_.master_repl_id != ToSV(LastResponseArgs()[0].GetBuf())) {
+    last_journal_LSNs_.reset();
+  }
   master_context_.master_repl_id = ToSV(LastResponseArgs()[0].GetBuf());
   master_context_.dfly_session_id = ToSV(LastResponseArgs()[1].GetBuf());
   num_df_flows_ = param_num_flows;
@@ -414,9 +418,10 @@ error_code Replica::InitiateDflySync() {
   absl::Cleanup cleanup = [this]() {
     // We do the following operations regardless of outcome.
-    JoinAllFlows();
+    JoinDflyFlows();
     service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
     state_mask_.fetch_and(~R_SYNCING);
+    last_journal_LSNs_.reset();
   };
 
   // Initialize MultiShardExecution.
@@ -450,25 +455,43 @@ error_code Replica::InitiateDflySync() {
   // Make sure we're in LOADING state.
   CHECK(service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) == GlobalState::LOADING);
 
-  // Flush dbs.
-  JournalExecutor{&service_}.FlushAll();
-
   // Start full sync flows.
   state_mask_.fetch_or(R_SYNCING);
+
+  std::string_view sync_type = "full";
   {
+    // Going out of the way to avoid using std::vector<bool>...
+    auto is_full_sync = std::make_unique<bool[]>(num_df_flows_);
     auto partition = Partition(num_df_flows_);
+    CHECK(!last_journal_LSNs_ || last_journal_LSNs_->size() == shard_flows_.size());
     auto shard_cb = [&](unsigned index, auto*) {
       for (auto id : partition[index]) {
-        auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block, &cntx_);
-        if (ec)
-          cntx_.ReportError(ec);
+        auto ec = shard_flows_[id]->StartSyncFlow(sync_block, &cntx_,
                                                  last_journal_LSNs_.has_value()
                                                      ? std::optional((*last_journal_LSNs_)[id])
                                                      : std::nullopt);
+        if (ec.has_value())
+          is_full_sync[id] = ec.value();
+        else
+          cntx_.ReportError(ec.error());
       }
     };
 
     // Lock to prevent the error handler from running instantly
     // while the flows are in a mixed state.
     lock_guard lk{flows_op_mu_};
     shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
+
+    size_t num_full_flows =
+        std::accumulate(is_full_sync.get(), is_full_sync.get() + num_df_flows_, 0);
+
+    if (num_full_flows == num_df_flows_) {
+      JournalExecutor{&service_}.FlushAll();
+    } else if (num_full_flows == 0) {
+      sync_type = "partial";
+    } else {
+      last_journal_LSNs_.reset();
+      cntx_.ReportError(std::make_error_code(errc::state_not_recoverable),
+                        "Won't do a partial sync: some flows must fully resync");
+    }
   }
 
   RETURN_ON_ERR(cntx_.GetError());
@@ -478,7 +501,7 @@ error_code Replica::InitiateDflySync() {
     return cntx_.ReportError(ec);
   }
 
-  LOG(INFO) << absl::StrCat("Started full sync with ", server().Description());
+  LOG(INFO) << "Started " << sync_type << " sync with " << server().Description();
 
   // Wait for all flows to receive full sync cut.
   // In case of an error, this is unblocked by the error handler.
@@ -497,7 +520,7 @@ error_code Replica::InitiateDflySync() {
   // Joining flows and resetting state is done by cleanup.
   double seconds = double(absl::ToInt64Milliseconds(absl::Now() - start_time)) / 1000;
-  LOG(INFO) << "Full sync finished in " << strings::HumanReadableElapsedTime(seconds);
+  LOG(INFO) << sync_type << " sync finished in " << strings::HumanReadableElapsedTime(seconds);
 
   return cntx_.GetError();
 }
@@ -590,7 +613,12 @@ error_code Replica::ConsumeDflyStream() {
     lock_guard lk{flows_op_mu_};
     shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
   }
-  JoinAllFlows();
+  JoinDflyFlows();
+
+  last_journal_LSNs_.emplace();
+  for (auto& flow : shard_flows_) {
+    last_journal_LSNs_->push_back(flow->JournalExecutedCount());
+  }
 
   LOG(INFO) << "Exit stable sync";
   // The only option to unblock is to cancel the context.
@@ -599,7 +627,7 @@ error_code Replica::ConsumeDflyStream() {
   return cntx_.GetError();
 }
 
-void Replica::JoinAllFlows() {
+void Replica::JoinDflyFlows() {
   for (auto& flow : shard_flows_) {
     flow->JoinFlow();
   }
@@ -625,30 +653,41 @@ error_code Replica::SendNextPhaseRequest(string_view kind) {
   return std::error_code{};
 }
 
-error_code DflyShardReplica::StartFullSyncFlow(BlockingCounter sb, Context* cntx) {
+io::Result<bool> DflyShardReplica::StartSyncFlow(BlockingCounter sb, Context* cntx,
+                                                 std::optional<LSN> lsn) {
+  using nonstd::make_unexpected;
   DCHECK(!master_context_.master_repl_id.empty() && !master_context_.dfly_session_id.empty());
 
-  RETURN_ON_ERR(ConnectAndAuth(absl::GetFlag(FLAGS_master_connect_timeout_ms) * 1ms, &cntx_));
+  RETURN_ON_ERR_T(make_unexpected,
                  ConnectAndAuth(absl::GetFlag(FLAGS_master_connect_timeout_ms) * 1ms, &cntx_));
 
   VLOG(1) << "Sending on flow " << master_context_.master_repl_id << " "
           << master_context_.dfly_session_id << " " << flow_id_;
 
-  auto cmd = StrCat("DFLY FLOW ", master_context_.master_repl_id, " ",
-                    master_context_.dfly_session_id, " ", flow_id_);
+  std::string cmd = StrCat("DFLY FLOW ", master_context_.master_repl_id, " ",
                           master_context_.dfly_session_id, " ", flow_id_);
+  // Try to negotiate a partial sync if possible.
+  if (lsn.has_value() && master_context_.version > DflyVersion::VER1) {
+    absl::StrAppend(&cmd, " ", *lsn);
+  }
 
   ResetParser(/*server_mode=*/false);
   leftover_buf_.emplace(128);
-  RETURN_ON_ERR(SendCommand(cmd));
+  RETURN_ON_ERR_T(make_unexpected, SendCommand(cmd));
   auto read_resp = ReadRespReply(&*leftover_buf_);
   if (!read_resp.has_value()) {
-    return read_resp.error();
+    return make_unexpected(read_resp.error());
   }
 
-  PC_RETURN_ON_BAD_RESPONSE(CheckRespFirstTypes({RespExpr::STRING, RespExpr::STRING}));
+  PC_RETURN_ON_BAD_RESPONSE_T(make_unexpected,
                             CheckRespFirstTypes({RespExpr::STRING, RespExpr::STRING}));
 
   string_view flow_directive = ToSV(LastResponseArgs()[0].GetBuf());
   string eof_token;
-  PC_RETURN_ON_BAD_RESPONSE(flow_directive == "FULL");
+  PC_RETURN_ON_BAD_RESPONSE_T(make_unexpected,
                             flow_directive == "FULL" || flow_directive == "PARTIAL");
+  bool is_full_sync = flow_directive == "FULL";
+
   eof_token = ToSV(LastResponseArgs()[1].GetBuf());
 
   leftover_buf_->ConsumeInput(read_resp->left_in_buffer);
@@ -658,7 +697,7 @@ error_code DflyShardReplica::StartFullSyncFlow(BlockingCounter sb, Context* cntx
   sync_fb_ = fb2::Fiber("shard_full_sync", &DflyShardReplica::FullSyncDflyFb, this,
                         std::move(eof_token), sb, cntx);
 
-  return error_code{};
+  return is_full_sync;
 }
 
 error_code DflyShardReplica::StartStableSyncFlow(Context* cntx) {
@@ -680,7 +719,7 @@ error_code DflyShardReplica::StartStableSyncFlow(Context* cntx) {
   return std::error_code{};
 }
 
-void DflyShardReplica::FullSyncDflyFb(const string& eof_token, BlockingCounter bc, Context* cntx) {
+void DflyShardReplica::FullSyncDflyFb(std::string eof_token, BlockingCounter bc, Context* cntx) {
   DCHECK(leftover_buf_);
   io::PrefixSource ps{leftover_buf_->InputBuffer(), Sock()};
 
@@ -722,7 +761,13 @@ void DflyShardReplica::FullSyncDflyFb(const string& eof_token, BlockingCounter b
     leftover_buf_.reset();
   }
 
-  this->journal_rec_executed_.store(loader.journal_offset());
+  if (auto jo = loader.journal_offset(); jo.has_value()) {
+    this->journal_rec_executed_.store(*jo);
+  } else {
+    if (master_context_.version > DflyVersion::VER0)
+      cntx->ReportError(std::make_error_code(errc::protocol_error),
+                        "Error finding journal offset in stream");
+  }
 
   VLOG(1) << "FullSyncDflyFb finished after reading " << loader.bytes_read() << " bytes";
 }
@@ -855,10 +900,11 @@ void DflyShardReplica::ExecuteTxWithNoShardSync(TransactionData&& tx_data, Conte
 }
 
 bool DflyShardReplica::InsertTxToSharedMap(const TransactionData& tx_data) {
-  std::lock_guard lk{multi_shard_exe_->map_mu};
-
+  std::unique_lock lk(multi_shard_exe_->map_mu);
   auto [it, was_insert] =
       multi_shard_exe_->tx_sync_execution.emplace(tx_data.txid, tx_data.shard_cnt);
+  lk.unlock();
+
   VLOG(2) << "txid: " << tx_data.txid << " unique_shard_cnt_: " << tx_data.shard_cnt
           << " was_insert: " << was_insert;
   it->second.block.Dec();
@@ -902,11 +948,11 @@ void DflyShardReplica::ExecuteTx(TransactionData&& tx_data, bool inserted_by_me,
   }
 
   VLOG(2) << "Execute txid: " << tx_data.txid;
-  multi_shard_exe_->map_mu.lock();
+  std::unique_lock lk(multi_shard_exe_->map_mu);
   auto it = multi_shard_exe_->tx_sync_execution.find(tx_data.txid);
   DCHECK(it != multi_shard_exe_->tx_sync_execution.end());
 
   auto& multi_shard_data = it->second;
-  multi_shard_exe_->map_mu.unlock();
+  lk.unlock();
 
   VLOG(2) << "Execute txid: " << tx_data.txid << " waiting for data in all shards";
   // Wait until shards flows got transaction data and inserted to map.


@@ -109,7 +109,10 @@ class Replica : ProtocolClient {
   void RedisStreamAcksFb();
 
-  void JoinAllFlows();  // Join all flows if possible.
+  // Joins all the flows when doing sharded replication. This is called in two
+  // places: once at the end of a full sync to join the full sync fibers, and
+  // again when a stable sync is interrupted to join the cancelled stable sync fibers.
+  void JoinDflyFlows();
   void SetShardStates(bool replica);  // Call SetReplica(replica) on all shards.
 
   // Send DFLY ${kind} to the master instance.
@@ -164,6 +167,9 @@ class Replica : ProtocolClient {
   EventCount waker_;
 
   std::vector<std::unique_ptr<DflyShardReplica>> shard_flows_;
+  // A vector of the last executed LSNs when replication is interrupted.
+  // Allows partial sync on reconnects.
+  std::optional<std::vector<LSN>> last_journal_LSNs_;
   std::shared_ptr<MultiShardExecution> multi_shard_exe_;
 
   // Guard operations where flows might be in a mixed state (transition/setup)
@@ -221,13 +227,14 @@ class DflyShardReplica : public ProtocolClient {
   void JoinFlow();
 
   // Start replica initialized as dfly flow.
-  std::error_code StartFullSyncFlow(BlockingCounter block, Context* cntx);
+  // On success, returns whether the flow will run a full sync.
+  io::Result<bool> StartSyncFlow(BlockingCounter block, Context* cntx, std::optional<LSN>);
 
   // Transition into stable state mode as dfly flow.
   std::error_code StartStableSyncFlow(Context* cntx);
 
   // Single flow full sync fiber spawned by StartFullSyncFlow.
-  void FullSyncDflyFb(const std::string& eof_token, BlockingCounter block, Context* cntx);
+  void FullSyncDflyFb(std::string eof_token, BlockingCounter block, Context* cntx);
 
   // Single flow stable state sync fiber spawned by StartStableSyncFlow.
   void StableSyncDflyReadFb(Context* cntx);


@@ -1827,9 +1827,7 @@ void ServerFamily::ReplConf(CmdArgList args, ConnectionContext* cntx) {
       if (!absl::SimpleAtoi(arg, &version)) {
         return (*cntx)->SendError(kInvalidIntErr);
       }
-      VLOG(1) << "Client version for session_id="
-              << cntx->conn_state.replication_info.repl_session_id << " is " << version;
-      cntx->conn_state.replication_info.repl_version = DflyVersion(version);
+      dfly_cmd_->SetDflyClientVersion(cntx, DflyVersion(version));
     } else if (cmd == "ACK" && args.size() == 2) {
       // Don't send error/Ok back through the socket, because we don't want to interleave with
       // the journal writes that we write into the same socket.


@@ -37,16 +37,16 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
   DCHECK(!snapshot_fb_.IsJoinable());
 
   auto db_cb = absl::bind_front(&SliceSnapshot::OnDbChange, this);
-  snapshot_version_ = db_slice_->RegisterOnChange(move(db_cb));
+  snapshot_version_ = db_slice_->RegisterOnChange(std::move(db_cb));
 
   if (stream_journal) {
     auto* journal = db_slice_->shard_owner()->journal();
     DCHECK(journal);
     auto journal_cb = absl::bind_front(&SliceSnapshot::OnJournalEntry, this);
-    journal_cb_id_ = journal->RegisterOnChange(move(journal_cb));
+    journal_cb_id_ = journal->RegisterOnChange(std::move(journal_cb));
   }
 
-  serializer_.reset(new RdbSerializer(compression_mode_));
+  serializer_ = std::make_unique<RdbSerializer>(compression_mode_);
 
   VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_;
 
@@ -61,6 +61,55 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
   });
 }
 
+void SliceSnapshot::StartIncremental(Context* cntx, LSN start_lsn) {
+  auto* journal = db_slice_->shard_owner()->journal();
+  DCHECK(journal);
+
+  serializer_ = std::make_unique<RdbSerializer>(compression_mode_);
+
+  snapshot_fb_ =
+      fb2::Fiber("incremental_snapshot", [this, journal, cntx, lsn = start_lsn]() mutable {
+        DCHECK(lsn <= journal->GetLsn()) << "The replica tried to sync from the future.";
+
+        VLOG(1) << "Starting incremental snapshot from lsn=" << lsn;
+
+        // The replica sends the LSN of the next entry it wants to receive.
+        while (!cntx->IsCancelled() && journal->IsLSNInBuffer(lsn)) {
+          serializer_->WriteJournalEntry(journal->GetEntry(lsn));
+          PushSerializedToChannel(false);
+          lsn++;
+        }
+
+        VLOG(1) << "Last LSN sent in incremental snapshot was " << (lsn - 1);
+
+        // This check is safe, but it is not trivially safe.
+        // We rely here on the fact that JournalSlice::AddLogRecord can
+        // only preempt while holding the callback lock.
+        // That guarantees that if we have processed the last LSN, the callback
+        // will only be added after JournalSlice::AddLogRecord has finished
+        // iterating its callbacks, and we won't process the record twice.
+        // We have to make sure we don't preempt ourselves before registering the callback!
+
+        // GetLsn() is always the next lsn that we expect to create.
+        if (journal->GetLsn() == lsn) {
+          {
+            FiberAtomicGuard fg;
+            serializer_->SendFullSyncCut();
+          }
+          auto journal_cb = absl::bind_front(&SliceSnapshot::OnJournalEntry, this);
+          journal_cb_id_ = journal->RegisterOnChange(std::move(journal_cb));
+          PushSerializedToChannel(true);
+        } else {
+          // We stopped but we didn't manage to send the whole stream.
+          cntx->ReportError(
+              std::make_error_code(errc::state_not_recoverable),
+              absl::StrCat("Partial sync was unsuccessful because entry #", lsn,
+                           " was dropped from the buffer. Current lsn=", journal->GetLsn()));
+          Cancel();
+        }
+      });
+}
+
 void SliceSnapshot::Stop() {
   // Wait for serialization to finish in any case.
   Join();
@@ -91,8 +140,7 @@ void SliceSnapshot::Cancel() {
 
 void SliceSnapshot::Join() {
   // Fiber could have already been joined by Stop.
-  if (snapshot_fb_.IsJoinable())
-    snapshot_fb_.Join();
+  snapshot_fb_.JoinIfNeeded();
 }
 
 // The algorithm is to go over all the buckets and serialize those with


@@ -61,6 +61,13 @@ class SliceSnapshot {
   // In journal streaming mode it needs to be stopped by either Stop or Cancel.
   void Start(bool stream_journal, const Cancellation* cll);
 
+  // Initialize a snapshot that sends only the missing journal updates
+  // since start_lsn and then registers a callback that switches into
+  // journal streaming mode until stopped.
+  // If we're slower than the buffer and can't continue, `Cancel()` is
+  // called.
+  void StartIncremental(Context* cntx, LSN start_lsn);
+
   // Stop snapshot. Only needs to be called for journal streaming mode.
   void Stop();


@@ -17,16 +17,20 @@ const char* GetVersion();
 // Please document for each new entry what the behavior changes are
 // and to which released versions this corresponds.
 enum class DflyVersion {
-  // Versions <= 1.3
+  // ver <= 1.3
   VER0,
 
-  // Versions 1.4 <=
+  // 1.4 <= ver <= 1.10
   // - Supports receiving ACKs from replicas
   // - Sends version back on REPLCONF capa dragonfly
   VER1,
 
+  // 1.11 <= ver
+  // Supports limited partial sync
+  VER2,
+
   // Always points to the latest version
-  CURRENT_VER = VER1,
+  CURRENT_VER = VER2,
 };
 
 }  // namespace dfly
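VER2 is what gates the new FLOW argument: the replica only appends seqid when the master has advertised a new enough version, so pre-1.11 masters keep receiving the three-argument command. The guard, as added to StartSyncFlow in replica.cc above:

    // Try to negotiate a partial sync if possible.
    if (lsn.has_value() && master_context_.version > DflyVersion::VER1) {
      absl::StrAppend(&cmd, " ", *lsn);
    }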


@@ -4,8 +4,9 @@ import subprocess
 import aiohttp
 import logging
 import os
+import re
 import psutil
 
-from typing import Optional
+from typing import Optional, List
 from prometheus_client.parser import text_string_to_metric_families
 from redis.asyncio import Redis as RedisClient
@@ -42,6 +43,7 @@ class DflyInstance:
         self.params = params
         self.proc: Optional[subprocess.Popen] = None
         self._client: Optional[RedisClient] = None
+        self.log_files: List[str] = []
 
         self.dynamic_port = False
         if self.params.existing_port:
@@ -55,6 +57,12 @@ class DflyInstance:
             self._port = None
             self.dynamic_port = True
 
+        # Some tests check the log files, so make sure the log files
+        # exist even when people try to debug their test.
+        if "logtostderr" in self.params.args:
+            self.params.args.remove("logtostderr")
+            self.params.args.append("alsologtostderr")
+
     def __del__(self):
         assert self.proc == None
@@ -94,6 +102,7 @@ class DflyInstance:
                 time.sleep(0.05)
         else:
             raise DflyStartException("Process didn't start listening on port in time")
+        self.log_files = self.get_logs_from_psutil()
 
     def stop(self, kill=False):
         proc, self.proc = self.proc, None
@@ -180,6 +189,14 @@ class DflyInstance:
                 return ports.pop()
         raise RuntimeError("Couldn't parse port")
 
+    def get_logs_from_psutil(self) -> List[str]:
+        p = psutil.Process(self.proc.pid)
+        rv = []
+        for file in p.open_files():
+            if ".log." in file.path and "dragonfly" in file.path:
+                rv.append(file.path)
+        return rv
+
     @staticmethod
     def format_args(args):
         out = []
@@ -200,6 +217,17 @@ class DflyInstance:
             for metric_family in text_string_to_metric_families(data)
         }
 
+    def is_in_logs(self, pattern):
+        if self.proc is not None:
+            raise RuntimeError("Must close server first")
+
+        matcher = re.compile(pattern)
+        for path in self.log_files:
+            for line in open(path):
+                if matcher.search(line):
+                    return True
+        return False
+
 
 class DflyInstanceFactory:
     """

tests/dragonfly/proxy.py (new file, 75 lines)

@@ -0,0 +1,75 @@
+import asyncio
+import random
+
+
+class Proxy:
+    def __init__(self, host, port, remote_host, remote_port):
+        self.host = host
+        self.port = port
+        self.remote_host = remote_host
+        self.remote_port = remote_port
+        self.stop_connections = []
+        self.server = None
+
+    def __del__(self):
+        self.close()
+
+    async def handle(self, reader, writer):
+        remote_reader, remote_writer = await asyncio.open_connection(
+            self.remote_host, self.remote_port
+        )
+
+        async def forward(reader, writer):
+            while True:
+                data = await reader.read(1024)
+                if not data:
+                    break
+                writer.write(data)
+                await writer.drain()
+            writer.close()
+
+        task1 = asyncio.ensure_future(forward(reader, remote_writer))
+        task2 = asyncio.ensure_future(forward(remote_reader, writer))
+
+        def cleanup():
+            task1.cancel()
+            task2.cancel()
+            writer.close()
+            remote_writer.close()
+
+        self.stop_connections.append(cleanup)
+
+        try:
+            await asyncio.gather(task1, task2)
+        except (asyncio.CancelledError, ConnectionResetError):
+            pass
+        finally:
+            cleanup()
+            if cleanup in self.stop_connections:
+                self.stop_connections.remove(cleanup)
+
+    async def start(self):
+        self.server = await asyncio.start_server(self.handle, self.host, self.port)
+        if self.port == 0:
+            _, port = self.server.sockets[0].getsockname()[:2]
+            self.port = port
+        async with self.server:
+            await self.server.serve_forever()
+
+    def drop_connection(self):
+        """
+        Randomly drop one connection
+        """
+        if self.stop_connections:
+            cb = random.choice(self.stop_connections)
+            self.stop_connections.remove(cb)
+            cb()
+
+    def close(self):
+        if self.server is not None:
+            self.server.close()
+            self.server = None
+        for cb in self.stop_connections:
+            cb()
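The replication tests below place this proxy between replica and master so individual connections can be severed mid-sync; a condensed usage sketch (port number illustrative):

    proxy = Proxy("localhost", 1111, "localhost", master.port)
    task = asyncio.create_task(proxy.start())   # serve_forever() runs until cancelled
    await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
    proxy.drop_connection()                     # kill one in-flight connection pair
    proxy.close()                               # stops the server; awaiting task raises CancelledError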


@@ -9,6 +9,7 @@ from .utility import *
 from . import DflyInstanceFactory, dfly_args
 import pymemcache
 import logging
+from .proxy import Proxy
 
 ADMIN_PORT = 1211
@@ -1364,13 +1365,17 @@ async def test_tls_replication(
 
 # busy wait for 'replica' instance to have replication status 'status'
-async def wait_for_replica_status(replica: aioredis.Redis, status: str, wait_for_seconds=0.01):
-    while True:
+async def wait_for_replica_status(
+    replica: aioredis.Redis, status: str, wait_for_seconds=0.01, timeout=20
+):
+    start = time.time()
+    while (time.time() - start) < timeout:
         await asyncio.sleep(wait_for_seconds)
 
         info = await replica.info("replication")
         if info["master_link_status"] == status:
             return
+    raise RuntimeError("Client did not become available in time!")
 
 
 @pytest.mark.asyncio
@@ -1397,7 +1402,8 @@ async def test_replicaof_flag(df_local_factory):
     c_replica = aioredis.Redis(port=replica.port)
 
     await wait_available_async(c_replica)  # give it time to startup
-    await wait_for_replica_status(c_replica, status="up")  # wait until we have a connection
+    # wait until we have a connection
+    await wait_for_replica_status(c_replica, status="up")
     await check_all_replicas_finished([c_replica], c_master)
 
     dbsize = await c_replica.dbsize()
@@ -1486,7 +1492,7 @@ async def test_replicaof_flag_disconnect(df_local_factory):
     val = await c_replica.get("KEY")
     assert b"VALUE" == val
 
     await c_replica.replicaof("no", "one")  # disconnect
     role = await c_replica.role()
     assert role[0] == b"master"
@@ -1548,3 +1554,129 @@ async def test_df_crash_on_replicaof_flag(df_local_factory):
 
     master.stop()
     replica.stop()
+
+
+async def test_network_disconnect(df_local_factory, df_seeder_factory):
+    master = df_local_factory.create(proactor_threads=6)
+    replica = df_local_factory.create(proactor_threads=4)
+
+    df_local_factory.start_all([replica, master])
+    seeder = df_seeder_factory.create(port=master.port)
+
+    async with replica.client() as c_replica:
+        await seeder.run(target_deviation=0.1)
+
+        proxy = Proxy("localhost", 1111, "localhost", master.port)
+        task = asyncio.create_task(proxy.start())
+
+        await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
+
+        for _ in range(10):
+            await asyncio.sleep(random.randint(0, 10) / 10)
+            proxy.drop_connection()
+
+        # Give time to detect dropped connection and reconnect
+        await asyncio.sleep(1.0)
+        await wait_for_replica_status(c_replica, status="up")
+        await wait_available_async(c_replica)
+
+        capture = await seeder.capture()
+        assert await seeder.compare(capture, replica.port)
+
+        proxy.close()
+        try:
+            await task
+        except asyncio.exceptions.CancelledError:
+            pass
+
+    master.stop()
+    replica.stop()
+    assert replica.is_in_logs("partial sync finished in")
+
+
+async def test_network_disconnect_active_stream(df_local_factory, df_seeder_factory):
+    master = df_local_factory.create(proactor_threads=4, shard_repl_backlog_len=10000)
+    replica = df_local_factory.create(proactor_threads=4)
+
+    df_local_factory.start_all([replica, master])
+    seeder = df_seeder_factory.create(port=master.port)
+
+    async with replica.client() as c_replica, master.client() as c_master:
+        await seeder.run(target_deviation=0.1)
+
+        proxy = Proxy("localhost", 1112, "localhost", master.port)
+        task = asyncio.create_task(proxy.start())
+
+        await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
+
+        fill_task = asyncio.create_task(seeder.run(target_ops=10000))
+
+        for _ in range(3):
+            await asyncio.sleep(random.randint(10, 20) / 10)
+            proxy.drop_connection()
+
+        seeder.stop()
+        await fill_task
+
+        # Give time to detect dropped connection and reconnect
+        await asyncio.sleep(1.0)
+        await wait_for_replica_status(c_replica, status="up")
+        await wait_available_async(c_replica)
+
+        logging.debug(await c_replica.execute_command("INFO REPLICATION"))
+        logging.debug(await c_master.execute_command("INFO REPLICATION"))
+        capture = await seeder.capture()
+        assert await seeder.compare(capture, replica.port)
+
+        proxy.close()
+        try:
+            await task
+        except asyncio.exceptions.CancelledError:
+            pass
+
+    master.stop()
+    replica.stop()
+    assert replica.is_in_logs("partial sync finished in")
+
+
+async def test_network_disconnect_small_buffer(df_local_factory, df_seeder_factory):
+    master = df_local_factory.create(proactor_threads=4, shard_repl_backlog_len=1)
+    replica = df_local_factory.create(proactor_threads=4)
+
+    df_local_factory.start_all([replica, master])
+    seeder = df_seeder_factory.create(port=master.port)
+
+    async with replica.client() as c_replica, master.client() as c_master:
+        await seeder.run(target_deviation=0.1)
+
+        proxy = Proxy("localhost", 1113, "localhost", master.port)
+        task = asyncio.create_task(proxy.start())
+
+        await c_replica.execute_command(f"REPLICAOF localhost {proxy.port}")
+
+        fill_task = asyncio.create_task(seeder.run(target_ops=10000))
+
+        for _ in range(3):
+            await asyncio.sleep(random.randint(5, 10) / 10)
+            proxy.drop_connection()
+
+        seeder.stop()
+        await fill_task
+
+        # Give time to detect dropped connection and reconnect
+        await asyncio.sleep(1.0)
+        await wait_for_replica_status(c_replica, status="up")
+        await wait_available_async(c_replica)
+
+        # logging.debug(await c_replica.execute_command("INFO REPLICATION"))
+        # logging.debug(await c_master.execute_command("INFO REPLICATION"))
+        capture = await seeder.capture()
+        assert await seeder.compare(capture, replica.port)
+
+        proxy.close()
+        try:
+            await task
+        except asyncio.exceptions.CancelledError:
+            pass
+
+    master.stop()
+    replica.stop()
+    assert master.is_in_logs("Partial sync requested from stale LSN")


@@ -1,4 +1,5 @@
 import itertools
+import logging
 import sys
 import asyncio
 from redis import asyncio as aioredis
@@ -39,10 +40,11 @@ def batch_fill_data(client, gen, batch_size=100):
         client.mset({k: v for k, v, in group})
 
 
-async def wait_available_async(client: aioredis.Redis):
+async def wait_available_async(client: aioredis.Redis, timeout=10):
     """Block until instance exits loading phase"""
     its = 0
-    while True:
+    start = time.time()
+    while (time.time() - start) < timeout:
         try:
             await client.get("key")
             return
@@ -56,6 +58,7 @@ async def wait_available_async(client: aioredis.Redis):
             print("W", end="", flush=True)
             await asyncio.sleep(0.01)
             its += 1
+    raise RuntimeError("Client did not become available in time!")
 
 
 class SizeChange(Enum):
@@ -374,7 +377,7 @@ class DflySeeder:
         Run a seeding cycle on all dbs either until stop(), a fixed number of commands (target_ops)
         or until reaching an allowed deviation from the target number of keys (target_deviation)
         """
-        print(f"Running ops:{target_ops} deviation:{target_deviation}")
+        logging.debug(f"Running ops:{target_ops} deviation:{target_deviation}")
         self.stop_flag = False
         queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)]
         producer = asyncio.create_task(
@@ -392,7 +395,7 @@ class DflySeeder:
 
         took = time.time() - time_start
         qps = round(cmdcount * self.dbcount / took, 2)
-        print(f"Filling took: {took}, QPS: {qps}")
+        logging.debug(f"Filling took: {took}, QPS: {qps}")
 
     def stop(self):
         """Stop all invocations to run"""
@@ -407,6 +410,7 @@ class DflySeeder:
         if port is None:
             port = self.port
 
+        logging.debug(f"Starting capture from {port=}")
         keys = sorted(list(self.gen.keys_and_types()))
 
         captures = await asyncio.gather(