feat: Yield inside huge values migration serialization (#4197)

* feat: Yield inside huge values migration serialization

With #4144 we broke huge-value slot migration into multiple commands.
This PR adds yields between those commands.
It also adds a test that verifies that modifying huge values during a
migration works correctly and that RSS doesn't grow too much.

Fixes #4100
Shahar Mike 2025-01-05 16:28:45 +02:00 committed by GitHub
parent ff4add0c9e
commit 7860a169d9
7 changed files with 148 additions and 16 deletions
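
For illustration only (not part of the diff): the idea described in the commit message is that a huge container is serialized as many small commands (from #4144) and the streamer yields between them, so other fibers can run and RSS stays bounded. Below is a minimal Python sketch of that pattern; the real implementation is the C++ CmdSerializer/RestoreStreamer change further down, and the names serialize_huge_list and sink are invented for this example.

import asyncio

# Break a huge list into many small RPUSH commands and yield to the scheduler
# between chunks (the Python analogue of ThrottleIfNeeded() in the C++ change),
# so serializing one huge value cannot monopolize the thread.
async def serialize_huge_list(key, items, sink, chunk_size=1024):
    for i in range(0, len(items), chunk_size):
        chunk = items[i : i + chunk_size]
        await sink(["RPUSH", key, *chunk])  # one small command per chunk
        await asyncio.sleep(0)  # cooperative yield between commands

async def main():
    commands = []

    async def sink(cmd):
        commands.append(cmd)

    await serialize_huge_list("big-list", [str(i) for i in range(10_000)], sink)
    print(f"{len(commands)} commands generated")

if __name__ == "__main__":
    asyncio.run(main())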

View file

@@ -218,6 +218,12 @@ void RestoreStreamer::Run() {
       if (fiber_cancelled_)  // Could have been cancelled in above call too
         return;

+      std::lock_guard guard(big_value_mu_);
+
+      // Locking this never preempts. See snapshot.cc for why we need it.
+      auto* blocking_counter = db_slice_->BlockingCounter();
+      std::lock_guard blocking_counter_guard(*blocking_counter);
+
       WriteBucket(it);
     });
@@ -281,7 +287,6 @@ bool RestoreStreamer::ShouldWrite(cluster::SlotId slot_id) const {
 void RestoreStreamer::WriteBucket(PrimeTable::bucket_iterator it) {
   if (it.GetVersion() < snapshot_version_) {
-    FiberAtomicGuard fg;
     it.SetVersion(snapshot_version_);
     string key_buffer;  // we can reuse it
     for (; !it.is_done(); ++it) {
@@ -302,6 +307,7 @@ void RestoreStreamer::WriteBucket(PrimeTable::bucket_iterator it) {
 }

 void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) {
+  std::lock_guard guard(big_value_mu_);
   DCHECK_EQ(db_index, 0) << "Restore migration only allowed in cluster mode in db0";
   PrimeTable* table = db_slice_->GetTables(0).first;
@@ -319,8 +325,12 @@ void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req
 void RestoreStreamer::WriteEntry(string_view key, const PrimeValue& pk, const PrimeValue& pv,
                                  uint64_t expire_ms) {
-  CmdSerializer serializer([&](std::string s) { Write(std::move(s)); },
-                           ServerState::tlocal()->serialization_max_chunk_size);
+  CmdSerializer serializer(
+      [&](std::string s) {
+        Write(std::move(s));
+        ThrottleIfNeeded();
+      },
+      ServerState::tlocal()->serialization_max_chunk_size);
   serializer.SerializeEntry(key, pk, pv, expire_ms);
 }

View file

@@ -112,6 +112,7 @@ class RestoreStreamer : public JournalStreamer {
   cluster::SlotSet my_slots_;
   bool fiber_cancelled_ = false;
   bool snapshot_finished_ = false;
+  ThreadLocalMutex big_value_mu_;
 };

 }  // namespace dfly

View file

@@ -14,8 +14,7 @@ from .replication_test import check_all_replicas_finished
 from redis.cluster import RedisCluster
 from redis.cluster import ClusterNode
 from .proxy import Proxy
-from .seeder import SeederBase
-from .seeder import StaticSeeder
+from .seeder import Seeder, SeederBase, StaticSeeder
 from . import dfly_args
@@ -33,6 +32,11 @@ def monotonically_increasing_port_number():
 next_port = monotonically_increasing_port_number()

+
+async def get_memory(client, field):
+    info = await client.info("memory")
+    return info[field]
+
 class RedisClusterNode:
     def __init__(self, port):
         self.port = port
@@ -1981,6 +1985,7 @@ async def test_cluster_migration_cancel(df_factory: DflyInstanceFactory):
 @dfly_args({"proactor_threads": 2, "cluster_mode": "yes"})
 @pytest.mark.asyncio
+@pytest.mark.opt_only
 async def test_cluster_migration_huge_container(df_factory: DflyInstanceFactory):
     instances = [
         df_factory.create(port=next(next_port), admin_port=next(next_port)) for i in range(2)
@@ -1995,7 +2000,7 @@ async def test_cluster_migration_huge_container(df_factory: DflyInstanceFactory)
     logging.debug("Generating huge containers")
     seeder = StaticSeeder(
-        key_target=10,
+        key_target=100,
         data_size=10_000_000,
         collection_size=10_000,
         variance=1,
@@ -2005,6 +2010,8 @@ async def test_cluster_migration_huge_container(df_factory: DflyInstanceFactory)
     await seeder.run(nodes[0].client)
     source_data = await StaticSeeder.capture(nodes[0].client)

+    mem_before = await get_memory(nodes[0].client, "used_memory_rss")
+
     nodes[0].migrations = [
         MigrationInfo("127.0.0.1", instances[1].admin_port, [(0, 16383)], nodes[1].id)
     ]
@@ -2017,6 +2024,74 @@ async def test_cluster_migration_huge_container(df_factory: DflyInstanceFactory)
     target_data = await StaticSeeder.capture(nodes[1].client)
     assert source_data == target_data

+    # Get peak memory, because migration removes the data
+    mem_after = await get_memory(nodes[0].client, "used_memory_peak_rss")
+    logging.debug(f"Memory before {mem_before} after {mem_after}")
+    assert mem_after < mem_before * 1.1
+
+
+@dfly_args({"proactor_threads": 2, "cluster_mode": "yes"})
+@pytest.mark.parametrize("chunk_size", [1_000_000, 30])
+@pytest.mark.asyncio
+async def test_cluster_migration_while_seeding(
+    df_factory: DflyInstanceFactory, df_seeder_factory: DflySeederFactory, chunk_size
+):
+    instances = [
+        df_factory.create(
+            port=next(next_port),
+            admin_port=next(next_port),
+            serialization_max_chunk_size=chunk_size,
+        )
+        for _ in range(2)
+    ]
+    df_factory.start_all(instances)
+
+    nodes = [await create_node_info(instance) for instance in instances]
+    nodes[0].slots = [(0, 16383)]
+    nodes[1].slots = []
+    client0 = nodes[0].client
+    client1 = nodes[1].client
+    await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes])
+
+    logging.debug("Seeding cluster")
+    seeder = df_seeder_factory.create(
+        keys=10_000, port=instances[0].port, cluster_mode=True, mirror_to_fake_redis=True
+    )
+    await seeder.run(target_deviation=0.1)
+
+    seed = asyncio.create_task(seeder.run())
+    await asyncio.sleep(1)
+
+    nodes[0].migrations = [
+        MigrationInfo("127.0.0.1", instances[1].admin_port, [(0, 16383)], nodes[1].id)
+    ]
+    logging.debug("Migrating slots")
+    await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes])
+
+    logging.debug("Waiting for migration to finish")
+    await wait_for_status(nodes[0].admin_client, nodes[1].id, "FINISHED", timeout=300)
+    logging.debug("Migration finished")
+
+    logging.debug("Finalizing migration")
+    nodes[0].slots = []
+    nodes[1].slots = [(0, 16383)]
+    await push_config(json.dumps(generate_config(nodes)), [node.admin_client for node in nodes])
+    await asyncio.sleep(1)  # Let seeder feed dest before migration finishes
+
+    seeder.stop()
+    await seed
+    logging.debug("Seeding finished")
+
+    assert (
+        await get_memory(client0, "used_memory_peak_rss")
+        < await get_memory(client0, "used_memory_rss") * 1.1
+    )
+
+    capture = await seeder.capture_fake_redis()
+    assert await seeder.compare(capture, instances[1].port)
+
+
 def parse_lag(replication_info: str):
     lags = re.findall("lag=([0-9]+)\r\n", replication_info)

View file

@@ -25,3 +25,4 @@ pytest-emoji==0.2.0
 pytest-icdiff==0.8
 pytest-timeout==2.2.0
 asyncio==3.4.3
+fakeredis[json]==2.26.2

View file

@@ -177,14 +177,16 @@ class Seeder(SeederBase):
         ]

         sha = await client.script_load(Seeder._load_script("generate"))
-        await asyncio.gather(
-            *(self._run_unit(client, sha, unit, using_stopkey, args) for unit in self.units)
-        )
+        for unit in self.units:
+            # Must be serial, otherwise cluster clients throws an exception
+            await self._run_unit(client, sha, unit, using_stopkey, args)

     async def stop(self, client: aioredis.Redis):
         """Request seeder seeder if it's running without a target, future returned from start() must still be awaited"""
-        await asyncio.gather(*(client.set(unit.stop_key, "X") for unit in self.units))
+        for unit in self.units:
+            # Must be serial, otherwise cluster clients throws an exception
+            await client.set(unit.stop_key, "X")

     def change_key_target(self, target: int):
         """Change key target, applied only on succeeding runs"""

View file

@@ -4,6 +4,8 @@ import string
 from redis import asyncio as aioredis
 from . import dfly_args
 from .seeder import Seeder, StaticSeeder
+from .instance import DflyInstanceFactory, DflyInstance
+from .utility import *

 @dfly_args({"proactor_threads": 4})
@@ -114,3 +116,22 @@ async def test_seeder_capture(async_client: aioredis.Redis):
     # Do another change
     await async_client.spop("set1")
     assert capture != await Seeder.capture(async_client)
+
+
+@pytest.mark.asyncio
+@dfly_args({"proactor_threads": 2})
+async def test_seeder_fake_redis(
+    df_factory: DflyInstanceFactory, df_seeder_factory: DflySeederFactory
+):
+    instance = df_factory.create()
+    df_factory.start_all([instance])
+
+    seeder = df_seeder_factory.create(
+        keys=100, port=instance.port, unsupported_types=[ValueType.JSON], mirror_to_fake_redis=True
+    )
+    await seeder.run(target_ops=5_000)
+
+    capture = await seeder.capture_fake_redis()
+
+    assert await seeder.compare(capture, instance.port)

View file

@@ -14,6 +14,7 @@ import json
 import subprocess
 import pytest
 import os
+import fakeredis
 from typing import Iterable, Union
 from enum import Enum
@@ -271,7 +272,7 @@ class CommandGenerator:
         ("LPUSH {k} {val}", ValueType.LIST),
         ("LPOP {k}", ValueType.LIST),
         ("SADD {k} {val}", ValueType.SET),
-        ("SPOP {k}", ValueType.SET),
+        # ("SPOP {k}", ValueType.SET),  # Disabled because it is inconsistent
         ("HSETNX {k} v0 {val}", ValueType.HSET),
         ("HINCRBY {k} v1 1", ValueType.HSET),
         ("ZPOPMIN {k} 1", ValueType.ZSET),
@@ -423,6 +424,7 @@ class DflySeeder:
         unsupported_types=[],
         stop_on_failure=True,
         cluster_mode=False,
+        mirror_to_fake_redis=False,
     ):
         if cluster_mode:
             max_multikey = 1
@@ -436,11 +438,16 @@ class DflySeeder:
         self.multi_transaction_probability = multi_transaction_probability
         self.stop_flag = False
         self.stop_on_failure = stop_on_failure
+        self.fake_redis = None

         self.log_file = log_file
         if self.log_file is not None:
             open(self.log_file, "w").close()

+        if mirror_to_fake_redis:
+            logging.debug("Creating FakeRedis instance")
+            self.fake_redis = fakeredis.FakeAsyncRedis()
+
     async def run(self, target_ops=None, target_deviation=None):
         """
         Run a seeding cycle on all dbs either until stop(), a fixed number of commands (target_ops)
@@ -474,6 +481,14 @@ class DflySeeder:
         """Reset internal state. Needs to be called after flush or restart"""
         self.gen.reset()

+    async def capture_fake_redis(self):
+        keys = sorted(list(self.gen.keys_and_types()))
+        # TODO: support multiple databases
+        assert self.dbcount == 1
+        assert self.fake_redis != None
+        capture = DataCapture(await self._capture_entries(self.fake_redis, keys))
+        return [capture]
+
     async def capture(self, port=None):
         """Create DataCapture for all dbs"""
@@ -588,12 +603,19 @@ class DflySeeder:
                 queue.task_done()
                 break

-            pipe = client.pipeline(transaction=tx_data[1])
-            for cmd in tx_data[0]:
-                pipe.execute_command(*cmd)
-
             try:
-                await pipe.execute()
+                if self.fake_redis is None:
+                    pipe = client.pipeline(transaction=tx_data[1])
+                    for cmd in tx_data[0]:
+                        pipe.execute_command(*cmd)
+                    await pipe.execute()
+                else:
+                    # To mirror consistently to Fake Redis we must only send to it successful
+                    # commands. We can't use pipes because they might succeed partially.
+                    for cmd in tx_data[0]:
+                        dfly_resp = await client.execute_command(*cmd)
+                        fake_resp = await self.fake_redis.execute_command(*cmd)
+                        assert dfly_resp == fake_resp
             except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError) as e:
                 if self.stop_on_failure:
                     await self._close_client(client)
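
For reference, the mirroring scheme the seeder uses above can be distilled into a standalone sketch: each command is sent to the server first and is replayed against fakeredis.FakeAsyncRedis only when it succeeds, so the fake copy can later be compared against the migrated data. This is only an illustration, assuming a Redis-compatible server on localhost:6379; the helper mirror_command is not part of the test suite.

import asyncio

import fakeredis
from redis import asyncio as aioredis


async def mirror_command(real, fake, *cmd):
    # Send to the real server first; mirror to fakeredis only when it succeeds,
    # so the fake copy never records commands that failed on the server.
    resp = await real.execute_command(*cmd)
    fake_resp = await fake.execute_command(*cmd)
    assert resp == fake_resp, f"{cmd}: {resp!r} != {fake_resp!r}"
    return resp


async def main():
    real = aioredis.Redis(host="localhost", port=6379)  # assumed local server
    fake = fakeredis.FakeAsyncRedis()

    await mirror_command(real, fake, "LPUSH", "list1", "a", "b")
    await mirror_command(real, fake, "SADD", "set1", "x")
    print(await fake.lrange("list1", 0, -1))

    await real.close()
    await fake.close()


if __name__ == "__main__":
    asyncio.run(main())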