fix: enforce load limits when loading snapshot (#4136)

* fix: enforce load limits when loading snapshot

Prevent loading snapshots whose recorded used memory is higher than the configured max memory limit.

1. Store the used-memory metadata only inside the summary file.
2. Load the summary file before loading anything else, and if its used-memory value is
   higher than the limit, abort the load.
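
For reference, the core of the check reads roughly as follows. This is a minimal
sketch with simplified names, not the exact code: the real logic lives in
RdbLoader::HandleAux in the diff below, max_memory_limit is the same global the
diff compares against, and UsedMemFits is a hypothetical helper name.

    // Illustrative sketch only; error plumbing simplified.
    #include <cstdint>
    #include <string>
    #include "absl/strings/numbers.h"

    extern size_t max_memory_limit;  // stand-in for the configured limit

    // Returns true if the "used-mem" aux value recorded in the summary file
    // fits into the configured memory budget; the load aborts when it does not.
    bool UsedMemFits(const std::string& auxval) {
      int64_t used_mem = 0;
      if (!absl::SimpleAtoi(auxval, &used_mem))
        return true;  // unparsable metadata is ignored, not treated as fatal
      return used_mem <= static_cast<int64_t>(max_memory_limit);
    }

Because "used-mem" is written only to the summary file and that file is read
first, the check fires before any shard data is loaded.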
---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
Roman Gershman authored 2024-11-20 06:12:47 +02:00, committed by GitHub
parent 4e7800f94f
commit 0e7ae34fe4
9 changed files with 82 additions and 81 deletions

@@ -155,6 +155,6 @@
 // Currently moved here from server.h
 #define LONG_STR_SIZE 21 /* Bytes needed for long -> str + '\0' */

-#define REDIS_VERSION "999.999.999"
+#define REDIS_VERSION "6.2.11"

 #endif

@@ -90,7 +90,8 @@ string SnapshotStorage::FindMatchingFile(string_view prefix, string_view dbfilen
   return {};
 }

-io::Result<vector<string>, GenericError> SnapshotStorage::ExpandSnapshot(const string& load_path) {
+io::Result<SnapshotStorage::ExpandResult, GenericError> SnapshotStorage::ExpandSnapshot(
+    const string& load_path) {
   if (!(absl::EndsWith(load_path, ".rdb") || absl::EndsWith(load_path, "summary.dfs"))) {
     return nonstd::make_unexpected(
         GenericError(std::make_error_code(std::errc::invalid_argument), "Bad filename extension"));
@@ -101,17 +102,20 @@ io::Result<vector<string>, GenericError> SnapshotStorage::ExpandSnapshot(const s
     return nonstd::make_unexpected(GenericError(ec, "File not found"));
   }

-  vector<string> paths{{load_path}};
+  ExpandResult result;
   // Collect all other files in case we're loading dfs.
   if (absl::EndsWith(load_path, "summary.dfs")) {
     auto res = ExpandFromPath(load_path);
     if (!res) {
-      return res;
+      return nonstd::make_unexpected(res.error());
     }
-    paths.insert(paths.end(), res->begin(), res->end());
+    result = std::move(*res);
+    result.push_back(load_path);
+  } else {
+    result.push_back(load_path);
   }
-  return paths;
+  return result;
 }

 FileSnapshotStorage::FileSnapshotStorage(fb2::FiberQueueThreadPool* fq_threadpool)

@@ -51,8 +51,9 @@ class SnapshotStorage {
   virtual io::Result<std::string, GenericError> LoadPath(std::string_view dir,
                                                          std::string_view dbfilename) = 0;

+  using ExpandResult = std::vector<std::string>;
   // Searches for all the relevant snapshot files given the RDB file or DFS summary file path.
-  io::Result<std::vector<std::string>, GenericError> ExpandSnapshot(const std::string& load_path);
+  io::Result<ExpandResult, GenericError> ExpandSnapshot(const std::string& load_path);

   virtual bool IsCloud() const {
     return false;

@@ -102,6 +102,8 @@ string error_category::message(int ev) const {
   switch (ev) {
     case errc::wrong_signature:
       return "Wrong signature while trying to load from rdb file";
+    case errc::out_of_memory:
+      return "Out of memory, or used memory is too high";
     default:
       return absl::StrCat("Internal error when loading RDB file ", ev);
       break;
@@ -2596,7 +2598,9 @@ error_code RdbLoader::HandleAux() {
   } else if (auxkey == "lua") {
     LoadScriptFromAux(std::move(auxval));
   } else if (auxkey == "redis-ver") {
-    VLOG(1) << "Loading RDB produced by version " << auxval;
+    VLOG(1) << "Loading RDB produced by Redis version " << auxval;
+  } else if (auxkey == "df-ver") {
+    VLOG(1) << "Loading RDB produced by Dragonfly version " << auxval;
   } else if (auxkey == "ctime") {
     int64_t ctime;
     if (absl::SimpleAtoi(auxval, &ctime)) {
@@ -2606,9 +2610,14 @@ error_code RdbLoader::HandleAux() {
       VLOG(1) << "RDB age " << strings::HumanReadableElapsedTime(age);
     }
   } else if (auxkey == "used-mem") {
-    long long usedmem;
+    int64_t usedmem;
     if (absl::SimpleAtoi(auxval, &usedmem)) {
       VLOG(1) << "RDB memory usage when created " << strings::HumanReadableNumBytes(usedmem);
+      if (usedmem > ssize_t(max_memory_limit)) {
+        LOG(WARNING) << "Could not load snapshot - its used memory is " << usedmem
+                     << " but the limit is " << max_memory_limit;
+        return RdbError(errc::out_of_memory);
+      }
     }
   } else if (auxkey == "aof-preamble") {
     long long haspreamble;

@@ -1561,16 +1561,18 @@ void RdbSaver::FillFreqMap(RdbTypeFreqMap* freq_map) {
 error_code RdbSaver::SaveAux(const GlobalData& glob_state) {
   static_assert(sizeof(void*) == 8, "");

-  int aof_preamble = false;
   error_code ec;

   /* Add a few fields about the state when the RDB was created. */
   RETURN_ON_ERR(impl_->SaveAuxFieldStrStr("redis-ver", REDIS_VERSION));
+  RETURN_ON_ERR(impl_->SaveAuxFieldStrStr("df-ver", GetVersion()));
   RETURN_ON_ERR(SaveAuxFieldStrInt("redis-bits", 64));

   RETURN_ON_ERR(SaveAuxFieldStrInt("ctime", time(NULL)));
-  RETURN_ON_ERR(SaveAuxFieldStrInt("used-mem", used_mem_current.load(memory_order_relaxed)));
+  auto used_mem = used_mem_current.load(memory_order_relaxed);
+  VLOG(1) << "Used memory during save: " << used_mem;
+  RETURN_ON_ERR(SaveAuxFieldStrInt("used-mem", used_mem));

-  RETURN_ON_ERR(SaveAuxFieldStrInt("aof-preamble", aof_preamble));
+  RETURN_ON_ERR(SaveAuxFieldStrInt("aof-preamble", 0));

   // Save lua scripts only in rdb or summary file
   DCHECK(save_mode_ != SaveMode::SINGLE_SHARD || glob_state.lua_scripts.empty());

@@ -94,7 +94,7 @@ TEST_F(RdbTest, Crc) {
 TEST_F(RdbTest, LoadEmpty) {
   auto ec = LoadRdb("empty.rdb");
-  CHECK(!ec);
+  ASSERT_FALSE(ec) << ec;
 }

 TEST_F(RdbTest, LoadSmall6) {
@@ -646,4 +646,13 @@ TEST_F(RdbTest, LoadHugeStream) {
   ASSERT_EQ(2000, CheckedInt({"xlen", "test:0"}));
 }

+TEST_F(RdbTest, SnapshotTooBig) {
+  // Run({"debug", "populate", "10000", "foo", "1000"});
+  // usleep(5000);  // let the stats to sync
+  max_memory_limit = 100000;
+  used_mem_current = 1000000;
+  auto resp = Run({"debug", "reload"});
+  ASSERT_THAT(resp, ErrArg("Out of memory"));
+}
+
 }  // namespace dfly

@@ -1083,25 +1083,24 @@ std::optional<fb2::Future<GenericError>> ServerFamily::Load(string_view load_pat
   DCHECK_GT(shard_count(), 0u);

+  // TODO: to move it to helio.
+  auto immediate = [](auto val) {
+    fb2::Future<GenericError> future;
+    future.Resolve(val);
+    return future;
+  };
+
   if (ServerState::tlocal() && !ServerState::tlocal()->is_master) {
-    fb2::Future<GenericError> future;
-    future.Resolve(string("Replica cannot load data"));
-    return future;
+    return immediate(string("Replica cannot load data"));
   }

-  auto paths_result = snapshot_storage_->ExpandSnapshot(path);
-  if (!paths_result) {
-    LOG(ERROR) << "Failed to load snapshot: " << paths_result.error().Format();
-    fb2::Future<GenericError> future;
-    future.Resolve(paths_result.error());
-    return future;
+  auto expand_result = snapshot_storage_->ExpandSnapshot(path);
+  if (!expand_result) {
+    LOG(ERROR) << "Failed to load snapshot: " << expand_result.error().Format();
+    return immediate(expand_result.error());
   }

-  std::vector<std::string> paths = *paths_result;
-
-  LOG(INFO) << "Loading " << path;
-
   auto new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING);
   if (new_state != GlobalState::LOADING) {
     LOG(WARNING) << new_state << " in progress, ignored";
@@ -1110,6 +1109,10 @@ std::optional<fb2::Future<GenericError>> ServerFamily::Load(string_view load_pat
   auto& pool = service_.proactor_pool();

+  const vector<string>& paths = *expand_result;
+
+  LOG(INFO) << "Loading " << path;
+
   vector<fb2::Fiber> load_fibers;
   load_fibers.reserve(paths.size());
@@ -1125,39 +1128,36 @@ std::optional<fb2::Future<GenericError>> ServerFamily::Load(string_view load_pat
       proactor = pool.GetNextProactor();
     }

-    auto load_fiber = [this, aggregated_result, existing_keys, path = std::move(path)]() {
+    auto load_func = [this, aggregated_result, existing_keys, path = std::move(path)]() {
       auto load_result = LoadRdb(path, existing_keys);
       if (load_result.has_value())
         aggregated_result->keys_read.fetch_add(*load_result);
       else
         aggregated_result->first_error = load_result.error();
     };
-    load_fibers.push_back(proactor->LaunchFiber(std::move(load_fiber)));
+    load_fibers.push_back(proactor->LaunchFiber(std::move(load_func)));
   }

   fb2::Future<GenericError> future;

   // Run fiber that empties the channel and sets ec_promise.
-  auto load_join_fiber = [this, aggregated_result, load_fibers = std::move(load_fibers),
-                          future]() mutable {
+  auto load_join_func = [this, aggregated_result, load_fibers = std::move(load_fibers),
+                         future]() mutable {
     for (auto& fiber : load_fibers) {
       fiber.Join();
     }

     if (aggregated_result->first_error) {
-      LOG(ERROR) << "Rdb load failed. " << (*aggregated_result->first_error).message();
-      service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
-      future.Resolve(*aggregated_result->first_error);
-      return;
+      LOG(ERROR) << "Rdb load failed: " << (*aggregated_result->first_error).message();
+    } else {
+      RdbLoader::PerformPostLoad(&service_);
+      LOG(INFO) << "Load finished, num keys read: " << aggregated_result->keys_read;
     }

-    RdbLoader::PerformPostLoad(&service_);
-    LOG(INFO) << "Load finished, num keys read: " << aggregated_result->keys_read;
-
     service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
     future.Resolve(*(aggregated_result->first_error));
   };

-  pool.GetNextProactor()->Dispatch(std::move(load_join_fiber));
+  pool.GetNextProactor()->Dispatch(std::move(load_join_func));

   return future;
 }
@@ -1196,6 +1196,7 @@ void ServerFamily::SnapshotScheduling() {
 io::Result<size_t> ServerFamily::LoadRdb(const std::string& rdb_file,
                                          LoadExistingKeys existing_keys) {
   VLOG(1) << "Loading data from " << rdb_file;
+  CHECK(fb2::ProactorBase::IsProactorThread()) << "must be called from proactor thread";
   error_code ec;

   io::ReadonlyFileOrError res = snapshot_storage_->OpenReadFile(rdb_file);

@@ -195,7 +195,6 @@ disconnect_cases = [
 ]


-@pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases)
 async def test_disconnect_replica(
     df_factory: DflyInstanceFactory,
@@ -327,7 +326,6 @@ master_crash_cases = [
 ]


-@pytest.mark.asyncio
 @pytest.mark.slow
 @pytest.mark.parametrize("t_master, t_replicas, n_random_crashes, n_keys", master_crash_cases)
 async def test_disconnect_master(
@@ -397,7 +395,6 @@ Test re-connecting replica to different masters.
 rotating_master_cases = [(4, [4, 4, 4, 4], dict(keys=2_000, dbcount=4))]


-@pytest.mark.asyncio
 @pytest.mark.slow
 @pytest.mark.parametrize("t_replica, t_masters, seeder_config", rotating_master_cases)
 async def test_rotating_masters(df_factory, df_seeder_factory, t_replica, t_masters, seeder_config):
@@ -433,7 +430,6 @@ async def test_rotating_masters(df_factory, df_seeder_factory, t_replica, t_mast
     fill_task.cancel()


-@pytest.mark.asyncio
 @pytest.mark.slow
 async def test_cancel_replication_immediately(df_factory, df_seeder_factory: DflySeederFactory):
     """
@@ -491,7 +487,6 @@ Check replica keys at the end.
 """


-@pytest.mark.asyncio
 async def test_flushall(df_factory):
     master = df_factory.create(proactor_threads=4)
     replica = df_factory.create(proactor_threads=2)
@@ -542,7 +537,6 @@ Test journal rewrites.
 @dfly_args({"proactor_threads": 4})
-@pytest.mark.asyncio
 async def test_rewrites(df_factory):
     CLOSE_TIMESTAMP = int(time.time()) + 100
     CLOSE_TIMESTAMP_MS = CLOSE_TIMESTAMP * 1000
@@ -727,7 +721,6 @@ Test automatic replication of expiry.
 @dfly_args({"proactor_threads": 4})
-@pytest.mark.asyncio
 async def test_expiry(df_factory: DflyInstanceFactory, n_keys=1000):
     master = df_factory.create()
     replica = df_factory.create()
@@ -866,7 +859,6 @@ return 'OK'
 """


-@pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_replicas, num_ops, num_keys, num_par, flags", script_cases)
 async def test_scripts(df_factory, t_master, t_replicas, num_ops, num_keys, num_par, flags):
     master = df_factory.create(proactor_threads=t_master)
@@ -900,7 +892,6 @@ async def test_scripts(df_factory, t_master, t_replicas, num_ops, num_keys, num_
 @dfly_args({"proactor_threads": 4})
-@pytest.mark.asyncio
 async def test_auth_master(df_factory, n_keys=20):
     masterpass = "requirepass"
     replicapass = "replicapass"
@@ -966,7 +957,6 @@ async def test_script_transfer(df_factory):
 @dfly_args({"proactor_threads": 4})
-@pytest.mark.asyncio
 async def test_role_command(df_factory, n_keys=20):
     master = df_factory.create()
     replica = df_factory.create()
@@ -1064,7 +1054,6 @@ async def assert_replica_reconnections(replica_inst, initial_reconnects_count):
 @dfly_args({"proactor_threads": 2})
-@pytest.mark.asyncio
 async def test_replication_info(df_factory: DflyInstanceFactory, df_seeder_factory, n_keys=2000):
     master = df_factory.create()
     replica = df_factory.create(replication_acks_interval=100)
@@ -1096,7 +1085,6 @@ More details in https://github.com/dragonflydb/dragonfly/issues/1231
 """


-@pytest.mark.asyncio
 @pytest.mark.slow
 async def test_flushall_in_full_sync(df_factory):
     master = df_factory.create(proactor_threads=4)
@@ -1155,7 +1143,6 @@ redis.call('SET', 'A', 'ErrroR')
 """


-@pytest.mark.asyncio
 async def test_readonly_script(df_factory):
     master = df_factory.create(proactor_threads=2)
     replica = df_factory.create(proactor_threads=2)
@@ -1188,7 +1175,6 @@ take_over_cases = [
 @pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
-@pytest.mark.asyncio
 async def test_take_over_counters(df_factory, master_threads, replica_threads):
     master = df_factory.create(proactor_threads=master_threads)
     replica1 = df_factory.create(proactor_threads=replica_threads)
@@ -1243,7 +1229,6 @@ async def test_take_over_counters(df_factory, master_threads, replica_threads):
 @pytest.mark.parametrize("master_threads, replica_threads", take_over_cases)
-@pytest.mark.asyncio
 async def test_take_over_seeder(
     request, df_factory, df_seeder_factory, master_threads, replica_threads
 ):
@@ -1299,7 +1284,6 @@ async def test_take_over_seeder(
 @pytest.mark.parametrize("master_threads, replica_threads", [[4, 4]])
-@pytest.mark.asyncio
 async def test_take_over_read_commands(df_factory, master_threads, replica_threads):
     master = df_factory.create(proactor_threads=master_threads)
     replica = df_factory.create(proactor_threads=replica_threads)
@@ -1333,7 +1317,6 @@ async def test_take_over_read_commands(df_factory, master_threads, replica_threa
     await promt_task


-@pytest.mark.asyncio
 async def test_take_over_timeout(df_factory, df_seeder_factory):
     master = df_factory.create(proactor_threads=2)
     replica = df_factory.create(proactor_threads=2)
@@ -1379,7 +1362,6 @@ async def test_take_over_timeout(df_factory, df_seeder_factory):
 replication_cases = [(8, 8)]


-@pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_replica", replication_cases)
 async def test_no_tls_on_admin_port(
     df_factory: DflyInstanceFactory,
@@ -1428,7 +1410,6 @@ async def test_no_tls_on_admin_port(
 replication_cases = [(8, 8, False), (8, 8, True)]


-@pytest.mark.asyncio
 @pytest.mark.parametrize("t_master, t_replica, test_admin_port", replication_cases)
 async def test_tls_replication(
     df_factory,
@@ -1521,7 +1502,6 @@ async def wait_for_replica_status(
     raise RuntimeError("Client did not become available in time!")


-@pytest.mark.asyncio
 async def test_replicaof_flag(df_factory):
     # tests --replicaof works under normal conditions
     master = df_factory.create(
@@ -1555,7 +1535,6 @@ async def test_replicaof_flag(df_factory):
     assert "VALUE" == val


-@pytest.mark.asyncio
 async def test_replicaof_flag_replication_waits(df_factory):
     # tests --replicaof works when we launch replication before the master
     BASE_PORT = 1111
@@ -1599,7 +1578,6 @@ async def test_replicaof_flag_replication_waits(df_factory):
     assert "VALUE" == val


-@pytest.mark.asyncio
 async def test_replicaof_flag_disconnect(df_factory):
     # test stopping replication when started using --replicaof
     master = df_factory.create(
@@ -1639,7 +1617,6 @@ async def test_replicaof_flag_disconnect(df_factory):
     assert role[0] == "master"


-@pytest.mark.asyncio
 async def test_df_crash_on_memcached_error(df_factory):
     master = df_factory.create(
         memcached_port=11211,
@@ -1667,7 +1644,6 @@ async def test_df_crash_on_memcached_error(df_factory):
     memcached_client.set("key", "data", noreply=False)


-@pytest.mark.asyncio
 async def test_df_crash_on_replicaof_flag(df_factory):
     master = df_factory.create(
         proactor_threads=2,
@@ -1931,7 +1907,6 @@ async def test_search_with_stream(df_factory: DflyInstanceFactory):
 # @pytest.mark.slow
-@pytest.mark.asyncio
 async def test_client_pause_with_replica(df_factory, df_seeder_factory):
     master = df_factory.create(proactor_threads=4)
     replica = df_factory.create(proactor_threads=4)
@@ -2003,7 +1978,6 @@ async def test_replicaof_reject_on_load(df_factory, df_seeder_factory):
     await c_replica.execute_command(f"REPLICAOF localhost {master.port}")


-@pytest.mark.asyncio
 async def test_heartbeat_eviction_propagation(df_factory):
     master = df_factory.create(
         proactor_threads=1, cache_mode="true", maxmemory="256mb", enable_heartbeat_eviction="false"
@@ -2039,7 +2013,6 @@ async def test_heartbeat_eviction_propagation(df_factory):
     assert set(keys_master) == set(keys_replica)


-@pytest.mark.asyncio
 async def test_policy_based_eviction_propagation(df_factory, df_seeder_factory):
     master = df_factory.create(
         proactor_threads=2,
@@ -2076,7 +2049,6 @@ async def test_policy_based_eviction_propagation(df_factory, df_seeder_factory):
     assert set(keys_master).difference(keys_replica) == set()


-@pytest.mark.asyncio
 async def test_journal_doesnt_yield_issue_2500(df_factory, df_seeder_factory):
     """
     Issues many SETEX commands through a Lua script so that no yields are done between them.
@@ -2129,7 +2101,6 @@ async def test_journal_doesnt_yield_issue_2500(df_factory, df_seeder_factory):
     assert set(keys_master) == set(keys_replica)


-@pytest.mark.asyncio
 async def test_saving_replica(df_factory):
     master = df_factory.create(proactor_threads=1)
     replica = df_factory.create(proactor_threads=1, dbfilename=f"dump_{tmp_file_name()}")
@@ -2157,7 +2128,6 @@ async def test_saving_replica(df_factory):
     assert not await is_saving(c_replica)


-@pytest.mark.asyncio
 async def test_start_replicating_while_save(df_factory):
     master = df_factory.create(proactor_threads=4)
     replica = df_factory.create(proactor_threads=4, dbfilename=f"dump_{tmp_file_name()}")
@@ -2183,7 +2153,6 @@ async def test_start_replicating_while_save(df_factory):
     assert not await is_saving(c_replica)


-@pytest.mark.asyncio
 async def test_user_acl_replication(df_factory):
     master = df_factory.create(proactor_threads=4)
     replica = df_factory.create(proactor_threads=4)
@@ -2217,7 +2186,6 @@ async def test_user_acl_replication(df_factory):
 @pytest.mark.parametrize("break_conn", [False, True])
-@pytest.mark.asyncio
 async def test_replica_reconnect(df_factory, break_conn):
     """
     Test replica does not connect to master if master restarted
@@ -2270,7 +2238,6 @@ async def test_replica_reconnect(df_factory, break_conn):
     assert await c_replica.execute_command("get k") == "6789"


-@pytest.mark.asyncio
 async def test_announce_ip_port(df_factory):
     master = df_factory.create()
     replica = df_factory.create(replica_announce_ip="overrode-host", announce_port="1337")
@@ -2291,7 +2258,6 @@ async def test_announce_ip_port(df_factory):
     assert port == "1337"


-@pytest.mark.asyncio
 async def test_replication_timeout_on_full_sync(df_factory: DflyInstanceFactory, df_seeder_factory):
     # setting replication_timeout to a very small value to force the replica to timeout
     master = df_factory.create(replication_timeout=100, vmodule="replica=2,dflycmd=2")
@@ -2435,7 +2401,6 @@ async def test_replicate_old_master(
 # For more information plz refer to the issue on gh:
 # https://github.com/dragonflydb/dragonfly/issues/3504
 @dfly_args({"proactor_threads": 1})
-@pytest.mark.asyncio
 async def test_empty_hash_map_replicate_old_master(df_factory):
     cpu = platform.processor()
     if cpu != "x86_64":
@@ -2494,7 +2459,6 @@ async def test_empty_hash_map_replicate_old_master(df_factory):
 # For more information plz refer to the issue on gh:
 # https://github.com/dragonflydb/dragonfly/issues/3504
 @dfly_args({"proactor_threads": 1})
-@pytest.mark.asyncio
 async def test_empty_hashmap_loading_bug(df_factory: DflyInstanceFactory):
     cpu = platform.processor()
     if cpu != "x86_64":
@@ -2565,7 +2529,6 @@ async def test_replicating_mc_flags(df_factory):
     assert c_mc_replica.get(f"key{i}") == str.encode(f"value{i}")


-@pytest.mark.asyncio
 async def test_double_take_over(df_factory, df_seeder_factory):
     master = df_factory.create(proactor_threads=4, dbfilename="", admin_port=ADMIN_PORT)
     replica = df_factory.create(proactor_threads=4, dbfilename="", admin_port=ADMIN_PORT + 1)
@@ -2607,7 +2570,6 @@ async def test_double_take_over(df_factory, df_seeder_factory):
     assert await seeder.compare(capture, port=master.port)


-@pytest.mark.asyncio
 async def test_replica_of_replica(df_factory):
     # Can't connect a replica to a replica, but OK to connect 2 replicas to the same master
     master = df_factory.create(proactor_threads=2)
@@ -2627,7 +2589,6 @@ async def test_replica_of_replica(df_factory):
     assert await c_replica2.execute_command(f"REPLICAOF localhost {master.port}") == "OK"


-@pytest.mark.asyncio
 async def test_replication_timeout_on_full_sync_heartbeat_expiry(
     df_factory: DflyInstanceFactory, df_seeder_factory
 ):
@@ -2680,7 +2641,6 @@ async def test_replication_timeout_on_full_sync_heartbeat_expiry(
     "element_size, elements_number",
     [(16, 20000), (20000, 16)],
 )
-@pytest.mark.asyncio
 async def test_big_containers(df_factory, element_size, elements_number):
     master = df_factory.create(proactor_threads=4)
     replica = df_factory.create(proactor_threads=4)
@@ -2708,3 +2668,18 @@ async def test_big_containers(df_factory, element_size, elements_number):
     replica_data = await StaticSeeder.capture(c_replica)
     master_data = await StaticSeeder.capture(c_master)
     assert master_data == replica_data
+
+
+async def test_master_too_big(df_factory):
+    master = df_factory.create(proactor_threads=4)
+    replica = df_factory.create(proactor_threads=2, maxmemory="600mb")
+    df_factory.start_all([master, replica])
+    c_master = master.client()
+    c_replica = replica.client()
+    await c_master.execute_command("DEBUG POPULATE 1000000 key 1000 RAND")
+    await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
+
+    # We should never sync due to used memory too high during full sync
+    with pytest.raises(TimeoutError):
+        await wait_available_async(c_replica, timeout=10)

@@ -83,8 +83,8 @@ async def tick_timer(func, timeout=5, step=0.1):
         await asyncio.sleep(step)

     if last_error:
-        raise RuntimeError("Timed out!") from last_error
-    raise RuntimeError("Timed out!")
+        raise TimeoutError("Timed out!") from last_error
+    raise TimeoutError("Timed out!")


 async def info_tick_timer(client: aioredis.Redis, section=None, **kwargs):
@@ -113,7 +113,7 @@ async def wait_available_async(
             assert "Dragonfly is loading the dataset in memory" in str(e)
         timeout -= time.time() - start
         if timeout <= 0:
-            raise RuntimeError("Timed out!")
+            raise TimeoutError("Timed out!")

     # Secondly for replicas, we make sure they reached stable state replicaton
     async for info, breaker in info_tick_timer(clients, "REPLICATION", timeout=timeout):