diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h
index 1de48d183..64b94b8b6 100644
--- a/src/facade/dragonfly_connection.h
+++ b/src/facade/dragonfly_connection.h
@@ -72,6 +72,10 @@ class Connection : public util::Connection {
     CopyCharBuf(name, sizeof(name_), name_);
   }
 
+  const char* GetName() const {
+    return name_;
+  }
+
   void SetPhase(std::string_view phase) {
     CopyCharBuf(phase, sizeof(phase_), phase_);
   }
diff --git a/src/server/conn_context.h b/src/server/conn_context.h
index 1a0e25e46..be1640489 100644
--- a/src/server/conn_context.h
+++ b/src/server/conn_context.h
@@ -89,7 +89,7 @@ struct ConnectionState {
   // If this server is master, and this connection is from a secondary replica,
   // then it holds positive sync session id.
   uint32_t repl_session_id = 0;
-  uint32_t repl_threadid = kuint32max;
+  uint32_t repl_flow_id = kuint32max;
 
   ExecInfo exec_info;
   std::optional<ScriptInfo> script_info;
diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc
index b72837511..a9cdea550 100644
--- a/src/server/dflycmd.cc
+++ b/src/server/dflycmd.cc
@@ -3,6 +3,7 @@
 //
 
 #include "server/dflycmd.h"
 
+#include <absl/random/random.h>
 #include <absl/strings/str_cat.h>
 #include <absl/strings/strip.h>
@@ -12,6 +13,8 @@
 #include "server/engine_shard_set.h"
 #include "server/error.h"
 #include "server/journal/journal.h"
+#include "server/rdb_save.h"
+#include "server/script_mgr.h"
 #include "server/server_family.h"
 #include "server/server_state.h"
 #include "server/transaction.h"
@@ -27,8 +30,10 @@
 using namespace std;
 using util::ProactorBase;
 
 namespace {
+const char kBadMasterId[] = "bad master id";
 const char kIdNotFound[] = "syncid not found";
 const char kInvalidSyncId[] = "bad sync id";
+const char kInvalidState[] = "invalid state";
 
 bool ToSyncId(string_view str, uint32_t* num) {
   if (!absl::StartsWith(str, "SYNC"))
@@ -37,6 +42,22 @@ bool ToSyncId(string_view str, uint32_t* num) {
   return absl::SimpleAtoi(str, num);
 }
 
+struct TransactionGuard {
+  constexpr static auto kEmptyCb = [](Transaction* t, EngineShard* shard) { return OpStatus::OK; };
+
+  TransactionGuard(Transaction* t) : t(t) {
+    t->Schedule();
+    t->Execute(kEmptyCb, false);
+  }
+
+  ~TransactionGuard() {
+    t->Execute(kEmptyCb, true);
+  }
+
+  Transaction* t;
+};
+
 }  // namespace
 
 DflyCmd::DflyCmd(util::ListenerInterface* listener, ServerFamily* server_family)
@@ -58,7 +79,11 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
     return Thread(args, cntx);
   }
 
-  if (sub_cmd == "SYNC" && args.size() == 5) {
+  if (sub_cmd == "FLOW" && args.size() == 5) {
+    return Flow(args, cntx);
+  }
+
+  if (sub_cmd == "SYNC" && args.size() == 3) {
     return Sync(args, cntx);
   }
 
@@ -70,8 +95,22 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
 }
 
 void DflyCmd::OnClose(ConnectionContext* cntx) {
-  if (cntx->conn_state.repl_session_id > 0 && cntx->conn_state.repl_threadid != kuint32max) {
-    DeleteSyncSession(cntx->conn_state.repl_session_id);
+  unsigned session_id = cntx->conn_state.repl_session_id;
+  unsigned flow_id = cntx->conn_state.repl_flow_id;
+
+  if (!session_id)
+    return;
+
+  if (flow_id == kuint32max) {
+    DeleteSyncSession(session_id);
+  } else {
+    shared_ptr<SyncInfo> sync_info = GetSyncInfo(session_id);
+    if (sync_info) {
+      lock_guard lk(sync_info->mu);
+      if (sync_info->state != SyncState::CANCELLED) {
+        UnregisterFlow(&sync_info->flows[flow_id]);
+      }
+    }
   }
 }
 
@@ -164,39 +203,88 @@ void DflyCmd::Thread(CmdArgList args, ConnectionContext* cntx) {
     return rb->SendOk();
   }
 
-  rb->SendError(kInvalidIntErr);
-  return;
+  return rb->SendError(kInvalidIntErr);
 }
 
-void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
+void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) {
   RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
-  string_view masterid = ArgS(args, 2);
+  string_view master_id = ArgS(args, 2);
   string_view sync_id_str = ArgS(args, 3);
   string_view flow_id_str = ArgS(args, 4);
-  unsigned flow_id;
 
-  VLOG(1) << "Got DFLY SYNC " << masterid << " " << sync_id_str << " " << flow_id_str;
+  VLOG(1) << "Got DFLY FLOW " << master_id << " " << sync_id_str << " " << flow_id_str;
 
-  if (masterid != sf_->master_id()) {
-    return rb->SendError("Bad master id");
+  if (master_id != sf_->master_id()) {
+    return rb->SendError(kBadMasterId);
   }
 
-  if (!absl::SimpleAtoi(flow_id_str, &flow_id) || !absl::StartsWith(sync_id_str, "SYNC")) {
-    return rb->SendError(kSyntaxErr);
+  unsigned flow_id;
+  if (!absl::SimpleAtoi(flow_id_str, &flow_id) || flow_id >= shard_set->pool()->size()) {
+    return rb->SendError(facade::kInvalidIntErr);
   }
 
   auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
   if (!sync_id)
     return;
 
-  // assuming here that shard id and thread id is the same thing.
-  if (int(flow_id) != ProactorBase::GetIndex()) {
-    listener_->Migrate(cntx->owner(), shard_set->pool()->at(flow_id));
+  unique_lock lk(sync_info->mu);
+  if (sync_info->state != SyncState::PREPARATION)
+    return rb->SendError(kInvalidState);
+
+  // Set meta info on the connection.
+  cntx->owner()->SetName(absl::StrCat("repl_flow_", sync_id));
+  cntx->conn_state.repl_session_id = sync_id;
+  cntx->conn_state.repl_flow_id = flow_id;
+
+  absl::InsecureBitGen gen;
+  string eof_token = GetRandomHex(gen, 40);
+
+  sync_info->flows[flow_id] = FlowInfo{cntx->owner(), eof_token};
+  listener_->Migrate(cntx->owner(), shard_set->pool()->at(flow_id));
+
+  rb->StartArray(2);
+  rb->SendSimpleString("FULL");
+  rb->SendSimpleString(eof_token);
+}
+
+void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
+  RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
+  string_view sync_id_str = ArgS(args, 2);
+
+  VLOG(1) << "Got DFLY SYNC " << sync_id_str;
+
+  auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb);
+  if (!sync_id)
+    return;
+
+  unique_lock lk(sync_info->mu);
+  if (sync_info->state != SyncState::PREPARATION)
+    return rb->SendError(kInvalidState);
+
+  // Check that all flows are connected.
+  // A flow can be missing if it abruptly disconnected before the SYNC request arrived.
+  for (const FlowInfo& flow : sync_info->flows) {
+    if (!flow.conn) {
+      return rb->SendError(kInvalidState);
+    }
   }
 
-  (void)sync_id;
-  (void)sync_info;
+  // Start full sync.
+  {
+    TransactionGuard tg{cntx->transaction};
+    AggregateStatus status;
+
+    auto cb = [this, &status, sync_info = sync_info](unsigned index, auto*) {
+      status = StartFullSyncInThread(&sync_info->flows[index], EngineShard::tlocal());
+    };
+    shard_set->pool()->AwaitFiberOnAll(std::move(cb));
+
+    // TODO: Send a more specific error.
+    if (*status != OpStatus::OK)
+      return rb->SendError(kInvalidState);
+  }
+
+  sync_info->state = SyncState::FULL_SYNC;
 
   return rb->SendOk();
 }
@@ -210,30 +298,123 @@ void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
   return rb->SendOk();
 }
 
+OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
+  DCHECK(!flow->fb.joinable());
+
+  SaveMode save_mode = shard == nullptr ? SaveMode::SUMMARY : SaveMode::SINGLE_SHARD;
+  flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
+
+  if (shard != nullptr) {
+    flow->saver->StartSnapshotInShard(false, shard);
+  }
+
+  flow->fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow);
+  return OpStatus::OK;
+}
+
+void DflyCmd::FullSyncFb(FlowInfo* flow) {
+  error_code ec;
+  RdbSaver* saver = flow->saver.get();
+
+  if (saver->Mode() == SaveMode::SUMMARY) {
+    auto scripts = sf_->script_mgr()->GetLuaScripts();
+    ec = saver->SaveHeader(scripts);
+  } else {
+    ec = saver->SaveHeader({});
+  }
+
+  if (ec) {
+    LOG(ERROR) << ec;
+    return;
+  }
+
+  if (saver->Mode() != SaveMode::SUMMARY) {
+    // TODO: we should be able to stop earlier if requested.
+    ec = saver->SaveBody(nullptr);
+    if (ec) {
+      LOG(ERROR) << ec;
+      return;
+    }
+  }
+
+  ec = flow->conn->socket()->Write(io::Buffer(flow->eof_token));
+  if (ec) {
+    LOG(ERROR) << ec;
+    return;
+  }
+
+  ec = flow->conn->socket()->Shutdown(SHUT_RDWR);
+}
+
 uint32_t DflyCmd::CreateSyncSession() {
   unique_lock lk(mu_);
 
-  auto [it, inserted] = sync_infos_.emplace(next_sync_id_, new SyncInfo);
+  auto sync_info = make_shared<SyncInfo>();
+  sync_info->flows.resize(shard_set->size() + 1);
+
+  auto [it, inserted] = sync_infos_.emplace(next_sync_id_, std::move(sync_info));
   CHECK(inserted);
 
   return next_sync_id_++;
 }
 
+void DflyCmd::UnregisterFlow(FlowInfo* flow) {
+  // TODO: Cancel saver operations.
+  flow->conn = nullptr;
+  flow->saver.reset();
+}
+
 void DflyCmd::DeleteSyncSession(uint32_t sync_id) {
+  shared_ptr<SyncInfo> sync_info;
+
+  // Remove sync_info from the map.
+  // Store it by value to keep it alive.
+  {
+    unique_lock lk(mu_);
+
+    auto it = sync_infos_.find(sync_id);
+    if (it == sync_infos_.end())
+      return;
+
+    sync_info = it->second;
+    sync_infos_.erase(it);
+  }
+
+  // Wait for all operations to finish.
+  // Set the state to CANCELLED so no other operations will run.
+  {
+    unique_lock lk(sync_info->mu);
+    sync_info->state = SyncState::CANCELLED;
+  }
+
+  // Try to clean up the flows.
+  for (auto& flow : sync_info->flows) {
+    if (flow.conn != nullptr) {
+      VLOG(1) << "Flow connection " << flow.conn->GetName() << " is still alive"
+              << " on sync_id " << sync_id;
+    }
+    // TODO: Implement cancellation.
+    if (flow.fb.joinable()) {
+      VLOG(1) << "Force joining fiber on sync_id " << sync_id;
+      flow.fb.join();
+    }
+  }
+}
+
+shared_ptr<DflyCmd::SyncInfo> DflyCmd::GetSyncInfo(uint32_t sync_id) {
   unique_lock lk(mu_);
 
   auto it = sync_infos_.find(sync_id);
-  if (it == sync_infos_.end())
-    return;
-
-  delete it->second;
-  sync_infos_.erase(it);
+  if (it != sync_infos_.end())
+    return it->second;
+  return {};
 }
 
-pair<uint32_t, DflyCmd::SyncInfo*> DflyCmd::GetSyncInfoOrReply(std::string_view id_str,
-                                                               RedisReplyBuilder* rb) {
-  uint32_t sync_id;
+pair<uint32_t, shared_ptr<DflyCmd::SyncInfo>> DflyCmd::GetSyncInfoOrReply(std::string_view id_str,
+                                                                          RedisReplyBuilder* rb) {
+  unique_lock lk(mu_);
 
+  uint32_t sync_id;
   if (!ToSyncId(id_str, &sync_id)) {
     rb->SendError(kInvalidSyncId);
     return {0, nullptr};
diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h
index 3db30b2b9..47a77f646 100644
--- a/src/server/dflycmd.h
+++ b/src/server/dflycmd.h
@@ -5,6 +5,9 @@
 #pragma once
 
 #include <absl/container/btree_map.h>
+#include <memory>
+
+#include <boost/fiber/fiber.hpp>
 
 #include "server/conn_context.h"
@@ -20,6 +23,7 @@ namespace dfly {
 
 class EngineShardSet;
 class ServerFamily;
+class RdbSaver;
 
 namespace journal {
 class Journal;
@@ -27,12 +31,27 @@ class Journal;
 
 class DflyCmd {
  public:
-  enum class SyncState { PREPARATION, FULL_SYNC };
+  enum class SyncState { PREPARATION, FULL_SYNC, CANCELLED };
+
+  struct FlowInfo {
+    FlowInfo() = default;
+    FlowInfo(facade::Connection* conn, const std::string& eof_token)
+        : conn(conn), eof_token(eof_token){};
+
+    facade::Connection* conn;
+    std::string eof_token;
+
+    std::unique_ptr<RdbSaver> saver;
+
+    ::boost::fibers::fiber fb;
+  };
 
   struct SyncInfo {
     SyncState state = SyncState::PREPARATION;
 
-    int64_t tx_id = 0;
+    std::vector<FlowInfo> flows;
+
+    ::boost::fibers::mutex mu;  // guards operations on the replica.
   };
 
  public:
@@ -57,6 +76,10 @@ class DflyCmd {
   // Return connection thread index or migrate to another thread.
   void Thread(CmdArgList args, ConnectionContext* cntx);
 
+  // FLOW <masterid> <syncid> <flowid>
+  // Register connection as a flow of a sync session.
+  void Flow(CmdArgList args, ConnectionContext* cntx);
+
   // SYNC <syncid>
   // Migrate connection to required flow thread.
   // Stub: will be replaced with full sync.
@@ -66,19 +89,31 @@ class DflyCmd {
   // Check all keys for expiry.
   void Expire(CmdArgList args, ConnectionContext* cntx);
 
-  // Delete sync session.
+  // Start full sync in a thread. Spawns FullSyncFb. Called for each flow.
+  facade::OpStatus StartFullSyncInThread(FlowInfo* flow, EngineShard* shard);
+
+  // Fiber that runs full sync for each flow.
+  void FullSyncFb(FlowInfo* flow);
+
+  // Unregister a flow. Must be called when the flow disconnects.
+  void UnregisterFlow(FlowInfo*);
+
+  // Delete sync session. Cleans up its flows.
   void DeleteSyncSession(uint32_t sync_id);
 
+  // Get SyncInfo by sync_id.
+  std::shared_ptr<SyncInfo> GetSyncInfo(uint32_t sync_id);
+
   // Find sync info by id or send error reply.
-  std::pair<uint32_t, SyncInfo*> GetSyncInfoOrReply(std::string_view id,
-                                                    facade::RedisReplyBuilder* rb);
+  std::pair<uint32_t, std::shared_ptr<SyncInfo>> GetSyncInfoOrReply(std::string_view id,
+                                                                    facade::RedisReplyBuilder* rb);
 
   ServerFamily* sf_;
 
   util::ListenerInterface* listener_;
   TxId journal_txid_ = 0;
 
-  absl::btree_map<uint32_t, SyncInfo*> sync_infos_;
+  absl::btree_map<uint32_t, std::shared_ptr<SyncInfo>> sync_infos_;
 
   uint32_t next_sync_id_ = 1;
   ::boost::fibers::mutex mu_;  // guard sync info and journal operations.
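The master-side protocol above is easiest to see end to end from a client's perspective: one FLOW registration per master thread, each on its own connection, then a single SYNC. A minimal sketch (Python with redis-py; MASTER_PORT, master_id and the surrounding script are illustrative — master_id and the session id come from the greeting stage, which is outside this diff):

    import redis

    MASTER_PORT = 6379   # illustrative
    master_id = "..."    # from the greet exchange, not shown in this diff
    sync_id = "SYNC1"    # session id string allocated by the master
    n_flows = 2          # must equal the master's proactor pool size

    # Each flow needs its own connection: after DFLY FLOW the master migrates
    # the socket to the target thread and later streams the snapshot over it.
    flows = [redis.Redis(port=MASTER_PORT, single_connection_client=True)
             for _ in range(n_flows)]
    tokens = []
    for flow_id, conn in enumerate(flows):
        kind, eof_token = conn.execute_command(
            "DFLY FLOW", master_id, sync_id, str(flow_id))
        assert kind == b"FULL"  # only full sync is implemented in this diff
        tokens.append(eof_token)

    # A separate control connection starts the sync; the master replies OK
    # only after StartFullSyncInThread succeeded on every shard.
    ctl = redis.Redis(port=MASTER_PORT)
    assert ctl.execute_command("DFLY SYNC", sync_id)
    # From here on, each flow socket carries an RDB stream ended by its token.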
diff --git a/src/server/replica.cc b/src/server/replica.cc
index 2c17551c8..85a14b0b9 100644
--- a/src/server/replica.cc
+++ b/src/server/replica.cc
@@ -254,6 +254,8 @@ void Replica::ReplicateFb() {
       state_mask_ &= R_ENABLED;  // reset all flags besides R_ENABLED
       continue;
     }
+
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
     VLOG(1) << "Replica greet ok";
   }
 
@@ -358,8 +360,8 @@ error_code Replica::Greet() {
       master_context_.dfly_session_id = param1;
       num_df_flows_ = param2;
 
-      VLOG(1) << "Master id: " << param0 << ", sync id: " << param1
-              << ", num journals " << num_df_flows_;
+      VLOG(1) << "Master id: " << param0 << ", sync id: " << param1 << ", num journals "
+              << num_df_flows_;
     } else {
       LOG(ERROR) << "Bad response " << ToSV(io_buf.InputBuffer());
 
@@ -474,6 +476,27 @@ error_code Replica::InitiateDflySync() {
   if (ec)
     return ec;
 
+  ReqSerializer serializer{sock_.get()};
+
+  // The master waits for this command before it starts sending the replication stream.
+  serializer.SendCommand(StrCat("DFLY SYNC ", master_context_.dfly_session_id));
+  RETURN_ON_ERR(serializer.ec());
+
+  base::IoBuf io_buf{128};
+  unsigned consumed = 0;
+  RETURN_ON_ERR(ReadRespReply(&io_buf, &consumed));
+  if (resp_args_.size() != 1 || resp_args_.front().type != RespExpr::STRING ||
+      ToSV(resp_args_.front().GetBuf()) != "OK") {
+    LOG(ERROR) << "Sync failed " << ToSV(io_buf.InputBuffer());
+    return make_error_code(errc::bad_message);
+  }
+
+  for (unsigned i = 0; i < num_df_flows_; ++i) {
+    shard_flows_[i]->sync_fb_.join();
+  }
+
+  LOG(INFO) << "Full sync finished";
+
   state_mask_ |= R_SYNC_OK;
 
   return error_code{};
@@ -710,6 +733,28 @@ error_code Replica::ParseAndExecute(base::IoBuf* io_buf) {
   return error_code{};
 }
 
+void Replica::ReplicateDFFb(unique_ptr<base::IoBuf> io_buf, string eof_token) {
+  SocketSource ss{sock_.get()};
+  io::PrefixSource ps{io_buf->InputBuffer(), &ss};
+
+  RdbLoader loader(NULL);
+  loader.Load(&ps);
+
+  if (!eof_token.empty()) {
+    unique_ptr<uint8_t[]> buf(new uint8_t[eof_token.size()]);
+    // Pass leftover data from the loader.
+    io::PrefixSource chained(loader.Leftover(), &ps);
+    VLOG(1) << "Before reading from chained stream";
+    io::Result<size_t> eof_res = chained.Read(io::MutableBytes{buf.get(), eof_token.size()});
+    if (!eof_res || *eof_res != eof_token.size()) {
+      LOG(ERROR) << "Error finding eof token in the stream";
+    }
+
+    // TODO: compare the received token with eof_token.
+  }
+  VLOG(1) << "ReplicateDFFb finished after reading " << loader.bytes_read() << " bytes";
+}
+
 error_code Replica::StartFlow() {
   CHECK(!sock_);
   DCHECK(!master_context_.master_repl_id.empty() && !master_context_.dfly_session_id.empty());
@@ -720,22 +765,43 @@ error_code Replica::StartFlow() {
   sock_.reset(mythread->CreateSocket());
   RETURN_ON_ERR(sock_->Connect(master_context_.master_ep));
 
+  VLOG(1) << "Sending on flow " << master_context_.master_repl_id << " "
+          << master_context_.dfly_session_id << " " << master_context_.flow_id;
+
   ReqSerializer serializer{sock_.get()};
-  serializer.SendCommand(StrCat("DFLY SYNC ", master_context_.master_repl_id, " ",
+  serializer.SendCommand(StrCat("DFLY FLOW ", master_context_.master_repl_id, " ",
                                 master_context_.dfly_session_id, " ", master_context_.flow_id));
   RETURN_ON_ERR(serializer.ec());
 
   parser_.reset(new RedisParser{false});  // client mode
 
-  base::IoBuf io_buf{128};
-  unsigned consumed = 0;
-  RETURN_ON_ERR(ReadRespReply(&io_buf, &consumed));
-  if (resp_args_.size() != 1 || resp_args_.front().type != RespExpr::STRING ||
-      ToSV(resp_args_.front().GetBuf()) != "OK") {
-    LOG(ERROR) << "Bad SYNC response " << ToSV(io_buf.InputBuffer());
+  std::unique_ptr<base::IoBuf> io_buf{new base::IoBuf(128)};
+  unsigned consumed = 0;
+  RETURN_ON_ERR(ReadRespReply(io_buf.get(), &consumed));  // uses parser_
+
+  if (resp_args_.size() < 2 || resp_args_[0].type != RespExpr::STRING ||
+      resp_args_[1].type != RespExpr::STRING) {
+    LOG(ERROR) << "Bad FLOW response " << ToSV(io_buf->InputBuffer());
     return make_error_code(errc::bad_message);
   }
 
+  string_view flow_directive = ToSV(resp_args_[0].GetBuf());
+  string eof_token;
+  if (flow_directive == "FULL") {
+    eof_token = ToSV(resp_args_[1].GetBuf());
+  } else {
+    LOG(ERROR) << "Bad FLOW response " << ToSV(io_buf->InputBuffer());
+    return make_error_code(errc::bad_message);
+  }
+  io_buf->ConsumeInput(consumed);
+
+  state_mask_ = R_ENABLED | R_TCP_CONNECTED;
+
+  // We cannot discard io_buf because it may contain data
+  // besides the response we parsed. Therefore we pass it further to ReplicateDFFb.
+  sync_fb_ =
+      ::boost::fibers::fiber(&Replica::ReplicateDFFb, this, std::move(io_buf), move(eof_token));
+
   return error_code{};
 }
diff --git a/src/server/replica.h b/src/server/replica.h
index ae1efc0e0..2b6c7429e 100644
--- a/src/server/replica.h
+++ b/src/server/replica.h
@@ -17,7 +17,6 @@ class Service;
 class ConnectionContext;
 
 class Replica {
-
   // The attributes of the master we are connecting to.
   struct MasterContext {
     std::string host;
@@ -102,6 +101,9 @@ class Replica {
 
   std::error_code StartFlow();
 
+  // Full sync fiber function.
+  void ReplicateDFFb(std::unique_ptr<base::IoBuf> io_buf, std::string eof_token);
+
   Service& service_;
 
   ::boost::fibers::fiber sync_fb_;
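On the wire, each flow stream is just the RDB payload followed by the 40-character hex token from the FLOW reply; ReplicateDFFb above consumes it via the loader's leftover buffer. For intuition, a consumer that only wants the raw payload could scan for the trailing token like this (a sketch, not the replica's actual mechanism; `sock` is assumed to be a plain TCP socket):

    def read_snapshot(sock, eof_token: bytes) -> bytes:
        # The master shuts the socket down right after writing the token,
        # so reading until the buffer ends with the token is sufficient.
        buf = b""
        while not buf.endswith(eof_token):
            chunk = sock.recv(4096)
            if not chunk:
                raise EOFError("stream closed before eof token")
            buf += chunk
        return buf[:-len(eof_token)]  # RDB payload without the trailing token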
diff --git a/src/server/server_family.cc b/src/server/server_family.cc
index aee6ec12d..e0f507097 100644
--- a/src/server/server_family.cc
+++ b/src/server/server_family.cc
@@ -872,7 +872,7 @@ error_code ServerFamily::DoSave(bool new_version, Transaction* trans, string* err_details) {
     // Save summary file.
     {
-      const auto& scripts = script_mgr_->GetLuaScripts();
+      const auto scripts = script_mgr_->GetLuaScripts();
       auto& summary_snapshot = snapshots[shard_set->size()];
       summary_snapshot.reset(new RdbSnapshot(fq_threadpool_.get()));
       if (ec = DoPartialSave(filename, path, now, scripts, summary_snapshot.get(), nullptr)) {
@@ -899,7 +899,7 @@ error_code ServerFamily::DoSave(bool new_version, Transaction* trans, string* err_details) {
     VLOG(1) << "Saving to " << path;
 
     snapshots[0].reset(new RdbSnapshot(fq_threadpool_.get()));
-    const auto& lua_scripts = script_mgr_->GetLuaScripts();
+    const auto lua_scripts = script_mgr_->GetLuaScripts();
     ec = snapshots[0]->Start(SaveMode::RDB, path.generic_string(), lua_scripts);
 
     if (!ec) {
@@ -1471,6 +1471,12 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
 
   replica_.swap(new_replica);
 
+  GlobalState new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING);
+  if (new_state != GlobalState::LOADING) {
+    LOG(WARNING) << GlobalStateName(new_state) << " in progress, ignored";
+    return;
+  }
+
   // Flushing all the data after we marked this instance as replica.
   Transaction* transaction = cntx->transaction;
   transaction->Schedule();
@@ -1484,6 +1490,7 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
   // Replica sends response in either case. No need to send response in this function.
   // It's a bit confusing but simpler.
   if (!replica_->Run(cntx)) {
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
     replica_.reset();
   }
diff --git a/tests/dragonfly/__init__.py b/tests/dragonfly/__init__.py
index aba567498..8d2a0bf9b 100644
--- a/tests/dragonfly/__init__.py
+++ b/tests/dragonfly/__init__.py
@@ -1,4 +1,5 @@
 import pytest
+import typing
 
 import time
 import subprocess
@@ -14,6 +17,7 @@ class DflyInstance:
         self.path = path
         self.args = args
         self.cwd = cwd
+        self.proc = None
 
     def start(self):
         arglist = DflyInstance.format_args(self.args)
@@ -29,14 +33,21 @@ class DflyInstance:
             raise Exception(
                 f"Failed to start instance, return code {return_code}")
 
-    def stop(self):
+    def stop(self, kill=False):
+        proc, self.proc = self.proc, None
+        if proc is None:
+            return
+
         print(f"Stopping instance on {self.port}")
         try:
-            self.proc.terminate()
-            outs, errs = self.proc.communicate(timeout=15)
+            if kill:
+                proc.kill()
+            else:
+                proc.terminate()
+            outs, errs = proc.communicate(timeout=15)
         except subprocess.TimeoutExpired:
             print("Unable to terminate DragonflyDB gracefully, it was killed")
-            outs, errs = self.proc.communicate()
+            outs, errs = proc.communicate()
             print(outs, errs)
 
     def __getitem__(self, k):
@@ -64,12 +75,21 @@ class DflyInstanceFactory:
         self.cwd = cwd
         self.path = path
         self.args = args
+        self.instances = []
 
     def create(self, **kwargs) -> DflyInstance:
         args = {**self.args, **kwargs}
         for k, v in args.items():
             args[k] = v.format(**self.env) if isinstance(v, str) else v
-        return DflyInstance(self.path, args, self.cwd)
+
+        instance = DflyInstance(self.path, args, self.cwd)
+        self.instances.append(instance)
+        return instance
+
+    def stop_all(self):
+        """Stop all launched instances."""
+        for instance in self.instances:
+            instance.stop()
 
 
 def dfly_args(*args):
diff --git a/tests/dragonfly/conftest.py b/tests/dragonfly/conftest.py
index e8ec1df49..b13c642f9 100644
--- a/tests/dragonfly/conftest.py
+++ b/tests/dragonfly/conftest.py
@@ -50,7 +50,17 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
         scripts_dir, '../../build-dbg/dragonfly'))
 
     args = request.param if request.param else {}
-    return DflyInstanceFactory(test_env, tmp_dir, path=path, args=args)
+    factory = DflyInstanceFactory(test_env, tmp_dir, path=path, args=args)
+    yield factory
+    factory.stop_all()
+
+
+@pytest.fixture(scope="function")
+def df_local_factory(df_factory: DflyInstanceFactory):
+    factory = DflyInstanceFactory(
+        df_factory.env, df_factory.cwd, df_factory.path, df_factory.args)
+    yield factory
+    factory.stop_all()
 
 
 @pytest.fixture(scope="session")
@@ -61,6 +71,7 @@ def df_server(df_factory: DflyInstanceFactory) -> DflyInstance:
     """
     instance = df_factory.create()
     instance.start()
+
     yield instance
 
     clients_left = None
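Together, these fixtures give each test a private factory whose instances are reaped on teardown. A minimal sketch of the intended usage (names come from the diffs above; the test body is illustrative):

    def test_example(df_local_factory):
        # Instances created through the factory are tracked in factory.instances.
        master = df_local_factory.create(port=1111, proactor_threads=2)
        master.start()
        # ... drive the instance with a redis client ...
        # No explicit stop() is needed: the fixture calls stop_all() on teardown,
        # and stop() is now safe to call twice since self.proc is cleared first.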
+""" + +# (threads_master, threads_replicas, n entries) +simple_full_sync_multi_cases = [ + (4, [3, 2], 500), + (8, [6, 5, 4], 500), + (8, [2] * 5, 100), + (4, [1] * 20, 500) +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("t_master, t_replicas, n_keys", simple_full_sync_multi_cases) +async def test_simple_full_sync_multi(df_local_factory, t_master, t_replicas, n_keys): + def data_gen(): return gen_test_data(n_keys) + + master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master) + replicas = [ + df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t) + for i, t in enumerate(t_replicas) + ] + + # Start master and fill with test data + master.start() + c_master = aioredis.Redis(port=master.port, single_connection_client=True) + await batch_fill_data_async(c_master, data_gen()) + + # Start replica tasks in parallel + tasks = [ + asyncio.create_task(run_sfs_replica( + replica, master, data_gen), name="replica-"+str(replica.port)) + for replica in replicas + ] + + for task in tasks: + assert await task + + await c_master.connection_pool.disconnect() + + +async def run_sfs_replica(replica, master, data_gen): + replica.start() + c_replica = aioredis.Redis( + port=replica.port, single_connection_client=None) + + await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) + + await wait_available_async(c_replica) + await batch_check_data_async(c_replica, data_gen()) + + await c_replica.connection_pool.disconnect() + return True + + +""" +Test replica crash during full sync on multiple replicas without altering data during replication. +""" + + +# (threads_master, threads_replicas, n entries) +simple_full_sync_multi_crash_cases = [ + (5, [1] * 15, 5000), + (5, [1] * 20, 5000), + (5, [1] * 25, 5000) +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("t_master, t_replicas, n_keys", simple_full_sync_multi_crash_cases) +async def test_simple_full_sync_mutli_crash(df_local_factory, t_master, t_replicas, n_keys): + def data_gen(): return gen_test_data(n_keys) + + master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master) + replicas = [ + df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t) + for i, t in enumerate(t_replicas) + ] + + # Start master and fill with test data + master.start() + c_master = aioredis.Redis(port=master.port, single_connection_client=True) + await batch_fill_data_async(c_master, data_gen()) + + # Start replica tasks in parallel + tasks = [ + asyncio.create_task(run_sfs_crash_replica( + replica, master, data_gen), name="replica-"+str(replica.port)) + for replica in replicas + ] + + for task in tasks: + assert await task + + # Check master is ok + await batch_check_data_async(c_master, data_gen()) + + await c_master.connection_pool.disconnect() + + +async def run_sfs_crash_replica(replica, master, data_gen): + replica.start() + c_replica = aioredis.Redis( + port=replica.port, single_connection_client=None) + + await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) + + # Kill the replica after a short delay + await asyncio.sleep(0.0) + replica.stop(kill=True) + + await c_replica.connection_pool.disconnect() + return True diff --git a/tests/dragonfly/snapshot_test.py b/tests/dragonfly/snapshot_test.py index 6d74f5f67..db607f02c 100644 --- a/tests/dragonfly/snapshot_test.py +++ b/tests/dragonfly/snapshot_test.py @@ -4,33 +4,22 @@ import redis import string import os import glob - - from pathlib import Path -from dragonfly import dfly_args + +from . 
import dfly_args +from .utility import batch_check_data, batch_fill_data, gen_test_data BASIC_ARGS = {"dir": "{DRAGONFLY_TMP}/"} +NUM_KEYS = 100 class SnapshotTestBase: - KEYS = string.ascii_lowercase - def setup(self, tmp_dir: Path): self.tmp_dir = tmp_dir self.rdb_out = tmp_dir / "test.rdb" if self.rdb_out.exists(): self.rdb_out.unlink() - def populate(self, client: redis.Redis): - """Populate instance with test data""" - for k in self.KEYS: - client.set(k, "val-"+k) - - def check(self, client: redis.Redis): - """Check instance contains test data""" - for k in self.KEYS: - assert client.get(k) == "val-"+k - def get_main_file(self, suffix): def is_main(f): return "summary" in f if suffix == "dfs" else True files = glob.glob(str(self.tmp_dir.absolute()) + '/test-*.'+suffix) @@ -45,14 +34,14 @@ class TestRdbSnapshot(SnapshotTestBase): super().setup(tmp_dir) def test_snapshot(self, client: redis.Redis): - super().populate(client) + batch_fill_data(client, gen_test_data(NUM_KEYS)) # save + flush + load client.execute_command("SAVE") assert client.flushall() client.execute_command("DEBUG LOAD " + super().get_main_file("rdb")) - super().check(client) + batch_check_data(client, gen_test_data(NUM_KEYS)) @dfly_args({**BASIC_ARGS, "dbfilename": "test"}) @@ -66,14 +55,14 @@ class TestDflySnapshot(SnapshotTestBase): os.remove(file) def test_snapshot(self, client: redis.Redis): - super().populate(client) + batch_fill_data(client, gen_test_data(NUM_KEYS)) # save + flush + load client.execute_command("SAVE DF") assert client.flushall() client.execute_command("DEBUG LOAD " + super().get_main_file("dfs")) - super().check(client) + batch_check_data(client, gen_test_data(NUM_KEYS)) @dfly_args({**BASIC_ARGS, "dbfilename": "test.rdb", "save_schedule": "*:*"}) @@ -84,7 +73,7 @@ class TestPeriodicSnapshot(SnapshotTestBase): super().setup(tmp_dir) def test_snapshot(self, client: redis.Redis): - super().populate(client) + batch_fill_data(client, gen_test_data(NUM_KEYS)) time.sleep(60) diff --git a/tests/dragonfly/utility.py b/tests/dragonfly/utility.py new file mode 100644 index 000000000..57b0b142f --- /dev/null +++ b/tests/dragonfly/utility.py @@ -0,0 +1,82 @@ +import redis +import aioredis +import itertools +import time +import asyncio + + +def grouper(n, iterable): + """Transform iterable into iterator of chunks of size n""" + it = iter(iterable) + while True: + chunk = tuple(itertools.islice(it, n)) + if not chunk: + return + yield chunk + + +BATCH_SIZE = 100 + + +def gen_test_data(n): + for i in range(n): + yield "k-"+str(i), "v-"+str(i) + + +def batch_fill_data(client: redis.Redis, gen): + for group in grouper(BATCH_SIZE, gen): + client.mset({k: v for k, v, in group}) + + +async def batch_fill_data_async(client: aioredis.Redis, gen): + for group in grouper(BATCH_SIZE, gen): + await client.mset({k: v for k, v in group}) + + +def as_str_val(v) -> str: + if isinstance(v, str): + return v + elif isinstance(v, bytes): + return v.decode() + else: + return str(v) + + +def batch_check_data(client: redis.Redis, gen): + for group in grouper(BATCH_SIZE, gen): + vals = client.mget(k for k, _ in group) + assert all(as_str_val(vals[i]) == v for i, (_, v) in enumerate(group)) + + +async def batch_check_data_async(client: aioredis.Redis, gen): + for group in grouper(BATCH_SIZE, gen): + vals = await client.mget(k for k, _ in group) + assert all(as_str_val(vals[i]) == v for i, (_, v) in enumerate(group)) + + +def wait_available(client: redis.Redis): + its = 0 + while True: + try: + client.get('key') + print("wait_available 
iterations:", its) + return + except redis.ResponseError as e: + assert "Can not execute during LOADING" in str(e) + + time.sleep(0.01) + its += 1 + + +async def wait_available_async(client: aioredis.Redis): + its = 0 + while True: + try: + await client.get('key') + print("wait_available iterations:", its) + return + except aioredis.ResponseError as e: + assert "Can not execute during LOADING" in str(e) + + await asyncio.sleep(0.01) + its += 1