fix(server): fix compatibility with rdb snapshot (#3121)

* fix(server): fix compatibility with rdb snapshot


Signed-off-by: adi_holden <adi@dragonflydb.io>
adiholden 2024-06-04 09:28:18 +03:00 committed by GitHub
parent b1063f7823
commit 6e33261402
8 changed files with 111 additions and 56 deletions

src/server/snapshot.cc

@@ -70,7 +70,7 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
   VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_;
 
   snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal, cll] {
-    IterateBucketsFb(cll);
+    IterateBucketsFb(cll, stream_journal);
     db_slice_->UnregisterOnChange(snapshot_version_);
     if (cll->IsCancelled()) {
       Cancel();
@@ -174,7 +174,7 @@ void SliceSnapshot::Join() {
 // and survived until it finished.
 
 // Serializes all the entries with version less than snapshot_version_.
-void SliceSnapshot::IterateBucketsFb(const Cancellation* cll) {
+void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut) {
   {
     auto fiber_name = absl::StrCat("SliceSnapshot-", ProactorBase::me()->GetPoolIndex());
     ThisFiber::SetName(std::move(fiber_name));
@@ -223,8 +223,10 @@
   } // for (dbindex)
 
   CHECK(!serialize_bucket_running_);
-  CHECK(!serializer_->SendFullSyncCut());
-  PushSerializedToChannel(true);
+  if (send_full_sync_cut) {
+    CHECK(!serializer_->SendFullSyncCut());
+    PushSerializedToChannel(true);
+  }
 
   // serialized + side_saved must be equal to the total saved.
   VLOG(1) << "Exit SnapshotSerializer (loop_serialized/side_saved/cbcalls): "

src/server/snapshot.h

@@ -88,7 +88,7 @@ class SliceSnapshot {
  private:
   // Main fiber that iterates over all buckets in the db slice
   // and submits them to SerializeBucket.
-  void IterateBucketsFb(const Cancellation* cll);
+  void IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut);
 
   // Called on traversing cursor by IterateBucketsFb.
   bool BucketSaveCb(PrimeIterator it);

tests/dragonfly/conftest.py

@@ -21,7 +21,7 @@ from copy import deepcopy
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
-from .instance import DflyInstance, DflyParams, DflyInstanceFactory
+from .instance import DflyInstance, DflyParams, DflyInstanceFactory, RedisServer
 from . import PortPicker, dfly_args
 from .utility import DflySeederFactory, gen_ca_cert, gen_certificate
 
@@ -366,3 +366,24 @@ def run_before_and_after_test():
     yield  # this is where the testing happens
 
     # Teardown
+
+
+@pytest.fixture(scope="function")
+def redis_server(port_picker) -> RedisServer:
+    s = RedisServer(port_picker.get_available_port())
+    try:
+        s.start()
+    except FileNotFoundError as e:
+        pytest.skip("Redis server not found")
+        return None
+    time.sleep(1)
+    yield s
+    s.stop()
+
+
+@pytest.fixture(scope="function")
+def redis_local_server(port_picker) -> RedisServer:
+    s = RedisServer(port_picker.get_available_port())
+    time.sleep(1)
+    yield s
+    s.stop()
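
Note: the two new fixtures differ only in who calls start(). redis_server boots the binary immediately and skips the test when redis-server-6.2.11 is not installed, while redis_local_server yields an instance that has not been started, so the test itself can pass startup flags (test_redis_load_snapshot below relies on this to point the server at the snapshot directory). A hypothetical usage sketch; the test name and maxmemory flag are illustrative, not part of this commit:

    def test_with_custom_flags(redis_local_server: RedisServer):
        # The test, not the fixture, decides the startup options;
        # the fixture still handles teardown via stop().
        redis_local_server.start(maxmemory="100mb")
        ...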

tests/dragonfly/instance.py

@@ -364,3 +364,36 @@ class DflyInstanceFactory:
     def __repr__(self) -> str:
         return f"Factory({self.args})"
+
+
+class RedisServer:
+    def __init__(self, port):
+        self.port = port
+        self.proc = None
+
+    def start(self, **kwargs):
+        command = [
+            "redis-server-6.2.11",
+            f"--port {self.port}",
+            "--save ''",
+            "--appendonly no",
+            "--protected-mode no",
+            "--repl-diskless-sync yes",
+            "--repl-diskless-sync-delay 0",
+        ]
+
+        # Convert kwargs to command-line arguments
+        for key, value in kwargs.items():
+            if value is None:
+                command.append(f"--{key}")
+            else:
+                command.append(f"--{key} {value}")
+        self.proc = subprocess.Popen(command)
+        logging.debug(self.proc.args)
+
+    def stop(self):
+        self.proc.terminate()
+        try:
+            self.proc.wait(timeout=10)
+        except Exception as e:
+            pass
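
Note: start(**kwargs) renders each keyword as a single "--key value" argv entry, matching the f"--port {self.port}" style above. redis-server concatenates everything after the executable name and parses it like a redis.conf file, so a flag and its value can share one argv element. A minimal usage sketch; the port and paths are illustrative:

    server = RedisServer(6379)
    # Spawns: redis-server-6.2.11 "--port 6379" ... "--dir /tmp/t" "--dbfilename dump.rdb"
    server.start(dir="/tmp/t", dbfilename="dump.rdb")
    # ... exercise the server ...
    server.stop()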

tests/dragonfly/replication_test.py

@@ -8,33 +8,6 @@ from .instance import DflyInstanceFactory
 from .proxy import Proxy
 
 
-class RedisServer:
-    def __init__(self, port):
-        self.port = port
-        self.proc = None
-
-    def start(self):
-        self.proc = subprocess.Popen(
-            [
-                "redis-server-6.2.11",
-                f"--port {self.port}",
-                "--save ''",
-                "--appendonly no",
-                "--protected-mode no",
-                "--repl-diskless-sync yes",
-                "--repl-diskless-sync-delay 0",
-            ]
-        )
-        logging.debug(self.proc.args)
-
-    def stop(self):
-        self.proc.terminate()
-        try:
-            self.proc.wait(timeout=10)
-        except Exception as e:
-            pass
-
-
 # Checks that master redis and dragonfly replica are synced by writing a random key to master
 # and waiting for it to exist in replica. Foreach db in 0..dbcount-1.
 async def await_synced(c_master: aioredis.Redis, c_replica: aioredis.Redis, dbcount=1):
@@ -71,19 +44,6 @@ async def check_data(seeder, replicas, c_replicas):
         assert await seeder.compare(capture, port=replica.port)
 
 
-@pytest.fixture(scope="function")
-def redis_server(port_picker) -> RedisServer:
-    s = RedisServer(port_picker.get_available_port())
-    try:
-        s.start()
-    except FileNotFoundError as e:
-        pytest.skip("Redis server not found")
-        return None
-    time.sleep(1)
-    yield s
-    s.stop()
-
-
 full_sync_replication_specs = [
     ([1], dict(keys=100, dbcount=1, unsupported_types=[ValueType.JSON])),
     ([1], dict(keys=5000, dbcount=2, unsupported_types=[ValueType.JSON])),

tests/dragonfly/seeder/__init__.py

@@ -18,20 +18,24 @@ except ImportError:
 
 
 class SeederBase:
     UID_COUNTER = 1 # multiple generators should not conflict on keys
     CACHED_SCRIPTS = {}
-    TYPES = ["STRING", "LIST", "SET", "HASH", "ZSET", "JSON"]
+    DEFAULT_TYPES = ["STRING", "LIST", "SET", "HASH", "ZSET", "JSON"]
 
-    def __init__(self):
+    def __init__(self, types: typing.Optional[typing.List[str]] = None):
         self.uid = SeederBase.UID_COUNTER
         SeederBase.UID_COUNTER += 1
+        self.types = types if types is not None else SeederBase.DEFAULT_TYPES
 
     @classmethod
-    async def capture(clz, client: aioredis.Redis) -> typing.Tuple[int]:
+    async def capture(
+        clz, client: aioredis.Redis, types: typing.Optional[typing.List[str]] = None
+    ) -> typing.Tuple[int]:
         """Generate hash capture for all data stored in instance pointed by client"""
         sha = await client.script_load(clz._load_script("hash"))
+        types_to_capture = types if types is not None else clz.DEFAULT_TYPES
         return tuple(
             await asyncio.gather(
-                *(clz._run_capture(client, sha, data_type) for data_type in clz.TYPES)
+                *(clz._run_capture(client, sha, data_type) for data_type in types_to_capture)
             )
         )
@@ -69,8 +73,15 @@
 class StaticSeeder(SeederBase):
     """Wrapper around DEBUG POPULATE with fuzzy key sizes and a balanced type mix"""
 
-    def __init__(self, key_target=10_000, data_size=100, variance=5, samples=10):
-        SeederBase.__init__(self)
+    def __init__(
+        self,
+        key_target=10_000,
+        data_size=100,
+        variance=5,
+        samples=10,
+        types: typing.Optional[typing.List[str]] = None,
+    ):
+        SeederBase.__init__(self, types)
         self.key_target = key_target
         self.data_size = data_size
         self.variance = variance
@@ -79,7 +90,7 @@ class StaticSeeder(SeederBase):
     async def run(self, client: aioredis.Redis):
         """Run with specified options until key_target is met"""
         samples = [
-            (dtype, f"k-s{self.uid}u{i}-") for i, dtype in enumerate(self.TYPES * self.samples)
+            (dtype, f"k-s{self.uid}u{i}-") for i, dtype in enumerate(self.types * self.samples)
         ]
 
         # Handle samples in chuncks of 24 to not overload client pool and instance
@@ -89,7 +100,7 @@
         )
 
     async def _run_unit(self, client: aioredis.Redis, dtype: str, prefix: str):
-        key_target = self.key_target // (self.samples * len(self.TYPES))
+        key_target = self.key_target // (self.samples * len(self.types))
         if dtype == "STRING":
             dsize = random.uniform(self.data_size / self.variance, self.data_size * self.variance)
             csize = 1
@@ -120,7 +131,7 @@ class Seeder(SeederBase):
         self.units = [
             Seeder.Unit(
                 prefix=f"k-s{self.uid}u{i}-",
-                type=Seeder.TYPES[i % len(Seeder.TYPES)],
+                type=Seeder.DEFAULT_TYPES[i % len(Seeder.DEFAULT_TYPES)],
                 counter=0,
                 stop_key=f"_s{self.uid}u{i}-stop",
             )
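
Note: a seeder restricted with types= must be paired with a capture restricted to the same list; otherwise the hash tuples are computed over different type sets and will never compare equal. Dragonfly's JSON type has no counterpart in the Redis 6.2 RDB format, which is presumably why test_redis_load_snapshot below seeds without it. A short sketch; client setup is elided:

    types = ["STRING", "LIST", "SET", "HASH", "ZSET"]  # no JSON
    seeder = StaticSeeder(key_target=1_000, types=types)
    await seeder.run(client)
    capture = await StaticSeeder.capture(client, types=types)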

tests/dragonfly/seeder_test.py

@@ -17,7 +17,7 @@ async def test_static_seeder(async_client: aioredis.Redis):
 @dfly_args({"proactor_threads": 4})
 async def test_seeder_key_target(async_client: aioredis.Redis):
     """Ensure seeder reaches its key targets"""
-    s = Seeder(units=len(Seeder.TYPES) * 2, key_target=5000)
+    s = Seeder(units=len(Seeder.DEFAULT_TYPES) * 2, key_target=5000)
 
     # Ensure tests are not reasonably slow
     async with async_timeout.timeout(1 + 4):

tests/dragonfly/snapshot_test.py

@@ -1,4 +1,5 @@
 import pytest
+import logging
 import os
 import glob
 import asyncio
@@ -7,6 +8,7 @@ import redis
 from redis import asyncio as aioredis
 from pathlib import Path
 import boto3
+from .instance import RedisServer
 
 from . import dfly_args
 from .utility import wait_available_async, chunked, is_saving
@@ -124,6 +126,32 @@ async def test_dbfilenames(
         assert await StaticSeeder.capture(client) == start_capture
 
 
+@pytest.mark.asyncio
+@dfly_args({**BASIC_ARGS, "proactor_threads": 4, "dbfilename": "test-redis-load-rdb"})
+async def test_redis_load_snapshot(
+    async_client: aioredis.Redis, df_server, redis_local_server: RedisServer, tmp_dir: Path
+):
+    """
+    Test redis server loading dragonfly snapshot rdb format
+    """
+    await StaticSeeder(
+        **LIGHTWEIGHT_SEEDER_ARGS, types=["STRING", "LIST", "SET", "HASH", "ZSET"]
+    ).run(async_client)
+
+    await async_client.execute_command("SAVE", "rdb")
+    dbsize = await async_client.dbsize()
+
+    await async_client.connection_pool.disconnect()
+    df_server.stop()
+
+    redis_local_server.start(dir=tmp_dir, dbfilename="test-redis-load-rdb.rdb")
+    await asyncio.sleep(1)
+    c_master = aioredis.Redis(port=redis_local_server.port)
+    await c_master.ping()
+
+    assert await c_master.dbsize() == dbsize
+
+
 @pytest.mark.slow
 @dfly_args({**BASIC_ARGS, "dbfilename": "test-cron", "snapshot_cron": "* * * * *"})
 async def test_cron_snapshot(tmp_dir: Path, async_client: aioredis.Redis):