diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 6329f2cdf..ffbec4363 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -492,21 +492,8 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorAwait( [this, &flag]() { this->Replicate(flag.host, flag.port); }); - return; // DONT load any snapshots - } - - const auto load_path_result = snapshot_storage_->LoadPath(flag_dir, GetFlag(FLAGS_dbfilename)); - if (load_path_result) { - const std::string load_path = *load_path_result; - if (!load_path.empty()) { - load_result_ = Load(load_path); - } - } else { - if (std::error_code(load_path_result.error()) == std::errc::no_such_file_or_directory) { - LOG(WARNING) << "Load snapshot: No snapshot found"; - } else { - LOG(ERROR) << "Failed to load snapshot: " << load_path_result.error().Format(); - } + } else { // load from snapshot only if --replicaof is empty + LoadFromSnapshot(); } const auto create_snapshot_schedule_fb = [this] { @@ -519,10 +506,26 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorLoadPath(GetFlag(FLAGS_dir), GetFlag(FLAGS_dbfilename)); + if (load_path_result) { + const std::string load_path = *load_path_result; + if (!load_path.empty()) { + load_result_ = Load(load_path); + } + } else { + if (std::error_code(load_path_result.error()) == std::errc::no_such_file_or_directory) { + LOG(WARNING) << "Load snapshot: No snapshot found"; + } else { + LOG(ERROR) << "Failed to load snapshot: " << load_path_result.error().Format(); + } + } +} + void ServerFamily::JoinSnapshotSchedule() { schedule_done_.Notify(); snapshot_schedule_fb_.JoinIfNeeded(); @@ -2007,9 +2010,14 @@ void ServerFamily::Hello(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::ReplicaOfInternal(string_view host, string_view port_sv, ConnectionContext* cntx, ActionOnConnectionFail on_err) { LOG(INFO) << "Replicating " << host << ":" << port_sv; - unique_lock lk(replicaof_mu_); // Only one REPLICAOF command can run at a time + // We should not execute replica of command while loading from snapshot. + if (ServerState::tlocal()->is_master && service_.GetGlobalState() == GlobalState::LOADING) { + cntx->SendError("Can not execute during LOADING"); + return; + } + // If NO ONE was supplied, just stop the current replica (if it exists) if (IsReplicatingNoOne(host, port_sv)) { if (!ServerState::tlocal()->is_master) { diff --git a/src/server/server_family.h b/src/server/server_family.h index 825758ce7..04c8cc189 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -207,6 +207,7 @@ class ServerFamily { private: void JoinSnapshotSchedule(); + void LoadFromSnapshot(); uint32_t shard_count() const { return shard_set->size(); diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index eedbc5c54..bc874f276 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -1806,3 +1806,33 @@ async def test_client_pause_with_replica(df_local_factory, df_seeder_factory): assert await seeder.compare(capture, port=replica.port) await disconnect_clients(c_master, c_replica) + + +async def test_replicaof_reject_on_load(df_local_factory, df_seeder_factory): + tmp_file_name = "".join(random.choices(string.ascii_letters, k=10)) + master = df_local_factory.create() + replica = df_local_factory.create(dbfilename=f"dump_{tmp_file_name}") + df_local_factory.start_all([master, replica]) + + seeder = df_seeder_factory.create(port=replica.port, keys=30000) + await seeder.run(target_deviation=0.1) + c_replica = replica.client() + dbsize = await c_replica.dbsize() + assert dbsize >= 9000 + + replica.stop() + replica.start() + c_replica = replica.client() + # Check replica of not alowed while loading snapshot + try: + await c_replica.execute_command(f"REPLICAOF localhost {master.port}") + assert False + except aioredis.ResponseError as e: + assert "Can not execute during LOADING" in str(e) + # Check one we finish loading snapshot replicaof success + await wait_available_async(c_replica) + await c_replica.execute_command(f"REPLICAOF localhost {master.port}") + + await c_replica.close() + master.stop() + replica.stop()