Pytests overhaul (#569)

Vladislav 2023-01-09 23:31:15 +03:00 committed by GitHub
parent 25a16ed343
commit 5ef8454aa7
8 changed files with 647 additions and 194 deletions

View file

@ -15,6 +15,8 @@ You can override the location of the binary using `DRAGONFLY_PATH` environment v
### Custom arguments
- use `--gdb` to start all instances inside gdb.
- use `--df arg=val` to pass custom arguments to all dragonfly instances. Can be used multiple times.
- use `--log-seeder file` to store all single-db commands from the latest test's seeder inside file (a combined example is shown below).
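For example, an illustrative combined invocation (the flag values here are placeholders, not recommendations) could look like `pytest -v --df dbfilename=dump --log-seeder seeder.log`, optionally with `--gdb` added when debugging crashes.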
### Before you start
Please make sure that you have Python 3 installed on your local host.

View file

@ -1,19 +1,19 @@
import pytest
import typing
import time
import subprocess
import time
import subprocess
from dataclasses import dataclass
START_DELAY = 0.4
START_GDB_DELAY = 3.0
@dataclass
class DflyParams:
path: str
cwd: str
gdb: bool
args: list
env: any
@ -29,24 +29,13 @@ class DflyInstance:
self.proc = None
def start(self):
arglist = DflyInstance.format_args(self.args)
print(f"Starting instance on {self.port} with arguments {arglist}")
args = [self.params.path, *arglist]
if self.params.gdb:
args = ["gdb", "--ex", "r", "--args"] + args
self.proc = subprocess.Popen(args, cwd=self.params.cwd)
self._start()
# Give Dragonfly time to start and detect possible failure causes
# Gdb starts slowly
time.sleep(0.4 if not self.params.gdb else 3.0)
time.sleep(START_DELAY if not self.params.gdb else START_GDB_DELAY)
return_code = self.proc.poll()
if return_code is not None:
raise Exception(
f"Failed to start instance, return code {return_code}")
self._check_status()
def stop(self, kill=False):
proc, self.proc = self.proc, None
@ -59,11 +48,26 @@ class DflyInstance:
proc.kill()
else:
proc.terminate()
outs, errs = proc.communicate(timeout=15)
proc.communicate(timeout=15)
except subprocess.TimeoutExpired:
print("Unable to terminate DragonflyDB gracefully, it was killed")
outs, errs = proc.communicate()
print(outs, errs)
proc.kill()
def _start(self):
base_args = [f"--{v}" for v in self.params.args]
all_args = self.format_args(self.args) + base_args
print(f"Starting instance on {self.port} with arguments {all_args}")
run_cmd = [self.params.path, *all_args]
if self.params.gdb:
run_cmd = ["gdb", "--ex", "r", "--args"] + run_cmd
self.proc = subprocess.Popen(run_cmd, cwd=self.params.cwd)
def _check_status(self):
return_code = self.proc.poll()
if return_code is not None:
raise Exception(
f"Failed to start instance, return code {return_code}")
def __getitem__(self, k):
return self.args.get(k)
@ -99,6 +103,17 @@ class DflyInstanceFactory:
self.instances.append(instance)
return instance
def start_all(self, instances):
""" Start multiple instances in parallel """
for instance in instances:
instance._start()
delay = START_DELAY if not self.params.gdb else START_GDB_DELAY
time.sleep(delay * (1 + len(instances) / 2))
for instance in instances:
instance._check_status()
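A minimal usage sketch (not part of this diff) of starting several instances in parallel through the factory; `df_local_factory` and `BASE_PORT` are the fixture and constant used by the tests below:
replicas = [df_local_factory.create(port=BASE_PORT + i + 1, proactor_threads=2) for i in range(3)]
df_local_factory.start_all(replicas)  # launch all processes, wait once, then check each status
# ... exercise the instances here ...
df_local_factory.stop_all()  # stop every instance this factory created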
def stop_all(self):
"""Stop all lanched instances."""
for instance in self.instances:

View file

@ -13,6 +13,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory
from . import DflyInstance, DflyInstanceFactory, DflyParams
from .utility import DflySeederFactory
DATABASE_INDEX = 1
@ -39,12 +40,9 @@ def test_env(tmp_dir: Path):
env["DRAGONFLY_TMP"] = str(tmp_dir)
return env
def pytest_addoption(parser):
parser.addoption(
'--gdb', action='store_true', default=False, help='Run instances in gdb'
)
@pytest.fixture(scope="session", params=[{}])
def df_seeder_factory(request) -> DflySeederFactory:
return DflySeederFactory(request.config.getoption("--log-seeder"))
@pytest.fixture(scope="session", params=[{}])
def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
@ -61,6 +59,7 @@ def df_factory(request, tmp_dir, test_env) -> DflyInstanceFactory:
path=path,
cwd=tmp_dir,
gdb=request.config.getoption("--gdb"),
args=request.config.getoption("--df"),
env=test_env
)
@ -136,3 +135,21 @@ async def async_client(async_pool):
client = aioredis.Redis(connection_pool=async_pool)
await client.flushall()
return client
def pytest_addoption(parser):
"""
Custom pytest options:
--gdb - start all instances inside gdb
--df arg - pass arg to all instances, can be used multiple times
--log-seeder file - to log commands of last seeder run
"""
parser.addoption(
'--gdb', action='store_true', default=False, help='Run instances in gdb'
)
parser.addoption(
'--df', action='append', default=[], help='Add arguments to dragonfly'
)
parser.addoption(
'--log-seeder', action='store', default=None, help='Store last generator commands in file'
)

View file

@ -1,4 +1,5 @@
import pytest
import redis
from redis.commands.json.path import Path
from .utility import *

View file

@ -3,7 +3,7 @@ import pytest
import asyncio
import aioredis
import random
from itertools import count, chain, repeat
from itertools import chain, repeat
from .utility import *
from . import dfly_args
@ -21,73 +21,79 @@ Test full replication pipeline. Test full sync with streaming changes and stable
# 1. Number of master threads
# 2. Number of threads for each replica
# 3. Number of keys stored and sent in full sync
# 4. Number of keys overwritten during full sync
# 3. Seeder config
replication_cases = [
(8, [8], 20000, 5000),
(8, [8], 10000, 10000),
(8, [2, 2, 2, 2], 20000, 5000),
(6, [6, 6, 6], 30000, 15000),
(4, [1] * 12, 10000, 4000),
(8, [8], dict(keys=10_000, dbcount=4)),
(6, [6, 6, 6], dict(keys=4_000, dbcount=4)),
(8, [2, 2, 2, 2], dict(keys=4_000, dbcount=4)),
(4, [8, 8], dict(keys=4_000, dbcount=4)),
(4, [1] * 8, dict(keys=500, dbcount=2)),
]
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, n_keys, n_stream_keys", replication_cases)
async def test_replication_all(df_local_factory, t_master, t_replicas, n_keys, n_stream_keys):
@pytest.mark.parametrize("t_master, t_replicas, seeder_config", replication_cases)
async def test_replication_all(df_local_factory, df_seeder_factory, t_master, t_replicas, seeder_config):
master = df_local_factory.create(port=1111, proactor_threads=t_master)
replicas = [
df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
for i, t in enumerate(t_replicas)
]
# Start master and fill with test data
# Start master
master.start()
c_master = aioredis.Redis(port=master.port)
await batch_fill_data_async(c_master, gen_test_data(n_keys, seed=1))
# Fill master with test data
seeder = df_seeder_factory.create(port=master.port, **seeder_config)
await seeder.run(target_deviation=0.1)
# Start replicas
for replica in replicas:
replica.start()
df_local_factory.start_all(replicas)
c_replicas = [aioredis.Redis(port=replica.port) for replica in replicas]
async def stream_data():
""" Stream data during stable state replication phase and afterwards """
gen = gen_test_data(n_stream_keys, seed=2)
for chunk in grouper(3, gen):
await c_master.mset({k: v for k, v in chunk})
# Start data stream
stream_task = asyncio.create_task(seeder.run(target_times=3))
await asyncio.sleep(0.0)
# Start replication
async def run_replication(c_replica):
await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
async def check_replication(c_replica):
""" Check that static and streamed data arrived """
await wait_available_async(c_replica)
# Check range [n_stream_keys, n_keys] is of seed 1
await batch_check_data_async(c_replica, gen_test_data(n_keys, start=n_stream_keys, seed=1))
# Check range [0, n_stream_keys] is of seed 2
await asyncio.sleep(1.0)
await batch_check_data_async(c_replica, gen_test_data(n_stream_keys, seed=2))
# Start streaming data and run REPLICAOF in parallel
stream_fut = asyncio.create_task(stream_data())
await asyncio.gather(*(asyncio.create_task(run_replication(c))
for c in c_replicas))
assert not stream_fut.done(
), "Weak testcase. Increase number of streamed keys to surpass full sync"
await stream_fut
# Wait for streaming to finish
assert not stream_task.done(
), "Weak testcase. Increase number of streamed iterations to surpass full sync"
await stream_task
# Check full sync results
await asyncio.gather(*(check_replication(c) for c in c_replicas))
# Check data after full sync
await asyncio.sleep(3.0)
await check_data(seeder, replicas, c_replicas)
# Check stable state streaming
await batch_fill_data_async(c_master, gen_test_data(n_keys, seed=3))
# Stream more data in stable state
await seeder.run(target_times=2)
await asyncio.sleep(1.0)
await asyncio.gather(*(batch_check_data_async(c, gen_test_data(n_keys, seed=3))
for c in c_replicas))
# Check data after stable state stream
await asyncio.sleep(3.0)
await check_data(seeder, replicas, c_replicas)
# Issue lots of deletes
# TODO: Enable after stable state is faster
# seeder.target(100)
# await seeder.run(target_deviation=0.1)
# Check data after deletes
# await asyncio.sleep(2.0)
# await check_data(seeder, replicas, c_replicas)
async def check_data(seeder, replicas, c_replicas):
capture = await seeder.capture()
for (replica, c_replica) in zip(replicas, c_replicas):
await wait_available_async(c_replica)
assert await seeder.compare(capture, port=replica.port)
"""
@ -109,22 +115,20 @@ Three types are tested:
# 5. Number of distinct keys that are constantly streamed
disconnect_cases = [
# balanced
(8, [4, 4], [4, 4], [4], 10000),
(8, [2] * 6, [2] * 6, [2, 2], 10000),
(8, [4, 4], [4, 4], [4], 4_000),
(4, [2] * 4, [2] * 4, [2, 2], 2_000),
# full sync heavy
(8, [4] * 6, [], [], 10000),
(8, [2] * 12, [], [], 10000),
(8, [4] * 4, [], [], 4_000),
# stable state heavy
(8, [], [4] * 6, [], 10000),
(8, [], [2] * 12, [], 10000),
(8, [], [4] * 4, [], 4_000),
# disconnect only
(8, [], [], [2] * 6, 10000)
(8, [], [], [4] * 4, 4_000)
]
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases)
async def test_disconnect_replica(df_local_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys):
async def test_disconnect_replica(df_local_factory, df_seeder_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys):
master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master)
replicas = [
(df_local_factory.create(
@ -143,8 +147,7 @@ async def test_disconnect_replica(df_local_factory, t_master, t_crash_fs, t_cras
c_master = aioredis.Redis(port=master.port, single_connection_client=True)
# Start replicas and create clients
for replica, _ in replicas:
replica.start()
df_local_factory.start_all([replica for replica, _ in replicas])
c_replicas = [
(replica, aioredis.Redis(port=replica.port), crash_type)
@ -158,13 +161,8 @@ async def test_disconnect_replica(df_local_factory, t_master, t_crash_fs, t_cras
]
# Start data fill loop
async def fill_loop():
local_c = aioredis.Redis(
port=master.port, single_connection_client=True)
for seed in count(1):
await batch_fill_data_async(local_c, gen_test_data(n_keys, seed=seed))
fill_task = asyncio.create_task(fill_loop())
seeder = df_seeder_factory.create(port=master.port, keys=n_keys, dbcount=2)
fill_task = asyncio.create_task(seeder.run())
# Run full sync
async def full_sync(replica, c_replica, crash_type):
@ -204,18 +202,19 @@ async def test_disconnect_replica(df_local_factory, t_master, t_crash_fs, t_cras
assert await c_replica.ping()
# Stop streaming
fill_task.cancel()
seeder.stop()
await fill_task
# Check master survived all crashes
assert await c_master.ping()
# Check phase 3 replicas are up-to-date and there is no gap or lag
def check_gen(): return gen_test_data(n_keys//5, seed=0)
await batch_fill_data_async(c_master, check_gen())
await seeder.run(target_times=2)
await asyncio.sleep(1.0)
for _, c_replica, _ in replicas_of_type(lambda t: t > 1):
await batch_check_data_async(c_replica, check_gen())
capture = await seeder.capture()
for replica, _, _ in replicas_of_type(lambda t: t > 1):
assert await seeder.compare(capture, port=replica.port)
# Check disconnects
async def disconnect(replica, c_replica, crash_type):
@ -228,9 +227,9 @@ async def test_disconnect_replica(df_local_factory, t_master, t_crash_fs, t_cras
await asyncio.sleep(0.5)
# Check phase 3 replica survived
for _, c_replica, _ in replicas_of_type(lambda t: t == 2):
for replica, c_replica, _ in replicas_of_type(lambda t: t == 2):
assert await c_replica.ping()
await batch_check_data_async(c_replica, check_gen())
assert await seeder.compare(capture, port=replica.port)
# Check master survived all disconnects
assert await c_master.ping()
@ -254,16 +253,14 @@ Three types are tested:
# 3. Number of times a random crash happens
# 4. Number of keys transferred (the more, the higher the probability of not missing full sync)
master_crash_cases = [
(4, [4], 3, 1000),
(8, [8], 3, 5000),
(6, [6, 6, 6], 3, 5000),
(4, [2] * 8, 3, 5000),
(6, [6], 3, 2_000),
(4, [4, 4, 4], 3, 2_000),
]
@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replicas, n_random_crashes, n_keys", master_crash_cases)
async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_random_crashes, n_keys):
async def test_disconnect_master(df_local_factory, df_seeder_factory, t_master, t_replicas, n_random_crashes, n_keys):
master = df_local_factory.create(port=1111, proactor_threads=t_master)
replicas = [
df_local_factory.create(
@ -271,11 +268,11 @@ async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_rando
for i, t in enumerate(t_replicas)
]
for replica in replicas:
replica.start()
df_local_factory.start_all(replicas)
c_replicas = [aioredis.Redis(port=replica.port) for replica in replicas]
seeder = df_seeder_factory.create(port=master.port, keys=n_keys, dbcount=2)
async def crash_master_fs():
await asyncio.sleep(random.random() / 10 + 0.1 * len(replicas))
master.stop(kill=True)
@ -284,7 +281,8 @@ async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_rando
master.start()
c_master = aioredis.Redis(port=master.port)
assert await c_master.ping()
await batch_fill_data_async(c_master, gen_test_data(n_keys, seed=0))
seeder.reset()
await seeder.run(target_deviation=0.1)
await start_master()
@ -302,15 +300,17 @@ async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_rando
await start_master()
await asyncio.sleep(1 + len(replicas) * 0.5) # Replicas check every 500ms.
for c_replica in c_replicas:
capture = await seeder.capture()
for replica, c_replica in zip(replicas, c_replicas):
await wait_available_async(c_replica)
await batch_check_data_async(c_replica, gen_test_data(n_keys, seed=0))
assert await seeder.compare(capture, port=replica.port)
# Crash master during stable state
master.stop(kill=True)
await start_master()
await asyncio.sleep(1 + len(replicas) * 0.5)
capture = await seeder.capture()
for c_replica in c_replicas:
await wait_available_async(c_replica)
await batch_check_data_async(c_replica, gen_test_data(n_keys, seed=0))
assert await seeder.compare(capture, port=replica.port)

View file

@ -1,9 +1,5 @@
import pytest
import redis
import random
from string import ascii_lowercase
import time
import datetime
from .utility import *
@ -44,6 +40,8 @@ When this issue is fully fixed, this test will fail, and then it should
be changed to match the fact that we support this operation.
For now we are expecting to get an error
'''
def test_multi_eval(client):
try:
pipeline = client.pipeline()
@ -66,17 +64,21 @@ def test_connection_name(client):
name = client.execute_command("CLIENT GETNAME")
assert name == "test_conn_name"
'''
Make sure that the SCAN command works from the Python client.
'''
def test_scan(client):
try:
for key, val in gen_test_data(n=10, seed="set-test-key"):
res = client.set(key, val)
assert res is not None
cur, keys = client.scan(cursor=0, match=key, count=2)
assert cur == 0
assert len(keys) == 1
assert keys[0] == key
except Exception as e:
assert False, str(e)
def gen_test_data():
for i in range(10):
yield f"key-{i}", f"value-{i}"
for key, val in gen_test_data():
res = client.set(key, val)
assert res is not None
cur, keys = client.scan(cursor=0, match=key, count=2)
assert cur == 0
assert len(keys) == 1
assert keys[0] == key

View file

@ -1,24 +1,20 @@
import time
import pytest
import redis
import string
import os
import glob
from pathlib import Path
from . import dfly_args
from .utility import batch_check_data, batch_fill_data, gen_test_data
from .utility import DflySeeder, wait_available_async
BASIC_ARGS = {"dir": "{DRAGONFLY_TMP}/"}
NUM_KEYS = 100
SEEDER_ARGS = dict(keys=12_000, dbcount=5)
class SnapshotTestBase:
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
self.rdb_out = tmp_dir / "test.rdb"
if self.rdb_out.exists():
self.rdb_out.unlink()
def get_main_file(self, suffix):
def is_main(f): return "summary" in f if suffix == "dfs" else True
@ -26,55 +22,62 @@ class SnapshotTestBase:
return next(f for f in sorted(files) if is_main(f))
@dfly_args({**BASIC_ARGS, "dbfilename": "test"})
@dfly_args({**BASIC_ARGS, "dbfilename": "test-rdb"})
class TestRdbSnapshot(SnapshotTestBase):
"""Test single file rdb snapshot"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
super().setup(tmp_dir)
def test_snapshot(self, client: redis.Redis):
batch_fill_data(client, gen_test_data(NUM_KEYS))
@pytest.mark.asyncio
async def test_snapshot(self, df_seeder_factory, async_client, df_server):
seeder = df_seeder_factory.create(port=df_server.port, **SEEDER_ARGS)
await seeder.run(target_deviation=0.1)
start_capture = await seeder.capture()
# save + flush + load
client.execute_command("SAVE")
assert client.flushall()
client.execute_command("DEBUG LOAD " + super().get_main_file("rdb"))
await async_client.execute_command("SAVE")
assert await async_client.flushall()
await async_client.execute_command("DEBUG LOAD " + super().get_main_file("rdb"))
batch_check_data(client, gen_test_data(NUM_KEYS))
assert await seeder.compare(start_capture)
@dfly_args({**BASIC_ARGS, "dbfilename": "test"})
@dfly_args({**BASIC_ARGS, "dbfilename": "test-dfs"})
class TestDflySnapshot(SnapshotTestBase):
"""Test multi file snapshot"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
self.tmp_dir = tmp_dir
files = glob.glob(str(tmp_dir.absolute()) + 'test-*.dfs')
for file in files:
os.remove(file)
def test_snapshot(self, client: redis.Redis):
batch_fill_data(client, gen_test_data(NUM_KEYS))
@pytest.mark.asyncio
async def test_snapshot(self, df_seeder_factory, async_client, df_server):
seeder = df_seeder_factory.create(port=df_server.port, **SEEDER_ARGS)
await seeder.run(target_deviation=0.1)
start_capture = await seeder.capture()
# save + flush + load
client.execute_command("SAVE DF")
assert client.flushall()
client.execute_command("DEBUG LOAD " + super().get_main_file("dfs"))
await async_client.execute_command("SAVE DF")
assert await async_client.flushall()
await async_client.execute_command("DEBUG LOAD " + super().get_main_file("dfs"))
batch_check_data(client, gen_test_data(NUM_KEYS))
assert await seeder.compare(start_capture)
@dfly_args({**BASIC_ARGS, "dbfilename": "test.rdb", "save_schedule": "*:*"})
@dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic.rdb", "save_schedule": "*:*"})
class TestPeriodicSnapshot(SnapshotTestBase):
"""Test periodic snapshotting"""
@pytest.fixture(autouse=True)
def setup(self, tmp_dir: Path):
super().setup(tmp_dir)
def test_snapshot(self, client: redis.Redis):
batch_fill_data(client, gen_test_data(NUM_KEYS))
@pytest.mark.asyncio
async def test_snapshot(self, df_seeder_factory, df_server):
seeder = df_seeder_factory.create(port=df_server.port, keys=10)
await seeder.run(target_deviation=0.5)
time.sleep(60)
assert self.rdb_out.exists()
assert (self.tmp_dir / "test-periodic.rdb").exists()

View file

@ -1,11 +1,16 @@
import redis
import aioredis
import itertools
import time
import sys
import asyncio
import random
import string
import itertools
import time
import difflib
from enum import Enum
def grouper(n, iterable):
def chunked(n, iterable):
"""Transform iterable into iterator of chunks of size n"""
it = iter(iterable)
while True:
@ -15,7 +20,9 @@ def grouper(n, iterable):
yield chunk
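A quick illustration of chunked, assuming the elided body gathers tuples via itertools.islice and stops on an empty chunk (so the last chunk may be shorter):
assert list(chunked(2, [1, 2, 3, 4, 5])) == [(1, 2), (3, 4), (5,)]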
BATCH_SIZE = 100
def eprint(*args, **kwargs):
"""Print to stderr"""
print(*args, file=sys.stderr, **kwargs)
def gen_test_data(n, start=0, seed=None):
@ -23,60 +30,466 @@ def gen_test_data(n, start=0, seed=None):
yield "k-"+str(i), "v-"+str(i) + ("-"+str(seed) if seed else "")
def batch_fill_data(client: redis.Redis, gen):
for group in grouper(BATCH_SIZE, gen):
def batch_fill_data(client, gen):
BATCH_SIZE = 100
for group in chunked(BATCH_SIZE, gen):
client.mset({k: v for k, v, in group})
async def batch_fill_data_async(client: aioredis.Redis, gen):
for group in grouper(BATCH_SIZE, gen):
await client.mset({k: v for k, v in group})
def as_str_val(v) -> str:
if isinstance(v, str):
return v
elif isinstance(v, bytes):
return v.decode()
else:
return str(v)
def batch_check_data(client: redis.Redis, gen):
for group in grouper(BATCH_SIZE, gen):
vals = [as_str_val(v) for v in client.mget(k for k, _ in group)]
gvals = [v for _, v in group]
assert vals == gvals
async def batch_check_data_async(client: aioredis.Redis, gen):
for group in grouper(BATCH_SIZE, gen):
vals = [as_str_val(v) for v in await client.mget(k for k, _ in group)]
gvals = [v for _, v in group]
assert vals == gvals
def wait_available(client: redis.Redis):
its = 0
while True:
try:
client.get('key')
print("wait_available iterations:", its)
return
except redis.ResponseError as e:
assert "Can not execute during LOADING" in str(e)
time.sleep(0.01)
its += 1
async def wait_available_async(client: aioredis.Redis):
"""Block until instance exits loading phase"""
its = 0
while True:
try:
await client.get('key')
print("wait_available iterations:", its)
return
except aioredis.ResponseError as e:
assert "Can not execute during LOADING" in str(e)
# Print W to indicate test is waiting for replica
print('W', end='', flush=True)
await asyncio.sleep(0.01)
its += 1
class SizeChange(Enum):
SHRINK = 0
NO_CHANGE = 1
GROW = 2
class ValueType(Enum):
STRING = 0
LIST = 1
SET = 2
HSET = 3
ZSET = 4
@staticmethod
def randomize():
return random.choice([t for t in ValueType])
class CommandGenerator:
"""Class for generating complex command sequences"""
def __init__(self, target_keys, val_size, batch_size, max_multikey):
self.key_cnt_target = target_keys
self.val_size = val_size
self.batch_size = min(batch_size, target_keys)
self.max_multikey = max_multikey
# Key management
self.key_sets = [set() for _ in ValueType]
self.key_cursor = 0
self.key_cnt = 0
# Grow factors
self.diff_speed = 5
self.base_diff_prob = 0.2
self.min_diff_prob = 0.1
def keys(self):
return itertools.chain(*self.key_sets)
def keys_and_types(self):
return ((k, t) for t in list(ValueType) for k in self.set_for_type(t))
def set_for_type(self, t: ValueType):
return self.key_sets[t.value]
def add_key(self, t: ValueType):
"""Add new key of type t"""
k, self.key_cursor = self.key_cursor, self.key_cursor + 1
self.set_for_type(t).add(k)
return k
def randomize_nonempty_set(self):
"""Return random non-empty set and its type"""
if not any(self.key_sets):
return None, None
t = ValueType.randomize()
s = self.set_for_type(t)
if len(s) == 0:
return self.randomize_nonempty_set()
else:
return s, t
def randomize_key(self, t=None, pop=False):
"""Return random key and its type"""
if t is None:
s, t = self.randomize_nonempty_set()
else:
s = self.set_for_type(t)
if s is None or len(s) == 0:
return None, None
k = s.pop()
if not pop:
s.add(k)
return k, t
def generate_val(self, t: ValueType):
"""Generate filler value of configured size for type t"""
def rand_str(k=3, s=''):
# Use small k value to reduce mem usage and increase number of ops
return s.join(random.choices(string.ascii_letters, k=k))
if t == ValueType.STRING:
# Random string for MSET
return rand_str(self.val_size)
elif t == ValueType.LIST:
# Random sequence of k-letter elements for LPUSH
return ' '.join(rand_str() for _ in range(self.val_size//4))
elif t == ValueType.SET:
# Random sequence of k-letter elements for SADD
return ' '.join(rand_str() for _ in range(self.val_size//4))
elif t == ValueType.HSET:
# Two fixed start fields (v0, v1) plus a random sequence of k-letter keys with int values for HSET
return 'v0 0 v1 0 ' + ' '.join(
rand_str() + ' ' + str(random.randint(0, self.val_size))
for _ in range(self.val_size//5)
)
else:
# Random sequence of k-letter keys and int scores for ZSET
return ' '.join(str(random.randint(0, self.val_size)) + ' ' + rand_str()
for _ in range(self.val_size//4))
def gen_shrink_cmd(self):
"""
Generate command that shrinks data: DEL of random keys.
"""
keys_gen = (self.randomize_key(pop=True)
for _ in range(random.randint(1, self.max_multikey)))
keys = [f"k{k}" for k, _ in keys_gen if k is not None]
if len(keys) == 0:
return None, 0
return "DEL " + " ".join(keys), -len(keys)
UPDATE_ACTIONS = [
('APPEND {k} {val}', ValueType.STRING),
('SETRANGE {k} 10 {val}', ValueType.STRING),
('LPUSH {k} {val}', ValueType.LIST),
('LPOP {k}', ValueType.LIST),
#('SADD {k} {val}', ValueType.SET),
#('SPOP {k}', ValueType.SET),
#('HSETNX {k} v0 {val}', ValueType.HSET),
#('HINCRBY {k} v1 1', ValueType.HSET),
#('ZPOPMIN {k} 1', ValueType.ZSET),
#('ZADD {k} 0 {val}', ValueType.ZSET)
]
def gen_update_cmd(self):
"""
Generate a command that makes no change to the keyset: a random one of UPDATE_ACTIONS.
"""
cmd, t = random.choice(self.UPDATE_ACTIONS)
k, _ = self.randomize_key(t)
val = ''.join(random.choices(string.ascii_letters, k=4))
return cmd.format(k=f"k{k}", val=val) if k is not None else None, 0
GROW_ACTINONS = {
ValueType.STRING: 'MSET',
ValueType.LIST: 'LPUSH',
ValueType.SET: 'SADD',
ValueType.HSET: 'HMSET',
ValueType.ZSET: 'ZADD'
}
def gen_grow_cmd(self):
"""
Generate command that grows keyset: Initialize key of random type with filler value.
"""
# TODO: Implement COPY in Dragonfly.
t = ValueType.randomize()
if t == ValueType.STRING:
count = random.randint(1, self.max_multikey)
else:
count = 1
keys = (self.add_key(t) for _ in range(count))
payload = " ".join(f"k{k}" + " " + self.generate_val(t) for k in keys)
return self.GROW_ACTINONS[t] + " " + payload, count
def make(self, action):
""" Create command for action and return it together with number of keys added (removed)"""
if action == SizeChange.SHRINK:
return self.gen_shrink_cmd()
elif action == SizeChange.NO_CHANGE:
return self.gen_update_cmd()
else:
return self.gen_grow_cmd()
def reset(self):
self.key_sets = [set() for _ in ValueType]
self.key_cursor = 0
self.key_cnt = 0
def size_change_probs(self):
"""Calculate probabilities of size change actions"""
# Relative distance to key target
dist = (self.key_cnt_target - self.key_cnt) / self.key_cnt_target
# A shrink command is expected to change roughly twice as many keys as a grow command
return [
max(self.base_diff_prob - self.diff_speed * dist, self.min_diff_prob),
1.0,
max(self.base_diff_prob + 2 *
self.diff_speed * dist, self.min_diff_prob)
]
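# Worked example with the defaults from __init__ above (diff_speed=5, base_diff_prob=0.2, min_diff_prob=0.1);
# these are relative weights for random.choices, not normalized probabilities:
#   empty keyset (key_cnt = 0)       -> dist = 1.0 -> weights [0.1, 1.0, 10.2]  (grow dominates)
#   at the target (key_cnt = target) -> dist = 0.0 -> weights [0.2, 1.0, 0.2]   (mostly updates)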
def generate(self):
"""Generate next batch of commands, return it and ratio of current keys to target"""
changes = []
cmds = []
while len(cmds) < self.batch_size:
# Re-calculating changes in small groups
if len(changes) == 0:
changes = random.choices(
list(SizeChange), weights=self.size_change_probs(), k=50)
cmd, delta = self.make(changes.pop())
if cmd is not None:
cmds.append(cmd)
self.key_cnt += delta
return cmds, self.key_cnt/self.key_cnt_target
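A standalone sketch (not from this commit) that drives CommandGenerator directly, using only the constructor and generate() shown above:
gen = CommandGenerator(target_keys=1_000, val_size=50, batch_size=100, max_multikey=5)
cmds, ratio = gen.generate()  # batch of textual commands plus current-keys/target ratio
while ratio < 0.95:  # keep generating until the keyset approaches the target
    cmds, ratio = gen.generate()
print(f"last batch had {len(cmds)} commands, keyset at {ratio:.0%} of target")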
class DataCapture:
"""
Captured state of single database.
"""
def __init__(self, entries):
self.entries = entries
def compare(self, other):
if self.entries == other.entries:
return True
self._print_diff(other)
return False
def _print_diff(self, other):
eprint("=== DIFF ===")
printed = 0
diff = difflib.ndiff(self.entries, other.entries)
for line in diff:
if line.startswith(' '):
continue
eprint(line)
if printed >= 20:
eprint("... omitted ...")
break
printed += 1
eprint("=== END DIFF ===")
class DflySeeder:
"""
Data seeder with support for multiple types and commands.
Usage:
Create a seeder with a target number of keys (100k) of a specified size (200) that works on 5 dbs,
seeder = DflySeeder(keys=100_000, val_size=200, dbcount=5)
Stop when we are within 5% of the target number of keys (i.e. above 95_000).
Because it's probabilistic we might never reach exactly 100_000.
await seeder.run(target_deviation=0.05)
Run 3 iterations (full batches) in stable state, create a capture and compare it to
the replica on port 1112
await seeder.run(target_times=3)
capture = await seeder.capture()
assert await seeder.compare(capture, port=1112)
"""
def __init__(self, port=6379, keys=1000, val_size=50, batch_size=1000, max_multikey=5, dbcount=1, log_file=None):
self.gen = CommandGenerator(
keys, val_size, batch_size, max_multikey
)
self.port = port
self.dbcount = dbcount
self.stop_flag = False
self.log_file = log_file
if self.log_file is not None:
open(self.log_file, 'w').close()
async def run(self, target_times=None, target_deviation=None):
"""
Run a seeding cycle on all dbs until either stop() is called, a fixed number of batches is reached (target_times),
or the keyset reaches an allowed deviation from the target number of keys (target_deviation).
"""
print(f"Running times:{target_times} deviation:{target_deviation}")
self.stop_flag = False
queues = [asyncio.Queue(maxsize=3) for _ in range(self.dbcount)]
producer = asyncio.create_task(self._generator_task(
queues, target_times=target_times, target_deviation=target_deviation))
consumers = [
asyncio.create_task(self._executor_task(i, queue))
for i, queue in enumerate(queues)
]
time_start = time.time()
cmdcount = await producer
for consumer in consumers:
await consumer
took = time.time() - time_start
qps = round(cmdcount * self.dbcount / took, 2)
print(f"Filling took: {took}, QPS: {qps}")
def stop(self):
"""Stop all invocations to run"""
self.stop_flag = True
def reset(self):
""" Reset internal state. Needs to be called after flush or restart"""
self.gen.reset()
async def capture(self, port=None, target_db=0, keys=None):
"""Create DataCapture for selected db"""
if port is None:
port = self.port
if keys is None:
keys = sorted(list(self.gen.keys_and_types()))
client = aioredis.Redis(port=port, db=target_db)
capture = DataCapture(await self._capture_entries(client, keys))
await client.connection_pool.disconnect()
return capture
async def compare(self, initial_capture, port=6379):
"""Compare data capture with all dbs of instance and return True if all dbs are correct"""
print(f"comparing capture to {port}")
keys = sorted(list(self.gen.keys_and_types()))
captures = await asyncio.gather(*(
self.capture(port=port, target_db=db, keys=keys) for db in range(self.dbcount)
))
for db, capture in zip(range(self.dbcount), captures):
if not initial_capture.compare(capture):
eprint(f">>> Inconsistent data on port {port}, db {db}")
return False
return True
def target(self, key_cnt):
self.gen.key_cnt_target = key_cnt
async def _generator_task(self, queues, target_times=None, target_deviation=None):
cpu_time = 0
submitted = 0
deviation = 0.0
file = None
if self.log_file:
file = open(self.log_file, 'a')
def should_run():
if self.stop_flag:
return False
if target_times is not None and submitted >= target_times:
return False
if target_deviation is not None and abs(1-deviation) < target_deviation:
return False
return True
while should_run():
start_time = time.time()
blob, deviation = self.gen.generate()
cpu_time += (time.time() - start_time)
await asyncio.gather(*(q.put(blob) for q in queues))
submitted += 1
if file is not None:
file.write('\n'.join(blob))
print('.', end='', flush=True)
await asyncio.sleep(0.0)
print("\ncpu time", cpu_time, "batches", submitted)
await asyncio.gather(*(q.put(None) for q in queues))
for q in queues:
await q.join()
if file is not None:
file.flush()
return submitted * self.gen.batch_size
async def _executor_task(self, db, queue):
client = aioredis.Redis(port=self.port, db=db)
while True:
cmds = await queue.get()
if cmds is None:
queue.task_done()
break
pipe = client.pipeline(transaction=False)
for cmd in cmds:
pipe.execute_command(cmd)
await pipe.execute()
queue.task_done()
await client.connection_pool.disconnect()
CAPTURE_COMMANDS = {
ValueType.STRING: lambda pipe, k: pipe.get(k),
ValueType.LIST: lambda pipe, k: pipe.lrange(k, 0, -1),
ValueType.SET: lambda pipe, k: pipe.smembers(k),
ValueType.HSET: lambda pipe, k: pipe.hgetall(k),
ValueType.ZSET: lambda pipe, k: pipe.zrange(
k, start=0, end=-1, withscores=True)
}
CAPTURE_EXTRACTORS = {
ValueType.STRING: lambda res, tostr: (tostr(res),),
ValueType.LIST: lambda res, tostr: (tostr(s) for s in res),
ValueType.SET: lambda res, tostr: sorted(tostr(s) for s in res),
ValueType.HSET: lambda res, tostr: sorted(tostr(k)+"="+tostr(v) for k, v in res.items()),
ValueType.ZSET: lambda res, tostr: (
tostr(s)+"-"+str(f) for (s, f) in res)
}
async def _capture_entries(self, client, keys):
def tostr(b):
return b.decode("utf-8") if isinstance(b, bytes) else str(b)
entries = []
for group in chunked(self.gen.batch_size * 2, keys):
pipe = client.pipeline(transaction=False)
for k, t in group:
self.CAPTURE_COMMANDS[t](pipe, f"k{k}")
results = await pipe.execute()
for (k, t), res in zip(group, results):
out = f"{t.name} k{k}: " + \
' '.join(self.CAPTURE_EXTRACTORS[t](res, tostr))
entries.append(out)
return entries
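# Illustrative entry strings produced above (values are made up; format is f"{type.name} k{key}: ..."):
#   STRING k3: xYz
#   HSET k12: f1=0 f2=7 v0=0 v1=3
#   ZSET k9: ab-0.0 cd-5.0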
class DflySeederFactory:
"""
Used to pass params to a DflySeeder.
"""
def __init__(self, log_file=None):
self.log_file = log_file
def create(self, **kwargs):
return DflySeeder(log_file=self.log_file, **kwargs)