diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 0c709b1c9..1ab9bcf45 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -300,9 +300,20 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, } migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections); + +#ifdef DFLY_USE_SSL + // Increment reference counter so Listener won't free the context while we're + // still using it. + if (ctx) { + SSL_CTX_up_ref(ctx); + } +#endif } Connection::~Connection() { +#ifdef DFLY_USE_SSL + SSL_CTX_free(ctx_); +#endif } // Called from Connection::Shutdown() right after socket_->Shutdown call. diff --git a/src/facade/dragonfly_listener.cc b/src/facade/dragonfly_listener.cc index cb92e76e2..64ee10380 100644 --- a/src/facade/dragonfly_listener.cc +++ b/src/facade/dragonfly_listener.cc @@ -67,24 +67,31 @@ namespace { #ifdef DFLY_USE_SSL +// Creates the TLS context. Returns nullptr if the TLS configuration is invalid. // To connect: openssl s_client -state -crlf -connect 127.0.0.1:6380 SSL_CTX* CreateSslServerCntx() { const auto& tls_key_file = GetFlag(FLAGS_tls_key_file); if (tls_key_file.empty()) { LOG(ERROR) << "To use TLS, a server certificate must be provided with the --tls_key_file flag!"; - exit(-1); + return nullptr; } SSL_CTX* ctx = SSL_CTX_new(TLS_server_method()); unsigned mask = SSL_VERIFY_NONE; - DFLY_SSL_CHECK(1 == SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM)); + if (SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM) != 1) { + LOG(ERROR) << "Failed to load TLS key"; + return nullptr; + } const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file); if (!tls_cert_file.empty()) { // TO connect with redis-cli you need both tls-key-file and tls-cert-file // loaded. Use `redis-cli --tls -p 6380 --insecure PING` to test - DFLY_SSL_CHECK(1 == SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str())); + if (SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()) != 1) { + LOG(ERROR) << "Failed to load TLS certificate"; + return nullptr; + } } const auto tls_ca_cert_file = GetFlag(FLAGS_tls_ca_cert_file); @@ -92,7 +99,10 @@ SSL_CTX* CreateSslServerCntx() { if (!tls_ca_cert_file.empty() || !tls_ca_cert_dir.empty()) { const auto* file = tls_ca_cert_file.empty() ? nullptr : tls_ca_cert_file.data(); const auto* dir = tls_ca_cert_dir.empty() ? nullptr : tls_ca_cert_dir.data(); - DFLY_SSL_CHECK(1 == SSL_CTX_load_verify_locations(ctx, file, dir)); + if (SSL_CTX_load_verify_locations(ctx, file, dir) != 1) { + LOG(ERROR) << "Failed to load TLS verify locations (CA cert file or CA cert dir)"; + return nullptr; + } mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT; } @@ -145,9 +155,10 @@ bool ConfigureKeepAlive(int fd) { Listener::Listener(Protocol protocol, ServiceInterface* si, Role role) : service_(si), protocol_(protocol) { #ifdef DFLY_USE_SSL - if (GetFlag(FLAGS_tls)) { - OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL); - ctx_ = CreateSslServerCntx(); + // Always initialise OpenSSL so we can enable TLS at runtime. + OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, nullptr); + if (!ReconfigureTLS()) { + exit(-1); } #endif role_ = role; @@ -198,6 +209,27 @@ error_code Listener::ConfigureServerSocket(int fd) { return error_code{}; } +bool Listener::ReconfigureTLS() { + SSL_CTX* prev_ctx = ctx_; + if (GetFlag(FLAGS_tls)) { + SSL_CTX* ctx = CreateSslServerCntx(); + if (!ctx) { + return false; + } + ctx_ = ctx; + } else { + ctx_ = nullptr; + } + + if (prev_ctx) { + // SSL_CTX is reference counted so if other connections have a reference + // to the context it won't be freed yet. + SSL_CTX_free(prev_ctx); + } + + return true; +} + void Listener::PreAcceptLoop(util::ProactorBase* pb) { per_thread_.resize(pool()->size()); } diff --git a/src/facade/dragonfly_listener.h b/src/facade/dragonfly_listener.h index 295ae37ab..ee3ee4547 100644 --- a/src/facade/dragonfly_listener.h +++ b/src/facade/dragonfly_listener.h @@ -34,6 +34,9 @@ class Listener : public util::ListenerInterface { std::error_code ConfigureServerSocket(int fd) final; + // ReconfigureTLS MUST be called from the same proactor as the listener. + bool ReconfigureTLS(); + // Wait until all connections that pass the filter have stopped dispatching or until a timeout has // run out. Returns true if the all connections have stopped dispatching. bool AwaitDispatches(absl::Duration timeout, diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 1d3e58528..2065ef9d5 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -298,9 +298,9 @@ bool IsValidSaveScheduleNibble(string_view time, unsigned int max) { // enabled. That means either using a password or giving a root // certificate for authenticating client certificates which will // be required. -void ValidateServerTlsFlags() { +bool ValidateServerTlsFlags() { if (!absl::GetFlag(FLAGS_tls)) { - return; + return true; } bool has_auth = false; @@ -316,8 +316,10 @@ void ValidateServerTlsFlags() { if (!has_auth) { LOG(ERROR) << "TLS configured but no authentication method is used!"; - exit(1); + return false; } + + return true; } bool IsReplicatingNoOne(string_view host, string_view port) { @@ -446,7 +448,9 @@ ServerFamily::ServerFamily(Service* service) : service_(*service) { exit(1); } - ValidateServerTlsFlags(); + if (!ValidateServerTlsFlags()) { + exit(1); + } ValidateClientTlsFlags(); } @@ -502,6 +506,25 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorsocket()->proactor()->Await([l] { return l->ReconfigureTLS(); })) { + return false; + } + } + return true; + }); + config_registry.RegisterMutable("tls_cert_file"); + config_registry.RegisterMutable("tls_key_file"); + config_registry.RegisterMutable("tls_ca_cert_file"); + config_registry.RegisterMutable("tls_ca_cert_dir"); + pb_task_ = shard_set->pool()->GetNextProactor(); if (pb_task_->GetKind() == ProactorBase::EPOLL) { fq_threadpool_.reset(new FiberQueueThreadPool(absl::GetFlag(FLAGS_epoll_file_threads))); @@ -1176,7 +1199,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) { string_view sub_cmd = ArgS(args, 0); if (sub_cmd == "SET") { - if (args.size() < 3) { + if (args.size() != 3) { return (*cntx)->SendError(WrongNumArgsError("config|set")); } diff --git a/tests/dragonfly/conftest.py b/tests/dragonfly/conftest.py index 8ff88806f..0a61e2073 100644 --- a/tests/dragonfly/conftest.py +++ b/tests/dragonfly/conftest.py @@ -21,7 +21,7 @@ from tempfile import TemporaryDirectory from .instance import DflyInstance, DflyParams, DflyInstanceFactory from . import PortPicker, dfly_args -from .utility import DflySeederFactory, gen_certificate +from .utility import DflySeederFactory, gen_ca_cert, gen_certificate logging.getLogger("asyncio").setLevel(logging.WARNING) @@ -254,30 +254,22 @@ def memcached_connection(df_server: DflyInstance): @pytest.fixture(scope="session") -def gen_ca_cert(tmp_dir): - # We first need to generate the tls certificates to be used by the server - - # Generate CA (certificate authority) key and self-signed certificate - # In production, CA should be generated by a third party authority - # Expires in one day and is not encrtypted (-nodes) - # X.509 format for the key +def with_tls_ca_cert_args(tmp_dir): ca_key = os.path.join(tmp_dir, "ca-key.pem") ca_cert = os.path.join(tmp_dir, "ca-cert.pem") - step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key} -out {ca_cert} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"' - subprocess.run(step, shell=True) - + gen_ca_cert(ca_key, ca_cert) return {"ca_key": ca_key, "ca_cert": ca_cert} @pytest.fixture(scope="session") -def with_tls_server_args(tmp_dir, gen_ca_cert): +def with_tls_server_args(tmp_dir, with_tls_ca_cert_args): tls_server_key = os.path.join(tmp_dir, "df-key.pem") tls_server_req = os.path.join(tmp_dir, "df-req.pem") tls_server_cert = os.path.join(tmp_dir, "df-cert.pem") gen_certificate( - gen_ca_cert["ca_key"], - gen_ca_cert["ca_cert"], + with_tls_ca_cert_args["ca_key"], + with_tls_ca_cert_args["ca_cert"], tls_server_req, tls_server_key, tls_server_cert, @@ -288,21 +280,21 @@ def with_tls_server_args(tmp_dir, gen_ca_cert): @pytest.fixture(scope="session") -def with_ca_tls_server_args(with_tls_server_args, gen_ca_cert): +def with_ca_tls_server_args(with_tls_server_args, with_tls_ca_cert_args): args = deepcopy(with_tls_server_args) - args["tls_ca_cert_file"] = gen_ca_cert["ca_cert"] + args["tls_ca_cert_file"] = with_tls_ca_cert_args["ca_cert"] return args @pytest.fixture(scope="session") -def with_tls_client_args(tmp_dir, gen_ca_cert): +def with_tls_client_args(tmp_dir, with_tls_ca_cert_args): tls_client_key = os.path.join(tmp_dir, "client-key.pem") tls_client_req = os.path.join(tmp_dir, "client-req.pem") tls_client_cert = os.path.join(tmp_dir, "client-cert.pem") gen_certificate( - gen_ca_cert["ca_key"], - gen_ca_cert["ca_cert"], + with_tls_ca_cert_args["ca_key"], + with_tls_ca_cert_args["ca_cert"], tls_client_req, tls_client_key, tls_client_cert, @@ -313,7 +305,7 @@ def with_tls_client_args(tmp_dir, gen_ca_cert): @pytest.fixture(scope="session") -def with_ca_tls_client_args(with_tls_client_args, gen_ca_cert): +def with_ca_tls_client_args(with_tls_client_args, with_tls_ca_cert_args): args = deepcopy(with_tls_client_args) - args["ssl_ca_certs"] = gen_ca_cert["ca_cert"] + args["ssl_ca_certs"] = with_tls_ca_cert_args["ca_cert"] return args diff --git a/tests/dragonfly/tls_conf_test.py b/tests/dragonfly/tls_conf_test.py index 35b0ae229..2fd21ca0a 100644 --- a/tests/dragonfly/tls_conf_test.py +++ b/tests/dragonfly/tls_conf_test.py @@ -18,20 +18,20 @@ async def test_tls_no_key(df_factory): server.start() -async def test_tls_password(df_factory, with_tls_server_args, gen_ca_cert): +async def test_tls_password(df_factory, with_tls_server_args, with_tls_ca_cert_args): with df_factory.create(requirepass="XXX", **with_tls_server_args) as server: async with server.client( - ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"] + ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] ) as client: await client.ping() async def test_tls_client_certs( - df_factory, with_ca_tls_server_args, with_tls_client_args, gen_ca_cert + df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args ): with df_factory.create(**with_ca_tls_server_args) as server: async with server.client( - **with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"] + **with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] ) as client: await client.ping() @@ -52,3 +52,102 @@ async def test_client_tls_cert(df_factory, with_tls_server_args): key_args.pop("tls") with df_factory.create(tls_replication=None, **key_args): pass + + +async def test_config_update_tls_certs( + df_factory, with_tls_server_args, with_tls_ca_cert_args, tmp_dir +): + # Generate new certificates. + ca_key = os.path.join(tmp_dir, "ca-key-new.pem") + ca_cert = os.path.join(tmp_dir, "ca-cert-new.pem") + gen_ca_cert(ca_key, ca_cert) + tls_server_key = os.path.join(tmp_dir, "df-key-new.pem") + tls_server_req = os.path.join(tmp_dir, "df-req-new.pem") + tls_server_cert = os.path.join(tmp_dir, "df-cert-new.pem") + gen_certificate( + ca_key, + ca_cert, + tls_server_req, + tls_server_key, + tls_server_cert, + ) + + with df_factory.create(requirepass="XXX", **with_tls_server_args) as server: + async with server.client( + ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] + ) as client: + await client.config_set( + "tls_key_file", + tls_server_key, + ) + await client.config_set("tls_cert_file", tls_server_cert) + # Note must still set `tls true` to reload the TLS context. + await client.config_set("tls", "true") + + # The existing connection should still work. + await client.ping() + + # Connecting with the old CA should fail. + with pytest.raises(redis.exceptions.ConnectionError): + async with server.client( + ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] + ) as client: + await client.ping() + + # Connecting with the new CA should succeed. + async with server.client(ssl=True, password="XXX", ssl_ca_certs=ca_cert) as client: + await client.ping() + + +async def test_config_enable_tls( + df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args +): + with df_factory.create() as server: + async with server.client() as client: + await client.ping() + + # Note the order here matters as flags are applied in order. + await client.config_set( + "tls_key_file", + with_ca_tls_server_args["tls_key_file"], + ) + await client.config_set( + "tls_cert_file", + with_ca_tls_server_args["tls_cert_file"], + ) + await client.config_set( + "tls_ca_cert_file", + with_ca_tls_server_args["tls_ca_cert_file"], + ) + await client.config_set( + "tls", + "true", + ) + + # The existing client should still be connected. + await client.ping() + + # Connecting without TLS should fail. + with pytest.raises(redis.exceptions.ConnectionError): + async with server.client() as client_unauth: + await client_unauth.ping() + + # Connecting with TLS should succeed. + async with server.client( + **with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] + ) as client_tls: + await client_tls.ping() + + +async def test_config_disable_tls( + df_factory, with_ca_tls_server_args, with_tls_client_args, with_tls_ca_cert_args +): + with df_factory.create(**with_ca_tls_server_args) as server: + async with server.client( + **with_tls_client_args, ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] + ) as client_tls: + await client_tls.config_set("tls", "false") + + # Connecting without TLS should succeed. + async with server.client() as client_unauth: + await client_unauth.ping() diff --git a/tests/dragonfly/utility.py b/tests/dragonfly/utility.py index b3e139d33..2aebb8b7f 100644 --- a/tests/dragonfly/utility.py +++ b/tests/dragonfly/utility.py @@ -581,6 +581,17 @@ async def disconnect_clients(*clients): await asyncio.gather(*(c.connection_pool.disconnect() for c in clients)) +def gen_ca_cert(ca_key_path, ca_cert_path): + # We first need to generate the tls certificates to be used by the server + + # Generate CA (certificate authority) key and self-signed certificate + # In production, CA should be generated by a third party authority + # Expires in one day and is not encrtypted (-nodes) + # X.509 format for the key + step = rf'openssl req -x509 -newkey rsa:4096 -days 1 -nodes -keyout {ca_key_path} -out {ca_cert_path} -subj "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/emailAddress=acme@gmail.com"' + subprocess.run(step, shell=True) + + def gen_certificate( ca_key_path, ca_certificate_path, certificate_request_path, private_key_path, certificate_path ):