diff --git a/src/server/detail/save_stages_controller.cc b/src/server/detail/save_stages_controller.cc index 0099974bd..0f0b9a5b8 100644 --- a/src/server/detail/save_stages_controller.cc +++ b/src/server/detail/save_stages_controller.cc @@ -35,10 +35,6 @@ namespace fs = std::filesystem; namespace { -bool IsCloudPath(string_view path) { - return absl::StartsWith(path, kS3Prefix) || absl::StartsWith(path, kGCSPrefix); -} - // Create a directory and all its parents if they don't exist. error_code CreateDirs(fs::path dir_path) { error_code ec; @@ -387,8 +383,8 @@ GenericError SaveStagesController::FinalizeFileMovement() { // Build full path: get dir, try creating dirs, get filename with placeholder GenericError SaveStagesController::BuildFullPath() { - fs::path dir_path = GetFlag(FLAGS_dir); - if (!dir_path.empty() && !IsCloudPath(GetFlag(FLAGS_dir))) { + fs::path dir_path = cloud_uri_.empty() ? GetFlag(FLAGS_dir) : cloud_uri_; + if (!dir_path.empty() && cloud_uri_.empty() && !IsCloudPath(GetFlag(FLAGS_dir))) { if (auto ec = CreateDirs(dir_path); ec) return {ec, "Failed to create directories"}; } diff --git a/src/server/detail/save_stages_controller.h b/src/server/detail/save_stages_controller.h index abf8bd6bd..564a672ac 100644 --- a/src/server/detail/save_stages_controller.h +++ b/src/server/detail/save_stages_controller.h @@ -29,6 +29,7 @@ struct SaveInfo { struct SaveStagesInputs { bool use_dfs_format_; + std::string_view cloud_uri_; std::string_view basename_; Transaction* trans_; Service* service_; diff --git a/src/server/detail/snapshot_storage.cc b/src/server/detail/snapshot_storage.cc index f327938a9..4ebeea78b 100644 --- a/src/server/detail/snapshot_storage.cc +++ b/src/server/detail/snapshot_storage.cc @@ -34,9 +34,6 @@ using namespace util; using namespace std; namespace { -inline bool IsGcsPath(string_view path) { - return absl::StartsWith(path, kGCSPrefix); -} constexpr string_view kSummarySuffix = "summary.dfs"sv; @@ -270,7 +267,7 @@ io::Result, GenericError> GcsSnapshotStorage::Open } io::ReadonlyFileOrError GcsSnapshotStorage::OpenReadFile(const std::string& path) { - if (!IsGcsPath(path)) + if (!IsGCSPath(path)) return nonstd::make_unexpected(GenericError("Invalid GCS path")); auto [bucket, key] = GetBucketPath(path); @@ -321,7 +318,7 @@ io::Result GcsSnapshotStorage::LoadPath(string_view d io::Result, GenericError> GcsSnapshotStorage::ExpandFromPath( const string& load_path) { - if (!IsGcsPath(load_path)) + if (!IsGCSPath(load_path)) return nonstd::make_unexpected( GenericError(make_error_code(errc::invalid_argument), "Invalid GCS path")); diff --git a/src/server/detail/snapshot_storage.h b/src/server/detail/snapshot_storage.h index 6f217a545..537235c0f 100644 --- a/src/server/detail/snapshot_storage.h +++ b/src/server/detail/snapshot_storage.h @@ -7,6 +7,8 @@ #include #endif +#include + #include #include #include @@ -186,5 +188,17 @@ struct FilenameSubstitutions { void SubstituteFilenamePlaceholders(fs::path* filename, const FilenameSubstitutions& fns); +inline bool IsS3Path(std::string_view path) { + return absl::StartsWith(path, detail::kS3Prefix); +} + +inline bool IsGCSPath(std::string_view path) { + return absl::StartsWith(path, detail::kGCSPrefix); +} + +inline bool IsCloudPath(std::string_view path) { + return IsS3Path(path) || IsGCSPath(path); +} + } // namespace detail } // namespace dfly diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 156634604..a69fd3a63 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -251,12 +251,29 @@ string UnknownCmd(string cmd, CmdArgList args) { absl::StrJoin(args.begin(), args.end(), ", ", CmdArgListFormatter())); } -bool IsS3Path(string_view path) { - return absl::StartsWith(path, detail::kS3Prefix); -} - -bool IsGCSPath(string_view path) { - return absl::StartsWith(path, detail::kGCSPrefix); +std::shared_ptr CreateCloudSnapshotStorage(std::string_view uri) { + if (detail::IsS3Path(uri)) { +#ifdef WITH_AWS + shard_set->pool()->GetNextProactor()->Await([&] { util::aws::Init(); }); + return std::make_shared( + absl::GetFlag(FLAGS_s3_endpoint), absl::GetFlag(FLAGS_s3_use_https), + absl::GetFlag(FLAGS_s3_ec2_metadata), absl::GetFlag(FLAGS_s3_sign_payload)); +#else + LOG(ERROR) << "Compiled without AWS support"; + exit(1); +#endif + } else if (detail::IsGCSPath(uri)) { + auto gcs = std::make_shared(); + auto ec = shard_set->pool()->GetNextProactor()->Await([&] { return gcs->Init(3000); }); + if (ec) { + LOG(ERROR) << "Failed to initialize GCS snapshot storage: " << ec.message(); + exit(1); + } + return gcs; + } else { + LOG(ERROR) << "Uknown cloud storage " << uri; + exit(1); + } } // Check that if TLS is used at least one form of client authentication is @@ -854,24 +871,9 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorpool()->GetNextProactor()->Await([&] { util::aws::Init(); }); - snapshot_storage_ = std::make_shared( - absl::GetFlag(FLAGS_s3_endpoint), absl::GetFlag(FLAGS_s3_use_https), - absl::GetFlag(FLAGS_s3_ec2_metadata), absl::GetFlag(FLAGS_s3_sign_payload)); -#else - LOG(ERROR) << "Compiled without AWS support"; - exit(1); -#endif - } else if (IsGCSPath(flag_dir)) { - auto gcs = std::make_shared(); - auto ec = shard_set->pool()->GetNextProactor()->Await([&] { return gcs->Init(3000); }); - if (ec) { - LOG(ERROR) << "Failed to initialize GCS snapshot storage: " << ec.message(); - exit(1); - } - snapshot_storage_ = std::move(gcs); + + if (detail::IsCloudPath(flag_dir)) { + snapshot_storage_ = CreateCloudSnapshotStorage(flag_dir); } else if (fq_threadpool_) { snapshot_storage_ = std::make_shared(fq_threadpool_.get()); } else { @@ -1655,10 +1657,11 @@ GenericError ServerFamily::DoSave(bool ignore_state) { CHECK_NOTNULL(cid); boost::intrusive_ptr trans(new Transaction{cid}); trans->InitByArgs(&namespaces->GetDefaultNamespace(), 0, {}); - return DoSave(absl::GetFlag(FLAGS_df_snapshot_format), {}, trans.get(), ignore_state); + return DoSave(SaveCmdOptions{absl::GetFlag(FLAGS_df_snapshot_format), {}, {}}, trans.get(), + ignore_state); } -GenericError ServerFamily::DoSaveCheckAndStart(bool new_version, string_view basename, +GenericError ServerFamily::DoSaveCheckAndStart(const SaveCmdOptions& save_cmd_opts, Transaction* trans, bool ignore_state) { auto state = ServerState::tlocal()->gstate(); @@ -1674,10 +1677,13 @@ GenericError ServerFamily::DoSaveCheckAndStart(bool new_version, string_view bas "SAVING - can not save database"}; } - VLOG(1) << "Saving snapshot to " << basename; + auto snapshot_storage = save_cmd_opts.cloud_uri.empty() + ? snapshot_storage_ + : CreateCloudSnapshotStorage(save_cmd_opts.cloud_uri); save_controller_ = make_unique(detail::SaveStagesInputs{ - new_version, basename, trans, &service_, fq_threadpool_.get(), snapshot_storage_}); + save_cmd_opts.new_version, save_cmd_opts.cloud_uri, save_cmd_opts.basename, trans, + &service_, fq_threadpool_.get(), snapshot_storage}); auto res = save_controller_->InitResourcesAndStart(); @@ -1714,9 +1720,9 @@ GenericError ServerFamily::WaitUntilSaveFinished(Transaction* trans, bool ignore return save_info.error; } -GenericError ServerFamily::DoSave(bool new_version, string_view basename, Transaction* trans, +GenericError ServerFamily::DoSave(const SaveCmdOptions& save_cmd_opts, Transaction* trans, bool ignore_state) { - if (auto ec = DoSaveCheckAndStart(new_version, basename, trans, ignore_state); ec) { + if (auto ec = DoSaveCheckAndStart(save_cmd_opts, trans, ignore_state); ec) { return ec; } @@ -2078,46 +2084,61 @@ void ServerFamily::BgSaveFb(boost::intrusive_ptr trans) { } } -std::optional ServerFamily::GetVersionAndBasename( - CmdArgList args, SinkReplyBuilder* builder) { - if (args.size() > 2) { +std::optional ServerFamily::GetSaveCmdOpts(CmdArgList args, + SinkReplyBuilder* builder) { + if (args.size() > 3) { builder->SendError(kSyntaxErr); return {}; } - bool new_version = absl::GetFlag(FLAGS_df_snapshot_format); + SaveCmdOptions save_cmd_opts; + save_cmd_opts.new_version = absl::GetFlag(FLAGS_df_snapshot_format); if (args.size() >= 1) { string sub_cmd = absl::AsciiStrToUpper(ArgS(args, 0)); if (sub_cmd == "DF") { - new_version = true; + save_cmd_opts.new_version = true; } else if (sub_cmd == "RDB") { - new_version = false; + save_cmd_opts.new_version = false; } else { builder->SendError(UnknownSubCmd(sub_cmd, "SAVE"), kSyntaxErrType); return {}; } } - string_view basename; - if (args.size() == 2) { - basename = ArgS(args, 1); + if (args.size() >= 2) { + if (detail::IsS3Path(ArgS(args, 1))) { +#ifdef WITH_AWS + save_cmd_opts.cloud_uri = ArgS(args, 1); +#else + LOG(ERROR) << "Compiled without AWS support"; + exit(1); +#endif + } else if (detail::IsGCSPath(ArgS(args, 1))) { + save_cmd_opts.cloud_uri = ArgS(args, 1); + } else { + // no cloud_uri get basename and return + save_cmd_opts.basename = ArgS(args, 1); + return save_cmd_opts; + } + // cloud_uri is set so get basename if provided + if (args.size() == 3) { + save_cmd_opts.basename = ArgS(args, 2); + } } - return ServerFamily::VersionBasename{new_version, basename}; + return save_cmd_opts; } -// BGSAVE [DF|RDB] [basename] +// SAVE [DF|RDB] [CLOUD_URI] [BASENAME] // TODO add missing [SCHEDULE] void ServerFamily::BgSave(CmdArgList args, const CommandContext& cmd_cntx) { - auto maybe_res = GetVersionAndBasename(args, cmd_cntx.rb); + auto maybe_res = GetSaveCmdOpts(args, cmd_cntx.rb); if (!maybe_res) { return; } - const auto [version, basename] = *maybe_res; - - if (auto ec = DoSaveCheckAndStart(version, basename, cmd_cntx.tx); ec) { + if (auto ec = DoSaveCheckAndStart(*maybe_res, cmd_cntx.tx); ec) { cmd_cntx.rb->SendError(ec.Format()); return; } @@ -2127,18 +2148,16 @@ void ServerFamily::BgSave(CmdArgList args, const CommandContext& cmd_cntx) { cmd_cntx.rb->SendOk(); } -// SAVE [DF|RDB] [basename] +// SAVE [DF|RDB] [CLOUD_URI] [BASENAME] // Allows saving the snapshot of the dataset on disk, potentially overriding the format // and the snapshot name. void ServerFamily::Save(CmdArgList args, const CommandContext& cmd_cntx) { - auto maybe_res = GetVersionAndBasename(args, cmd_cntx.rb); + auto maybe_res = GetSaveCmdOpts(args, cmd_cntx.rb); if (!maybe_res) { return; } - const auto [version, basename] = *maybe_res; - - GenericError ec = DoSave(version, basename, cmd_cntx.tx); + GenericError ec = DoSave(*maybe_res, cmd_cntx.tx); if (ec) { cmd_cntx.rb->SendError(ec.Format()); } else { diff --git a/src/server/server_family.h b/src/server/server_family.h index db479ab1c..092622701 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -158,6 +158,15 @@ struct ReplicaOffsetInfo { std::vector flow_offsets; }; +struct SaveCmdOptions { + // if new_version is true, saves DF specific, non redis compatible snapshot. + bool new_version; + // cloud storage URI + std::string_view cloud_uri; + // if basename is not empty it will override dbfilename flag + std::string_view basename; +}; + class ServerFamily { using SinkReplyBuilder = facade::SinkReplyBuilder; @@ -193,9 +202,7 @@ class ServerFamily { void StatsMC(std::string_view section, SinkReplyBuilder* builder); - // if new_version is true, saves DF specific, non redis compatible snapshot. - // if basename is not empty it will override dbfilename flag. - GenericError DoSave(bool new_version, std::string_view basename, Transaction* transaction, + GenericError DoSave(const SaveCmdOptions& save_cmd_opts, Transaction* transaction, bool ignore_state = false); // Calls DoSave with a default generated transaction and with the format @@ -313,14 +320,11 @@ class ServerFamily { void SendInvalidationMessages() const; - // Helper function to retrieve version(true if format is dfs rdb), and basename from args. - // In case of an error an empty optional is returned. - using VersionBasename = std::pair; - std::optional GetVersionAndBasename(CmdArgList args, SinkReplyBuilder* builder); + std::optional GetSaveCmdOpts(CmdArgList args, SinkReplyBuilder* builder); void BgSaveFb(boost::intrusive_ptr trans); - GenericError DoSaveCheckAndStart(bool new_version, string_view basename, Transaction* trans, + GenericError DoSaveCheckAndStart(const SaveCmdOptions& save_cmd_opts, Transaction* trans, bool ignore_state = false) ABSL_LOCKS_EXCLUDED(save_mu_); GenericError WaitUntilSaveFinished(Transaction* trans, diff --git a/tests/dragonfly/snapshot_test.py b/tests/dragonfly/snapshot_test.py index ba6544f48..483fca160 100644 --- a/tests/dragonfly/snapshot_test.py +++ b/tests/dragonfly/snapshot_test.py @@ -311,6 +311,21 @@ async def test_info_persistence_field(async_client): assert "loading:0" in (await async_client.execute_command("INFO PERSISTENCE")) +def delete_s3_objects(bucket, prefix): + client = boto3.client("s3") + resp = client.list_objects_v2( + Bucket=bucket, + Prefix=prefix, + ) + keys = [] + for obj in resp["Contents"]: + keys.append({"Key": obj["Key"]}) + client.delete_objects( + Bucket=bucket, + Delete={"Objects": keys}, + ) + + # If DRAGONFLY_S3_BUCKET is configured, AWS credentials must also be # configured. @pytest.mark.skipif( @@ -338,27 +353,36 @@ async def test_s3_snapshot(async_client, tmp_dir): assert await StaticSeeder.capture(async_client) == start_capture finally: - - def delete_objects(bucket, prefix): - client = boto3.client("s3") - resp = client.list_objects_v2( - Bucket=bucket, - Prefix=prefix, - ) - keys = [] - for obj in resp["Contents"]: - keys.append({"Key": obj["Key"]}) - client.delete_objects( - Bucket=bucket, - Delete={"Objects": keys}, - ) - - delete_objects( + delete_s3_objects( os.environ["DRAGONFLY_S3_BUCKET"], str(tmp_dir)[1:], ) +# If DRAGONFLY_S3_BUCKET is configured, AWS credentials must also be +# configured. +@pytest.mark.skipif( + "DRAGONFLY_S3_BUCKET" not in os.environ or os.environ["DRAGONFLY_S3_BUCKET"] == "", + reason="AWS S3 snapshots bucket is not configured", +) +@dfly_args({**BASIC_ARGS}) +async def test_s3_save_local_dir(async_client): + seeder = StaticSeeder(key_target=10_000) + await seeder.run(async_client) + + try: + # SAVE to S3 bucket with `s3_dump` as filename prefix + await async_client.execute_command( + "SAVE", "DF", "s3://" + os.environ["DRAGONFLY_S3_BUCKET"], "s3_dump" + ) + + finally: + delete_s3_objects( + os.environ["DRAGONFLY_S3_BUCKET"], + "s3_dump", + ) + + @dfly_args({**BASIC_ARGS, "dbfilename": "test-shutdown"}) class TestDflySnapshotOnShutdown: SEEDER_ARGS = dict(key_target=10_000)