diff --git a/src/server/cluster/cluster_family_test.cc b/src/server/cluster/cluster_family_test.cc index 315cfe1a5..69cef1a68 100644 --- a/src/server/cluster/cluster_family_test.cc +++ b/src/server/cluster/cluster_family_test.cc @@ -625,7 +625,6 @@ TEST_F(ClusterFamilyTest, ClusterModePubSubNotAllowed) { ErrArg("PSUBSCRIBE is not supported in cluster mode yet")); EXPECT_THAT(Run({"PUNSUBSCRIBE", "ch?"}), ErrArg("PUNSUBSCRIBE is not supported in cluster mode yet")); - EXPECT_THAT(Run({"PUBSUB", "CHANNELS"}), ErrArg("PUBSUB is not supported in cluster mode yet")); } TEST_F(ClusterFamilyTest, ClusterFirstConfigCallDropsEntriesNotOwnedByNode) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 12dfb5196..bd902c988 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -2337,14 +2337,12 @@ void Service::PubsubChannels(string_view pattern, SinkReplyBuilder* builder) { void Service::PubsubPatterns(SinkReplyBuilder* builder) { size_t pattern_count = ServerState::tlocal()->channel_store()->PatternCount(); - builder->SendLong(pattern_count); } void Service::PubsubNumSub(CmdArgList args, SinkReplyBuilder* builder) { auto* rb = static_cast(builder); rb->StartArray(args.size() * 2); - for (string_view channel : args) { rb->SendBulkString(channel); rb->SendLong(ServerState::tlocal()->channel_store()->FetchSubscribers(channel).size()); @@ -2362,9 +2360,6 @@ void Service::Monitor(CmdArgList args, const CommandContext& cmd_cntx) { void Service::Pubsub(CmdArgList args, const CommandContext& cmd_cntx) { auto* rb = static_cast(cmd_cntx.rb); - if (IsClusterEnabled()) { - return rb->SendError("PUBSUB is not supported in cluster mode yet"); - } if (args.size() < 1) { rb->SendError(WrongNumArgsError(cmd_cntx.conn_cntx->cid->name())); return; @@ -2382,6 +2377,12 @@ void Service::Pubsub(CmdArgList args, const CommandContext& cmd_cntx) { "NUMSUB [ ]", "\tReturns the number of subscribers for the specified channels, excluding", "\tpattern subscriptions.", + "SHARDCHANNELS [pattern]", + "\tReturns a list of active shard channels, optionally matching the specified pattern ", + "(default: '*').", + "SHARDNUMSUB [ ]", + "\tReturns the number of subscribers for the specified shard channels, excluding", + "\tpattern subscriptions.", "HELP", "\tPrints this help."}; @@ -2389,16 +2390,21 @@ void Service::Pubsub(CmdArgList args, const CommandContext& cmd_cntx) { return; } - if (subcmd == "CHANNELS") { + // Don't allow SHARD subcommands in non cluster mode + if (!IsClusterEnabledOrEmulated() && ((subcmd == "SHARDCHANNELS") || (subcmd == "SHARDNUMSUB"))) { + auto err = absl::StrCat("PUBSUB ", subcmd, " is not supported in non cluster mode"); + return rb->SendError(err); + } + + if (subcmd == "CHANNELS" || subcmd == "SHARDCHANNELS") { string_view pattern; if (args.size() > 1) { pattern = ArgS(args, 1); } - PubsubChannels(pattern, rb); } else if (subcmd == "NUMPAT") { PubsubPatterns(rb); - } else if (subcmd == "NUMSUB") { + } else if (subcmd == "NUMSUB" || subcmd == "SHARDNUMSUB") { args.remove_prefix(1); PubsubNumSub(args, rb); } else { diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index 613341fcc..d0e59ba64 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -7,6 +7,7 @@ #include #include "absl/strings/str_cat.h" +#include "base/flags.h" #include "base/gtest.h" #include "base/logging.h" #include "facade/facade_test.h" @@ -17,6 +18,8 @@ using namespace std; using namespace util; using namespace boost; +ABSL_DECLARE_FLAG(string, cluster_mode); + namespace dfly { class ServerFamilyTest : public BaseFamilyTest { @@ -547,4 +550,17 @@ TEST_F(ServerFamilyTest, CommandDocsOk) { EXPECT_THAT(Run({"command", "docs"}), "OK"); } +TEST_F(ServerFamilyTest, PubSubCommandErr) { + // Check conditions only in non cluster mode + if (auto cluster_mode = absl::GetFlag(FLAGS_cluster_mode); cluster_mode == "") { + EXPECT_THAT(Run({"PUBSUB", "SHARDCHANNELS"}), + ErrArg("PUBSUB SHARDCHANNELS is not supported in non cluster mode")); + EXPECT_THAT(Run({"PUBSUB", "SHARDNUMSUB"}), + ErrArg("PUBSUB SHARDNUMSUB is not supported in non cluster mode")); + } + EXPECT_THAT(Run({"PUBSUB", "INVALIDSUBCOMMAND"}), + ErrArg("Unknown subcommand or wrong number of arguments for 'INVALIDSUBCOMMAND'. Try " + "PUBSUB HELP.")); +} + } // namespace dfly diff --git a/tests/dragonfly/cluster_test.py b/tests/dragonfly/cluster_test.py index 615208840..9c288da84 100644 --- a/tests/dragonfly/cluster_test.py +++ b/tests/dragonfly/cluster_test.py @@ -2982,6 +2982,51 @@ async def test_cluster_sharded_pub_sub(df_factory: DflyInstanceFactory): assert message == {"type": "unsubscribe", "pattern": None, "channel": b"kostas", "data": 0} +@dfly_args({"proactor_threads": 2, "cluster_mode": "yes"}) +async def test_cluster_sharded_pubsub_shard_commands(df_factory: DflyInstanceFactory): + nodes = [df_factory.create(port=next(next_port)) for i in range(2)] + df_factory.start_all(nodes) + + c_nodes = [node.client() for node in nodes] + + nodes_info = [(await create_node_info(instance)) for instance in nodes] + nodes_info[0].slots = [(0, 16383)] + nodes_info[1].slots = [] + + await push_config(json.dumps(generate_config(nodes_info)), [node.client for node in nodes_info]) + + node_a = ClusterNode("localhost", nodes[0].port) + node_b = ClusterNode("localhost", nodes[1].port) + + consumer_client = RedisCluster(startup_nodes=[node_a, node_b]) + consumer = consumer_client.pubsub() + + consumer.ssubscribe("pubsub-shard-channel") + consumer.ssubscribe("shard-channel") + + message = await c_nodes[0].execute_command("PUBSUB SHARDCHANNELS") + message.sort() + assert message == ["pubsub-shard-channel", "shard-channel"] + + message = await c_nodes[0].execute_command("PUBSUB SHARDCHANNELS pubsub*") + assert message == ["pubsub-shard-channel"] + + message = await c_nodes[0].execute_command("PUBSUB SHARDCHANNELS *channel") + message.sort() + assert message == ["pubsub-shard-channel", "shard-channel"] + + message = await c_nodes[0].execute_command("PUBSUB SHARDNUMSUB pubsub-shard-channel") + assert message == ["pubsub-shard-channel", 1] + + message = await c_nodes[0].execute_command( + "PUBSUB SHARDNUMSUB pubsub-shard-channel shard-channel" + ) + assert message == ["pubsub-shard-channel", 1, "shard-channel", 1] + + message = await c_nodes[0].execute_command("PUBSUB SHARDNUMSUB") + assert message == [] + + @dfly_args({"proactor_threads": 2, "cluster_mode": "yes"}) async def test_cluster_migration_errors_num(df_factory: DflyInstanceFactory): # create cluster with several nodes and create migrations from one node to others