feat(aws): add s3 awsv2 (#1929)

* feat(aws): add s3 awsv2

* feat(aws): add s3 snapshot test

* feat(aws): disable ec2 metadata by default

* feat(aws): add s3 disable payload signing flag

* chore: update helio

* fix: fix requirements.txt

* feat(s3): update sign payload flag

* chore: update helio
Andy Dunstall 2023-10-06 10:24:56 +01:00 committed by GitHub
parent 0c1402c4ab
commit 2d28b48481
10 changed files with 148 additions and 76 deletions

helio

@@ -1 +1 @@
Subproject commit 1a2a1fc5ce2ad09d4dbe310470e06776f2c527ab
Subproject commit 6655f713c7cb6aebc586973ebea6d6bfea06df3c


@@ -46,7 +46,7 @@ add_library(dragonfly_lib engine_shard_set.cc channel_store.cc command_registry.
find_library(ZSTD_LIB NAMES libzstd.a libzstdstatic.a zstd NAMES_PER_DIR REQUIRED)
cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib aws_lib strings_lib html_lib
cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib awsv2_lib strings_lib html_lib
http_client_lib absl::random_random TRDP::jsoncons ${ZSTD_LIB} TRDP::lz4
TRDP::croncpp)


@@ -258,14 +258,6 @@ void SaveStagesController::UpdateSaveInfo() {
}
GenericError SaveStagesController::InitResources() {
if (is_cloud_ && !aws_) {
*aws_ = make_unique<cloud::AWS>("s3");
if (auto ec = aws_->get()->Init(); ec) {
aws_->reset();
return {ec, "Couldn't initialize AWS"};
}
}
snapshots_.resize(use_dfs_format_ ? shard_set->size() + 1 : 1);
for (auto& [snapshot, _] : snapshots_)
snapshot = make_unique<RdbSnapshot>(fq_threadpool_, snapshot_storage_.get());


@@ -10,7 +10,6 @@
#include "server/detail/snapshot_storage.h"
#include "server/rdb_save.h"
#include "server/server_family.h"
#include "util/cloud/aws.h"
#include "util/fibers/fiberqueue_threadpool.h"
namespace dfly {
@@ -29,7 +28,6 @@ struct SaveStagesInputs {
util::fb2::FiberQueueThreadPool* fq_threadpool_;
std::shared_ptr<LastSaveInfo>* last_save_info_;
util::fb2::Mutex* save_mu_;
std::unique_ptr<util::cloud::AWS>* aws_;
std::shared_ptr<SnapshotStorage> snapshot_storage_;
};
@@ -119,7 +117,5 @@ struct SaveStagesController : public SaveStagesInputs {
GenericError ValidateFilename(const std::filesystem::path& filename, bool new_version);
std::string InferLoadFile(string_view dir, util::cloud::AWS* aws);
} // namespace detail
} // namespace dfly


@@ -5,13 +5,21 @@
#include <absl/strings/str_replace.h>
#include <absl/strings/strip.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/ListObjectsV2Request.h>
#include <aws/s3/model/PutObjectRequest.h>
#include <regex>
#include "base/logging.h"
#include "io/file_util.h"
#include "server/engine_shard_set.h"
#include "util/cloud/s3.h"
#include "util/aws/aws.h"
#include "util/aws/credentials_provider_chain.h"
#include "util/aws/s3_endpoint_provider.h"
#include "util/aws/s3_read_file.h"
#include "util/aws/s3_write_file.h"
#include "util/fibers/fiber_file.h"
namespace dfly {
@@ -166,30 +174,44 @@ io::Result<std::vector<std::string>, GenericError> FileSnapshotStorage::LoadPath
return paths;
}
AwsS3SnapshotStorage::AwsS3SnapshotStorage(util::cloud::AWS* aws) : aws_{aws} {
AwsS3SnapshotStorage::AwsS3SnapshotStorage(const std::string& endpoint, bool ec2_metadata,
bool sign_payload) {
shard_set->pool()->GetNextProactor()->Await([&] {
if (!ec2_metadata) {
setenv("AWS_EC2_METADATA_DISABLED", "true", 0);
}
// S3ClientConfiguration may request configuration and credentials from
// EC2 metadata, so it must be run in a proactor thread.
Aws::S3::S3ClientConfiguration s3_conf{};
if (!sign_payload) {
s3_conf.payloadSigningPolicy = Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::ForceNever;
}
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> credentials_provider =
std::make_shared<util::aws::CredentialsProviderChain>();
// Pass a custom endpoint. If empty, the default S3 endpoint is used.
std::shared_ptr<Aws::S3::S3EndpointProviderBase> endpoint_provider =
std::make_shared<util::aws::S3EndpointProvider>(endpoint);
s3_ = std::make_shared<Aws::S3::S3Client>(credentials_provider, endpoint_provider, s3_conf);
});
}
io::Result<std::pair<io::Sink*, uint8_t>, GenericError> AwsS3SnapshotStorage::OpenWriteFile(
const std::string& path) {
DCHECK(aws_);
util::fb2::ProactorBase* proactor = shard_set->pool()->GetNextProactor();
return proactor->Await([&]() -> io::Result<std::pair<io::Sink*, uint8_t>, GenericError> {
std::optional<std::pair<std::string, std::string>> bucket_path = GetBucketPath(path);
if (!bucket_path) {
return nonstd::make_unexpected(GenericError("Invalid S3 path"));
}
auto [bucket, key] = *bucket_path;
io::Result<util::aws::S3WriteFile> file = util::aws::S3WriteFile::Open(bucket, key, s3_);
if (!file) {
return nonstd::make_unexpected(GenericError(file.error(), "Failed to open write file"));
}
std::optional<std::pair<std::string, std::string>> bucket_path = GetBucketPath(path);
if (!bucket_path) {
return nonstd::make_unexpected(GenericError("Invalid S3 path"));
}
auto [bucket_name, obj_path] = *bucket_path;
util::cloud::S3Bucket bucket(*aws_, bucket_name);
std::error_code ec = bucket.Connect(kBucketConnectMs);
if (ec) {
return nonstd::make_unexpected(GenericError(ec, "Couldn't connect to S3 bucket"));
}
auto res = bucket.OpenWriteFile(obj_path);
if (!res) {
return nonstd::make_unexpected(GenericError(res.error(), "Couldn't open file for writing"));
}
return std::pair<io::Sink*, uint8_t>(*res, FileType::CLOUD);
util::aws::S3WriteFile* f = new util::aws::S3WriteFile(std::move(*file));
return std::pair<io::Sink*, uint8_t>(f, FileType::CLOUD);
});
}
io::ReadonlyFileOrError AwsS3SnapshotStorage::OpenReadFile(const std::string& path) {
@@ -197,18 +219,8 @@ io::ReadonlyFileOrError AwsS3SnapshotStorage::OpenReadFile(const std::string& pa
if (!bucket_path) {
return nonstd::make_unexpected(GenericError("Invalid S3 path"));
}
auto [bucket_name, obj_path] = *bucket_path;
util::cloud::S3Bucket bucket{*aws_, bucket_name};
std::error_code ec = bucket.Connect(kBucketConnectMs);
if (ec) {
return nonstd::make_unexpected(GenericError(ec, "Couldn't connect to S3 bucket"));
}
auto res = bucket.OpenReadFile(obj_path);
if (!res) {
return nonstd::make_unexpected(GenericError(res.error(), "Couldn't open file for reading"));
}
return res;
auto [bucket, key] = *bucket_path;
return new util::aws::S3ReadFile(bucket, key, s3_);
}
io::Result<std::string, GenericError> AwsS3SnapshotStorage::LoadPath(std::string_view dir,
@@ -312,19 +324,29 @@ io::Result<std::vector<std::string>, GenericError> AwsS3SnapshotStorage::LoadPat
io::Result<std::vector<std::string>, GenericError> AwsS3SnapshotStorage::ListObjects(
std::string_view bucket_name, std::string_view prefix) {
util::cloud::S3Bucket bucket(*aws_, bucket_name);
std::error_code ec = bucket.Connect(kBucketConnectMs);
if (ec) {
return nonstd::make_unexpected(GenericError{ec, "Couldn't connect to S3 bucket"});
}
// Each list objects request has a 1000 object limit, so page through the
// objects if needed.
std::string continuation_token;
std::vector<std::string> keys;
ec = bucket.ListAllObjects(
prefix, [&](size_t sz, std::string_view name) { keys.push_back(std::string(name)); });
if (ec) {
return nonstd::make_unexpected(GenericError{ec, "Couldn't list objects in S3 bucket"});
}
do {
Aws::S3::Model::ListObjectsV2Request request;
request.SetBucket(std::string(bucket_name));
request.SetPrefix(std::string(prefix));
if (!continuation_token.empty()) {
request.SetContinuationToken(continuation_token);
}
Aws::S3::Model::ListObjectsV2Outcome outcome = s3_->ListObjectsV2(request);
if (outcome.IsSuccess()) {
continuation_token = outcome.GetResult().GetNextContinuationToken();
for (const auto& object : outcome.GetResult().GetContents()) {
keys.push_back(object.GetKey());
}
} else {
return nonstd::make_unexpected(GenericError{"Failed to list objects in S3 bucket: " +
outcome.GetError().GetExceptionName()});
}
} while (!continuation_token.empty());
return keys;
}
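For reference, the boto3 client used by the test suite pages the same way. A minimal sketch of the equivalent continuation-token loop (the bucket and prefix names are placeholders, not values from this commit):

import boto3

def list_keys(bucket: str, prefix: str) -> list:
    # Collect every key under a prefix, paging past the 1000-object-per-request limit.
    client = boto3.client("s3")
    keys, token = [], None
    while True:
        kwargs = {"Bucket": bucket, "Prefix": prefix}
        if token:
            kwargs["ContinuationToken"] = token
        resp = client.list_objects_v2(**kwargs)
        keys.extend(obj["Key"] for obj in resp.get("Contents", []))
        token = resp.get("NextContinuationToken")
        if not token:  # no continuation token means this was the last page
            return keys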


@@ -3,6 +3,8 @@
#pragma once
#include <aws/s3/S3Client.h>
#include <filesystem>
#include <string>
#include <string_view>
@@ -10,7 +12,6 @@
#include "io/io.h"
#include "server/common.h"
#include "util/cloud/aws.h"
#include "util/fibers/fiberqueue_threadpool.h"
#include "util/fibers/uring_file.h"
@@ -71,7 +72,7 @@ class FileSnapshotStorage : public SnapshotStorage {
class AwsS3SnapshotStorage : public SnapshotStorage {
public:
AwsS3SnapshotStorage(util::cloud::AWS* aws);
AwsS3SnapshotStorage(const std::string& endpoint, bool ec2_metadata, bool sign_payload);
io::Result<std::pair<io::Sink*, uint8_t>, GenericError> OpenWriteFile(
const std::string& path) override;
@@ -90,7 +91,7 @@ class AwsS3SnapshotStorage : public SnapshotStorage {
io::Result<std::vector<std::string>, GenericError> ListObjects(std::string_view bucket_name,
std::string_view prefix);
util::cloud::AWS* aws_;
std::shared_ptr<Aws::S3::S3Client> s3_;
};
// Returns bucket_name, obj_path for an s3 path.
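The new constructor takes (endpoint, ec2_metadata, sign_payload). For intuition only, a rough boto3 analogue of those three settings, assuming the botocore options map one-to-one (this snippet is not used anywhere in the commit):

import os
import boto3
from botocore.config import Config

# ec2_metadata=false: the SDK-wide env var, the same one the C++ constructor sets.
os.environ.setdefault("AWS_EC2_METADATA_DISABLED", "true")
s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:9000",                  # placeholder custom endpoint; omit for the default
    config=Config(s3={"payload_signing_enabled": False}),  # sign_payload=false
)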


@@ -51,8 +51,7 @@ extern "C" {
#include "server/version.h"
#include "strings/human_readable.h"
#include "util/accept_server.h"
#include "util/cloud/aws.h"
#include "util/cloud/s3.h"
#include "util/aws/aws.h"
#include "util/fibers/fiber_file.h"
using namespace std;
@@ -90,6 +89,17 @@ ABSL_FLAG(ReplicaOfFlag, replicaof, ReplicaOfFlag{},
"to replicate. "
"Format should be <IPv4>:<PORT> or host:<PORT> or [<IPv6>]:<PORT>");
ABSL_FLAG(string, s3_endpoint, "", "endpoint for s3 snapshots, default uses aws regional endpoint");
// Disable EC2 metadata by default. Otherwise, if a user's credentials are
// invalid, the AWS client will spend 30s trying to connect to inaccessible
// EC2 endpoints to load the credentials.
ABSL_FLAG(bool, s3_ec2_metadata, false,
"whether to load credentials and configuration from EC2 metadata");
// Enables S3 payload signing over HTTP. Disabling this reduces the latency
// and resource usage when writing snapshots to S3, at the expense of security.
ABSL_FLAG(bool, s3_sign_payload, true,
"whether to sign the s3 request payload when uploading snapshots");
ABSL_DECLARE_FLAG(int32_t, port);
ABSL_DECLARE_FLAG(bool, cache_mode);
ABSL_DECLARE_FLAG(uint32_t, hz);
@@ -422,12 +432,10 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vector<facade::Listen
string flag_dir = GetFlag(FLAGS_dir);
if (IsCloudPath(flag_dir)) {
aws_ = make_unique<cloud::AWS>("s3");
auto ec = shard_set->pool()->GetNextProactor()->Await([&] { return aws_->Init(); });
if (ec) {
LOG(FATAL) << "Failed to initialize AWS " << ec;
}
snapshot_storage_ = std::make_shared<detail::AwsS3SnapshotStorage>(aws_.get());
shard_set->pool()->GetNextProactor()->Await([&] { util::aws::Init(); });
snapshot_storage_ = std::make_shared<detail::AwsS3SnapshotStorage>(
absl::GetFlag(FLAGS_s3_endpoint), absl::GetFlag(FLAGS_s3_ec2_metadata),
absl::GetFlag(FLAGS_s3_sign_payload));
} else if (fq_threadpool_) {
snapshot_storage_ = std::make_shared<detail::FileSnapshotStorage>(fq_threadpool_.get());
} else {
@@ -898,9 +906,9 @@ GenericError ServerFamily::DoSave() {
}
GenericError ServerFamily::DoSave(bool new_version, string_view basename, Transaction* trans) {
SaveStagesController sc{detail::SaveStagesInputs{
new_version, basename, trans, &service_, &is_saving_, fq_threadpool_.get(), &last_save_info_,
&save_mu_, &aws_, snapshot_storage_}};
SaveStagesController sc{detail::SaveStagesInputs{new_version, basename, trans, &service_,
&is_saving_, fq_threadpool_.get(),
&last_save_info_, &save_mu_, snapshot_storage_}};
return sc.Save();
}
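Taken together with the existing --dir flag, the new flags select and tune the S3 snapshot path. A minimal launch sketch, assuming a locally built dragonfly binary and a pre-existing bucket (both assumptions, not part of this commit):

import subprocess

subprocess.run([
    "./dragonfly",
    "--dir=s3://my-snapshot-bucket/backups",   # placeholder bucket; any s3:// dir selects AwsS3SnapshotStorage
    # "--s3_endpoint=http://localhost:9000",   # optional custom endpoint; empty means the regional S3 endpoint
    "--s3_ec2_metadata=false",                 # the default: skip EC2 metadata lookups
    "--s3_sign_payload=false",                 # trade payload signing for lower upload latency
])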


@@ -15,14 +15,11 @@
#include "server/replica.h"
namespace util {
class AcceptServer;
class ListenerInterface;
class HttpListenerBase;
namespace cloud {
class AWS;
} // namespace cloud
} // namespace util
namespace dfly {
@@ -252,7 +249,6 @@ class ServerFamily {
Done schedule_done_;
std::unique_ptr<FiberQueueThreadPool> fq_threadpool_;
std::unique_ptr<util::cloud::AWS> aws_;
std::shared_ptr<detail::SnapshotStorage> snapshot_storage_;
};


@@ -18,3 +18,4 @@ aiohttp==3.8.4
numpy==1.24.3
pytest-json-report==1.5.0
psutil==5.9.5
boto3==1.28.55


@@ -5,6 +5,8 @@ import glob
import asyncio
from redis import asyncio as aioredis
from pathlib import Path
import boto3
import logging
from . import dfly_args
from .utility import DflySeeder, wait_available_async
@@ -261,3 +263,57 @@ class TestDflyInfoPersistenceLoadingField(SnapshotTestBase):
assert "0" == self.extract_is_loading_field(res)
await a_client.connection_pool.disconnect()
# If DRAGONFLY_S3_BUCKET is configured, AWS credentials must also be
# configured.
@pytest.mark.skipif(
"DRAGONFLY_S3_BUCKET" not in os.environ, reason="AWS S3 snapshots bucket is not configured"
)
@dfly_args({"dir": "s3://{DRAGONFLY_S3_BUCKET}{DRAGONFLY_TMP}", "dbfilename": ""})
class TestS3Snapshot:
"""Test a snapshot using S3 storage"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
@pytest.mark.asyncio
@pytest.mark.slow
async def test_snapshot(self, df_seeder_factory, async_client, df_server):
seeder = df_seeder_factory.create(port=df_server.port, **SEEDER_ARGS)
await seeder.run(target_deviation=0.1)
start_capture = await seeder.capture()
try:
# save + flush + load
await async_client.execute_command("SAVE DF snapshot")
assert await async_client.flushall()
await async_client.execute_command(
"DEBUG LOAD "
+ os.environ["DRAGONFLY_S3_BUCKET"]
+ str(self.tmp_dir)
+ "/snapshot-summary.dfs"
)
assert await seeder.compare(start_capture, port=df_server.port)
finally:
self._delete_objects(
os.environ["DRAGONFLY_S3_BUCKET"],
str(self.tmp_dir)[1:],
)
def _delete_objects(self, bucket, prefix):
client = boto3.client("s3")
resp = client.list_objects_v2(
Bucket=bucket,
Prefix=prefix,
)
keys = []
for obj in resp["Contents"]:
keys.append({"Key": obj["Key"]})
client.delete_objects(
Bucket=bucket,
Delete={"Objects": keys},
)
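The cleanup helper above assumes the prefix is non-empty and holds at most 1000 objects. A slightly hardened variant (a sketch, not part of this commit) that pages with a paginator and skips empty pages:

import boto3

def _delete_objects_paged(bucket: str, prefix: str) -> None:
    # Hypothetical hardened cleanup: tolerates empty prefixes and more than 1000 keys.
    client = boto3.client("s3")
    paginator = client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
        if keys:  # delete_objects rejects an empty Objects list
            client.delete_objects(Bucket=bucket, Delete={"Objects": keys})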