feat(server): Implement NUMSUB subcommand (#2282)

* feat(server): Implement NUMSUB subcommand

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix: test

* fix: build error
This commit is contained in:
s-shiraki 2023-12-17 03:42:15 +09:00 committed by GitHub
parent 4cce3b4a01
commit bd3e57d262
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 70 additions and 0 deletions

View file

@ -14,5 +14,6 @@
* **[Ali-Akber Saifee](https://github.com/alisaifee)** * **[Ali-Akber Saifee](https://github.com/alisaifee)**
* **[Elle Y](https://github.com/inohime)** * **[Elle Y](https://github.com/inohime)**
* **[ATM SALEH](https://github.com/ATM-SALEH)** * **[ATM SALEH](https://github.com/ATM-SALEH)**
* **[Shohei Shiraki](https://github.com/highpon)**
* **[Leonardo Mello](https://github.com/lsvmello)** * **[Leonardo Mello](https://github.com/lsvmello)**
* **[Nico Coetzee](https://github.com/nicc777)** * **[Nico Coetzee](https://github.com/nicc777)**

0
src/facade/reply_builder.cc Normal file → Executable file
View file

0
src/facade/reply_builder.h Normal file → Executable file
View file

18
src/server/main_service.cc Normal file → Executable file
View file

@ -2197,6 +2197,18 @@ void Service::PubsubPatterns(ConnectionContext* cntx) {
cntx->SendLong(pattern_count); cntx->SendLong(pattern_count);
} }
void Service::PubsubNumSub(CmdArgList args, ConnectionContext* cntx) {
int channels_size = args.size();
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
rb->StartArray(channels_size * 2);
for (auto i = 0; i < channels_size; i++) {
auto channel = ArgS(args, i);
rb->SendBulkString(channel);
rb->SendLong(ServerState::tlocal()->channel_store()->FetchSubscribers(channel).size());
}
}
void Service::Monitor(CmdArgList args, ConnectionContext* cntx) { void Service::Monitor(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "starting monitor on this connection: " << cntx->conn()->GetClientId(); VLOG(1) << "starting monitor on this connection: " << cntx->conn()->GetClientId();
// we are registering the current connection for all threads so they will be aware of // we are registering the current connection for all threads so they will be aware of
@ -2221,6 +2233,9 @@ void Service::Pubsub(CmdArgList args, ConnectionContext* cntx) {
"\tReturn the currently active channels matching a <pattern> (default: '*').", "\tReturn the currently active channels matching a <pattern> (default: '*').",
"NUMPAT", "NUMPAT",
"\tReturn number of subscriptions to patterns.", "\tReturn number of subscriptions to patterns.",
"NUMSUB [<channel> <channel...>]",
"\tReturns the number of subscribers for the specified channels, excluding",
"\tpattern subscriptions.",
"HELP", "HELP",
"\tPrints this help."}; "\tPrints this help."};
@ -2238,6 +2253,9 @@ void Service::Pubsub(CmdArgList args, ConnectionContext* cntx) {
PubsubChannels(pattern, cntx); PubsubChannels(pattern, cntx);
} else if (subcmd == "NUMPAT") { } else if (subcmd == "NUMPAT") {
PubsubPatterns(cntx); PubsubPatterns(cntx);
} else if (subcmd == "NUMSUB") {
args.remove_prefix(1);
PubsubNumSub(args, cntx);
} else { } else {
cntx->SendError(UnknownSubCmd(subcmd, "PUBSUB")); cntx->SendError(UnknownSubCmd(subcmd, "PUBSUB"));
} }

View file

@ -150,6 +150,7 @@ class Service : public facade::ServiceInterface {
void PubsubChannels(std::string_view pattern, ConnectionContext* cntx); void PubsubChannels(std::string_view pattern, ConnectionContext* cntx);
void PubsubPatterns(ConnectionContext* cntx); void PubsubPatterns(ConnectionContext* cntx);
void PubsubNumSub(CmdArgList channels, ConnectionContext* cntx);
struct EvalArgs { struct EvalArgs {
std::string_view sha; // only one of them is defined. std::string_view sha; // only one of them is defined.

50
tests/dragonfly/connection_test.py Normal file → Executable file
View file

@ -1,6 +1,7 @@
import random import random
import pytest import pytest
import asyncio import asyncio
import time
from redis import asyncio as aioredis from redis import asyncio as aioredis
from redis.exceptions import ConnectionError as redis_conn_error from redis.exceptions import ConnectionError as redis_conn_error
import async_timeout import async_timeout
@ -287,6 +288,55 @@ async def test_multi_pubsub(async_client):
assert state, message assert state, message
"""
Test PUBSUB NUMSUB command.
"""
@pytest.mark.asyncio
async def test_pubsub_subcommand_for_numsub(async_client):
subs1 = [async_client.pubsub() for i in range(5)]
for s in subs1:
await s.subscribe("channel_name1")
result = await async_client.pubsub_numsub("channel_name1")
assert result[0][0] == "channel_name1" and result[0][1] == 5
for s in subs1:
await s.unsubscribe("channel_name1")
result = await async_client.pubsub_numsub("channel_name1")
retry = 5
for i in range(0, retry):
result = await async_client.pubsub_numsub("channel_name1")
if result[0][0] == "channel_name1" and result[0][1] == 0:
break
else:
time.sleep(1)
assert result[0][0] == "channel_name1" and result[0][1] == 0
result = await async_client.pubsub_numsub()
assert len(result) == 0
subs2 = [async_client.pubsub() for i in range(5)]
for s in subs2:
await s.subscribe("channel_name2")
subs3 = [async_client.pubsub() for i in range(10)]
for s in subs3:
await s.subscribe("channel_name3")
result = await async_client.pubsub_numsub("channel_name2", "channel_name3")
assert result[0][0] == "channel_name2" and result[0][1] == 5
assert result[1][0] == "channel_name3" and result[1][1] == 10
for s in subs2:
await s.unsubscribe("channel_name2")
for s in subs3:
await s.unsubscribe("channel_name3")
""" """
Test that pubsub clients who are stuck on backpressure from a slow client (the one in the test doesn't read messages at all) Test that pubsub clients who are stuck on backpressure from a slow client (the one in the test doesn't read messages at all)
will eventually unblock when it disconnects. will eventually unblock when it disconnects.