feat: add replication over tls (#1525)

1. Introduces `tls_replication` flag to allow tls connections for replicas
2. Add pytests
This commit is contained in:
Kostas Kyrimis 2023-07-19 21:21:46 +03:00 committed by GitHub
parent 6e9f092fa2
commit 078d152ae0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 198 additions and 71 deletions

2
helio

@ -1 +1 @@
Subproject commit 6cb1fed97288339e907eada0a0e30ba531de184d
Subproject commit 0087d9cc59d0522325aacc91b93fdc8509cfb006

View file

@ -6,16 +6,9 @@
#include <absl/base/internal/endian.h>
#ifdef __clang__
#include <experimental/memory_resource>
namespace PMR_NS = std::experimental::pmr;
#else
#include <memory_resource>
namespace PMR_NS = std::pmr;
#endif
#include <optional>
#include "base/pmr/memory_resource.h"
#include "core/json_object.h"
#include "core/small_string.h"

View file

@ -3,16 +3,9 @@
//
#pragma once
#ifdef __clang__
#include <experimental/memory_resource>
namespace PMR_NS = std::experimental::pmr;
#else
#include <memory_resource>
namespace PMR_NS = std::pmr;
#endif
#include <vector>
#include "base/pmr/memory_resource.h"
#include "core/dash_internal.h"
namespace dfly {

View file

@ -8,13 +8,7 @@
#include <functional>
#include <type_traits>
#ifdef __clang__
#include <experimental/memory_resource>
namespace PMR_NS = std::experimental::pmr;
#else
#include <memory_resource>
namespace PMR_NS = std::pmr;
#endif
#include "base/pmr/memory_resource.h"
namespace dfly {

View file

@ -6,13 +6,7 @@
#include <mimalloc.h>
#ifdef __clang__
#include <experimental/memory_resource>
namespace PMR_NS = std::experimental::pmr;
#else
#include <memory_resource>
namespace PMR_NS = std::pmr;
#endif
#include "base/pmr/memory_resource.h"
namespace dfly {

View file

@ -258,8 +258,7 @@ void Connection::OnShutdown() {
void Connection::OnPreMigrateThread() {
// If we migrating to another io_uring we should cancel any pending requests we have.
if (break_poll_id_ != UINT32_MAX) {
auto* ls = static_cast<LinuxSocketBase*>(socket_.get());
ls->CancelPoll(break_poll_id_);
socket_->CancelPoll(break_poll_id_);
break_poll_id_ = UINT32_MAX;
}
}
@ -269,9 +268,8 @@ void Connection::OnPostMigrateThread() {
if (breaker_cb_) {
DCHECK_EQ(UINT32_MAX, break_poll_id_);
auto* ls = static_cast<LinuxSocketBase*>(socket_.get());
break_poll_id_ =
ls->PollEvent(POLLERR | POLLHUP, [this](int32_t mask) { this->OnBreakCb(mask); });
socket_->PollEvent(POLLERR | POLLHUP, [this](int32_t mask) { this->OnBreakCb(mask); });
}
}
@ -293,30 +291,28 @@ void Connection::UnregisterShutdownHook(ShutdownHandle id) {
void Connection::HandleRequests() {
ThisFiber::SetName("DflyConnection");
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
if (absl::GetFlag(FLAGS_tcp_nodelay)) {
int val = 1;
CHECK_EQ(0, setsockopt(lsb->native_handle(), IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)));
CHECK_EQ(0, setsockopt(socket_->native_handle(), IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)));
}
auto remote_ep = lsb->RemoteEndpoint();
auto remote_ep = socket_->RemoteEndpoint();
FiberSocketBase* peer = socket_.get();
#ifdef DFLY_USE_SSL
unique_ptr<tls::TlsSocket> tls_sock;
if (ctx_) {
const bool no_tls_on_admin_port = absl::GetFlag(FLAGS_no_tls_on_admin_port);
if (!(IsAdmin() && no_tls_on_admin_port)) {
tls_sock.reset(new tls::TlsSocket(socket_.get()));
unique_ptr<tls::TlsSocket> tls_sock = make_unique<tls::TlsSocket>(std::move(socket_));
tls_sock->InitSSL(ctx_);
FiberSocketBase::AcceptResult aresult = tls_sock->Accept();
SetSocket(tls_sock.release());
if (!aresult) {
LOG(WARNING) << "Error handshaking " << aresult.error().message();
return;
}
peer = tls_sock.get();
peer = socket_.get();
VLOG(1) << "TLS handshake succeeded";
}
}
@ -339,16 +335,15 @@ void Connection::HandleRequests() {
http_conn.ReleaseSocket();
} else {
cc_.reset(service_->CreateContext(peer, this));
auto* us = static_cast<LinuxSocketBase*>(socket_.get());
if (breaker_cb_) {
break_poll_id_ =
us->PollEvent(POLLERR | POLLHUP, [this](int32_t mask) { this->OnBreakCb(mask); });
socket_->PollEvent(POLLERR | POLLHUP, [this](int32_t mask) { this->OnBreakCb(mask); });
}
ConnectionFlow(peer);
if (break_poll_id_ != UINT32_MAX) {
us->CancelPoll(break_poll_id_);
socket_->CancelPoll(break_poll_id_);
}
cc_.reset();
@ -363,22 +358,19 @@ void Connection::RegisterBreakHook(BreakerCb breaker_cb) {
}
std::string Connection::LocalBindAddress() const {
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
auto le = lsb->LocalEndpoint();
auto le = socket_->LocalEndpoint();
return le.address().to_string();
}
string Connection::GetClientInfo(unsigned thread_id) const {
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
string res;
auto le = lsb->LocalEndpoint();
auto re = lsb->RemoteEndpoint();
auto le = socket_->LocalEndpoint();
auto re = socket_->RemoteEndpoint();
time_t now = time(nullptr);
int cpu = 0;
socklen_t len = sizeof(cpu);
getsockopt(lsb->native_handle(), SOL_SOCKET, SO_INCOMING_CPU, &cpu, &len);
getsockopt(socket_->native_handle(), SOL_SOCKET, SO_INCOMING_CPU, &cpu, &len);
int my_cpu_id = sched_getcpu();
static constexpr string_view PHASE_NAMES[] = {"readsock", "process"};
@ -386,7 +378,7 @@ string Connection::GetClientInfo(unsigned thread_id) const {
absl::StrAppend(&res, "id=", id_, " addr=", re.address().to_string(), ":", re.port());
absl::StrAppend(&res, " laddr=", le.address().to_string(), ":", le.port());
absl::StrAppend(&res, " fd=", lsb->native_handle(), " name=", name_);
absl::StrAppend(&res, " fd=", socket_->native_handle(), " name=", name_);
absl::StrAppend(&res, " tid=", thread_id, " irqmatch=", int(cpu == my_cpu_id));
absl::StrAppend(&res, " age=", now - creation_time_, " idle=", now - last_interaction_);
absl::StrAppend(&res, " phase=", PHASE_NAMES[phase_]);
@ -406,9 +398,8 @@ uint32_t Connection::GetClientId() const {
}
bool Connection::IsAdmin() const {
auto* lsb = static_cast<LinuxSocketBase*>(socket_.get());
uint16_t admin_port = absl::GetFlag(FLAGS_admin_port);
return lsb->LocalEndpoint().port() == admin_port;
return socket_->LocalEndpoint().port() == admin_port;
}
io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
@ -903,18 +894,16 @@ void Connection::EnsureAsyncMemoryBudget() {
}
std::string Connection::RemoteEndpointStr() const {
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
bool unix_socket = lsb->IsUDS();
const bool unix_socket = socket_->IsUDS();
std::string connection_str = unix_socket ? "unix:" : std::string{};
auto re = lsb->RemoteEndpoint();
auto re = socket_->RemoteEndpoint();
absl::StrAppend(&connection_str, re.address().to_string(), ":", re.port());
return connection_str;
}
std::string Connection::RemoteEndpointAddress() const {
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
auto re = lsb->RemoteEndpoint();
auto re = socket_->RemoteEndpoint();
return re.address().to_string();
}

View file

@ -57,7 +57,7 @@ namespace {
#ifdef DFLY_USE_SSL
// To connect: openssl s_client -cipher "ADH:@SECLEVEL=0" -state -crlf -connect 127.0.0.1:6380
static SSL_CTX* CreateSslCntx() {
static SSL_CTX* CreateSslServerCntx() {
SSL_CTX* ctx = SSL_CTX_new(TLS_server_method());
const auto& tls_key_file = GetFlag(FLAGS_tls_key_file);
unsigned mask = SSL_VERIFY_NONE;
@ -139,7 +139,7 @@ Listener::Listener(Protocol protocol, ServiceInterface* si) : service_(si), prot
#ifdef DFLY_USE_SSL
if (GetFlag(FLAGS_tls)) {
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL);
ctx_ = CreateSslCntx();
ctx_ = CreateSslServerCntx();
}
#endif
@ -275,7 +275,7 @@ void Listener::OnConnectionClose(util::Connection* conn) {
}
// We can limit number of threads handling dragonfly connections.
ProactorBase* Listener::PickConnectionProactor(LinuxSocketBase* sock) {
ProactorBase* Listener::PickConnectionProactor(util::FiberSocketBase* sock) {
util::ProactorPool* pp = pool();
uint32_t res_id = kuint32max;

View file

@ -12,6 +12,7 @@
#include <vector>
#include "facade/facade_types.h"
#include "util/fiber_socket_base.h"
#include "util/fibers/proactor_base.h"
#include "util/http/http_handler.h"
#include "util/listener_interface.h"
@ -36,7 +37,7 @@ class Listener : public util::ListenerInterface {
private:
util::Connection* NewConnection(ProactorBase* proactor) final;
ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final;
ProactorBase* PickConnectionProactor(util::FiberSocketBase* sock) final;
void OnConnectionStart(util::Connection* conn) final;
void OnConnectionClose(util::Connection* conn) final;

View file

@ -31,6 +31,11 @@ add_library(dragonfly_lib channel_store.cc command_registry.cc
cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib aws_lib strings_lib html_lib
http_client_lib absl::random_random TRDP::jsoncons zstd TRDP::lz4)
if (DF_USE_SSL)
set(TLS_LIB tls_lib)
target_compile_definitions(dragonfly_lib PRIVATE DFLY_USE_SSL)
endif()
add_library(dfly_test_lib test_utils.cc)
cxx_link(dfly_test_lib dragonfly_lib facade_test gtest_main_ext)

View file

@ -15,6 +15,7 @@ extern "C" {
#include <absl/strings/strip.h>
#include <boost/asio/ip/tcp.hpp>
#include <string>
#include "base/logging.h"
#include "facade/dragonfly_connection.h"
@ -26,7 +27,17 @@ extern "C" {
#include "server/rdb_load.h"
#include "strings/human_readable.h"
#ifdef DFLY_USE_SSL
#include "util/tls/tls_socket.h"
#endif
ABSL_FLAG(std::string, masterauth, "", "password for authentication with master");
ABSL_FLAG(bool, tls_replication, false, "Enable TLS on replication");
ABSL_DECLARE_FLAG(std::string, tls_cert_file);
ABSL_DECLARE_FLAG(std::string, tls_key_file);
ABSL_DECLARE_FLAG(std::string, tls_ca_cert_file);
ABSL_DECLARE_FLAG(std::string, tls_ca_cert_dir);
namespace dfly {
@ -39,6 +50,36 @@ using absl::StrCat;
namespace {
#ifdef DFLY_USE_SSL
static ProtocolClient::SSL_CTX* CreateSslClientCntx() {
ProtocolClient::SSL_CTX* ctx = SSL_CTX_new(TLS_client_method());
const auto& tls_key_file = GetFlag(FLAGS_tls_key_file);
unsigned mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
CHECK_EQ(1, SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM));
const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file);
CHECK_EQ(1, SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()));
const auto& tls_ca_cert_file = GetFlag(FLAGS_tls_ca_cert_file);
const auto& tls_ca_cert_dir = GetFlag(FLAGS_tls_ca_cert_dir);
const auto* file = tls_ca_cert_file.empty() ? nullptr : tls_ca_cert_file.data();
const auto* dir = tls_ca_cert_dir.empty() ? nullptr : tls_ca_cert_dir.data();
CHECK_EQ(1, SSL_CTX_load_verify_locations(ctx, file, dir));
CHECK_EQ(1, SSL_CTX_set_cipher_list(ctx, "DEFAULT"));
SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION);
SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
SSL_CTX_set_verify(ctx, mask, NULL);
CHECK_EQ(1, SSL_CTX_set_dh_auto(ctx, 1));
return ctx;
}
#endif
int ResolveDns(std::string_view host, char* dest) {
struct addrinfo hints, *servinfo;
@ -101,9 +142,37 @@ std::string ProtocolClient::ServerContext::Description() const {
return absl::StrCat(host, ":", port);
}
void ProtocolClient::ValidateTlsFlags() const {
if (absl::GetFlag(FLAGS_tls_cert_file).empty()) {
LOG(ERROR) << "tls_cert_file flag should be set";
exit(1);
}
if (absl::GetFlag(FLAGS_tls_ca_cert_file).empty() &&
absl::GetFlag(FLAGS_tls_ca_cert_dir).empty()) {
LOG(ERROR) << "Either or both tls_ca_cert_file or tls_ca_cert_dir flags must be set";
exit(1);
}
}
void ProtocolClient::MaybeInitSslCtx() {
if (absl::GetFlag(FLAGS_tls_replication)) {
ValidateTlsFlags();
ssl_ctx_ = CreateSslClientCntx();
}
}
ProtocolClient::ProtocolClient(string host, uint16_t port) {
server_context_.host = std::move(host);
server_context_.port = port;
#ifdef DFLY_USE_SSL
MaybeInitSslCtx();
#endif
}
ProtocolClient::ProtocolClient(ServerContext context) : server_context_(std::move(context)) {
#ifdef DFLY_USE_SSL
MaybeInitSslCtx();
#endif
}
ProtocolClient::~ProtocolClient() {
@ -111,6 +180,11 @@ ProtocolClient::~ProtocolClient() {
auto ec = sock_->Close();
LOG_IF(ERROR, ec) << "Error closing socket " << ec;
}
#ifdef DFLY_USE_SSL
if (ssl_ctx_) {
SSL_CTX_free(ssl_ctx_);
}
#endif
}
error_code ProtocolClient::ResolveMasterDns() {
@ -136,7 +210,13 @@ error_code ProtocolClient::ConnectAndAuth(std::chrono::milliseconds connect_time
// run we must not create a new socket. sock_mu_ syncs between the two
// functions.
if (!cntx->IsCancelled()) {
sock_.reset(mythread->CreateSocket());
if (ssl_ctx_) {
auto tls_sock = std::make_unique<tls::TlsSocket>(mythread->CreateSocket());
tls_sock->InitSSL(ssl_ctx_);
sock_.reset(tls_sock.release());
} else {
sock_.reset(mythread->CreateSocket());
}
serializer_.reset(new ReqSerializer(sock_.get()));
} else {
return cntx->GetError();

View file

@ -17,6 +17,10 @@
#include "server/version.h"
#include "util/fiber_socket_base.h"
#ifdef DFLY_USE_SSL
#include <openssl/ssl.h>
#endif
namespace facade {
class ReqSerializer;
}; // namespace facade
@ -32,6 +36,10 @@ struct JournalReader;
// This class should be inherited from.
class ProtocolClient {
public:
#ifdef DFLY_USE_SSL
using SSL_CTX = struct ssl_ctx_st;
#endif
ProtocolClient(std::string master_host, uint16_t port);
virtual ~ProtocolClient();
@ -51,8 +59,7 @@ class ProtocolClient {
// Constructing using a fully initialized ServerContext allows to skip
// the DNS resolution step.
explicit ProtocolClient(ServerContext context) : server_context_(std::move(context)) {
}
explicit ProtocolClient(ServerContext context);
std::error_code ResolveMasterDns(); // Resolve master dns
// Connect to master and authenticate if needed.
@ -101,7 +108,7 @@ class ProtocolClient {
return sock_->proactor();
}
util::LinuxSocketBase* Sock() const {
util::FiberSocketBase* Sock() const {
return sock_.get();
}
@ -113,7 +120,7 @@ class ProtocolClient {
facade::RespVec resp_args_;
base::IoBuf resp_buf_;
std::unique_ptr<util::LinuxSocketBase> sock_;
std::unique_ptr<util::FiberSocketBase> sock_;
Mutex sock_mu_;
protected:
@ -123,6 +130,13 @@ class ProtocolClient {
std::string last_resp_;
uint64_t last_io_time_ = 0; // in ns, monotonic clock.
#ifdef DFLY_USE_SSL
void ValidateTlsFlags() const;
void MaybeInitSslCtx();
SSL_CTX* ssl_ctx_{nullptr};
#endif
};
} // namespace dfly

View file

@ -1259,7 +1259,6 @@ async def test_no_tls_on_admin_port(
df_local_factory, df_seeder_factory, t_master, t_replica, with_tls_server_args
):
# 1. Spin up dragonfly without tls, debug populate
master = df_local_factory.create(
no_tls_on_admin_port="true",
admin_port=ADMIN_PORT,
@ -1274,7 +1273,6 @@ async def test_no_tls_on_admin_port(
assert 100 == db_size
# 2. Spin up a replica and initiate a REPLICAOF
replica = df_local_factory.create(
no_tls_on_admin_port="true",
admin_port=ADMIN_PORT + 1,
@ -1293,3 +1291,69 @@ async def test_no_tls_on_admin_port(
assert 100 == db_size
await c_replica.close()
await c_master.close()
# 1. Number of master threads
# 2. Number of threads for each replica
# 3. Admin port
replication_cases = [(8, 8, False), (8, 8, True)]
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replica, test_admin_port", replication_cases)
async def test_tls_replication(
df_local_factory,
df_seeder_factory,
t_master,
t_replica,
test_admin_port,
with_ca_tls_server_args,
with_ca_tls_client_args,
):
# 1. Spin up dragonfly tls enabled, debug populate
master = df_local_factory.create(
tls_replication="true",
**with_ca_tls_server_args,
port=BASE_PORT,
admin_port=ADMIN_PORT,
proactor_threads=t_master,
)
master.start()
c_master = aioredis.Redis(port=master.port, **with_ca_tls_client_args)
await c_master.execute_command("DEBUG POPULATE 100")
db_size = await c_master.execute_command("DBSIZE")
assert 100 == db_size
# 2. Spin up a replica and initiate a REPLICAOF
replica = df_local_factory.create(
tls_replication="true",
**with_ca_tls_server_args,
port=BASE_PORT + 1,
proactor_threads=t_replica,
)
replica.start()
c_replica = aioredis.Redis(port=replica.port, **with_ca_tls_client_args)
port = master.port if not test_admin_port else master.admin_port
res = await c_replica.execute_command("REPLICAOF localhost " + str(port))
assert b"OK" == res
await check_all_replicas_finished([c_replica], c_master)
# 3. Verify that replica dbsize == debug populate key size -- replication works
db_size = await c_replica.execute_command("DBSIZE")
assert 100 == db_size
# 4. Kill master, spin it up and see if replica reconnects
master.stop(kill=True)
master.start()
c_master = aioredis.Redis(port=master.port, **with_ca_tls_client_args)
# Master doesn't load the snapshot, therefore dbsize should be 0
await c_master.execute_command("SET MY_KEY 1")
db_size = await c_master.execute_command("DBSIZE")
assert 1 == db_size
await check_all_replicas_finished([c_replica], c_master)
db_size = await c_replica.execute_command("DBSIZE")
assert 1 == db_size
await c_replica.close()
await c_master.close()