From da2ad7eceb4523d16b2a285d3806e15b145788b6 Mon Sep 17 00:00:00 2001 From: Abhradeep Chakraborty Date: Tue, 11 Jul 2023 10:41:19 +0530 Subject: [PATCH] feat(stream): add support for xreadgroup command (#1475) Signed-off-by: Abhradeep Chakraborty --- src/server/stream_family.cc | 330 +++++++++++++++++++++++++++---- src/server/stream_family.h | 1 + src/server/stream_family_test.cc | 157 +++++++++++++++ src/server/transaction.cc | 2 +- 4 files changed, 453 insertions(+), 37 deletions(-) diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc index 713ec9db2..5990b46f1 100644 --- a/src/server/stream_family.cc +++ b/src/server/stream_family.cc @@ -83,16 +83,36 @@ struct RangeOpts { ParsedStreamId end; bool is_rev = false; uint32_t count = kuint32max; + + // readgroup range fields + streamCG* group = nullptr; + streamConsumer* consumer = nullptr; + bool noack = false; +}; + +struct StreamIDsItem { + ParsedStreamId id; + + // Readgroup fields - id and group-consumer pair is exclusive. + streamCG* group = nullptr; + streamConsumer* consumer = nullptr; }; struct ReadOpts { // Contains a mapping from stream name to the starting stream ID. - unordered_map stream_ids; + unordered_map stream_ids; // Contains the maximum number of entries to return for each stream. uint32_t count = kuint32max; // Contains the time to block waiting for entries, or -1 if should not block. int64_t timeout = -1; size_t streams_arg = 0; + + // readgroup fields + bool read_group = false; + bool serve_history = false; + string_view group_name; + string_view consumer_name; + bool noack = false; }; const char kInvalidStreamId[] = "Invalid stream ID specified as stream command argument"; @@ -100,6 +120,7 @@ const char kXGroupKeyNotFound[] = "The XGROUP subcommand requires the key to exist. " "Note that for CREATE you may want to use the MKSTREAM option to create " "an empty stream automatically."; +const char kSameStreamFound[] = "Same stream specified multiple time"; const uint32_t STREAM_LISTPACK_MAX_SIZE = 1 << 30; const uint32_t kStreamNodeMaxBytes = 4096; @@ -613,6 +634,9 @@ OpResult OpRange(const OpArgs& op_args, string_view key, const RangeO Record rec; rec.id = id; rec.kv_arr.reserve(numfields); + if (opts.group && streamCompareID(&id, &opts.group->last_id) > 0) { + opts.group->last_id = id; + } /* Emit the field-value pairs. */ while (numfields--) { @@ -622,11 +646,40 @@ OpResult OpRange(const OpArgs& op_args, string_view key, const RangeO string skey(reinterpret_cast(key), key_len); string sval(reinterpret_cast(value), value_len); - rec.kv_arr.emplace_back(move(skey), move(sval)); + rec.kv_arr.emplace_back(std::move(skey), std::move(sval)); } - result.push_back(move(rec)); + result.push_back(std::move(rec)); + if (opts.group && !opts.noack) { + unsigned char buf[sizeof(streamID)]; + StreamEncodeID(buf, &id); + + /* Try to add a new NACK. Most of the time this will work and + * will not require extra lookups. We'll fix the problem later + * if we find that there is already an entry for this ID. */ + streamNACK* nack = streamCreateNACK(opts.consumer); + int group_inserted = raxTryInsert(opts.group->pel, buf, sizeof(buf), nack, nullptr); + int consumer_inserted = raxTryInsert(opts.consumer->pel, buf, sizeof(buf), nack, nullptr); + + /* Now we can check if the entry was already busy, and + * in that case reassign the entry to the new consumer, + * or update it if the consumer is the same as before. */ + if (group_inserted == 0) { + streamFreeNACK(nack); + nack = static_cast(raxFind(opts.group->pel, buf, sizeof(buf))); + DCHECK(nack != raxNotFound); + raxRemove(nack->consumer->pel, buf, sizeof(buf), NULL); + /* Update the consumer and NACK metadata. */ + nack->consumer = opts.consumer; + nack->delivery_time = mstime(); + nack->delivery_count = 1; + /* Add the entry in the new consumer local PEL. */ + raxInsert(opts.consumer->pel, buf, sizeof(buf), nack, NULL); + } else if (group_inserted == 1 && consumer_inserted == 0) { + return OpStatus::SKIPPED; // ("NACK half-created. Should not be possible."); + } + } if (opts.count == result.size()) break; } @@ -636,6 +689,49 @@ OpResult OpRange(const OpArgs& op_args, string_view key, const RangeO return result; } +OpResult OpRangeFromConsumerPEL(const OpArgs& op_args, string_view key, + const RangeOpts& opts) { + RecordVec result; + + if (opts.count == 0) + return result; + + unsigned char start_key[sizeof(streamID)]; + unsigned char end_key[sizeof(streamID)]; + auto sstart = opts.start.val; + auto send = opts.end.val; + + StreamEncodeID(start_key, &sstart); + StreamEncodeID(end_key, &send); + raxIterator ri; + + raxStart(&ri, opts.consumer->pel); + raxSeek(&ri, ">=", start_key, sizeof(start_key)); + size_t ecount = 0; + while (raxNext(&ri) && (!opts.count || ecount < opts.count)) { + if (memcmp(ri.key, &send, ri.key_len) > 0) + break; + streamID id; + + streamDecodeID(ri.key, &id); + RangeOpts ropts; + ropts.start.val = id; + ropts.end.val = id; + auto op_result = OpRange(op_args, key, ropts); + if (!op_result || !op_result.value().size()) { + result.push_back(Record{id, vector>()}); + } else { + streamNACK* nack = static_cast(ri.data); + nack->delivery_time = mstime(); + nack->delivery_count++; + result.push_back(std::move(op_result.value()[0])); + } + ecount++; + } + raxStop(&ri); + return result; +} + // Returns a map of stream to the ID of the last entry in the stream. Any // streams not found are omitted from the result. OpResult>> OpLastIDs(const OpArgs& op_args, @@ -682,9 +778,21 @@ vector OpRead(const OpArgs& op_args, const ArgSlice& args, const Read for (size_t i = 0; i < args.size(); ++i) { string_view key = args[i]; - range_opts.start = opts.stream_ids.at(key); + auto sitem = opts.stream_ids.at(key); + if (!sitem.group && opts.read_group) { + continue; + } + range_opts.start = sitem.id; + range_opts.group = sitem.group; + range_opts.consumer = sitem.consumer; + range_opts.noack = opts.noack; - auto range_res = OpRange(op_args, key, range_opts); + OpResult range_res; + + if (opts.serve_history) + range_res = OpRangeFromConsumerPEL(op_args, key, range_opts); + else + range_res = OpRange(op_args, key, range_opts); if (range_res) { response[i] = std::move(range_res.value()); } @@ -817,6 +925,45 @@ OpStatus OpDestroyGroup(const OpArgs& op_args, string_view key, string_view gnam return OpStatus::SKIPPED; } +struct GroupConsumerPair { + streamCG* group; + streamConsumer* consumer; +}; + +struct GroupConsumerPairOpts { + string_view group; + string_view consumer; +}; + +vector OpGetGroupConsumerPairs(ArgSlice slice_args, const OpArgs& op_args, + const GroupConsumerPairOpts& opts) { + vector sid_items(slice_args.size()); + + // get group and consumer + for (size_t i = 0; i < slice_args.size(); i++) { + string_view key = slice_args[i]; + streamCG* group = nullptr; + streamConsumer* consumer = nullptr; + auto group_res = FindGroup(op_args, key, opts.group); + if (!group_res) { + continue; + } + if (group = group_res->second; !group) { + continue; + } + + op_args.shard->tmp_str1 = + sdscpylen(op_args.shard->tmp_str1, opts.consumer.data(), opts.consumer.size()); + consumer = streamLookupConsumer(group, op_args.shard->tmp_str1, SLC_NO_REFRESH); + if (!consumer) { + consumer = streamCreateConsumer(group, op_args.shard->tmp_str1, NULL, 0, + SCC_NO_NOTIFY | SCC_NO_DIRTIFY); + } + sid_items[i] = {group, consumer}; + } + return sid_items; +} + // XGROUP CREATECONSUMER key groupname consumername OpResult OpCreateConsumer(const OpArgs& op_args, string_view key, string_view gname, string_view consumer_name) { @@ -1354,12 +1501,35 @@ void StreamFamily::XRevRange(CmdArgList args, ConnectionContext* cntx) { XRangeGeneric(std::move(args), true, cntx); } -std::optional ParseReadArgsOrReply(CmdArgList args, ConnectionContext* cntx) { +std::optional ParseReadArgsOrReply(CmdArgList args, bool read_group, + ConnectionContext* cntx) { size_t streams_count = 0; ReadOpts opts; + opts.read_group = read_group; + size_t id_indx = 0; - for (size_t id_indx = 0; id_indx < args.size(); ++id_indx) { + if (opts.read_group) { + ToUpper(&args[id_indx]); + string_view arg = ArgS(args, id_indx); + + if (arg.size() - 1 < 2) { + (*cntx)->SendError(kSyntaxErr); + return std::nullopt; + } + + if (arg != "GROUP") { + const auto m = "Missing 'GROUP' in 'XREADGROUP' command"; + (*cntx)->SendError(m, kSyntaxErr); + return std::nullopt; + } + id_indx++; + opts.group_name = ArgS(args, id_indx); + opts.consumer_name = ArgS(args, ++id_indx); + id_indx++; + } + + for (; id_indx < args.size(); ++id_indx) { ToUpper(&args[id_indx]); string_view arg = ArgS(args, id_indx); @@ -1378,13 +1548,14 @@ std::optional ParseReadArgsOrReply(CmdArgList args, ConnectionContext* (*cntx)->SendError(kInvalidIntErr); return std::nullopt; } + } else if (opts.read_group && arg == "NOACK") { + opts.noack = true; } else if (arg == "STREAMS" && remaining_args) { opts.streams_arg = id_indx + 1; size_t pair_count = args.size() - opts.streams_arg; if ((pair_count % 2) != 0) { - const auto m = - "Unbalanced 'XREAD' list of streams: for each stream key an ID must be specified"; + const auto m = "Unbalanced list of streams: for each stream key an ID must be specified"; (*cntx)->SendError(m, kSyntaxErr); return std::nullopt; } @@ -1407,25 +1578,45 @@ std::optional ParseReadArgsOrReply(CmdArgList args, ConnectionContext* string_view key = ArgS(args, i - streams_count); string_view idstr = ArgS(args, i); + StreamIDsItem sitem; ParsedStreamId id; if (idstr == "$") { // Set ID to 0 so if the ID cannot be resolved (when the stream doesn't // exist) it takes the first entry added. + if (opts.read_group) { + (*cntx)->SendError("The $ can be specified only when calling XREAD.", kSyntaxErr); + return std::nullopt; + } id.val.ms = 0; id.val.seq = 0; id.last_id = true; - opts.stream_ids.emplace(key, id); + sitem.id = id; + auto [_, is_inserted] = opts.stream_ids.emplace(key, sitem); + if (!is_inserted) { + (*cntx)->SendError(kSameStreamFound); + return std::nullopt; + } continue; } if (idstr == ">") { - // XREADGROUP is not supported. - (*cntx)->SendError( - "The > ID can be specified only when calling XREADGROUP using the GROUP " - " option.", - kSyntaxErr); - return std::nullopt; + if (!opts.read_group) { + (*cntx)->SendError( + "The > ID can be specified only when calling XREADGROUP using the GROUP " + " option.", + kSyntaxErr); + return std::nullopt; + } + id.val.ms = UINT64_MAX; + id.val.seq = UINT64_MAX; + sitem.id = id; + auto [_, is_inserted] = opts.stream_ids.emplace(key, sitem); + if (!is_inserted) { + (*cntx)->SendError(kSameStreamFound); + return std::nullopt; + } + continue; } if (!ParseID(idstr, true, 0, &id)) { @@ -1436,9 +1627,13 @@ std::optional ParseReadArgsOrReply(CmdArgList args, ConnectionContext* // We only include messages with IDs greater than start so increment the // starting ID. streamIncrID(&id.val); - opts.stream_ids.emplace(key, id); + sitem.id = id; + auto [_, is_inserted] = opts.stream_ids.emplace(key, sitem); + if (!is_inserted) { + (*cntx)->SendError(kSameStreamFound); + return std::nullopt; + } } - return opts; } @@ -1498,7 +1693,12 @@ void XReadBlock(ReadOpts opts, ConnectionContext* cntx) { .ms = UINT64_MAX, .seq = UINT64_MAX, }}; - range_opts.start = opts.stream_ids.at(*wake_key); + auto sitem = opts.stream_ids.at(*wake_key); + range_opts.start = sitem.id; + range_opts.group = sitem.group; + range_opts.consumer = sitem.consumer; + range_opts.noack = opts.noack; + result = OpRange(t->GetOpArgs(shard), *wake_key, range_opts); key = *wake_key; } @@ -1530,14 +1730,8 @@ void XReadBlock(ReadOpts opts, ConnectionContext* cntx) { } } -void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) { - auto opts = ParseReadArgsOrReply(args, cntx); - if (!opts) { - return; - } - - cntx->transaction->Schedule(); - +// Read entries from given streams +void XReadImpl(CmdArgList args, std::optional opts, ConnectionContext* cntx) { auto last_ids = StreamLastIDs(cntx->transaction); if (!last_ids) { // Close the transaction. @@ -1555,21 +1749,37 @@ void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) { // Resolve '$' IDs and check if there are any streams with entries that can // be resolved without blocking. bool block = true; - for (auto& [stream, requested_id] : opts->stream_ids) { + for (auto& [stream, requested_sitem] : opts->stream_ids) { if (auto last_id_it = last_ids->find(stream); last_id_it != last_ids->end()) { streamID last_id = last_id_it->second; - // Resolve $ to the last ID in the stream. - if (requested_id.last_id) { - requested_id.val = last_id; - // We only include messages with IDs greater than the last message so - // increment the ID. - streamIncrID(&requested_id.val); - requested_id.last_id = false; + if (opts->read_group && !requested_sitem.group) { + // if the group associated with the key is not found, + // we will not read entries from the key. continue; } - if (streamCompareID(&last_id, &requested_id.val) >= 0) { + // Resolve $ to the last ID in the stream. + if (requested_sitem.id.last_id && !opts->read_group) { + requested_sitem.id.val = last_id; + // We only include messages with IDs greater than the last message so + // increment the ID. + streamIncrID(&requested_sitem.id.val); + requested_sitem.id.last_id = false; + continue; + } + if (opts->read_group) { + // If '>' is not provided, consumer PEL is used. So don't need to block. + if (requested_sitem.id.val.ms != UINT64_MAX || requested_sitem.id.val.seq != UINT64_MAX) { + block = false; + opts->serve_history = true; + continue; + } + requested_sitem.id.val = requested_sitem.group->last_id; + streamIncrID(&requested_sitem.id.val); + } + + if (streamCompareID(&last_id, &requested_sitem.id.val) >= 0) { block = false; } } @@ -1640,6 +1850,52 @@ void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) { } } +void XReadGeneric(CmdArgList args, bool read_group, ConnectionContext* cntx) { + auto opts = ParseReadArgsOrReply(args, read_group, cntx); + if (!opts) { + return; + } + + vector> res_pairs(shard_set->size()); + auto cb = [&](Transaction* t, EngineShard* shard) { + auto sid = shard->shard_id(); + auto s_args = t->GetShardArgs(sid); + GroupConsumerPairOpts gc_opts = {opts->group_name, opts->consumer_name}; + + res_pairs[sid] = OpGetGroupConsumerPairs(s_args, t->GetOpArgs(shard), gc_opts); + return OpStatus::OK; + }; + cntx->transaction->Schedule(); + if (opts->read_group) { + // If the command is `XReadGroup`, we need to get + // the (group, consumer) pairs for each key. + cntx->transaction->Execute(std::move(cb), false); + + for (size_t i = 0; i < shard_set->size(); i++) { + auto s_item = res_pairs[i]; + auto s_args = cntx->transaction->GetShardArgs(i); + if (s_item.size() == 0) { + continue; + } + for (size_t j = 0; j < s_args.size(); j++) { + string_view key = s_args[j]; + StreamIDsItem& item = opts->stream_ids.at(key); + item.consumer = s_item[j].consumer; + item.group = s_item[j].group; + } + } + } + return XReadImpl(args, opts, cntx); +} + +void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) { + return XReadGeneric(args, false, cntx); +} + +void StreamFamily::XReadGroup(CmdArgList args, ConnectionContext* cntx) { + return XReadGeneric(args, true, cntx); +} + void StreamFamily::XSetId(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); string_view idstr = ArgS(args, 1); @@ -1763,6 +2019,8 @@ void StreamFamily::Register(CommandRegistry* registry) { << CI{"XREVRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(XRevRange) << CI{"XREAD", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 3, 3, 1} .HFUNC(XRead) + << CI{"XREADGROUP", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -6, 6, 6, 1} + .HFUNC(XReadGroup) << CI{"XSETID", CO::WRITE | CO::DENYOOM, 3, 1, 1, 1}.HFUNC(XSetId) << CI{"XTRIM", CO::WRITE | CO::FAST, -4, 1, 1, 1}.HFUNC(XTrim) << CI{"_XGROUP_HELP", CO::NOSCRIPT | CO::HIDDEN, 2, 0, 0, 0}.SetHandler(XGroupHelp); diff --git a/src/server/stream_family.h b/src/server/stream_family.h index 9c3e9137d..f964ce766 100644 --- a/src/server/stream_family.h +++ b/src/server/stream_family.h @@ -24,6 +24,7 @@ class StreamFamily { static void XRevRange(CmdArgList args, ConnectionContext* cntx); static void XRange(CmdArgList args, ConnectionContext* cntx); static void XRead(CmdArgList args, ConnectionContext* cntx); + static void XReadGroup(CmdArgList args, ConnectionContext* cntx); static void XSetId(CmdArgList args, ConnectionContext* cntx); static void XTrim(CmdArgList args, ConnectionContext* cntx); static void XRangeGeneric(CmdArgList args, bool is_rev, ConnectionContext* cntx); diff --git a/src/server/stream_family_test.cc b/src/server/stream_family_test.cc index 325337c79..2d2944a62 100644 --- a/src/server/stream_family_test.cc +++ b/src/server/stream_family_test.cc @@ -159,6 +159,86 @@ TEST_F(StreamFamilyTest, XRead) { EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); } +TEST_F(StreamFamilyTest, XReadGroup) { + Run({"xadd", "foo", "1-*", "k1", "v1"}); + Run({"xadd", "foo", "1-*", "k2", "v2"}); + Run({"xadd", "foo", "1-*", "k3", "v3"}); + Run({"xadd", "bar", "1-*", "k4", "v4"}); + + Run({"xadd", "mystream", "k1", "v1"}); + Run({"xadd", "mystream", "k2", "v2"}); + Run({"xadd", "mystream", "k3", "v3"}); + + Run({"xgroup", "create", "foo", "group", "0"}); + Run({"xgroup", "create", "bar", "group", "0"}); + + // consumer PEL is empty, so resp should have empty list + auto resp = Run({"xreadgroup", "group", "group", "alice", "streams", "foo", "0"}); + EXPECT_THAT(resp, ArrLen(0)); + + // should return unread entries with key "foo" + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "foo", ">"}); + // only "foo" key entries are read + EXPECT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec()[1], ArrLen(3)); + + Run({"xadd", "foo", "1-*", "k5", "v5"}); + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "bar", "foo", ">", ">"}); + EXPECT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("bar", ArrLen(1))); + EXPECT_THAT(resp.GetVec()[0].GetVec()[1].GetVec()[0].GetVec(), ElementsAre("1-0", ArrLen(2))); + EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("foo", ArrLen(1))); + EXPECT_THAT(resp.GetVec()[1].GetVec()[1].GetVec()[0].GetVec(), ElementsAre("1-3", ArrLen(2))); + + // now we can specify id for "foo" and it fetches from alice's consumer PEL + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "foo", "0"}); + EXPECT_THAT(resp.GetVec()[1], ArrLen(4)); + + // now ">" gives nil + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "foo", ">"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); + + // count limits the fetched entries + resp = Run( + {"xreadgroup", "group", "group", "alice", "count", "2", "streams", "foo", "bar", "0", "0"}); + EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("foo", ArrLen(2))); + EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("bar", ArrLen(1))); + + // bob will not get entries of alice + resp = Run({"xreadgroup", "group", "group", "bob", "streams", "foo", "0"}); + EXPECT_THAT(resp, ArrLen(0)); + + resp = Run({"xinfo", "groups", "foo"}); + // 2 consumers created + EXPECT_THAT(resp.GetVec()[3], IntArg(2)); + // check last_delivery_id + EXPECT_THAT(resp.GetVec()[7], "1-3"); + + // Noack + Run({"xadd", "foo", "1-*", "k6", "v6"}); + resp = Run({"xreadgroup", "group", "group", "bob", "noack", "streams", "foo", ">"}); + // check basic results + EXPECT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), ElementsAre("foo", ArrLen(1))); + // Entry is not inserted in Bob's consumer PEL. + resp = Run({"xreadgroup", "group", "group", "bob", "streams", "foo", "0"}); + EXPECT_THAT(resp, ArrLen(0)); + + // No Group + resp = Run({"xreadgroup", "group", "nogroup", "alice", "streams", "foo", "0"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); + + // '>' gives the null array result if group doesn't exist + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "mystream", ">"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); + + Run({"xadd", "foo", "1-*", "k7", "v7"}); + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "mystream", "foo", ">", ">"}); + // Only entries of 'foo' is read + EXPECT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), ElementsAre("foo", ArrLen(1))); +} + TEST_F(StreamFamilyTest, XReadBlock) { Run({"xadd", "foo", "1-*", "k1", "v1"}); Run({"xadd", "foo", "1-*", "k2", "v2"}); @@ -197,6 +277,51 @@ TEST_F(StreamFamilyTest, XReadBlock) { EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1))); } +TEST_F(StreamFamilyTest, XReadGroupBlock) { + Run({"xadd", "foo", "1-*", "k1", "v1"}); + Run({"xadd", "foo", "1-*", "k2", "v2"}); + Run({"xadd", "foo", "1-*", "k3", "v3"}); + Run({"xadd", "bar", "1-*", "k4", "v4"}); + + Run({"xgroup", "create", "foo", "group", "0"}); + Run({"xgroup", "create", "bar", "group", "0"}); + + // Receive all records from both streams. + auto resp = Run( + {"xreadgroup", "group", "group", "alice", "block", "100", "streams", "foo", "bar", ">", ">"}); + EXPECT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("foo", ArrLen(3))); + EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("bar", ArrLen(1))); + + // Timeout + resp = Run( + {"xreadgroup", "group", "group", "alice", "block", "1", "streams", "foo", "bar", ">", ">"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY)); + + // Run XREADGROUP BLOCK from 2 fibers. + RespExpr resp0, resp1; + auto fb0 = pp_->at(0)->LaunchFiber(Launch::dispatch, [&] { + resp0 = Run( + {"xreadgroup", "group", "group", "alice", "block", "0", "streams", "foo", "bar", ">", ">"}); + }); + auto fb1 = pp_->at(1)->LaunchFiber(Launch::dispatch, [&] { + resp1 = Run( + {"xreadgroup", "group", "group", "alice", "block", "0", "streams", "foo", "bar", ">", ">"}); + }); + ThisFiber::SleepFor(50us); + + resp = pp_->at(1)->Await([&] { return Run("xadd", {"xadd", "foo", "1-*", "k5", "v5"}); }); + + fb0.Join(); + fb1.Join(); + + // Both xread calls should have been unblocked. + // + // Note when the response has length 1, Run returns the first element. + EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1))); + EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1))); +} + TEST_F(StreamFamilyTest, XReadInvalidArgs) { // Invalid COUNT value. auto resp = Run({"xread", "count", "invalid", "streams", "s1", "s2", "0", "0"}); @@ -228,6 +353,38 @@ TEST_F(StreamFamilyTest, XReadInvalidArgs) { EXPECT_THAT(resp, ErrArg("key holding the wrong kind of value")); } +TEST_F(StreamFamilyTest, XReadGroupInvalidArgs) { + Run({"xgroup", "create", "group", "foo", "0", "mkstream"}); + // Invalid COUNT value. + auto resp = + Run({"xreadgroup", "group", "group", "alice", "count", "invalid", "streams", "foo", "0"}); + EXPECT_THAT(resp, ErrArg("not an integer or out of range")); + + // Invalid "stream" instead of GROUP. + resp = Run({"xreadgroup", "stream", "group", "alice", "count", "1", "streams", "foo", "0"}); + EXPECT_THAT(resp, ErrArg("Missing 'GROUP' in 'XREADGROUP' command")); + + // Missing streams. + resp = Run({"xreadgroup", "group", "group", "alice", "streams"}); + EXPECT_THAT(resp, ErrArg("wrong number of arguments for 'xreadgroup' command")); + + // Missing consumer. + resp = Run({"xreadgroup", "group", "group", "streams", "foo", "0"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + // Missing block value. + resp = Run({"xreadgroup", "group", "group", "alice", "block", "streams", "foo", "0"}); + EXPECT_THAT(resp, ErrArg("not an integer or out of range")); + + // Invalid block value. + resp = Run({"xreadgroup", "group", "group", "alice", "block", "invalid", "streams", "foo", "0"}); + EXPECT_THAT(resp, ErrArg("not an integer or out of range")); + + // Unbalanced list of streams. + resp = Run({"xreadgroup", "group", "group", "alice", "streams", "s1", "s2", "s3", "0", "0"}); + EXPECT_THAT(resp, ErrArg("syntax error")); +} + TEST_F(StreamFamilyTest, Issue854) { auto resp = Run({"xgroup", "help"}); EXPECT_THAT(resp, ArgType(RespExpr::ARRAY)); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 834ee4c42..11339f3f2 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -1414,7 +1414,7 @@ OpResult DetermineKeys(const CommandId* cid, CmdArgList args) { string_view name{cid->name()}; - if (name == "XREAD") { + if (name == "XREAD" || name == "XREADGROUP") { for (size_t i = 0; i < args.size(); ++i) { string_view arg = ArgS(args, i); if (absl::EqualsIgnoreCase(arg, "STREAMS")) {