feat(tls): support runtime tls reconfig (#2047)

* feat(tls): add tls reconfig

* feat(config): error if multiple config params given

* tls: move ctx ref to connection
This commit is contained in:
Andy Dunstall 2023-10-23 17:35:39 +01:00 committed by GitHub
parent 3095d8a168
commit 124bafc06b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 37 deletions

View file

@ -300,9 +300,20 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener,
}
migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections);
#ifdef DFLY_USE_SSL
// Increment reference counter so Listener won't free the context while we're
// still using it.
if (ctx) {
SSL_CTX_up_ref(ctx);
}
#endif
}
Connection::~Connection() {
#ifdef DFLY_USE_SSL
SSL_CTX_free(ctx_);
#endif
}
// Called from Connection::Shutdown() right after socket_->Shutdown call.

View file

@ -67,24 +67,31 @@ namespace {
#ifdef DFLY_USE_SSL
// Creates the TLS context. Returns nullptr if the TLS configuration is invalid.
// To connect: openssl s_client -state -crlf -connect 127.0.0.1:6380
SSL_CTX* CreateSslServerCntx() {
const auto& tls_key_file = GetFlag(FLAGS_tls_key_file);
if (tls_key_file.empty()) {
LOG(ERROR) << "To use TLS, a server certificate must be provided with the --tls_key_file flag!";
exit(-1);
return nullptr;
}
SSL_CTX* ctx = SSL_CTX_new(TLS_server_method());
unsigned mask = SSL_VERIFY_NONE;
DFLY_SSL_CHECK(1 == SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM));
if (SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM) != 1) {
LOG(ERROR) << "Failed to load TLS key";
return nullptr;
}
const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file);
if (!tls_cert_file.empty()) {
// TO connect with redis-cli you need both tls-key-file and tls-cert-file
// loaded. Use `redis-cli --tls -p 6380 --insecure PING` to test
DFLY_SSL_CHECK(1 == SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()));
if (SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()) != 1) {
LOG(ERROR) << "Failed to load TLS certificate";
return nullptr;
}
}
const auto tls_ca_cert_file = GetFlag(FLAGS_tls_ca_cert_file);
@ -92,7 +99,10 @@ SSL_CTX* CreateSslServerCntx() {
if (!tls_ca_cert_file.empty() || !tls_ca_cert_dir.empty()) {
const auto* file = tls_ca_cert_file.empty() ? nullptr : tls_ca_cert_file.data();
const auto* dir = tls_ca_cert_dir.empty() ? nullptr : tls_ca_cert_dir.data();
DFLY_SSL_CHECK(1 == SSL_CTX_load_verify_locations(ctx, file, dir));
if (SSL_CTX_load_verify_locations(ctx, file, dir) != 1) {
LOG(ERROR) << "Failed to load TLS verify locations (CA cert file or CA cert dir)";
return nullptr;
}
mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
}
@ -145,9 +155,10 @@ bool ConfigureKeepAlive(int fd) {
Listener::Listener(Protocol protocol, ServiceInterface* si, Role role)
: service_(si), protocol_(protocol) {
#ifdef DFLY_USE_SSL
if (GetFlag(FLAGS_tls)) {
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL);
ctx_ = CreateSslServerCntx();
// Always initialise OpenSSL so we can enable TLS at runtime.
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, nullptr);
if (!ReconfigureTLS()) {
exit(-1);
}
#endif
role_ = role;
@ -198,6 +209,27 @@ error_code Listener::ConfigureServerSocket(int fd) {
return error_code{};
}
bool Listener::ReconfigureTLS() {
SSL_CTX* prev_ctx = ctx_;
if (GetFlag(FLAGS_tls)) {
SSL_CTX* ctx = CreateSslServerCntx();
if (!ctx) {
return false;
}
ctx_ = ctx;
} else {
ctx_ = nullptr;
}
if (prev_ctx) {
// SSL_CTX is reference counted so if other connections have a reference
// to the context it won't be freed yet.
SSL_CTX_free(prev_ctx);
}
return true;
}
void Listener::PreAcceptLoop(util::ProactorBase* pb) {
per_thread_.resize(pool()->size());
}

View file

@ -34,6 +34,9 @@ class Listener : public util::ListenerInterface {
std::error_code ConfigureServerSocket(int fd) final;
// ReconfigureTLS MUST be called from the same proactor as the listener.
bool ReconfigureTLS();
// Wait until all connections that pass the filter have stopped dispatching or until a timeout has
// run out. Returns true if the all connections have stopped dispatching.
bool AwaitDispatches(absl::Duration timeout,

View file

@ -298,9 +298,9 @@ bool IsValidSaveScheduleNibble(string_view time, unsigned int max) {
// enabled. That means either using a password or giving a root
// certificate for authenticating client certificates which will
// be required.
void ValidateServerTlsFlags() {
bool ValidateServerTlsFlags() {
if (!absl::GetFlag(FLAGS_tls)) {
return;
return true;
}
bool has_auth = false;
@ -316,8 +316,10 @@ void ValidateServerTlsFlags() {
if (!has_auth) {
LOG(ERROR) << "TLS configured but no authentication method is used!";
exit(1);
return false;
}
return true;
}
bool IsReplicatingNoOne(string_view host, string_view port) {
@ -446,7 +448,9 @@ ServerFamily::ServerFamily(Service* service) : service_(*service) {
exit(1);
}
ValidateServerTlsFlags();
if (!ValidateServerTlsFlags()) {
exit(1);
}
ValidateClientTlsFlags();
}
@ -502,6 +506,25 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vector<facade::Listen
return res.has_value();
});
// We only reconfigure TLS when the 'tls' config key changes. Therefore to
// update TLS certs, first update tls_cert_file, then set 'tls true'.
config_registry.RegisterMutable("tls", [this](const absl::CommandLineFlag& flag) {
if (!ValidateServerTlsFlags()) {
return false;
}
for (facade::Listener* l : listeners_) {
// Must reconfigure in the listener proactor to avoid a race.
if (!l->socket()->proactor()->Await([l] { return l->ReconfigureTLS(); })) {
return false;
}
}
return true;
});
config_registry.RegisterMutable("tls_cert_file");
config_registry.RegisterMutable("tls_key_file");
config_registry.RegisterMutable("tls_ca_cert_file");
config_registry.RegisterMutable("tls_ca_cert_dir");
pb_task_ = shard_set->pool()->GetNextProactor();
if (pb_task_->GetKind() == ProactorBase::EPOLL) {
fq_threadpool_.reset(new FiberQueueThreadPool(absl::GetFlag(FLAGS_epoll_file_threads)));
@ -1176,7 +1199,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
string_view sub_cmd = ArgS(args, 0);
if (sub_cmd == "SET") {
if (args.size() < 3) {
if (args.size() != 3) {
return (*cntx)->SendError(WrongNumArgsError("config|set"));
}

View file

@ -21,7 +21,7 @@ from tempfile import TemporaryDirectory
from .instance import DflyInstance, DflyParams, DflyInstanceFactory
from . import PortPicker, dfly_args
from .utility import DflySeederFactory, gen_certificate
from .utility import DflySeederFactory, gen_ca_cert, gen_certificate
logging.getLogger("asyncio").setLevel(logging.WARNING)
@ -254,30 +254,22 @@ def memcached_connection(df_server: DflyInstance):
@pytest.fixture(scope="session")
def gen_ca_cert(tmp_dir):
# We first need to generate the tls certificates to be used by the server
# Generate CA (certificate authority) key and self-signed certificate
# In production, CA should be generated by a third party authority
# Expires in one day and is not encrtypted (-nodes)
# X.509 format for the key
def with_tls_ca_cert_args(tmp_dir):
ca_key = os.path.join(tmp_dir, "ca-key.pem")
ca_cert = os.path.join(tmp_dir, "ca-cert.pem")
step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key} -out {ca_cert} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"'
subprocess.run(step, shell=True)
gen_ca_cert(ca_key, ca_cert)
return {"ca_key": ca_key, "ca_cert": ca_cert}
@pytest.fixture(scope="session")
def with_tls_server_args(tmp_dir, gen_ca_cert):
def with_tls_server_args(tmp_dir, with_tls_ca_cert_args):
tls_server_key = os.path.join(tmp_dir, "df-key.pem")
tls_server_req = os.path.join(tmp_dir, "df-req.pem")
tls_server_cert = os.path.join(tmp_dir, "df-cert.pem")
gen_certificate(
gen_ca_cert["ca_key"],
gen_ca_cert["ca_cert"],
with_tls_ca_cert_args["ca_key"],
with_tls_ca_cert_args["ca_cert"],
tls_server_req,
tls_server_key,
tls_server_cert,
@ -288,21 +280,21 @@ def with_tls_server_args(tmp_dir, gen_ca_cert):
@pytest.fixture(scope="session")
def with_ca_tls_server_args(with_tls_server_args, gen_ca_cert):
def with_ca_tls_server_args(with_tls_server_args, with_tls_ca_cert_args):
args = deepcopy(with_tls_server_args)
args["tls_ca_cert_file"] = gen_ca_cert["ca_cert"]
args["tls_ca_cert_file"] = with_tls_ca_cert_args["ca_cert"]
return args
@pytest.fixture(scope="session")
def with_tls_client_args(tmp_dir, gen_ca_cert):
def with_tls_client_args(tmp_dir, with_tls_ca_cert_args):
tls_client_key = os.path.join(tmp_dir, "client-key.pem")
tls_client_req = os.path.join(tmp_dir, "client-req.pem")
tls_client_cert = os.path.join(tmp_dir, "client-cert.pem")
gen_certificate(
gen_ca_cert["ca_key"],
gen_ca_cert["ca_cert"],
with_tls_ca_cert_args["ca_key"],
with_tls_ca_cert_args["ca_cert"],
tls_client_req,
tls_client_key,
tls_client_cert,
@ -313,7 +305,7 @@ def with_tls_client_args(tmp_dir, gen_ca_cert):
@pytest.fixture(scope="session")
def with_ca_tls_client_args(with_tls_client_args, gen_ca_cert):
def with_ca_tls_client_args(with_tls_client_args, with_tls_ca_cert_args):
args = deepcopy(with_tls_client_args)
args["ssl_ca_certs"] = gen_ca_cert["ca_cert"]
args["ssl_ca_certs"] = with_tls_ca_cert_args["ca_cert"]
return args

View file

@ -18,20 +18,20 @@ async def test_tls_no_key(df_factory):
server.start()
async def test_tls_password(df_factory, with_tls_server_args, gen_ca_cert):
async def test_tls_password(df_factory, with_tls_server_args, with_tls_ca_cert_args):
with df_factory.create(requirepass="XXX", **with_tls_server_args) as server:
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"]
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client:
await client.ping()
async def test_tls_client_certs(
df_factory, with_ca_tls_server_args, with_tls_client_args, gen_ca_cert
df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args
):
with df_factory.create(**with_ca_tls_server_args) as server:
async with server.client(
**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]
**with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client:
await client.ping()
@ -52,3 +52,102 @@ async def test_client_tls_cert(df_factory, with_tls_server_args):
key_args.pop("tls")
with df_factory.create(tls_replication=None, **key_args):
pass
async def test_config_update_tls_certs(
df_factory, with_tls_server_args, with_tls_ca_cert_args, tmp_dir
):
# Generate new certificates.
ca_key = os.path.join(tmp_dir, "ca-key-new.pem")
ca_cert = os.path.join(tmp_dir, "ca-cert-new.pem")
gen_ca_cert(ca_key, ca_cert)
tls_server_key = os.path.join(tmp_dir, "df-key-new.pem")
tls_server_req = os.path.join(tmp_dir, "df-req-new.pem")
tls_server_cert = os.path.join(tmp_dir, "df-cert-new.pem")
gen_certificate(
ca_key,
ca_cert,
tls_server_req,
tls_server_key,
tls_server_cert,
)
with df_factory.create(requirepass="XXX", **with_tls_server_args) as server:
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client:
await client.config_set(
"tls_key_file",
tls_server_key,
)
await client.config_set("tls_cert_file", tls_server_cert)
# Note must still set `tls true` to reload the TLS context.
await client.config_set("tls", "true")
# The existing connection should still work.
await client.ping()
# Connecting with the old CA should fail.
with pytest.raises(redis.exceptions.ConnectionError):
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client:
await client.ping()
# Connecting with the new CA should succeed.
async with server.client(ssl=True, password="XXX", ssl_ca_certs=ca_cert) as client:
await client.ping()
async def test_config_enable_tls(
df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args
):
with df_factory.create() as server:
async with server.client() as client:
await client.ping()
# Note the order here matters as flags are applied in order.
await client.config_set(
"tls_key_file",
with_ca_tls_server_args["tls_key_file"],
)
await client.config_set(
"tls_cert_file",
with_ca_tls_server_args["tls_cert_file"],
)
await client.config_set(
"tls_ca_cert_file",
with_ca_tls_server_args["tls_ca_cert_file"],
)
await client.config_set(
"tls",
"true",
)
# The existing client should still be connected.
await client.ping()
# Connecting without TLS should fail.
with pytest.raises(redis.exceptions.ConnectionError):
async with server.client() as client_unauth:
await client_unauth.ping()
# Connecting with TLS should succeed.
async with server.client(
**with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client_tls:
await client_tls.ping()
async def test_config_disable_tls(
df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args
):
with df_factory.create(**with_ca_tls_server_args) as server:
async with server.client(
**with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client_tls:
await client_tls.config_set("tls", "false")
# Connecting without TLS should succeed.
async with server.client() as client_unauth:
await client_unauth.ping()

View file

@ -581,6 +581,17 @@ async def disconnect_clients(*clients):
await asyncio.gather(*(c.connection_pool.disconnect() for c in clients))
def gen_ca_cert(ca_key_path, ca_cert_path):
# We first need to generate the tls certificates to be used by the server
# Generate CA (certificate authority) key and self-signed certificate
# In production, CA should be generated by a third party authority
# Expires in one day and is not encrtypted (-nodes)
# X.509 format for the key
step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key_path} -out {ca_cert_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"'
subprocess.run(step, shell=True)
def gen_certificate(
ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path
):