From bd3e57d2622b53931b97d109224781e0ab6cdeab Mon Sep 17 00:00:00 2001 From: s-shiraki <54130718+highpon@users.noreply.github.com> Date: Sun, 17 Dec 2023 03:42:15 +0900 Subject: [PATCH] feat(server): Implement NUMSUB subcommand (#2282) * feat(server): Implement NUMSUB subcommand * fix * fix * fix * fix * fix * fix * fix * fix: test * fix: build error --- CONTRIBUTORS.md | 1 + src/facade/reply_builder.cc | 0 src/facade/reply_builder.h | 0 src/server/main_service.cc | 18 +++++++++++ src/server/main_service.h | 1 + tests/dragonfly/connection_test.py | 50 ++++++++++++++++++++++++++++++ 6 files changed, 70 insertions(+) mode change 100644 => 100755 src/facade/reply_builder.cc mode change 100644 => 100755 src/facade/reply_builder.h mode change 100644 => 100755 src/server/main_service.cc mode change 100644 => 100755 tests/dragonfly/connection_test.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index a8a29cd1d..5ca6ec6bd 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -14,5 +14,6 @@ * **[Ali-Akber Saifee](https://github.com/alisaifee)** * **[Elle Y](https://github.com/inohime)** * **[ATM SALEH](https://github.com/ATM-SALEH)** +* **[Shohei Shiraki](https://github.com/highpon)** * **[Leonardo Mello](https://github.com/lsvmello)** * **[Nico Coetzee](https://github.com/nicc777)** diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc old mode 100644 new mode 100755 diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h old mode 100644 new mode 100755 diff --git a/src/server/main_service.cc b/src/server/main_service.cc old mode 100644 new mode 100755 index aa560c613..bbc2184d6 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -2197,6 +2197,18 @@ void Service::PubsubPatterns(ConnectionContext* cntx) { cntx->SendLong(pattern_count); } +void Service::PubsubNumSub(CmdArgList args, ConnectionContext* cntx) { + int channels_size = args.size(); + auto* rb = static_cast(cntx->reply_builder()); + rb->StartArray(channels_size * 2); + + for (auto i = 0; i < channels_size; i++) { + auto channel = ArgS(args, i); + rb->SendBulkString(channel); + rb->SendLong(ServerState::tlocal()->channel_store()->FetchSubscribers(channel).size()); + } +} + void Service::Monitor(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "starting monitor on this connection: " << cntx->conn()->GetClientId(); // we are registering the current connection for all threads so they will be aware of @@ -2221,6 +2233,9 @@ void Service::Pubsub(CmdArgList args, ConnectionContext* cntx) { "\tReturn the currently active channels matching a (default: '*').", "NUMPAT", "\tReturn number of subscriptions to patterns.", + "NUMSUB [ ]", + "\tReturns the number of subscribers for the specified channels, excluding", + "\tpattern subscriptions.", "HELP", "\tPrints this help."}; @@ -2238,6 +2253,9 @@ void Service::Pubsub(CmdArgList args, ConnectionContext* cntx) { PubsubChannels(pattern, cntx); } else if (subcmd == "NUMPAT") { PubsubPatterns(cntx); + } else if (subcmd == "NUMSUB") { + args.remove_prefix(1); + PubsubNumSub(args, cntx); } else { cntx->SendError(UnknownSubCmd(subcmd, "PUBSUB")); } diff --git a/src/server/main_service.h b/src/server/main_service.h index f80e92b5b..0912d7b91 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -150,6 +150,7 @@ class Service : public facade::ServiceInterface { void PubsubChannels(std::string_view pattern, ConnectionContext* cntx); void PubsubPatterns(ConnectionContext* cntx); + void PubsubNumSub(CmdArgList channels, ConnectionContext* cntx); struct EvalArgs { std::string_view sha; // only one of them is defined. diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py old mode 100644 new mode 100755 index f48db493b..2c128ffd5 --- a/tests/dragonfly/connection_test.py +++ b/tests/dragonfly/connection_test.py @@ -1,6 +1,7 @@ import random import pytest import asyncio +import time from redis import asyncio as aioredis from redis.exceptions import ConnectionError as redis_conn_error import async_timeout @@ -287,6 +288,55 @@ async def test_multi_pubsub(async_client): assert state, message +""" +Test PUBSUB NUMSUB command. +""" + + +@pytest.mark.asyncio +async def test_pubsub_subcommand_for_numsub(async_client): + subs1 = [async_client.pubsub() for i in range(5)] + for s in subs1: + await s.subscribe("channel_name1") + result = await async_client.pubsub_numsub("channel_name1") + assert result[0][0] == "channel_name1" and result[0][1] == 5 + + for s in subs1: + await s.unsubscribe("channel_name1") + result = await async_client.pubsub_numsub("channel_name1") + + retry = 5 + for i in range(0, retry): + result = await async_client.pubsub_numsub("channel_name1") + if result[0][0] == "channel_name1" and result[0][1] == 0: + break + else: + time.sleep(1) + + assert result[0][0] == "channel_name1" and result[0][1] == 0 + + result = await async_client.pubsub_numsub() + assert len(result) == 0 + + subs2 = [async_client.pubsub() for i in range(5)] + for s in subs2: + await s.subscribe("channel_name2") + + subs3 = [async_client.pubsub() for i in range(10)] + for s in subs3: + await s.subscribe("channel_name3") + + result = await async_client.pubsub_numsub("channel_name2", "channel_name3") + assert result[0][0] == "channel_name2" and result[0][1] == 5 + assert result[1][0] == "channel_name3" and result[1][1] == 10 + + for s in subs2: + await s.unsubscribe("channel_name2") + + for s in subs3: + await s.unsubscribe("channel_name3") + + """ Test that pubsub clients who are stuck on backpressure from a slow client (the one in the test doesn't read messages at all) will eventually unblock when it disconnects.