chore: refactor VersionMonitor into a separate file (#2326)

* chore: refactor VersionMonitor into a separate file
---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2023-12-24 22:06:57 +02:00 committed by GitHub
parent d129674e17
commit 700a65ece5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 230 additions and 191 deletions

View file

@ -1,4 +1,4 @@
add_executable(dragonfly dfly_main.cc) add_executable(dragonfly dfly_main.cc version_monitor.cc)
cxx_link(dragonfly base dragonfly_lib) cxx_link(dragonfly base dragonfly_lib)
if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_BUILD_TYPE STREQUAL "Release") if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_BUILD_TYPE STREQUAL "Release")

View file

@ -23,12 +23,10 @@
#endif #endif
#include <mimalloc.h> #include <mimalloc.h>
#include <openssl/err.h>
#include <signal.h> #include <signal.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <regex>
#include "base/init.h" #include "base/init.h"
#include "base/proc_util.h" // for GetKernelVersion #include "base/proc_util.h" // for GetKernelVersion
@ -40,10 +38,10 @@
#include "server/generic_family.h" #include "server/generic_family.h"
#include "server/main_service.h" #include "server/main_service.h"
#include "server/version.h" #include "server/version.h"
#include "server/version_monitor.h"
#include "strings/human_readable.h" #include "strings/human_readable.h"
#include "util/accept_server.h" #include "util/accept_server.h"
#include "util/fibers/pool.h" #include "util/fibers/pool.h"
#include "util/http/http_client.h"
#include "util/varz.h" #include "util/varz.h"
#ifdef __APPLE__ #ifdef __APPLE__
@ -88,189 +86,6 @@ namespace {
using util::http::TlsClient; using util::http::TlsClient;
std::optional<std::string> GetVersionString(const std::string& version_str) {
// The server sends a message such as {"latest": "0.12.0"}
const auto reg_match_expr = R"(\{\"latest"\:[ \t]*\"([0-9]+\.[0-9]+\.[0-9]+)\"\})";
VLOG(1) << "checking version '" << version_str << "'";
auto const regex = std::regex(reg_match_expr);
std::smatch match;
if (std::regex_match(version_str, match, regex) && match.size() > 1) {
// the second entry is the match to the group that holds the version string
return match[1].str();
} else {
LOG_FIRST_N(WARNING, 1) << "Remote version - invalid version number: '" << version_str << "'";
return std::nullopt;
}
}
std::optional<std::string> GetRemoteVersion(ProactorBase* proactor, SSL_CTX* ssl_context,
const std::string host, std::string_view service,
const std::string& resource,
const std::string& ver_header) {
namespace bh = boost::beast::http;
using ResponseType = bh::response<bh::string_body>;
bh::request<bh::string_body> req{bh::verb::get, resource, 11 /*http 1.1*/};
req.set(bh::field::host, host);
req.set(bh::field::user_agent, ver_header);
ResponseType res;
TlsClient http_client{proactor};
http_client.set_connect_timeout_ms(2000);
auto ec = http_client.Connect(host, service, ssl_context);
if (ec) {
LOG_FIRST_N(WARNING, 1) << "Remote version - connection error [" << host << ":" << service
<< "] : " << ec.message();
return nullopt;
}
ec = http_client.Send(req, &res);
if (!ec) {
VLOG(1) << "successfully got response from HTTP GET for host " << host << ":" << service << "/"
<< resource << " response code is " << res.result();
if (res.result() == bh::status::ok) {
return GetVersionString(res.body());
}
} else {
static bool is_logged{false};
if (!is_logged) {
is_logged = true;
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
const char* func_err = "ssl_internal_error";
#else
const char* func_err = ERR_func_error_string(ec.value());
#endif
// Unfortunately AsioStreamAdapter looses the original error category
// because std::error_code can not be converted into boost::system::error_code.
// It's fixed in later versions of Boost, but for now we assume it's from TLS.
LOG(WARNING) << "Remote version - HTTP GET error [" << host << ":" << service << resource
<< "], error: " << ec.value();
LOG(WARNING) << "ssl error: " << func_err << "/" << ERR_reason_error_string(ec.value());
}
}
return nullopt;
}
struct VersionMonitor {
Fiber version_fiber_;
Done monitor_ver_done_;
void Run(ProactorPool* proactor_pool);
void Shutdown() {
monitor_ver_done_.Notify();
if (version_fiber_.IsJoinable()) {
version_fiber_.Join();
}
}
private:
struct SslDeleter {
void operator()(SSL_CTX* ssl) {
if (ssl) {
TlsClient::FreeContext(ssl);
}
}
};
using SslPtr = std::unique_ptr<SSL_CTX, SslDeleter>;
void RunTask(SslPtr);
bool IsVersionOutdated(std::string_view remote, std::string_view current) const;
};
bool VersionMonitor::IsVersionOutdated(const std::string_view remote,
const std::string_view current) const {
const absl::InlinedVector<absl::string_view, 3> remote_xyz = absl::StrSplit(remote, ".");
const absl::InlinedVector<absl::string_view, 3> current_xyz = absl::StrSplit(current, ".");
if (remote_xyz.size() != current_xyz.size()) {
LOG(WARNING) << "Can't compare Dragonfly version " << current << " to latest version "
<< remote;
return false;
}
const auto print_to_log = [](const std::string_view version, const absl::string_view part) {
LOG(WARNING) << "Can't parse " << version << " part of version " << part << " as a number";
};
for (size_t i = 0; i < remote_xyz.size(); ++i) {
size_t remote_x = 0;
if (!absl::SimpleAtoi(remote_xyz[i], &remote_x)) {
print_to_log(remote, remote_xyz[i]);
return false;
}
size_t current_x = 0;
if (!absl::SimpleAtoi(current_xyz[i], &current_x)) {
print_to_log(current, current_xyz[i]);
return false;
}
if (remote_x > current_x) {
return true;
}
if (remote_x < current_x) {
return false;
}
}
return false;
}
void VersionMonitor::Run(ProactorPool* proactor_pool) {
// Avoid running dev environments.
if (getenv("DFLY_DEV_ENV")) {
LOG(WARNING) << "Running in dev environment (DFLY_DEV_ENV is set) - version monitoring is "
"disabled";
return;
}
// not a production release tag.
if (!GetFlag(FLAGS_version_check) || kGitTag[0] != 'v' || strchr(kGitTag, '-')) {
return;
}
SslPtr ssl_ctx(TlsClient::CreateSslContext());
if (!ssl_ctx) {
VLOG(1) << "Remote version - failed to create SSL context - cannot run version monitoring";
return;
}
version_fiber_ = proactor_pool->GetNextProactor()->LaunchFiber(
[ssl_ctx = std::move(ssl_ctx), this]() mutable { RunTask(std::move(ssl_ctx)); });
}
void VersionMonitor::RunTask(SslPtr ssl_ctx) {
const auto loop_sleep_time = std::chrono::hours(24); // every 24 hours
const std::string host_name = "version.dragonflydb.io";
const std::string_view port = "443";
const std::string resource = "/v1";
string_view current_version(kGitTag);
current_version.remove_prefix(1);
const std::string version_header = absl::StrCat("DragonflyDB/", current_version);
ProactorBase* my_pb = ProactorBase::me();
while (true) {
const std::optional<std::string> remote_version =
GetRemoteVersion(my_pb, ssl_ctx.get(), host_name, port, resource, version_header);
if (remote_version) {
const std::string_view rv = remote_version.value();
if (IsVersionOutdated(rv, current_version)) {
LOG_FIRST_N(INFO, 1) << "Your current version '" << current_version
<< "' is not the latest version. A newer version '" << rv
<< "' is now available. Please consider an update.";
}
}
if (monitor_ver_done_.WaitFor(loop_sleep_time)) {
VLOG(1) << "finish running version monitor task";
return;
}
}
}
enum class TermColor { kDefault, kRed, kGreen, kYellow }; enum class TermColor { kDefault, kRed, kGreen, kYellow };
// Returns the ANSI color code for the given color. TermColor::kDefault is // Returns the ANSI color code for the given color. TermColor::kDefault is
// an invalid input. // an invalid input.
@ -459,7 +274,11 @@ bool RunEngine(ProactorPool* pool, AcceptServer* acceptor) {
service.Init(acceptor, listeners, opts); service.Init(acceptor, listeners, opts);
VersionMonitor version_monitor; VersionMonitor version_monitor;
version_monitor.Run(pool);
// check if it's a production release tag.
if (GetFlag(FLAGS_version_check) && kGitTag[0] == 'v' && strchr(kGitTag, '-') == nullptr) {
version_monitor.Run(pool);
}
// Start the acceptor loop and wait for the server to shutdown. // Start the acceptor loop and wait for the server to shutdown.
acceptor->Run(); acceptor->Run();
@ -733,9 +552,9 @@ void PrintBasicUsageInfo() {
void ParseFlagsFromEnv() { void ParseFlagsFromEnv() {
if (getenv("DFLY_PASSWORD")) { if (getenv("DFLY_PASSWORD")) {
LOG(WARNING) LOG(FATAL) << "DFLY_PASSWORD environment variable was deprecated in favor of DFLY_requirepass";
<< "DFLY_PASSWORD environment variable is being deprecated in favour of DFLY_requirepass";
} }
// Allowed environment variable names that can have // Allowed environment variable names that can have
// DFLY_ prefix, but don't necessarily have an ABSL flag created // DFLY_ prefix, but don't necessarily have an ABSL flag created
absl::flat_hash_set<std::string_view> ignored_environment_flag_names = {"DEV_ENV", "PASSWORD"}; absl::flat_hash_set<std::string_view> ignored_environment_flag_names = {"DEV_ENV", "PASSWORD"};

View file

@ -0,0 +1,184 @@
// Copyright 2023, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/version_monitor.h"
#include <absl/strings/str_split.h>
#include <openssl/err.h>
#include <boost/beast/http/string_body.hpp>
#include <regex>
#include "base/logging.h"
#include "server/version.h"
namespace dfly {
using namespace std;
using namespace util;
using http::TlsClient;
namespace {
std::optional<std::string> GetVersionString(const std::string& version_str) {
// The server sends a message such as {"latest": "0.12.0"}
const auto reg_match_expr = R"(\{\"latest"\:[ \t]*\"([0-9]+\.[0-9]+\.[0-9]+)\"\})";
VLOG(1) << "checking version '" << version_str << "'";
auto const regex = std::regex(reg_match_expr);
std::smatch match;
if (std::regex_match(version_str, match, regex) && match.size() > 1) {
// the second entry is the match to the group that holds the version string
return match[1].str();
} else {
LOG_FIRST_N(WARNING, 1) << "Remote version - invalid version number: '" << version_str << "'";
return std::nullopt;
}
}
std::optional<std::string> GetRemoteVersion(ProactorBase* proactor, SSL_CTX* ssl_context,
const std::string host, std::string_view service,
const std::string& resource,
const std::string& ver_header) {
namespace bh = boost::beast::http;
using ResponseType = bh::response<bh::string_body>;
bh::request<bh::string_body> req{bh::verb::get, resource, 11 /*http 1.1*/};
req.set(bh::field::host, host);
req.set(bh::field::user_agent, ver_header);
ResponseType res;
TlsClient http_client{proactor};
http_client.set_connect_timeout_ms(2000);
auto ec = http_client.Connect(host, service, ssl_context);
if (ec) {
LOG_FIRST_N(WARNING, 1) << "Remote version - connection error [" << host << ":" << service
<< "] : " << ec.message();
return nullopt;
}
ec = http_client.Send(req, &res);
if (!ec) {
VLOG(1) << "successfully got response from HTTP GET for host " << host << ":" << service << "/"
<< resource << " response code is " << res.result();
if (res.result() == bh::status::ok) {
return GetVersionString(res.body());
}
} else {
static bool is_logged{false};
if (!is_logged) {
is_logged = true;
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
const char* func_err = "ssl_internal_error";
#else
const char* func_err = ERR_func_error_string(ec.value());
#endif
// Unfortunately AsioStreamAdapter looses the original error category
// because std::error_code can not be converted into boost::system::error_code.
// It's fixed in later versions of Boost, but for now we assume it's from TLS.
LOG(WARNING) << "Remote version - HTTP GET error [" << host << ":" << service << resource
<< "], error: " << ec.value();
LOG(WARNING) << "ssl error: " << func_err << "/" << ERR_reason_error_string(ec.value());
}
}
return nullopt;
}
} // namespace
bool VersionMonitor::IsVersionOutdated(const std::string_view remote,
const std::string_view current) const {
const absl::InlinedVector<absl::string_view, 3> remote_xyz = absl::StrSplit(remote, ".");
const absl::InlinedVector<absl::string_view, 3> current_xyz = absl::StrSplit(current, ".");
if (remote_xyz.size() != current_xyz.size()) {
LOG(WARNING) << "Can't compare Dragonfly version " << current << " to latest version "
<< remote;
return false;
}
const auto print_to_log = [](const std::string_view version, const absl::string_view part) {
LOG(WARNING) << "Can't parse " << version << " part of version " << part << " as a number";
};
for (size_t i = 0; i < remote_xyz.size(); ++i) {
size_t remote_x = 0;
if (!absl::SimpleAtoi(remote_xyz[i], &remote_x)) {
print_to_log(remote, remote_xyz[i]);
return false;
}
size_t current_x = 0;
if (!absl::SimpleAtoi(current_xyz[i], &current_x)) {
print_to_log(current, current_xyz[i]);
return false;
}
if (remote_x > current_x) {
return true;
}
if (remote_x < current_x) {
return false;
}
}
return false;
}
void VersionMonitor::Run(ProactorPool* proactor_pool) {
// Avoid running dev environments.
if (getenv("DFLY_DEV_ENV")) {
LOG(WARNING) << "Running in dev environment (DFLY_DEV_ENV is set) - version monitoring is "
"disabled";
return;
}
SslPtr ssl_ctx(TlsClient::CreateSslContext());
if (!ssl_ctx) {
VLOG(1) << "Remote version - failed to create SSL context - cannot run version monitoring";
return;
}
version_fiber_ = proactor_pool->GetNextProactor()->LaunchFiber(
[ssl_ctx = std::move(ssl_ctx), this]() mutable { RunTask(std::move(ssl_ctx)); });
}
void VersionMonitor::Shutdown() {
monitor_ver_done_.Notify();
if (version_fiber_.IsJoinable()) {
version_fiber_.Join();
}
}
void VersionMonitor::RunTask(SslPtr ssl_ctx) {
const auto loop_sleep_time = std::chrono::hours(24); // every 24 hours
const std::string host_name = "version.dragonflydb.io";
const std::string_view port = "443";
const std::string resource = "/v1";
string_view current_version(kGitTag);
current_version.remove_prefix(1);
const std::string version_header = absl::StrCat("DragonflyDB/", current_version);
ProactorBase* my_pb = ProactorBase::me();
while (true) {
const std::optional<std::string> remote_version =
GetRemoteVersion(my_pb, ssl_ctx.get(), host_name, port, resource, version_header);
if (remote_version) {
const std::string_view rv = remote_version.value();
if (IsVersionOutdated(rv, current_version)) {
LOG_FIRST_N(INFO, 1) << "Your current version '" << current_version
<< "' is not the latest version. A newer version '" << rv
<< "' is now available. Please consider an update.";
}
}
if (monitor_ver_done_.WaitFor(loop_sleep_time)) {
VLOG(1) << "finish running version monitor task";
return;
}
}
}
} // namespace dfly

View file

@ -0,0 +1,36 @@
// Copyright 2023, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include "util/fibers/fibers.h"
#include "util/fibers/pool.h"
#include "util/http/http_client.h"
namespace dfly {
class VersionMonitor {
public:
void Run(util::ProactorPool* proactor_pool);
void Shutdown();
private:
struct SslDeleter {
void operator()(SSL_CTX* ssl) {
if (ssl) {
util::http::TlsClient::FreeContext(ssl);
}
}
};
using SslPtr = std::unique_ptr<SSL_CTX, SslDeleter>;
void RunTask(SslPtr);
bool IsVersionOutdated(std::string_view remote, std::string_view current) const;
util::fb2::Fiber version_fiber_;
util::fb2::Done monitor_ver_done_;
};
} // namespace dfly

View file

@ -23,7 +23,7 @@ class TestKeys:
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def export_dfly_password() -> str: def export_dfly_password() -> str:
pwd = "flypwd" pwd = "flypwd"
with EnvironCntx(DFLY_PASSWORD=pwd): with EnvironCntx(DFLY_requirepass=pwd):
yield pwd yield pwd