diff --git a/src/server/debugcmd.cc b/src/server/debugcmd.cc
index 20e951940..932ba2915 100644
--- a/src/server/debugcmd.cc
+++ b/src/server/debugcmd.cc
@@ -170,13 +170,9 @@ void DebugCmd::Reload(CmdArgList args) {
 
   if (save) {
     string err_details;
-    const CommandId* cid = sf_.service().FindCmd("SAVE");
-    CHECK_NOTNULL(cid);
-    intrusive_ptr<Transaction> trans(new Transaction{cid, ServerState::tlocal()->thread_index()});
-    trans->InitByArgs(0, {});
     VLOG(1) << "Performing save";
 
-    GenericError ec = sf_.DoSave(absl::GetFlag(FLAGS_df_snapshot_format), trans.get());
+    GenericError ec = sf_.DoSave();
     if (ec) {
       return (*cntx_)->SendError(ec.Format());
     }
diff --git a/src/server/rdb_test.cc b/src/server/rdb_test.cc
index 8c7341815..04fd2097c 100644
--- a/src/server/rdb_test.cc
+++ b/src/server/rdb_test.cc
@@ -38,18 +38,21 @@ namespace dfly {
 
 class RdbTest : public BaseFamilyTest {
  protected:
-  static void SetUpTestSuite();
   void TearDown();
+  void SetUp();
 
   io::FileSource GetSource(string name);
 };
 
-void RdbTest::SetUpTestSuite() {
-  BaseFamilyTest::SetUpTestSuite();
+void RdbTest::SetUp() {
   SetFlag(&FLAGS_dbfilename, "rdbtestdump");
+  BaseFamilyTest::SetUp();
 }
 
 void RdbTest::TearDown() {
+  // Disable save on shutdown
+  SetFlag(&FLAGS_dbfilename, "");
+
   auto rdb_files = io::StatFiles("rdbtestdump*");
   CHECK(rdb_files);
   for (const auto& fl : *rdb_files) {
diff --git a/src/server/server_family.cc b/src/server/server_family.cc
index 29082d517..a93beca98 100644
--- a/src/server/server_family.cc
+++ b/src/server/server_family.cc
@@ -467,7 +467,7 @@ void ServerFamily::Init(util::AcceptServer* acceptor, util::ListenerInterface* m
   if (!save_time.empty()) {
     std::optional<SnapshotSpec> spec = ParseSaveSchedule(save_time);
     if (spec) {
-      snapshot_fiber_ = service_.proactor_pool().GetNextProactor()->LaunchFiber(
+      snapshot_schedule_fb_ = service_.proactor_pool().GetNextProactor()->LaunchFiber(
           [save_spec = std::move(spec.value()), this] { SnapshotScheduling(save_spec); });
     } else {
       LOG(WARNING) << "Invalid snapshot time specifier " << save_time;
@@ -481,9 +481,18 @@ void ServerFamily::Shutdown() {
   if (load_result_.valid())
     load_result_.wait();
 
-  is_snapshot_done_.Notify();
-  if (snapshot_fiber_.IsJoinable()) {
-    snapshot_fiber_.Join();
+  schedule_done_.Notify();
+  if (snapshot_schedule_fb_.IsJoinable()) {
+    snapshot_schedule_fb_.Join();
+  }
+
+  if (save_on_shutdown_ && !absl::GetFlag(FLAGS_dbfilename).empty()) {
+    shard_set->pool()->GetNextProactor()->Await([this] {
+      GenericError ec = DoSave();
+      if (ec) {
+        LOG(WARNING) << "Failed to perform snapshot " << ec.Format();
+      }
+    });
   }
 
   pb_task_->Await([this] {
@@ -608,7 +617,7 @@ Future<std::error_code> ServerFamily::Load(const std::string& load_path) {
 void ServerFamily::SnapshotScheduling(const SnapshotSpec& spec) {
   const auto loop_sleep_time = std::chrono::seconds(20);
   while (true) {
-    if (is_snapshot_done_.WaitFor(loop_sleep_time)) {
+    if (schedule_done_.WaitFor(loop_sleep_time)) {
       break;
     }
 
@@ -629,13 +638,7 @@ void ServerFamily::SnapshotScheduling(const SnapshotSpec& spec) {
       continue;
     }
 
-    const CommandId* cid = service().FindCmd("SAVE");
-    CHECK_NOTNULL(cid);
-    boost::intrusive_ptr<Transaction> trans(
-        new Transaction{cid, ServerState::tlocal()->thread_index()});
-    trans->InitByArgs(0, {});
-
-    GenericError ec = DoSave(absl::GetFlag(FLAGS_df_snapshot_format), trans.get());
+    GenericError ec = DoSave();
     if (ec) {
       LOG(WARNING) << "Failed to perform snapshot " << ec.Format();
     }
@@ -665,9 +668,6 @@ error_code ServerFamily::LoadRdb(const std::string& rdb_file) {
   } else {
     ec = res.error();
   }
-
-  service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
-
   return ec;
 }
 
@@ -924,6 +924,15 @@ GenericError DoPartialSave(PartialSaveOpts opts, const dfly::StringVec& scripts,
   return local_ec;
 }
 
+GenericError ServerFamily::DoSave() {
+  const CommandId* cid = service().FindCmd("SAVE");
+  CHECK_NOTNULL(cid);
+  boost::intrusive_ptr<Transaction> trans(
+      new Transaction{cid, ServerState::tlocal()->thread_index()});
+  trans->InitByArgs(0, {});
+  return DoSave(absl::GetFlag(FLAGS_df_snapshot_format), trans.get());
+}
+
 GenericError ServerFamily::DoSave(bool new_version, Transaction* trans) {
   fs::path dir_path(GetFlag(FLAGS_dir));
   AggregateGenericError ec;
@@ -2036,6 +2045,22 @@ void ServerFamily::Latency(CmdArgList args, ConnectionContext* cntx) {
 }
 
 void ServerFamily::_Shutdown(CmdArgList args, ConnectionContext* cntx) {
+  if (args.size() > 1) {
+    (*cntx)->SendError(kSyntaxErr);
+    return;
+  }
+
+  if (args.size() == 1) {
+    auto sub_cmd = ArgS(args, 0);
+    if (absl::EqualsIgnoreCase(sub_cmd, "SAVE")) {
+    } else if (absl::EqualsIgnoreCase(sub_cmd, "NOSAVE")) {
+      save_on_shutdown_ = false;
+    } else {
+      (*cntx)->SendError(kSyntaxErr);
+      return;
+    }
+  }
+
   CHECK_NOTNULL(acceptor_)->Stop();
   (*cntx)->SendOk();
 }
@@ -2079,7 +2104,7 @@ void ServerFamily::Register(CommandRegistry* registry) {
       << CI{"LATENCY", CO::NOSCRIPT | CO::LOADING | CO::FAST, -2, 0, 0, 0}.HFUNC(Latency)
       << CI{"MEMORY", kMemOpts, -2, 0, 0, 0}.HFUNC(Memory)
       << CI{"SAVE", CO::ADMIN | CO::GLOBAL_TRANS, -1, 0, 0, 0}.HFUNC(Save)
-      << CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, 1, 0, 0, 0}.HFUNC(_Shutdown)
+      << CI{"SHUTDOWN", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.HFUNC(_Shutdown)
       << CI{"SLAVEOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf)
       << CI{"READONLY", CO::READONLY, 1, 0, 0, 0}.HFUNC(ReadOnly)
       << CI{"REPLICAOF", kReplicaOpts, 3, 0, 0, 0}.HFUNC(ReplicaOf)
diff --git a/src/server/server_family.h b/src/server/server_family.h
index be6ba113b..431b91ccd 100644
--- a/src/server/server_family.h
+++ b/src/server/server_family.h
@@ -91,6 +91,10 @@ class ServerFamily {
   // if new_version is true, saves DF specific, non redis compatible snapshot.
   GenericError DoSave(bool new_version, Transaction* transaction);
 
+  // Calls DoSave with a default generated transaction and with the format
+  // specified in --df_snapshot_format
+  GenericError DoSave();
+
   // Burns down and destroy all the data from the database.
   // if kDbAll is passed, burns all the databases to the ground.
   std::error_code Drakarys(Transaction* transaction, DbIndex db_ind);
@@ -161,7 +165,7 @@ class ServerFamily {
 
   void SnapshotScheduling(const SnapshotSpec& time);
 
-  Fiber snapshot_fiber_;
+  Fiber snapshot_schedule_fb_;
   Future<std::error_code> load_result_;
 
   uint32_t stats_caching_task_ = 0;
@@ -186,7 +190,11 @@ class ServerFamily {
   std::shared_ptr<LastSaveInfo> last_save_info_;  // protected by save_mu_;
   std::atomic_bool is_saving_{false};
 
-  Done is_snapshot_done_;
+  // Used to override save on shutdown behavior that is usually set
+  // by --dbfilename.
+  bool save_on_shutdown_{true};
+
+  Done schedule_done_;
 
   std::unique_ptr<util::fibers_ext::FiberQueueThreadPool> fq_threadpool_;
 };
diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py
index 62cb5035a..78c8cc629 100644
--- a/tests/dragonfly/replication_test.py
+++ b/tests/dragonfly/replication_test.py
@@ -36,9 +36,9 @@ replication_cases = [
 @pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_replicas, seeder_config", replication_cases)
 async def test_replication_all(df_local_factory, df_seeder_factory, t_master, t_replicas, seeder_config):
-    master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
+    master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master, dbfilename="")
     replicas = [
-        df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
+        df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t, dbfilename="")
         for i, t in enumerate(t_replicas)
     ]
 
@@ -148,10 +148,10 @@ disconnect_cases = [
 @pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases)
 async def test_disconnect_replica(df_local_factory: DflyInstanceFactory, df_seeder_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys):
-    master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
+    master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master, dbfilename="")
     replicas = [
         (df_local_factory.create(
-            port=BASE_PORT+i+1, proactor_threads=t), crash_fs)
+            port=BASE_PORT+i+1, proactor_threads=t, dbfilename=""), crash_fs)
         for i, (t, crash_fs) in enumerate(
             chain(
                 zip(t_crash_fs, repeat(DISCONNECT_CRASH_FULL_SYNC)),
@@ -284,10 +284,10 @@ master_crash_cases = [
 @pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_replicas, n_random_crashes, n_keys", master_crash_cases)
 async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master, t_replicas, n_random_crashes, n_keys):
-    master = df_local_factory.create(port=1111, proactor_threads=t_master)
+    master = df_local_factory.create(port=1111, proactor_threads=t_master, dbfilename="")
     replicas = [
         df_local_factory.create(
-            port=BASE_PORT+i+1, proactor_threads=t)
+            port=BASE_PORT+i+1, proactor_threads=t, dbfilename="")
         for i, t in enumerate(t_replicas)
     ]
 
@@ -398,8 +398,8 @@ async def test_cancel_replication_immediately(df_local_factory, df_seeder_factor
     """
     COMMANDS_TO_ISSUE = 40
 
-    replica = df_local_factory.create(port=BASE_PORT, v=1)
-    masters = [df_local_factory.create(port=BASE_PORT+i+1) for i in range(4)]
+    replica = df_local_factory.create(port=BASE_PORT, dbfilename="")
+    masters = [df_local_factory.create(port=BASE_PORT+i+1, dbfilename="") for i in range(4)]
     seeders = [df_seeder_factory.create(port=m.port) for m in masters]
 
     df_local_factory.start_all([replica] + masters)
diff --git a/tests/dragonfly/snapshot_test.py b/tests/dragonfly/snapshot_test.py
index ac5b4e7e6..4f7eb93d7 100644
--- a/tests/dragonfly/snapshot_test.py
+++ b/tests/dragonfly/snapshot_test.py
@@ -4,6 +4,7 @@ import os
 import glob
 import aioredis
 from pathlib import Path
+import aioredis
 
 from . import dfly_args
 from .utility import DflySeeder, wait_available_async
@@ -18,7 +19,8 @@ class SnapshotTestBase:
         self.tmp_dir = tmp_dir
 
     def get_main_file(self, pattern):
-        def is_main(f): return "summary" in f if pattern.endswith("dfs") else True
+        def is_main(f): return "summary" in f if pattern.endswith(
+            "dfs") else True
         files = glob.glob(str(self.tmp_dir.absolute()) + '/' + pattern)
         possible_mains = list(filter(is_main, files))
         assert len(possible_mains) == 1, possible_mains
@@ -92,6 +94,8 @@ class TestDflySnapshot(SnapshotTestBase):
         assert await seeder.compare(start_capture)
 
 # We spawn instances manually, so reduce memory usage of default to minimum
+
+
 @dfly_args({"proactor_threads": "1"})
 class TestDflyAutoLoadSnapshot(SnapshotTestBase):
     """Test automatic loading of dump files on startup with timestamp"""
@@ -138,7 +142,8 @@ class TestPeriodicSnapshot(SnapshotTestBase):
 
     @pytest.mark.asyncio
     async def test_snapshot(self, df_seeder_factory, df_server):
-        seeder = df_seeder_factory.create(port=df_server.port, keys=10, multi_transaction_probability=0)
+        seeder = df_seeder_factory.create(
+            port=df_server.port, keys=10, multi_transaction_probability=0)
         await seeder.run(target_deviation=0.5)
 
         time.sleep(60)
@@ -156,9 +161,34 @@ class TestPathEscapes(SnapshotTestBase):
 
     @pytest.mark.asyncio
     async def test_snapshot(self, df_local_factory):
-        df_server = df_local_factory.create(dbfilename="../../../../etc/passwd")
+        df_server = df_local_factory.create(
+            dbfilename="../../../../etc/passwd")
         try:
             df_server.start()
             assert False, "Server should not start correctly"
         except Exception as e:
             pass
+
+
+@dfly_args({**BASIC_ARGS, "dbfilename": "test-shutdown"})
+class TestDflySnapshotOnShutdown(SnapshotTestBase):
+    """Test snapshot on shutdown"""
+    @pytest.fixture(autouse=True)
+    def setup(self, tmp_dir: Path):
+        self.tmp_dir = tmp_dir
+
+    @pytest.mark.asyncio
+    async def test_snapshot(self, df_seeder_factory, df_server):
+        seeder = df_seeder_factory.create(port=df_server.port, **SEEDER_ARGS)
+        await seeder.run(target_deviation=0.1)
+
+        start_capture = await seeder.capture()
+
+        df_server.stop()
+        df_server.start()
+
+        a_client = aioredis.Redis(port=df_server.port)
+        await wait_available_async(a_client)
+        await a_client.connection_pool.disconnect()
+
+        assert await seeder.compare(start_capture)
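
For context, a hedged client-side sketch (not part of the patch): it shows how the `SHUTDOWN [SAVE|NOSAVE]` argument introduced above might be driven from a client. The port and the generic `execute_command()` call are illustrative assumptions, mirroring the aioredis usage in the tests.

```python
# Hypothetical usage sketch, not part of the patch: exercising the new
# SHUTDOWN [SAVE|NOSAVE] argument. Port 6379 is an assumption for illustration.
import asyncio

import aioredis


async def shutdown_without_snapshot(port=6379):
    client = aioredis.Redis(port=port)
    try:
        # NOSAVE flips save_on_shutdown_ to false, so no snapshot is written
        # even when --dbfilename is set; a bare SHUTDOWN (or SHUTDOWN SAVE)
        # keeps the new save-on-shutdown default.
        await client.execute_command("SHUTDOWN", "NOSAVE")
    except aioredis.ConnectionError:
        pass  # the server closes the connection while shutting down
    finally:
        await client.connection_pool.disconnect()

asyncio.run(shutdown_without_snapshot())
```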