mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-10 18:05:44 +02:00
feat(tls): support runtime tls reconfig (#2047)
* feat(tls): add tls reconfig * feat(config): error if multiple config params given * tls: move ctx ref to connection
This commit is contained in:
parent
3095d8a168
commit
124bafc06b
7 changed files with 208 additions and 37 deletions
|
@ -300,9 +300,20 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener,
|
|||
}
|
||||
|
||||
migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections);
|
||||
|
||||
#ifdef DFLY_USE_SSL
|
||||
// Increment reference counter so Listener won't free the context while we're
|
||||
// still using it.
|
||||
if (ctx) {
|
||||
SSL_CTX_up_ref(ctx);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Connection::~Connection() {
|
||||
#ifdef DFLY_USE_SSL
|
||||
SSL_CTX_free(ctx_);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Called from Connection::Shutdown() right after socket_->Shutdown call.
|
||||
|
|
|
@ -67,24 +67,31 @@ namespace {
|
|||
|
||||
#ifdef DFLY_USE_SSL
|
||||
|
||||
// Creates the TLS context. Returns nullptr if the TLS configuration is invalid.
|
||||
// To connect: openssl s_client -state -crlf -connect 127.0.0.1:6380
|
||||
SSL_CTX* CreateSslServerCntx() {
|
||||
const auto& tls_key_file = GetFlag(FLAGS_tls_key_file);
|
||||
if (tls_key_file.empty()) {
|
||||
LOG(ERROR) << "To use TLS, a server certificate must be provided with the --tls_key_file flag!";
|
||||
exit(-1);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SSL_CTX* ctx = SSL_CTX_new(TLS_server_method());
|
||||
unsigned mask = SSL_VERIFY_NONE;
|
||||
|
||||
DFLY_SSL_CHECK(1 == SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM));
|
||||
if (SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM) != 1) {
|
||||
LOG(ERROR) << "Failed to load TLS key";
|
||||
return nullptr;
|
||||
}
|
||||
const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file);
|
||||
|
||||
if (!tls_cert_file.empty()) {
|
||||
// TO connect with redis-cli you need both tls-key-file and tls-cert-file
|
||||
// loaded. Use `redis-cli --tls -p 6380 --insecure PING` to test
|
||||
DFLY_SSL_CHECK(1 == SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()));
|
||||
if (SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()) != 1) {
|
||||
LOG(ERROR) << "Failed to load TLS certificate";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
const auto tls_ca_cert_file = GetFlag(FLAGS_tls_ca_cert_file);
|
||||
|
@ -92,7 +99,10 @@ SSL_CTX* CreateSslServerCntx() {
|
|||
if (!tls_ca_cert_file.empty() || !tls_ca_cert_dir.empty()) {
|
||||
const auto* file = tls_ca_cert_file.empty() ? nullptr : tls_ca_cert_file.data();
|
||||
const auto* dir = tls_ca_cert_dir.empty() ? nullptr : tls_ca_cert_dir.data();
|
||||
DFLY_SSL_CHECK(1 == SSL_CTX_load_verify_locations(ctx, file, dir));
|
||||
if (SSL_CTX_load_verify_locations(ctx, file, dir) != 1) {
|
||||
LOG(ERROR) << "Failed to load TLS verify locations (CA cert file or CA cert dir)";
|
||||
return nullptr;
|
||||
}
|
||||
mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
|
||||
}
|
||||
|
||||
|
@ -145,9 +155,10 @@ bool ConfigureKeepAlive(int fd) {
|
|||
Listener::Listener(Protocol protocol, ServiceInterface* si, Role role)
|
||||
: service_(si), protocol_(protocol) {
|
||||
#ifdef DFLY_USE_SSL
|
||||
if (GetFlag(FLAGS_tls)) {
|
||||
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL);
|
||||
ctx_ = CreateSslServerCntx();
|
||||
// Always initialise OpenSSL so we can enable TLS at runtime.
|
||||
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, nullptr);
|
||||
if (!ReconfigureTLS()) {
|
||||
exit(-1);
|
||||
}
|
||||
#endif
|
||||
role_ = role;
|
||||
|
@ -198,6 +209,27 @@ error_code Listener::ConfigureServerSocket(int fd) {
|
|||
return error_code{};
|
||||
}
|
||||
|
||||
bool Listener::ReconfigureTLS() {
|
||||
SSL_CTX* prev_ctx = ctx_;
|
||||
if (GetFlag(FLAGS_tls)) {
|
||||
SSL_CTX* ctx = CreateSslServerCntx();
|
||||
if (!ctx) {
|
||||
return false;
|
||||
}
|
||||
ctx_ = ctx;
|
||||
} else {
|
||||
ctx_ = nullptr;
|
||||
}
|
||||
|
||||
if (prev_ctx) {
|
||||
// SSL_CTX is reference counted so if other connections have a reference
|
||||
// to the context it won't be freed yet.
|
||||
SSL_CTX_free(prev_ctx);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Listener::PreAcceptLoop(util::ProactorBase* pb) {
|
||||
per_thread_.resize(pool()->size());
|
||||
}
|
||||
|
|
|
@ -34,6 +34,9 @@ class Listener : public util::ListenerInterface {
|
|||
|
||||
std::error_code ConfigureServerSocket(int fd) final;
|
||||
|
||||
// ReconfigureTLS MUST be called from the same proactor as the listener.
|
||||
bool ReconfigureTLS();
|
||||
|
||||
// Wait until all connections that pass the filter have stopped dispatching or until a timeout has
|
||||
// run out. Returns true if the all connections have stopped dispatching.
|
||||
bool AwaitDispatches(absl::Duration timeout,
|
||||
|
|
|
@ -298,9 +298,9 @@ bool IsValidSaveScheduleNibble(string_view time, unsigned int max) {
|
|||
// enabled. That means either using a password or giving a root
|
||||
// certificate for authenticating client certificates which will
|
||||
// be required.
|
||||
void ValidateServerTlsFlags() {
|
||||
bool ValidateServerTlsFlags() {
|
||||
if (!absl::GetFlag(FLAGS_tls)) {
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool has_auth = false;
|
||||
|
@ -316,8 +316,10 @@ void ValidateServerTlsFlags() {
|
|||
|
||||
if (!has_auth) {
|
||||
LOG(ERROR) << "TLS configured but no authentication method is used!";
|
||||
exit(1);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsReplicatingNoOne(string_view host, string_view port) {
|
||||
|
@ -446,7 +448,9 @@ ServerFamily::ServerFamily(Service* service) : service_(*service) {
|
|||
exit(1);
|
||||
}
|
||||
|
||||
ValidateServerTlsFlags();
|
||||
if (!ValidateServerTlsFlags()) {
|
||||
exit(1);
|
||||
}
|
||||
ValidateClientTlsFlags();
|
||||
}
|
||||
|
||||
|
@ -502,6 +506,25 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vector<facade::Listen
|
|||
return res.has_value();
|
||||
});
|
||||
|
||||
// We only reconfigure TLS when the 'tls' config key changes. Therefore to
|
||||
// update TLS certs, first update tls_cert_file, then set 'tls true'.
|
||||
config_registry.RegisterMutable("tls", [this](const absl::CommandLineFlag& flag) {
|
||||
if (!ValidateServerTlsFlags()) {
|
||||
return false;
|
||||
}
|
||||
for (facade::Listener* l : listeners_) {
|
||||
// Must reconfigure in the listener proactor to avoid a race.
|
||||
if (!l->socket()->proactor()->Await([l] { return l->ReconfigureTLS(); })) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
config_registry.RegisterMutable("tls_cert_file");
|
||||
config_registry.RegisterMutable("tls_key_file");
|
||||
config_registry.RegisterMutable("tls_ca_cert_file");
|
||||
config_registry.RegisterMutable("tls_ca_cert_dir");
|
||||
|
||||
pb_task_ = shard_set->pool()->GetNextProactor();
|
||||
if (pb_task_->GetKind() == ProactorBase::EPOLL) {
|
||||
fq_threadpool_.reset(new FiberQueueThreadPool(absl::GetFlag(FLAGS_epoll_file_threads)));
|
||||
|
@ -1176,7 +1199,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
|
|||
string_view sub_cmd = ArgS(args, 0);
|
||||
|
||||
if (sub_cmd == "SET") {
|
||||
if (args.size() < 3) {
|
||||
if (args.size() != 3) {
|
||||
return (*cntx)->SendError(WrongNumArgsError("config|set"));
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from tempfile import TemporaryDirectory
|
|||
|
||||
from .instance import DflyInstance, DflyParams, DflyInstanceFactory
|
||||
from . import PortPicker, dfly_args
|
||||
from .utility import DflySeederFactory, gen_certificate
|
||||
from .utility import DflySeederFactory, gen_ca_cert, gen_certificate
|
||||
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
|
||||
|
@ -254,30 +254,22 @@ def memcached_connection(df_server: DflyInstance):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def gen_ca_cert(tmp_dir):
|
||||
# We first need to generate the tls certificates to be used by the server
|
||||
|
||||
# Generate CA (certificate authority) key and self-signed certificate
|
||||
# In production, CA should be generated by a third party authority
|
||||
# Expires in one day and is not encrtypted (-nodes)
|
||||
# X.509 format for the key
|
||||
def with_tls_ca_cert_args(tmp_dir):
|
||||
ca_key = os.path.join(tmp_dir, "ca-key.pem")
|
||||
ca_cert = os.path.join(tmp_dir, "ca-cert.pem")
|
||||
step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key} -out {ca_cert} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"'
|
||||
subprocess.run(step, shell=True)
|
||||
|
||||
gen_ca_cert(ca_key, ca_cert)
|
||||
return {"ca_key": ca_key, "ca_cert": ca_cert}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def with_tls_server_args(tmp_dir, gen_ca_cert):
|
||||
def with_tls_server_args(tmp_dir, with_tls_ca_cert_args):
|
||||
tls_server_key = os.path.join(tmp_dir, "df-key.pem")
|
||||
tls_server_req = os.path.join(tmp_dir, "df-req.pem")
|
||||
tls_server_cert = os.path.join(tmp_dir, "df-cert.pem")
|
||||
|
||||
gen_certificate(
|
||||
gen_ca_cert["ca_key"],
|
||||
gen_ca_cert["ca_cert"],
|
||||
with_tls_ca_cert_args["ca_key"],
|
||||
with_tls_ca_cert_args["ca_cert"],
|
||||
tls_server_req,
|
||||
tls_server_key,
|
||||
tls_server_cert,
|
||||
|
@ -288,21 +280,21 @@ def with_tls_server_args(tmp_dir, gen_ca_cert):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def with_ca_tls_server_args(with_tls_server_args, gen_ca_cert):
|
||||
def with_ca_tls_server_args(with_tls_server_args, with_tls_ca_cert_args):
|
||||
args = deepcopy(with_tls_server_args)
|
||||
args["tls_ca_cert_file"] = gen_ca_cert["ca_cert"]
|
||||
args["tls_ca_cert_file"] = with_tls_ca_cert_args["ca_cert"]
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def with_tls_client_args(tmp_dir, gen_ca_cert):
|
||||
def with_tls_client_args(tmp_dir, with_tls_ca_cert_args):
|
||||
tls_client_key = os.path.join(tmp_dir, "client-key.pem")
|
||||
tls_client_req = os.path.join(tmp_dir, "client-req.pem")
|
||||
tls_client_cert = os.path.join(tmp_dir, "client-cert.pem")
|
||||
|
||||
gen_certificate(
|
||||
gen_ca_cert["ca_key"],
|
||||
gen_ca_cert["ca_cert"],
|
||||
with_tls_ca_cert_args["ca_key"],
|
||||
with_tls_ca_cert_args["ca_cert"],
|
||||
tls_client_req,
|
||||
tls_client_key,
|
||||
tls_client_cert,
|
||||
|
@ -313,7 +305,7 @@ def with_tls_client_args(tmp_dir, gen_ca_cert):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def with_ca_tls_client_args(with_tls_client_args, gen_ca_cert):
|
||||
def with_ca_tls_client_args(with_tls_client_args, with_tls_ca_cert_args):
|
||||
args = deepcopy(with_tls_client_args)
|
||||
args["ssl_ca_certs"] = gen_ca_cert["ca_cert"]
|
||||
args["ssl_ca_certs"] = with_tls_ca_cert_args["ca_cert"]
|
||||
return args
|
||||
|
|
|
@ -18,20 +18,20 @@ async def test_tls_no_key(df_factory):
|
|||
server.start()
|
||||
|
||||
|
||||
async def test_tls_password(df_factory, with_tls_server_args, gen_ca_cert):
|
||||
async def test_tls_password(df_factory, with_tls_server_args, with_tls_ca_cert_args):
|
||||
with df_factory.create(requirepass="XXX", **with_tls_server_args) as server:
|
||||
async with server.client(
|
||||
ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"]
|
||||
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
|
||||
) as client:
|
||||
await client.ping()
|
||||
|
||||
|
||||
async def test_tls_client_certs(
|
||||
df_factory, with_ca_tls_server_args, with_tls_client_args, gen_ca_cert
|
||||
df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args
|
||||
):
|
||||
with df_factory.create(**with_ca_tls_server_args) as server:
|
||||
async with server.client(
|
||||
**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]
|
||||
**with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
|
||||
) as client:
|
||||
await client.ping()
|
||||
|
||||
|
@ -52,3 +52,102 @@ async def test_client_tls_cert(df_factory, with_tls_server_args):
|
|||
key_args.pop("tls")
|
||||
with df_factory.create(tls_replication=None, **key_args):
|
||||
pass
|
||||
|
||||
|
||||
async def test_config_update_tls_certs(
|
||||
df_factory, with_tls_server_args, with_tls_ca_cert_args, tmp_dir
|
||||
):
|
||||
# Generate new certificates.
|
||||
ca_key = os.path.join(tmp_dir, "ca-key-new.pem")
|
||||
ca_cert = os.path.join(tmp_dir, "ca-cert-new.pem")
|
||||
gen_ca_cert(ca_key, ca_cert)
|
||||
tls_server_key = os.path.join(tmp_dir, "df-key-new.pem")
|
||||
tls_server_req = os.path.join(tmp_dir, "df-req-new.pem")
|
||||
tls_server_cert = os.path.join(tmp_dir, "df-cert-new.pem")
|
||||
gen_certificate(
|
||||
ca_key,
|
||||
ca_cert,
|
||||
tls_server_req,
|
||||
tls_server_key,
|
||||
tls_server_cert,
|
||||
)
|
||||
|
||||
with df_factory.create(requirepass="XXX", **with_tls_server_args) as server:
|
||||
async with server.client(
|
||||
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
|
||||
) as client:
|
||||
await client.config_set(
|
||||
"tls_key_file",
|
||||
tls_server_key,
|
||||
)
|
||||
await client.config_set("tls_cert_file", tls_server_cert)
|
||||
# Note must still set `tls true` to reload the TLS context.
|
||||
await client.config_set("tls", "true")
|
||||
|
||||
# The existing connection should still work.
|
||||
await client.ping()
|
||||
|
||||
# Connecting with the old CA should fail.
|
||||
with pytest.raises(redis.exceptions.ConnectionError):
|
||||
async with server.client(
|
||||
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
|
||||
) as client:
|
||||
await client.ping()
|
||||
|
||||
# Connecting with the new CA should succeed.
|
||||
async with server.client(ssl=True, password="XXX", ssl_ca_certs=ca_cert) as client:
|
||||
await client.ping()
|
||||
|
||||
|
||||
async def test_config_enable_tls(
|
||||
df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args
|
||||
):
|
||||
with df_factory.create() as server:
|
||||
async with server.client() as client:
|
||||
await client.ping()
|
||||
|
||||
# Note the order here matters as flags are applied in order.
|
||||
await client.config_set(
|
||||
"tls_key_file",
|
||||
with_ca_tls_server_args["tls_key_file"],
|
||||
)
|
||||
await client.config_set(
|
||||
"tls_cert_file",
|
||||
with_ca_tls_server_args["tls_cert_file"],
|
||||
)
|
||||
await client.config_set(
|
||||
"tls_ca_cert_file",
|
||||
with_ca_tls_server_args["tls_ca_cert_file"],
|
||||
)
|
||||
await client.config_set(
|
||||
"tls",
|
||||
"true",
|
||||
)
|
||||
|
||||
# The existing client should still be connected.
|
||||
await client.ping()
|
||||
|
||||
# Connecting without TLS should fail.
|
||||
with pytest.raises(redis.exceptions.ConnectionError):
|
||||
async with server.client() as client_unauth:
|
||||
await client_unauth.ping()
|
||||
|
||||
# Connecting with TLS should succeed.
|
||||
async with server.client(
|
||||
**with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
|
||||
) as client_tls:
|
||||
await client_tls.ping()
|
||||
|
||||
|
||||
async def test_config_disable_tls(
|
||||
df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args
|
||||
):
|
||||
with df_factory.create(**with_ca_tls_server_args) as server:
|
||||
async with server.client(
|
||||
**with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
|
||||
) as client_tls:
|
||||
await client_tls.config_set("tls", "false")
|
||||
|
||||
# Connecting without TLS should succeed.
|
||||
async with server.client() as client_unauth:
|
||||
await client_unauth.ping()
|
||||
|
|
|
@ -581,6 +581,17 @@ async def disconnect_clients(*clients):
|
|||
await asyncio.gather(*(c.connection_pool.disconnect() for c in clients))
|
||||
|
||||
|
||||
def gen_ca_cert(ca_key_path, ca_cert_path):
|
||||
# We first need to generate the tls certificates to be used by the server
|
||||
|
||||
# Generate CA (certificate authority) key and self-signed certificate
|
||||
# In production, CA should be generated by a third party authority
|
||||
# Expires in one day and is not encrtypted (-nodes)
|
||||
# X.509 format for the key
|
||||
step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key_path} -out {ca_cert_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"'
|
||||
subprocess.run(step, shell=True)
|
||||
|
||||
|
||||
def gen_certificate(
|
||||
ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path
|
||||
):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue