mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 02:15:45 +02:00
237 lines
7.8 KiB
Python
237 lines
7.8 KiB
Python
import asyncio
|
|
import sys
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from asyncio import timeout as async_timeout
|
|
else:
|
|
from async_timeout import timeout as async_timeout
|
|
import pytest
|
|
import pytest_asyncio
|
|
import redis
|
|
import redis.asyncio
|
|
|
|
from fakeredis import FakeServer, aioredis
|
|
from test import testtools
|
|
|
|
# Every coroutine test in this module runs under pytest-asyncio.
pytestmark = [pytest.mark.asyncio]
|
|
|
|
|
|
@pytest_asyncio.fixture
async def conn(async_redis: redis.asyncio.Redis):
    """A single connection, rather than a pool."""
    # ``client()`` hands out a dedicated client; the context manager
    # releases it once the test using this fixture is done.
    async with async_redis.client() as single_conn:
        yield single_conn
|
|
|
|
|
|
async def test_ping(async_redis: redis.asyncio.Redis):
    """PING answers with the boolean ``True``."""
    assert await async_redis.ping() is True
|
|
|
|
|
|
async def test_types(async_redis: redis.asyncio.Redis):
    """HSET with mixed str/int values reads back as bytes from HGETALL."""
    fields = {"key1": "value1", "key2": "value2", "key3": 123}
    await async_redis.hset("hash", mapping=fields)

    expected = {b"key1": b"value1", b"key2": b"value2", b"key3": b"123"}
    assert await async_redis.hgetall("hash") == expected
|
|
|
|
|
|
async def test_transaction(async_redis: redis.asyncio.Redis):
    """A MULTI/EXEC pipeline applies both queued SETs."""
    async with async_redis.pipeline(transaction=True) as pipe:
        pipe.set("key1", "value1")
        pipe.set("key2", "value2")
        # Exactly two replies are expected, one per queued command.
        first_ok, second_ok = await pipe.execute()
    assert first_ok
    assert second_ok

    stored = await async_redis.get("key1")
    assert stored == b"value1"
|
|
|
|
|
|
async def test_transaction_fail(async_redis: redis.asyncio.Redis):
    """Modifying a WATCHed key from elsewhere aborts the transaction."""
    await async_redis.set("foo", "1")
    async with async_redis.pipeline(transaction=True) as pipe:
        await pipe.watch("foo")
        # The write goes through the pool, i.e. a different connection
        # than the one holding the WATCH.
        await async_redis.set("foo", "2")
        pipe.multi()
        pipe.get("foo")
        with pytest.raises(redis.asyncio.WatchError):
            await pipe.execute()
|
|
|
|
|
|
async def test_pubsub(async_redis):
    """Messages published to a subscribed channel arrive in order.

    A background reader task drains the subscription into a queue until it
    sees the ``b"stop"`` sentinel. The whole exchange is bounded by a
    5-second timeout.
    """
    queue = asyncio.Queue()

    async def reader(ps):
        # Forward non-subscription messages until the sentinel arrives.
        # ``get_message`` returns None when its own timeout elapses; we
        # then simply poll again.
        while True:
            message = await ps.get_message(ignore_subscribe_messages=True, timeout=5)
            if message is None:
                continue
            if message.get("data") == b"stop":
                break
            queue.put_nowait(message)

    async with async_timeout(5), async_redis.pubsub() as ps:
        await ps.subscribe("channel")
        # Schedule on the running loop; the pytest-asyncio ``event_loop``
        # fixture used previously is deprecated and removed in newer releases.
        task = asyncio.create_task(reader(ps))
        try:
            await async_redis.publish("channel", "message1")
            await async_redis.publish("channel", "message2")
            result1 = await queue.get()
            result2 = await queue.get()
            assert result1 == {
                "channel": b"channel",
                "pattern": None,
                "type": "message",
                "data": b"message1",
            }
            assert result2 == {
                "channel": b"channel",
                "pattern": None,
                "type": "message",
                "data": b"message2",
            }
            await async_redis.publish("channel", "stop")
            await task
        finally:
            # On an assertion failure or timeout, don't leave the reader
            # polling a pubsub connection that is about to be closed.
            if not task.done():
                task.cancel()
|
|
|
|
|
|
@pytest.mark.slow
async def test_pubsub_timeout(async_redis: redis.asyncio.Redis):
    """With nothing published, get_message times out and returns None."""
    async with async_redis.pubsub() as ps:
        await ps.subscribe("channel")
        # Drain the subscription confirmation first.
        await ps.get_message(timeout=0.5)
        assert await ps.get_message(timeout=0.5) is None
|
|
|
|
|
|
@pytest.mark.slow
async def test_pubsub_disconnect(async_redis: redis.asyncio.Redis):
    """After a forced disconnect, polling yields one message and then None."""
    async with async_redis.pubsub() as ps:
        await ps.subscribe("channel")
        await ps.connection.disconnect()

        first = await ps.get_message(timeout=0.5)  # Subscription message
        assert first is not None

        second = await ps.get_message(timeout=0.5)
        assert second is None
|
|
|
|
|
|
async def test_blocking_ready(async_redis, conn):
    """Blocking command which does not need to block."""
    await async_redis.rpush("list", "x")
    assert await conn.blpop("list", timeout=1) == (b"list", b"x")
|
|
|
|
|
|
@pytest.mark.slow
async def test_blocking_timeout(conn):
    """Blocking command that times out without completing."""
    popped = await conn.blpop("missing", timeout=1)
    assert popped is None
|
|
|
|
|
|
@pytest.mark.slow
async def test_blocking_unblock(async_redis, conn):
    """Blocking command that gets unblocked after some time.

    A background task pushes to the list 0.1s after the BLPOP starts.
    """

    async def unblock():
        await asyncio.sleep(0.1)
        await async_redis.rpush("list", "y")

    # Schedule on the running loop rather than via the deprecated
    # pytest-asyncio ``event_loop`` fixture.
    task = asyncio.create_task(unblock())
    try:
        result = await conn.blpop("list", timeout=1)
        assert result == (b"list", b"y")
    finally:
        # Always wait for the pusher, so it is not left pending (and its
        # own errors surface) even when the assertion above fails.
        await task
|
|
|
|
|
|
async def test_wrongtype_error(async_redis: redis.asyncio.Redis):
    """A list command on a string key fails with a WRONGTYPE error."""
    await async_redis.set("foo", "bar")
    expected_error = pytest.raises(redis.asyncio.ResponseError, match="^WRONGTYPE")
    with expected_error:
        await async_redis.rpush("foo", "baz")
|
|
|
|
|
|
async def test_syntax_error(async_redis: redis.asyncio.Redis):
    """GET without a key is rejected with an arity error."""
    arity_message = "^wrong number of arguments for 'get' command$"
    with pytest.raises(redis.asyncio.ResponseError, match=arity_message):
        await async_redis.execute_command("get")
|
|
|
|
|
|
@testtools.run_test_if_lupa
class TestScripts:
    """Lua scripting error paths (gated by ``testtools.run_test_if_lupa``)."""

    async def test_no_script_error(self, async_redis: redis.asyncio.Redis):
        # EVALSHA with a digest the server has never seen.
        unknown_sha = "0123456789abcdef0123456789abcdef"
        with pytest.raises(redis.exceptions.NoScriptError):
            await async_redis.evalsha(unknown_sha, 0)

    @pytest.mark.max_server("6.2.7")
    async def test_failed_script_error6(self, async_redis):
        # ZCOUNT against a string key fails inside the script; older
        # servers prefix the error with "Error running script".
        script = 'return redis.call("ZCOUNT", KEYS[1])'
        await async_redis.set("foo", "bar")
        with pytest.raises(redis.asyncio.ResponseError, match="^Error running script"):
            await async_redis.eval(script, 1, "foo")

    @pytest.mark.min_server("7")
    async def test_failed_script_error7(self, async_redis):
        # Same failing script; only the exception type is checked here.
        script = 'return redis.call("ZCOUNT", KEYS[1])'
        await async_redis.set("foo", "bar")
        with pytest.raises(redis.asyncio.ResponseError):
            await async_redis.eval(script, 1, "foo")
|
|
|
|
|
|
async def test_type(async_redis: redis.asyncio.Redis):
    """TYPE names each data structure, and reports b"none" for a missing key."""
    await async_redis.set("string_key", "value")
    await async_redis.lpush("list_key", "value")
    await async_redis.sadd("set_key", "value")
    await async_redis.zadd("zset_key", {"value": 1})
    await async_redis.hset("hset_key", "key", "value")

    # Checked in insertion order, same sequence of TYPE calls as before.
    expected_types = {
        "string_key": b"string",
        "list_key": b"list",
        "set_key": b"set",
        "zset_key": b"zset",
        "hset_key": b"hash",
        "none_key": b"none",
    }
    for key, expected in expected_types.items():
        assert expected == await async_redis.type(key)
|
|
|
|
|
|
async def test_xdel(async_redis: redis.asyncio.Redis):
    """XDEL returns how many of the given entry IDs were actually removed."""
    stream = "stream"

    # Deleting from an empty stream is a no-op.
    assert await async_redis.xdel(stream, 1) == 0

    # Three entries, added one after another.
    m1, m2, m3 = [await async_redis.xadd(stream, {"foo": "bar"}) for _ in range(3)]

    assert await async_redis.xdel(stream, m1) == 1
    assert await async_redis.xdel(stream, m2, m3) == 2
|
|
|
|
|
|
async def test_connection_with_username_and_password():
    """A FakeRedis client constructed with username/password works normally.

    This uses its own ``FakeServer`` rather than the ``async_redis``
    fixture. The client is used as an async context manager, so its
    connections are closed when the test ends instead of leaking.
    """
    server = FakeServer()
    test_value = "this_is_a_test"

    async with aioredis.FakeRedis(
        server=server, username="username", password="password"
    ) as r:
        await r.hset("test:key", "test_hash", test_value)
        result = await r.hget("test:key", "test_hash")

    # Compare bytes directly, so a missing value fails the assertion
    # instead of raising AttributeError on ``None.decode()``.
    assert result == test_value.encode()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cause_fakeredis_bug(async_redis):
    """Interleaved BLPOP/RPUSH between two tasks, then a GET of the worker's write.

    The trailing ``# N`` comments give the expected order of operations
    across the main task and the worker task.
    """
    if sys.version_info < (3, 11):
        # asyncio.TaskGroup only exists on 3.11+. Report a skip rather than
        # a silent pass.
        pytest.skip("asyncio.TaskGroup requires Python 3.11+")

    async def worker_task():
        assert await async_redis.rpush("list1", "list1_val") == 1  # 1
        assert await async_redis.blpop("list2") == (b"list2", b"list2_val")  # 4
        assert await async_redis.set("foo", "bar") is True  # 5

    # Both BLPOPs are issued without a timeout and would otherwise block
    # forever. Bound the exchange so a regression fails instead of hanging
    # the suite.
    async with async_timeout(10):
        async with asyncio.TaskGroup() as tg:
            tg.create_task(worker_task())
            assert await async_redis.blpop("list1") == (b"list1", b"list1_val")  # 2
            assert await async_redis.rpush("list2", "list2_val") == 1  # 3

    # NOTE: per the original author, issuing an extra
    # ``await async_redis.get("foo")`` here made the test pass. The
    # assertion below is therefore the one that exposes the bug.
    assert await async_redis.get("foo") == b"bar"
|