feat(server): Basic capped full sync (#440)

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
Vladislav, 2022-11-06 17:27:43 +03:00, committed by GitHub
parent 2ed4d3489b
commit 8424f74bec
13 changed files with 639 additions and 75 deletions

View file

@@ -72,6 +72,10 @@ class Connection : public util::Connection {
     CopyCharBuf(name, sizeof(name_), name_);
   }

+  const char* GetName() const {
+    return name_;
+  }
+
   void SetPhase(std::string_view phase) {
     CopyCharBuf(phase, sizeof(phase_), phase_);
   }

View file

@@ -89,7 +89,7 @@ struct ConnectionState {
   // If this server is master, and this connection is from a secondary replica,
   // then it holds positive sync session id.
   uint32_t repl_session_id = 0;
-  uint32_t repl_threadid = kuint32max;
+  uint32_t repl_flow_id = kuint32max;

   ExecInfo exec_info;
   std::optional<ScriptInfo> script_info;

View file

@@ -3,6 +3,7 @@
 //
 #include "server/dflycmd.h"

+#include <absl/random/random.h>
 #include <absl/strings/str_cat.h>
 #include <absl/strings/strip.h>
@@ -12,6 +13,8 @@
 #include "server/engine_shard_set.h"
 #include "server/error.h"
 #include "server/journal/journal.h"
+#include "server/rdb_save.h"
+#include "server/script_mgr.h"
 #include "server/server_family.h"
 #include "server/server_state.h"
 #include "server/transaction.h"
@@ -27,8 +30,10 @@ using namespace std;
 using util::ProactorBase;

 namespace {
+const char kBadMasterId[] = "bad master id";
 const char kIdNotFound[] = "syncid not found";
 const char kInvalidSyncId[] = "bad sync id";
+const char kInvalidState[] = "invalid state";

 bool ToSyncId(string_view str, uint32_t* num) {
   if (!absl::StartsWith(str, "SYNC"))
@@ -37,6 +42,22 @@ bool ToSyncId(string_view str, uint32_t* num) {
   return absl::SimpleAtoi(str, num);
 }

+struct TransactionGuard {
+  constexpr static auto kEmptyCb = [](Transaction* t, EngineShard* shard) { return OpStatus::OK; };
+
+  TransactionGuard(Transaction* t) : t(t) {
+    t->Schedule();
+    t->Execute(kEmptyCb, false);
+  }
+
+  ~TransactionGuard() {
+    t->Execute(kEmptyCb, true);
+  }
+
+  Transaction* t;
+};
+
 }  // namespace

 DflyCmd::DflyCmd(util::ListenerInterface* listener, ServerFamily* server_family)
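(Note on TransactionGuard: it schedules an empty callback as a global transaction and executes it without concluding, which in effect holds every shard and keeps other transactions from running while the guard is alive; the destructor runs the concluding hop and releases them. Sync() uses it below so that the per-shard snapshots all start from one consistent cut.)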
@@ -58,7 +79,11 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
     return Thread(args, cntx);
   }

-  if (sub_cmd == "SYNC" && args.size() == 5) {
+  if (sub_cmd == "FLOW" && args.size() == 5) {
+    return Flow(args, cntx);
+  }
+
+  if (sub_cmd == "SYNC" && args.size() == 3) {
     return Sync(args, cntx);
   }
@@ -70,8 +95,22 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
 }

 void DflyCmd::OnClose(ConnectionContext* cntx) {
-  if (cntx->conn_state.repl_session_id > 0 && cntx->conn_state.repl_threadid != kuint32max) {
-    DeleteSyncSession(cntx->conn_state.repl_session_id);
+  unsigned session_id = cntx->conn_state.repl_session_id;
+  unsigned flow_id = cntx->conn_state.repl_flow_id;
+
+  if (!session_id)
+    return;
+
+  if (flow_id == kuint32max) {
+    DeleteSyncSession(session_id);
+  } else {
+    shared_ptr<SyncInfo> sync_info = GetSyncInfo(session_id);
+    if (sync_info) {
+      lock_guard lk(sync_info->mu);
+      if (sync_info->state != SyncState::CANCELLED) {
+        UnregisterFlow(&sync_info->flows[flow_id]);
+      }
+    }
   }
 }
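So a closed control connection (one that registered a session but no flow) tears down the whole sync session, while a dropped flow connection merely vacates its slot; once the session is CANCELLED, even that is skipped because DeleteSyncSession already owns the cleanup.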
@@ -164,39 +203,88 @@ void DflyCmd::Thread(CmdArgList args, ConnectionContext* cntx) {
     return rb->SendOk();
   }

-  rb->SendError(kInvalidIntErr);
-  return;
+  return rb->SendError(kInvalidIntErr);
 }
-void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
+void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
   RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
-  string_view masterid = ArgS(args, 2);
+  string_view master_id = ArgS(args, 2);
   string_view sync_id_str = ArgS(args, 3);
   string_view flow_id_str = ArgS(args, 4);

-  unsigned flow_id;
-  VLOG(1) << "Got DFLY SYNC " << masterid << " " << sync_id_str << " " << flow_id_str;
+  VLOG(1) << "Got DFLY FLOW " << master_id << " " << sync_id_str << " " << flow_id_str;

-  if (masterid != sf_->master_id()) {
-    return rb->SendError("Bad master id");
+  if (master_id != sf_->master_id()) {
+    return rb->SendError(kBadMasterId);
   }

-  if (!absl::SimpleAtoi(flow_id_str, &flow_id) || !absl::StartsWith(sync_id_str, "SYNC")) {
-    return rb->SendError(kSyntaxErr);
+  unsigned flow_id;
+  if (!absl::SimpleAtoi(flow_id_str, &flow_id) || flow_id >= shard_set->pool()->size()) {
+    return rb->SendError(facade::kInvalidIntErr);
   }

   auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
   if (!sync_id)
     return;

-  // assuming here that shard id and thread id is the same thing.
-  if (int(flow_id) != ProactorBase::GetIndex()) {
-    listener_->Migrate(cntx->owner(), shard_set->pool()->at(flow_id));
+  unique_lock lk(sync_info->mu);
+  if (sync_info->state != SyncState::PREPARATION)
+    return rb->SendError(kInvalidState);
+
+  // Set meta info on connection.
+  cntx->owner()->SetName(absl::StrCat("repl_flow_", sync_id));
+  cntx->conn_state.repl_session_id = sync_id;
+  cntx->conn_state.repl_flow_id = flow_id;
+
+  absl::InsecureBitGen gen;
+  string eof_token = GetRandomHex(gen, 40);
+
+  sync_info->flows[flow_id] = FlowInfo{cntx->owner(), eof_token};
+  listener_->Migrate(cntx->owner(), shard_set->pool()->at(flow_id));
+
+  rb->StartArray(2);
+  rb->SendSimpleString("FULL");
+  rb->SendSimpleString(eof_token);
+}
+
+void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
+  RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
+  string_view sync_id_str = ArgS(args, 2);
+
+  VLOG(1) << "Got DFLY SYNC " << sync_id_str;
+
+  auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
+  if (!sync_id)
+    return;
+
+  unique_lock lk(sync_info->mu);
+  if (sync_info->state != SyncState::PREPARATION)
+    return rb->SendError(kInvalidState);
+
+  // Check all flows are connected.
+  // This might happen if a flow abruptly disconnected before sending the SYNC request.
+  for (const FlowInfo& flow : sync_info->flows) {
+    if (!flow.conn) {
+      return rb->SendError(kInvalidState);
+    }
   }

-  (void)sync_id;
-  (void)sync_info;
+  // Start full sync.
+  {
+    TransactionGuard tg{cntx->transaction};
+    AggregateStatus status;
+
+    auto cb = [this, &status, sync_info = sync_info](unsigned index, auto*) {
+      status = StartFullSyncInThread(&sync_info->flows[index], EngineShard::tlocal());
+    };
+    shard_set->pool()->AwaitFiberOnAll(std::move(cb));
+
+    // TODO: Send better error
+    if (*status != OpStatus::OK)
+      return rb->SendError(kInvalidState);
+  }
+
+  sync_info->state = SyncState::FULL_SYNC;

   return rb->SendOk();
 }
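Taken together, the handlers above define a small handshake. Schematically, for a session "SYNC1" on a master with two flows (values illustrative; the session id is established during the earlier greeting, which this diff does not change):

flow 0> DFLY FLOW <master_repl_id> SYNC1 0    ->  ["FULL", <eof_token_0>]
flow 1> DFLY FLOW <master_repl_id> SYNC1 1    ->  ["FULL", <eof_token_1>]
ctrl>   DFLY SYNC SYNC1                       ->  +OK, and the snapshot starts streaming on every flow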
@@ -210,30 +298,123 @@ void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
   return rb->SendOk();
 }

+OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
+  DCHECK(!flow->fb.joinable());
+
+  SaveMode save_mode = shard == nullptr ? SaveMode::SUMMARY : SaveMode::SINGLE_SHARD;
+  flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
+
+  if (shard != nullptr) {
+    flow->saver->StartSnapshotInShard(false, shard);
+  }
+
+  flow->fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow);
+  return OpStatus::OK;
+}
+
+void DflyCmd::FullSyncFb(FlowInfo* flow) {
+  error_code ec;
+  RdbSaver* saver = flow->saver.get();
+
+  if (saver->Mode() == SaveMode::SUMMARY) {
+    auto scripts = sf_->script_mgr()->GetLuaScripts();
+    ec = saver->SaveHeader(scripts);
+  } else {
+    ec = saver->SaveHeader({});
+  }
+
+  if (ec) {
+    LOG(ERROR) << ec;
+    return;
+  }
+
+  if (saver->Mode() != SaveMode::SUMMARY) {
+    // TODO: we should be able to stop earlier if requested.
+    ec = saver->SaveBody(nullptr);
+    if (ec) {
+      LOG(ERROR) << ec;
+      return;
+    }
+  }
+
+  ec = flow->conn->socket()->Write(io::Buffer(flow->eof_token));
+  if (ec) {
+    LOG(ERROR) << ec;
+    return;
+  }
+
+  ec = flow->conn->socket()->Shutdown(SHUT_RDWR);
+}
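Each flow's stream is thus framed as <RDB header><optional RDB body><40-char EOF token> with no length prefix: the replica detects the end of the snapshot by scanning for the token it was handed in the FLOW reply, and the socket shutdown closes the stream right after it.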
 uint32_t DflyCmd::CreateSyncSession() {
   unique_lock lk(mu_);

-  auto [it, inserted] = sync_infos_.emplace(next_sync_id_, new SyncInfo);
+  auto sync_info = make_shared<SyncInfo>();
+  sync_info->flows.resize(shard_set->size() + 1);
+
+  auto [it, inserted] = sync_infos_.emplace(next_sync_id_, std::move(sync_info));
   CHECK(inserted);

   return next_sync_id_++;
 }

+void DflyCmd::UnregisterFlow(FlowInfo* flow) {
+  // TODO: Cancel saver operations.
+  flow->conn = nullptr;
+  flow->saver.reset();
+}
+
 void DflyCmd::DeleteSyncSession(uint32_t sync_id) {
+  shared_ptr<SyncInfo> sync_info;
+
+  // Remove sync_info from map.
+  // Store by value to keep alive.
+  {
+    unique_lock lk(mu_);
+
+    auto it = sync_infos_.find(sync_id);
+    if (it == sync_infos_.end())
+      return;
+
+    sync_info = it->second;
+    sync_infos_.erase(it);
+  }
+
+  // Wait for all operations to finish.
+  // Set state to CANCELLED so no other operations will run.
+  {
+    unique_lock lk(sync_info->mu);
+    sync_info->state = SyncState::CANCELLED;
+  }
+
+  // Try to cleanup flows.
+  for (auto& flow : sync_info->flows) {
+    if (flow.conn != nullptr) {
+      VLOG(1) << "Flow connection " << flow.conn->GetName() << " is still alive"
+              << " on sync_id " << sync_id;
+    }
+
+    // TODO: Implement cancellation.
+    if (flow.fb.joinable()) {
+      VLOG(1) << "Force joining fiber on sync_id " << sync_id;
+      flow.fb.join();
+    }
+  }
+}
+
+shared_ptr<DflyCmd::SyncInfo> DflyCmd::GetSyncInfo(uint32_t sync_id) {
   unique_lock lk(mu_);

   auto it = sync_infos_.find(sync_id);
-  if (it == sync_infos_.end())
-    return;
-
-  delete it->second;
-  sync_infos_.erase(it);
+  if (it != sync_infos_.end())
+    return it->second;
+  return {};
 }
-pair<uint32_t, DflyCmd::SyncInfo*> DflyCmd::GetSyncInfoOrReply(std::string_view id_str,
-                                                               RedisReplyBuilder* rb) {
-  uint32_t sync_id;
+pair<uint32_t, shared_ptr<DflyCmd::SyncInfo>> DflyCmd::GetSyncInfoOrReply(std::string_view id_str,
+                                                                          RedisReplyBuilder* rb) {
+  unique_lock lk(mu_);
+
+  uint32_t sync_id;
   if (!ToSyncId(id_str, &sync_id)) {
     rb->SendError(kInvalidSyncId);
     return {0, nullptr};

View file

@@ -5,6 +5,9 @@
 #pragma once

 #include <absl/container/btree_map.h>
+#include <memory.h>
+
+#include <boost/fiber/fiber.hpp>

 #include "server/conn_context.h"
@@ -20,6 +23,7 @@ namespace dfly {

 class EngineShardSet;
 class ServerFamily;
+class RdbSaver;

 namespace journal {
 class Journal;
@@ -27,12 +31,27 @@ class Journal;

 class DflyCmd {
  public:
-  enum class SyncState { PREPARATION, FULL_SYNC };
+  enum class SyncState { PREPARATION, FULL_SYNC, CANCELLED };
+
+  struct FlowInfo {
+    FlowInfo() = default;
+    FlowInfo(facade::Connection* conn, const std::string& eof_token)
+        : conn(conn), eof_token(eof_token){};
+
+    facade::Connection* conn;
+    std::string eof_token;
+
+    std::unique_ptr<RdbSaver> saver;
+    ::boost::fibers::fiber fb;
+  };

   struct SyncInfo {
     SyncState state = SyncState::PREPARATION;
-    int64_t tx_id = 0;
+
+    std::vector<FlowInfo> flows;
+
+    ::boost::fibers::mutex mu;  // guard operations on replica.
   };

  public:
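A session pre-allocates shard_set->size() + 1 flow slots (see CreateSyncSession in the .cc file): one per shard, plus an extra slot for the flow that lands on a thread without an EngineShard and therefore serializes only the summary (header plus Lua scripts) in StartFullSyncInThread.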
@@ -57,6 +76,10 @@ class DflyCmd {
   // Return connection thread index or migrate to another thread.
   void Thread(CmdArgList args, ConnectionContext* cntx);

+  // FLOW <masterid> <syncid> <flowid>
+  // Register connection as flow for sync session.
+  void Flow(CmdArgList args, ConnectionContext* cntx);
+
   // SYNC <masterid> <syncid> <flowid>
   // Migrate connection to required flow thread.
   // Stub: will be replaced with full sync.
@@ -66,19 +89,31 @@ class DflyCmd {
   // Check all keys for expiry.
   void Expire(CmdArgList args, ConnectionContext* cntx);

-  // Delete sync session.
+  // Start full sync in thread. Start FullSyncFb. Called for each flow.
+  facade::OpStatus StartFullSyncInThread(FlowInfo* flow, EngineShard* shard);
+
+  // Fiber that runs full sync for each flow.
+  void FullSyncFb(FlowInfo* flow);
+
+  // Unregister flow. Must be called when flow disconnects.
+  void UnregisterFlow(FlowInfo*);
+
+  // Delete sync session. Cleanup flows.
   void DeleteSyncSession(uint32_t sync_id);

+  // Get SyncInfo by sync_id.
+  std::shared_ptr<SyncInfo> GetSyncInfo(uint32_t sync_id);
+
   // Find sync info by id or send error reply.
-  std::pair<uint32_t, SyncInfo*> GetSyncInfoOrReply(std::string_view id,
-                                                    facade::RedisReplyBuilder* rb);
+  std::pair<uint32_t, std::shared_ptr<SyncInfo>> GetSyncInfoOrReply(std::string_view id,
+                                                                    facade::RedisReplyBuilder* rb);
   ServerFamily* sf_;

   util::ListenerInterface* listener_;
   TxId journal_txid_ = 0;

-  absl::btree_map<uint32_t, SyncInfo*> sync_infos_;
+  absl::btree_map<uint32_t, std::shared_ptr<SyncInfo>> sync_infos_;

   uint32_t next_sync_id_ = 1;
   ::boost::fibers::mutex mu_;  // guard sync info and journal operations.

View file

@@ -254,6 +254,8 @@ void Replica::ReplicateFb() {
       state_mask_ &= R_ENABLED;  // reset all flags besides R_ENABLED
       continue;
     }
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
+
     VLOG(1) << "Replica greet ok";
   }
@@ -358,8 +360,8 @@ error_code Replica::Greet() {
     master_context_.dfly_session_id = param1;
     num_df_flows_ = param2;

-    VLOG(1) << "Master id: " << param0 << ", sync id: " << param1
-            << ", num journals " << num_df_flows_;
+    VLOG(1) << "Master id: " << param0 << ", sync id: " << param1 << ", num journals "
+            << num_df_flows_;
   } else {
     LOG(ERROR) << "Bad response " << ToSV(io_buf.InputBuffer());
@@ -474,6 +476,27 @@ error_code Replica::InitiateDflySync() {
   if (ec)
     return ec;

+  ReqSerializer serializer{sock_.get()};
+
+  // Master waits for this command in order to start sending replication stream.
+  serializer.SendCommand(StrCat("DFLY SYNC ", master_context_.dfly_session_id));
+  RETURN_ON_ERR(serializer.ec());
+
+  base::IoBuf io_buf{128};
+  unsigned consumed = 0;
+  RETURN_ON_ERR(ReadRespReply(&io_buf, &consumed));
+  if (resp_args_.size() != 1 || resp_args_.front().type != RespExpr::STRING ||
+      ToSV(resp_args_.front().GetBuf()) != "OK") {
+    LOG(ERROR) << "Sync failed " << ToSV(io_buf.InputBuffer());
+    return make_error_code(errc::bad_message);
+  }
+
+  for (unsigned i = 0; i < num_df_flows_; ++i) {
+    shard_flows_[i]->sync_fb_.join();
+  }
+
+  LOG(INFO) << "Full sync finished";
   state_mask_ |= R_SYNC_OK;
   return error_code{};
@@ -710,6 +733,28 @@ error_code Replica::ParseAndExecute(base::IoBuf* io_buf) {
   return error_code{};
 }

+void Replica::ReplicateDFFb(unique_ptr<base::IoBuf> io_buf, string eof_token) {
+  SocketSource ss{sock_.get()};
+  io::PrefixSource ps{io_buf->InputBuffer(), &ss};
+
+  RdbLoader loader(NULL);
+  loader.Load(&ps);
+
+  if (!eof_token.empty()) {
+    unique_ptr<uint8_t[]> buf(new uint8_t[eof_token.size()]);
+    // pass leftover data from the loader.
+    io::PrefixSource chained(loader.Leftover(), &ps);
+    VLOG(1) << "Before reading from chained stream";
+    io::Result<size_t> eof_res = chained.Read(io::MutableBytes{buf.get(), eof_token.size()});
+    if (!eof_res || *eof_res != eof_token.size()) {
+      LOG(ERROR) << "Error finding eof token in the stream";
+    }
+    // TODO - to compare tokens
+  }
+  VLOG(1) << "ReplicateDFFb finished after reading " << loader.bytes_read() << " bytes";
+}
+
 error_code Replica::StartFlow() {
   CHECK(!sock_);
   DCHECK(!master_context_.master_repl_id.empty() && !master_context_.dfly_session_id.empty());
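Note the two-level PrefixSource chaining in ReplicateDFFb: the bytes that arrived together with the FLOW reply (still sitting in io_buf) are replayed ahead of the socket for the RDB loader, and afterwards whatever the loader over-read is replayed again so the trailing EOF token can still be located.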
@@ -720,22 +765,42 @@ error_code Replica::StartFlow() {
   sock_.reset(mythread->CreateSocket());
   RETURN_ON_ERR(sock_->Connect(master_context_.master_ep));

+  VLOG(1) << "Sending on flow " << master_context_.master_repl_id << " "
+          << master_context_.dfly_session_id << " " << master_context_.flow_id;
+
   ReqSerializer serializer{sock_.get()};
-  serializer.SendCommand(StrCat("DFLY SYNC ", master_context_.master_repl_id, " ",
+  serializer.SendCommand(StrCat("DFLY FLOW ", master_context_.master_repl_id, " ",
                                 master_context_.dfly_session_id, " ", master_context_.flow_id));
   RETURN_ON_ERR(serializer.ec());

   parser_.reset(new RedisParser{false});  // client mode
-  base::IoBuf io_buf{128};
-  unsigned consumed = 0;
-  RETURN_ON_ERR(ReadRespReply(&io_buf, &consumed));

-  if (resp_args_.size() != 1 || resp_args_.front().type != RespExpr::STRING ||
-      ToSV(resp_args_.front().GetBuf()) != "OK") {
-    LOG(ERROR) << "Bad SYNC response " << ToSV(io_buf.InputBuffer());
+  std::unique_ptr<base::IoBuf> io_buf{new base::IoBuf(128)};
+  unsigned consumed = 0;
+  RETURN_ON_ERR(ReadRespReply(io_buf.get(), &consumed));  // uses parser_
+
+  if (resp_args_.size() < 2 || resp_args_[0].type != RespExpr::STRING ||
+      resp_args_[1].type != RespExpr::STRING) {
+    LOG(ERROR) << "Bad FLOW response " << ToSV(io_buf->InputBuffer());
     return make_error_code(errc::bad_message);
   }

+  string_view flow_directive = ToSV(resp_args_[0].GetBuf());
+  string eof_token;
+  if (flow_directive == "FULL") {
+    eof_token = ToSV(resp_args_[1].GetBuf());
+  } else {
+    LOG(ERROR) << "Bad FLOW response " << ToSV(io_buf->InputBuffer());
+  }
+  io_buf->ConsumeInput(consumed);
+
+  state_mask_ = R_ENABLED | R_TCP_CONNECTED;
+
+  // We can not discard io_buf because it may contain data
+  // besides the response we parsed. Therefore we pass it further to ReplicateDFFb.
+  sync_fb_ =
+      ::boost::fibers::fiber(&Replica::ReplicateDFFb, this, std::move(io_buf), move(eof_token));
+
   return error_code{};
 }
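For readers following the wire protocol rather than the C++ plumbing, here is a minimal sketch of one flow connection driven with raw sockets. It is illustrative only: the host/port, the master id placeholder, and the "SYNC1" session id are assumptions (the real values come from the greeting), and a real replica feeds the bytes into an RDB loader instead of counting them.

import socket

def send_cmd(sock, *parts):
    # Encode the command as a RESP array of bulk strings.
    msg = [f"*{len(parts)}\r\n".encode()]
    for p in parts:
        p = p.encode()
        msg.append(b"$" + str(len(p)).encode() + b"\r\n" + p + b"\r\n")
    sock.sendall(b"".join(msg))

# Flow connection: register as flow 0 of session SYNC1.
flow = socket.create_connection(("localhost", 6379))
send_cmd(flow, "DFLY", "FLOW", "<master_repl_id>", "SYNC1", "0")

# Expected reply: *2\r\n+FULL\r\n+<40 hex chars>\r\n
lines = flow.recv(4096).decode().split("\r\n")
assert lines[0] == "*2" and lines[1] == "+FULL"
eof_token = lines[2].lstrip("+").encode()

# Control connection: DFLY SYNC triggers the snapshot on every flow.
ctrl = socket.create_connection(("localhost", 6379))
send_cmd(ctrl, "DFLY", "SYNC", "SYNC1")
assert ctrl.recv(64).startswith(b"+OK")

# The flow socket now carries <RDB stream><eof_token>; read until the token.
buf = b""
while eof_token not in buf:
    chunk = flow.recv(16384)
    if not chunk:
        raise EOFError("stream closed before eof token")
    buf += chunk
print("snapshot payload:", buf.index(eof_token), "bytes")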

View file

@@ -17,7 +17,6 @@ class Service;
 class ConnectionContext;

 class Replica {
-
   // The attributes of the master we are connecting to.
   struct MasterContext {
     std::string host;
@@ -102,6 +101,9 @@ class Replica {

   std::error_code StartFlow();

+  // Full sync fiber function.
+  void ReplicateDFFb(std::unique_ptr<base::IoBuf> io_buf, std::string eof_token);
+
   Service& service_;
   ::boost::fibers::fiber sync_fb_;

View file

@@ -872,7 +872,7 @@ error_code ServerFamily::DoSave(bool new_version, Transaction* trans, string* er
   // Save summary file.
   {
-    const auto& scripts = script_mgr_->GetLuaScripts();
+    const auto scripts = script_mgr_->GetLuaScripts();
     auto& summary_snapshot = snapshots[shard_set->size()];
     summary_snapshot.reset(new RdbSnapshot(fq_threadpool_.get()));
     if (ec = DoPartialSave(filename, path, now, scripts, summary_snapshot.get(), nullptr)) {
@@ -899,7 +899,7 @@ error_code ServerFamily::DoSave(bool new_version, Transaction* trans, string* er
   VLOG(1) << "Saving to " << path;
   snapshots[0].reset(new RdbSnapshot(fq_threadpool_.get()));

-  const auto& lua_scripts = script_mgr_->GetLuaScripts();
+  const auto lua_scripts = script_mgr_->GetLuaScripts();
   ec = snapshots[0]->Start(SaveMode::RDB, path.generic_string(), lua_scripts);

   if (!ec) {
@@ -1471,6 +1471,12 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
   replica_.swap(new_replica);

+  GlobalState new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING);
+  if (new_state != GlobalState::LOADING) {
+    LOG(WARNING) << GlobalStateName(new_state) << " in progress, ignored";
+    return;
+  }
+
   // Flushing all the data after we marked this instance as replica.
   Transaction* transaction = cntx->transaction;
   transaction->Schedule();
@@ -1484,6 +1490,7 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
   // Replica sends response in either case. No need to send response in this function.
   // It's a bit confusing but simpler.
   if (!replica_->Run(cntx)) {
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
     replica_.reset();
   }

View file

@@ -1,4 +1,7 @@
 import pytest
+import typing
+import time
+import subprocess

 import time
 import subprocess
@@ -14,6 +17,7 @@ class DflyInstance:
         self.path = path
         self.args = args
         self.cwd = cwd
+        self.proc = None

     def start(self):
         arglist = DflyInstance.format_args(self.args)
@@ -29,14 +33,21 @@ class DflyInstance:
             raise Exception(
                 f"Failed to start instance, return code {return_code}")

-    def stop(self):
+    def stop(self, kill=False):
+        proc, self.proc = self.proc, None
+        if proc is None:
+            return
+
         print(f"Stopping instance on {self.port}")
         try:
-            self.proc.terminate()
-            outs, errs = self.proc.communicate(timeout=15)
+            if kill:
+                proc.kill()
+            else:
+                proc.terminate()
+            outs, errs = proc.communicate(timeout=15)
         except subprocess.TimeoutExpired:
             print("Unable to terminate DragonflyDB gracefully, it was killed")
-            outs, errs = self.proc.communicate()
+            outs, errs = proc.communicate()
             print(outs, errs)

     def __getitem__(self, k):
@@ -64,12 +75,21 @@ class DflyInstanceFactory:
         self.cwd = cwd
         self.path = path
         self.args = args
+        self.instances = []

     def create(self, **kwargs) -> DflyInstance:
         args = {**self.args, **kwargs}
         for k, v in args.items():
             args[k] = v.format(**self.env) if isinstance(v, str) else v
-        return DflyInstance(self.path, args, self.cwd)
+
+        instance = DflyInstance(self.path, args, self.cwd)
+        self.instances.append(instance)
+        return instance
+
+    def stop_all(self):
+        """Stop all launched instances."""
+        for instance in self.instances:
+            instance.stop()


 def dfly_args(*args):
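As a usage sketch (the env, cwd, and binary path values are placeholder assumptions, mirroring what conftest.py passes in):

factory = DflyInstanceFactory({}, "/tmp", path="/path/to/dragonfly", args={})
master = factory.create(port=6379, proactor_threads=4)
master.start()
# ... interact via redis.Redis(port=master.port) ...
factory.stop_all()  # terminates every instance created by this factory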

View file

@@ -50,7 +50,17 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
         scripts_dir, '../../build-dbg/dragonfly'))

     args = request.param if request.param else {}
-    return DflyInstanceFactory(test_env, tmp_dir, path=path, args=args)
+    factory = DflyInstanceFactory(test_env, tmp_dir, path=path, args=args)
+    yield factory
+    factory.stop_all()
+
+
+@pytest.fixture(scope="function")
+def df_local_factory(df_factory: DflyInstanceFactory):
+    factory = DflyInstanceFactory(
+        df_factory.env, df_factory.cwd, df_factory.path, df_factory.args)
+    yield factory
+    factory.stop_all()


 @pytest.fixture(scope="session")
@@ -61,6 +71,7 @@ def df_server(df_factory: DflyInstanceFactory) -> DflyInstance:
     """
     instance = df_factory.create()
     instance.start()
+
     yield instance

     clients_left = None

View file

@@ -1,4 +1,5 @@
-from dragonfly import dfly_multi_test_args
+from . import dfly_multi_test_args
+from .utility import batch_fill_data, gen_test_data


 @dfly_multi_test_args({'keys_output_limit': 512}, {'keys_output_limit': 1024})
@@ -6,7 +7,7 @@ class TestKeys:
     def test_max_keys(self, client, df_server):
         max_keys = df_server['keys_output_limit']

-        for x in range(max_keys*3):
-            client.set(str(x), str(x))
+        batch_fill_data(client, gen_test_data(max_keys * 3))

         keys = client.keys()
         assert len(keys) in range(max_keys, max_keys+512)

View file

@@ -0,0 +1,167 @@
+import pytest
+import asyncio
+import aioredis
+import redis
+import time
+
+from .utility import *
+
+BASE_PORT = 1111
+
+
+"""
+Test simple full sync on one replica without altering data during replication.
+"""
+
+# (threads_master, threads_replica, n entries)
+simple_full_sync_cases = [
+    (2, 2, 100),
+    (8, 2, 500),
+    (2, 8, 500),
+    (6, 4, 500)
+]
+
+
+@pytest.mark.parametrize("t_master, t_replica, n_keys", simple_full_sync_cases)
+def test_simple_full_sync(df_local_factory, t_master, t_replica, n_keys):
+    master = df_local_factory.create(port=1111, proactor_threads=t_master)
+    replica = df_local_factory.create(port=1112, proactor_threads=t_replica)
+
+    # Start master and fill with test data
+    master.start()
+    c_master = redis.Redis(port=master.port)
+    batch_fill_data(c_master, gen_test_data(n_keys))
+
+    # Start replica and run REPLICAOF
+    replica.start()
+    c_replica = redis.Redis(port=replica.port)
+    c_replica.replicaof("localhost", str(master.port))
+
+    # Check replica received test data
+    wait_available(c_replica)
+    batch_check_data(c_replica, gen_test_data(n_keys))
+
+    # Stop replication manually
+    c_replica.replicaof("NO", "ONE")
+    assert c_replica.set("writeable", "true")
+
+    # Check test data persisted
+    batch_check_data(c_replica, gen_test_data(n_keys))
+
+
+"""
+Test simple full sync on multiple replicas without altering data during replication.
+The replicas start running in parallel.
+"""
+
+# (threads_master, threads_replicas, n entries)
+simple_full_sync_multi_cases = [
+    (4, [3, 2], 500),
+    (8, [6, 5, 4], 500),
+    (8, [2] * 5, 100),
+    (4, [1] * 20, 500)
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("t_master, t_replicas, n_keys", simple_full_sync_multi_cases)
+async def test_simple_full_sync_multi(df_local_factory, t_master, t_replicas, n_keys):
+    def data_gen(): return gen_test_data(n_keys)
+
+    master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
+    replicas = [
+        df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
+        for i, t in enumerate(t_replicas)
+    ]
+
+    # Start master and fill with test data
+    master.start()
+    c_master = aioredis.Redis(port=master.port, single_connection_client=True)
+    await batch_fill_data_async(c_master, data_gen())
+
+    # Start replica tasks in parallel
+    tasks = [
+        asyncio.create_task(run_sfs_replica(
+            replica, master, data_gen), name="replica-"+str(replica.port))
+        for replica in replicas
+    ]
+
+    for task in tasks:
+        assert await task
+
+    await c_master.connection_pool.disconnect()
+
+
+async def run_sfs_replica(replica, master, data_gen):
+    replica.start()
+
+    c_replica = aioredis.Redis(
+        port=replica.port, single_connection_client=None)
+
+    await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
+    await wait_available_async(c_replica)
+
+    await batch_check_data_async(c_replica, data_gen())
+
+    await c_replica.connection_pool.disconnect()
+    return True
+
+
+"""
+Test replica crash during full sync on multiple replicas without altering data during replication.
+"""
+
+# (threads_master, threads_replicas, n entries)
+simple_full_sync_multi_crash_cases = [
+    (5, [1] * 15, 5000),
+    (5, [1] * 20, 5000),
+    (5, [1] * 25, 5000)
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("t_master, t_replicas, n_keys", simple_full_sync_multi_crash_cases)
+async def test_simple_full_sync_multi_crash(df_local_factory, t_master, t_replicas, n_keys):
+    def data_gen(): return gen_test_data(n_keys)
+
+    master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
+    replicas = [
+        df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
+        for i, t in enumerate(t_replicas)
+    ]
+
+    # Start master and fill with test data
+    master.start()
+    c_master = aioredis.Redis(port=master.port, single_connection_client=True)
+    await batch_fill_data_async(c_master, data_gen())
+
+    # Start replica tasks in parallel
+    tasks = [
+        asyncio.create_task(run_sfs_crash_replica(
+            replica, master, data_gen), name="replica-"+str(replica.port))
+        for replica in replicas
+    ]
+
+    for task in tasks:
+        assert await task
+
+    # Check master is ok
+    await batch_check_data_async(c_master, data_gen())
+
+    await c_master.connection_pool.disconnect()
+
+
+async def run_sfs_crash_replica(replica, master, data_gen):
+    replica.start()
+
+    c_replica = aioredis.Redis(
+        port=replica.port, single_connection_client=None)
+
+    await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
+
+    # Kill the replica after a short delay
+    await asyncio.sleep(0.0)
+    replica.stop(kill=True)
+
+    await c_replica.connection_pool.disconnect()
+    return True

View file

@@ -4,33 +4,22 @@ import redis
 import string
 import os
 import glob
 from pathlib import Path

-from dragonfly import dfly_args
+from . import dfly_args
+from .utility import batch_check_data, batch_fill_data, gen_test_data

 BASIC_ARGS = {"dir": "{DRAGONFLY_TMP}/"}
+NUM_KEYS = 100


 class SnapshotTestBase:
-    KEYS = string.ascii_lowercase
-
     def setup(self, tmp_dir: Path):
         self.tmp_dir = tmp_dir
         self.rdb_out = tmp_dir / "test.rdb"
         if self.rdb_out.exists():
             self.rdb_out.unlink()

-    def populate(self, client: redis.Redis):
-        """Populate instance with test data"""
-        for k in self.KEYS:
-            client.set(k, "val-"+k)
-
-    def check(self, client: redis.Redis):
-        """Check instance contains test data"""
-        for k in self.KEYS:
-            assert client.get(k) == "val-"+k
-
     def get_main_file(self, suffix):
         def is_main(f): return "summary" in f if suffix == "dfs" else True
         files = glob.glob(str(self.tmp_dir.absolute()) + '/test-*.'+suffix)
@@ -45,14 +34,14 @@ class TestRdbSnapshot(SnapshotTestBase):
         super().setup(tmp_dir)

     def test_snapshot(self, client: redis.Redis):
-        super().populate(client)
+        batch_fill_data(client, gen_test_data(NUM_KEYS))

         # save + flush + load
         client.execute_command("SAVE")
         assert client.flushall()
         client.execute_command("DEBUG LOAD " + super().get_main_file("rdb"))

-        super().check(client)
+        batch_check_data(client, gen_test_data(NUM_KEYS))


 @dfly_args({**BASIC_ARGS, "dbfilename": "test"})
@@ -66,14 +55,14 @@ class TestDflySnapshot(SnapshotTestBase):
             os.remove(file)

     def test_snapshot(self, client: redis.Redis):
-        super().populate(client)
+        batch_fill_data(client, gen_test_data(NUM_KEYS))

         # save + flush + load
         client.execute_command("SAVE DF")
         assert client.flushall()
         client.execute_command("DEBUG LOAD " + super().get_main_file("dfs"))

-        super().check(client)
+        batch_check_data(client, gen_test_data(NUM_KEYS))


 @dfly_args({**BASIC_ARGS, "dbfilename": "test.rdb", "save_schedule": "*:*"})
@@ -84,7 +73,7 @@ class TestPeriodicSnapshot(SnapshotTestBase):
         super().setup(tmp_dir)

     def test_snapshot(self, client: redis.Redis):
-        super().populate(client)
+        batch_fill_data(client, gen_test_data(NUM_KEYS))

         time.sleep(60)

View file

@@ -0,0 +1,82 @@
+import redis
+import aioredis
+import itertools
+import time
+import asyncio
+
+
+def grouper(n, iterable):
+    """Transform iterable into iterator of chunks of size n"""
+    it = iter(iterable)
+    while True:
+        chunk = tuple(itertools.islice(it, n))
+        if not chunk:
+            return
+        yield chunk
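For illustration, a quick check of what grouper yields (an interactive sketch, not part of the commit):

>>> list(grouper(2, "abcde"))
[('a', 'b'), ('c', 'd'), ('e',)]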
+
+BATCH_SIZE = 100
+
+
+def gen_test_data(n):
+    for i in range(n):
+        yield "k-"+str(i), "v-"+str(i)
+
+
+def batch_fill_data(client: redis.Redis, gen):
+    for group in grouper(BATCH_SIZE, gen):
+        client.mset({k: v for k, v in group})
+
+
+async def batch_fill_data_async(client: aioredis.Redis, gen):
+    for group in grouper(BATCH_SIZE, gen):
+        await client.mset({k: v for k, v in group})
+
+
+def as_str_val(v) -> str:
+    if isinstance(v, str):
+        return v
+    elif isinstance(v, bytes):
+        return v.decode()
+    else:
+        return str(v)
+
+
+def batch_check_data(client: redis.Redis, gen):
+    for group in grouper(BATCH_SIZE, gen):
+        vals = client.mget(k for k, _ in group)
+        assert all(as_str_val(vals[i]) == v for i, (_, v) in enumerate(group))
+
+
+async def batch_check_data_async(client: aioredis.Redis, gen):
+    for group in grouper(BATCH_SIZE, gen):
+        vals = await client.mget(k for k, _ in group)
+        assert all(as_str_val(vals[i]) == v for i, (_, v) in enumerate(group))
+
+
+def wait_available(client: redis.Redis):
+    its = 0
+    while True:
+        try:
+            client.get('key')
+            print("wait_available iterations:", its)
+            return
+        except redis.ResponseError as e:
+            assert "Can not execute during LOADING" in str(e)
+
+        time.sleep(0.01)
+        its += 1
+
+
+async def wait_available_async(client: aioredis.Redis):
+    its = 0
+    while True:
+        try:
+            await client.get('key')
+            print("wait_available iterations:", its)
+            return
+        except aioredis.ResponseError as e:
+            assert "Can not execute during LOADING" in str(e)
+
+        await asyncio.sleep(0.01)
+        its += 1