feat(streams): Add support for XREAD BLOCK (#1291)

* feat(streams): Add support for XREAD BLOCK

---------

Signed-off-by: Andrew Dunstall <andydunstall@hotmail.co.uk>
This commit is contained in:
Andy Dunstall 2023-05-27 20:47:31 +01:00 committed by GitHub
parent bc717a037d
commit 1cfeff21a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 309 additions and 73 deletions

View file

@ -212,9 +212,9 @@ void BlockingController::NotifyPending() {
// Double verify we still got the item.
auto [it, exp_it] = owner_->db_slice().FindExt(context, sv_key);
if (!IsValid(it) ||
!(it->second.ObjType() == OBJ_LIST ||
it->second.ObjType() == OBJ_ZSET)) // Only LIST and ZSET are allowed to block.
// Only LIST, ZSET and STREAM are allowed to block.
if (!IsValid(it) || !(it->second.ObjType() == OBJ_LIST || it->second.ObjType() == OBJ_ZSET ||
it->second.ObjType() == OBJ_STREAM))
continue;
NotifyWatchQueue(sv_key, &wt.queue_map);

View file

@ -40,6 +40,7 @@ void BlockingControllerTest::SetUp() {
pp_.reset(fb2::Pool::IOUring(16, kNumThreads));
pp_->Run();
pp_->Await([](unsigned index, ProactorBase* p) { ServerState::Init(index); });
ServerState::Init(kNumThreads);
shard_set = new EngineShardSet(pp_.get());
shard_set->Init(kNumThreads, false);

View file

@ -282,11 +282,8 @@ facade::OpStatus RunCbOnFirstNonEmptyBlocking(BlockingResultCb&& func, std::stri
};
VLOG(1) << "Blocking BLPOP " << trans->DebugId();
auto* stats = ServerState::tl_connection_stats();
++stats->num_blocked_clients;
bool wait_succeeded = trans->WaitOnWatch(limit_tp, std::move(wcb));
--stats->num_blocked_clients;
bool wait_succeeded = trans->WaitOnWatch(limit_tp, std::move(wcb));
if (!wait_succeeded)
return OpStatus::TIMED_OUT;
} else {

View file

@ -830,15 +830,10 @@ OpResult<string> BPopPusher::RunSingle(Transaction* t, time_point tp) {
return op_res;
}
auto* stats = ServerState::tl_connection_stats();
auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; };
// Block
++stats->num_blocked_clients;
bool wait_succeeded = t->WaitOnWatch(tp, std::move(wcb));
--stats->num_blocked_clients;
if (!wait_succeeded)
return OpStatus::TIMED_OUT;
@ -857,19 +852,13 @@ OpResult<string> BPopPusher::RunPair(Transaction* t, time_point tp) {
return op_res;
}
auto* stats = ServerState::tl_connection_stats();
// a hack: we watch in both shards for pop_key but only in the source shard it's relevant.
// Therefore we follow the regular flow of watching the key but for the destination shard it
// will never be triggered.
// This allows us to run Transaction::Execute on watched transactions in both shards.
auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; };
++stats->num_blocked_clients;
bool wait_succeeded = t->WaitOnWatch(tp, std::move(wcb));
--stats->num_blocked_clients;
if (!wait_succeeded)
return OpStatus::TIMED_OUT;

View file

@ -13,9 +13,11 @@ extern "C" {
#include "base/logging.h"
#include "facade/error.h"
#include "server/blocking_controller.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/engine_shard_set.h"
#include "server/server_state.h"
#include "server/transaction.h"
namespace dfly {
@ -34,8 +36,15 @@ using RecordVec = vector<Record>;
// A stream ID as parsed from a command argument, together with flags that
// record how the ID was written.
struct ParsedStreamId {
  // The parsed ID value.
  streamID val;

  // Was an ID different than "ms-*" specified? for XADD only.
  bool has_seq = false;
  // Was an ID different than "*" specified? for XADD only.
  bool id_given = false;

  // Whether to lookup messages after the last ID in the stream. Used for XREAD
  // when using ID '$'.
  bool last_id = false;
};
struct RangeId {
@ -78,6 +87,9 @@ struct ReadOpts {
unordered_map<string_view, ParsedStreamId> stream_ids;
// Contains the maximum number of entries to return for each stream.
uint32_t count = kuint32max;
// Contains the time to block waiting for entries, or -1 if should not block.
int64_t timeout = -1;
size_t streams_arg;
};
const char kInvalidStreamId[] = "Invalid stream ID specified as stream command argument";
@ -555,6 +567,12 @@ OpResult<streamID> OpAdd(const OpArgs& op_args, string_view key, const AddOpts&
}
streamTrim(stream_inst, &add_args);
}
EngineShard* es = op_args.shard;
if (es->blocking_controller()) {
es->blocking_controller()->AwakeWatched(op_args.db_cntx.db_index, key);
}
return result_id;
}
@ -604,6 +622,36 @@ OpResult<RecordVec> OpRange(const OpArgs& op_args, string_view key, const RangeO
return result;
}
// Collects, for every key in `args` that resolves to a stream on this shard,
// the ID of its last entry. The result is a list of (key, last ID) pairs in
// the order the keys appear in `args`. Keys that are not found are left out;
// any other lookup status is returned to the caller as-is.
OpResult<vector<pair<string_view, streamID>>> OpLastIDs(const OpArgs& op_args,
                                                        const ArgSlice& args) {
  DCHECK(!args.empty());

  vector<pair<string_view, streamID>> ids_by_key;
  auto& slice = op_args.shard->db_slice();

  for (string_view stream_key : args) {
    OpResult<PrimeIterator> found = slice.Find(op_args.db_cntx, stream_key, OBJ_STREAM);
    if (found) {
      stream* stream_inst = (stream*)(*found)->second.RObjPtr();
      streamID tail_id;
      streamLastValidID(stream_inst, &tail_id);
      ids_by_key.emplace_back(stream_key, tail_id);
      continue;
    }

    // A missing key is simply skipped; every other failure ends the scan.
    if (found.status() != OpStatus::KEY_NOTFOUND) {
      return found.status();
    }
  }

  return ids_by_key;
}
// Returns the range response for each stream on this shard in order of
// GetShardArgs.
vector<RecordVec> OpRead(const OpArgs& op_args, const ArgSlice& args, const ReadOpts& opts) {
@ -1192,97 +1240,243 @@ void StreamFamily::XRevRange(CmdArgList args, ConnectionContext* cntx) {
XRangeGeneric(std::move(args), true, cntx);
}
void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) {
std::optional<ReadOpts> ParseReadArgsOrReply(CmdArgList args, ConnectionContext* cntx) {
size_t streams_count = 0;
size_t streams_arg = 0;
uint32_t count = kuint32max;
ReadOpts opts;
// Parse the arguments.
for (size_t id_indx = 0; id_indx < args.size(); ++id_indx) {
ToUpper(&args[id_indx]);
string_view arg = ArgS(args, id_indx);
size_t remaining_args = args.size() - id_indx - 1;
if (arg == "BLOCK") {
return (*cntx)->SendError("BLOCK is not supported", kSyntaxErrType);
} else if (arg == "COUNT" && remaining_args > 0) {
bool remaining_args = args.size() - id_indx - 1 > 0;
if (arg == "BLOCK" && remaining_args) {
id_indx++;
arg = ArgS(args, id_indx);
if (!absl::SimpleAtoi(arg, &count)) {
return (*cntx)->SendError(kSyntaxErr);
if (!absl::SimpleAtoi(arg, &opts.timeout)) {
(*cntx)->SendError(kInvalidIntErr);
return std::nullopt;
}
} else if (arg == "STREAMS" && remaining_args > 0) {
streams_arg = id_indx + 1;
} else if (arg == "COUNT" && remaining_args) {
id_indx++;
arg = ArgS(args, id_indx);
if (!absl::SimpleAtoi(arg, &opts.count)) {
(*cntx)->SendError(kInvalidIntErr);
return std::nullopt;
}
} else if (arg == "STREAMS" && remaining_args) {
opts.streams_arg = id_indx + 1;
size_t pair_count = args.size() - streams_arg;
size_t pair_count = args.size() - opts.streams_arg;
if ((pair_count % 2) != 0) {
const auto m =
"Unbalanced 'XREAD' list of streams: for each stream key an ID must be specified";
return (*cntx)->SendError(m, kSyntaxErr);
(*cntx)->SendError(m, kSyntaxErr);
return std::nullopt;
}
streams_count = pair_count / 2;
break;
} else {
return (*cntx)->SendError(kSyntaxErr);
(*cntx)->SendError(kSyntaxErr);
return std::nullopt;
}
}
// STREAMS option is required.
if (streams_arg == 0) {
return (*cntx)->SendError(kSyntaxErr);
if (opts.streams_arg == 0) {
(*cntx)->SendError(kSyntaxErr);
return std::nullopt;
}
ReadOpts read_opts;
read_opts.count = count;
// Parse the stream IDs.
for (size_t i = streams_arg + streams_count; i < args.size(); i++) {
for (size_t i = opts.streams_arg + streams_count; i < args.size(); i++) {
string_view key = ArgS(args, i - streams_count);
string_view idstr = ArgS(args, i);
ParsedStreamId id;
if (idstr == "$") {
return (*cntx)->SendError(
"Since BLOCK is not supported, the $ ID is meaningless as it will always return an empty "
"result set.",
kSyntaxErr);
// Set ID to 0 so if the ID cannot be resolved (when the stream doesn't
// exist) it takes the first entry added.
id.val.ms = 0;
id.val.seq = 0;
id.last_id = true;
opts.stream_ids.emplace(key, id);
continue;
}
if (idstr == ">") {
// XREADGROUP is not supported.
return (*cntx)->SendError(
(*cntx)->SendError(
"The > ID can be specified only when calling XREADGROUP using the GROUP <group> "
"<consumer> option.",
kSyntaxErr);
return std::nullopt;
}
ParsedStreamId id;
if (!ParseID(idstr, true, 0, &id)) {
return (*cntx)->SendError(kInvalidStreamId, kSyntaxErrType);
(*cntx)->SendError(kInvalidStreamId, kSyntaxErrType);
return std::nullopt;
}
// We only include messages with IDs greater than start so increment the
// starting ID.
streamIncrID(&id.val);
read_opts.stream_ids.emplace(key, id);
opts.stream_ids.emplace(key, id);
}
unsigned shard_count = shard_set->size();
vector<vector<RecordVec>> xread_resp(shard_count);
return opts;
}
// Returns the last ID of each stream in the transaction, keyed by stream name.
// Streams that were not found are absent from the map. If any shard reports a
// failure, that status is returned instead of a map.
//
// The hop is run with Execute(..., false).
// NOTE(review): presumably this leaves the transaction open for further hops
// by the caller - confirm against Transaction::Execute.
OpResult<unordered_map<string_view, streamID>> StreamLastIDs(Transaction* trans) {
  // One result slot per shard, indexed by shard ID and filled in by that
  // shard's callback.
  vector<OpResult<vector<pair<string_view, streamID>>>> last_ids_res(shard_set->size());

  auto cb = [&](Transaction* t, EngineShard* shard) {
    ShardId sid = shard->shard_id();
    last_ids_res[sid] = OpLastIDs(t->GetOpArgs(shard), t->GetShardArgs(sid));
    return OpStatus::OK;
  };
  trans->Execute(std::move(cb), false);

  unordered_map<string_view, streamID> last_ids;
  // Iterate by reference: every slot owns a vector, so copying it per
  // iteration would be wasted work.
  for (auto& shard_res : last_ids_res) {
    if (!shard_res) {
      return shard_res.status();
    }

    for (auto& entry : *shard_res) {
      last_ids.emplace(entry.first, entry.second);
    }
  }

  return last_ids;
}
// Completes an XREAD for which no requested stream has entries that can be
// returned right away.
//
// When no BLOCK timeout was given (timeout == -1), or the command runs inside
// a multi transaction, the transaction is concluded and a null array is sent.
// Otherwise the transaction waits on the keys of each shard until the deadline
// (a timeout of 0 waits without a deadline). A null array is sent when the
// wait does not succeed; otherwise the entries of the key reported by
// GetWakeKey are read with OpRange, starting from the requested ID, and sent
// as a single-stream reply.
void XReadBlock(ReadOpts opts, ConnectionContext* cntx) {
  Transaction* tx = cntx->transaction;

  const bool may_block = opts.timeout != -1 && !tx->IsMulti();
  if (!may_block) {
    // Run an empty final hop so the transaction is closed and locks released.
    auto noop_cb = [](Transaction*, EngineShard*) { return OpStatus::OK; };
    tx->Execute(std::move(noop_cb), true);
    return (*cntx)->SendNullArray();
  }

  auto watch_keys_cb = [](Transaction* t, EngineShard* shard) {
    return t->GetShardArgs(shard->shard_id());
  };

  auto deadline = (opts.timeout == 0)
                      ? Transaction::time_point::max()
                      : chrono::steady_clock::now() + chrono::milliseconds(opts.timeout);

  if (!tx->WaitOnWatch(deadline, std::move(watch_keys_cb))) {
    return (*cntx)->SendNullArray();
  }

  // Resolve the entry in the woken key. Note this must not use OpRead since
  // only the shard that contains the woken key blocks for the awoken
  // transaction to proceed.
  OpResult<RecordVec> records;
  std::string woken_key;
  auto read_woken_cb = [&](Transaction* t, EngineShard* shard) {
    auto wake_key = t->GetWakeKey(shard->shard_id());
    if (wake_key) {
      RangeOpts range;
      range.start = opts.stream_ids.at(*wake_key);
      // No upper bound: read up to the largest possible ID.
      range.end = ParsedStreamId{.val = streamID{
                                     .ms = UINT64_MAX,
                                     .seq = UINT64_MAX,
                                 }};

      records = OpRange(t->GetOpArgs(shard), *wake_key, range);
      woken_key = *wake_key;
    }
    return OpStatus::OK;
  };
  tx->Execute(std::move(read_woken_cb), true);

  if (!records) {
    return (*cntx)->SendNullArray();
  }

  // Reply layout: [[key, [[id, [field, value, ...]], ...]]]
  (*cntx)->StartArray(1);
  (*cntx)->StartArray(2);
  (*cntx)->SendBulkString(woken_key);

  (*cntx)->StartArray(records->size());
  for (const auto& record : *records) {
    (*cntx)->StartArray(2);
    (*cntx)->SendBulkString(StreamIdRepr(record.id));

    (*cntx)->StartArray(record.kv_arr.size() * 2);
    for (const auto& field_value : record.kv_arr) {
      (*cntx)->SendBulkString(field_value.first);
      (*cntx)->SendBulkString(field_value.second);
    }
  }
}
void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) {
auto opts = ParseReadArgsOrReply(args, cntx);
if (!opts) {
return;
}
cntx->transaction->Schedule();
auto last_ids = StreamLastIDs(cntx->transaction);
if (!last_ids) {
// Close the transaction.
auto close_cb = [&](Transaction* t, EngineShard* shard) { return OpStatus::OK; };
cntx->transaction->Execute(std::move(close_cb), true);
if (last_ids.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
return;
}
return (*cntx)->SendNullArray();
}
// Resolve '$' IDs and check if there are any streams with entries that can
// be resolved without blocking.
bool block = true;
for (auto& [stream, requested_id] : opts->stream_ids) {
if (auto last_id_it = last_ids->find(stream); last_id_it != last_ids->end()) {
streamID last_id = last_id_it->second;
// Resolve $ to the last ID in the stream.
if (requested_id.last_id) {
requested_id.val = last_id;
// We only include messages with IDs greater than the last message so
// increment the ID.
streamIncrID(&requested_id.val);
requested_id.last_id = false;
continue;
}
if (streamCompareID(&last_id, &requested_id.val) >= 0) {
block = false;
}
}
}
if (block) {
return XReadBlock(*opts, cntx);
}
vector<vector<RecordVec>> xread_resp(shard_set->size());
auto read_cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
xread_resp[sid] = OpRead(t->GetOpArgs(shard), t->GetShardArgs(shard->shard_id()), *opts);
return OpStatus::OK;
};
cntx->transaction->Execute(std::move(read_cb), true);
// Merge the results into a single response ordered by stream.
vector<RecordVec> res(streams_count);
vector<RecordVec> res(opts->stream_ids.size());
// Track the number of streams with records as empty streams are excluded from
// the response.
int resolved_streams = 0;
for (ShardId sid = 0; sid < shard_count; ++sid) {
for (ShardId sid = 0; sid < shard_set->size(); ++sid) {
if (!cntx->transaction->IsActive(sid))
continue;
@ -1302,7 +1496,7 @@ void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) {
// Add the stream records ordered by the original stream arguments.
size_t indx = cntx->transaction->ReverseArgIndex(sid, i);
res[indx - streams_arg] = std::move(results[i]);
res[indx - opts->streams_arg] = std::move(results[i]);
}
}
@ -1314,7 +1508,7 @@ void StreamFamily::XRead(CmdArgList args, ConnectionContext* cntx) {
}
(*cntx)->StartArray(2);
(*cntx)->SendBulkString(ArgS(args, i + streams_arg));
(*cntx)->SendBulkString(ArgS(args, i + opts->streams_arg));
(*cntx)->StartArray(res[i].size());
for (const auto& item : res[i]) {
(*cntx)->StartArray(2);
@ -1428,7 +1622,6 @@ void StreamFamily::Register(CommandRegistry* registry) {
<< CI{"XLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(XLen)
<< CI{"XRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(XRange)
<< CI{"XREVRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(XRevRange)
// TODO NB: Doesn't support BLOCK
<< CI{"XREAD", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 3, 3, 1}
.HFUNC(XRead)
<< CI{"XSETID", CO::WRITE | CO::DENYOOM, 3, 1, 1, 1}.HFUNC(XSetId)

View file

@ -127,13 +127,13 @@ TEST_F(StreamFamilyTest, XRead) {
Run({"xadd", "bar", "1-*", "k4", "v4"});
// Receive all records from both streams.
auto resp = Run({"xread", "count", "10", "streams", "foo", "bar", "0", "0"});
auto resp = Run({"xread", "streams", "foo", "bar", "0", "0"});
EXPECT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("foo", ArrLen(3)));
EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("bar", ArrLen(1)));
// Order of the requested streams is maintained.
resp = Run({"xread", "count", "10", "streams", "bar", "foo", "0", "0"});
resp = Run({"xread", "streams", "bar", "foo", "0", "0"});
EXPECT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("bar", ArrLen(1)));
EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("foo", ArrLen(3)));
@ -143,31 +143,77 @@ TEST_F(StreamFamilyTest, XRead) {
EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("foo", ArrLen(1)));
EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("bar", ArrLen(1)));
// Stream not found.
resp = Run({"xread", "count", "10", "streams", "foo", "notfound", "0", "0"});
// Note when the response has length 1, Run returns the first element.
EXPECT_THAT(resp.GetVec(), ElementsAre("foo", ArrLen(3)));
// Read from ID.
resp = Run({"xread", "count", "10", "streams", "foo", "bar", "1-1", "2-0"});
// Note when the response has length 1, Run returns the first element.
EXPECT_THAT(resp.GetVec(), ElementsAre("foo", ArrLen(1)));
EXPECT_THAT(resp.GetVec()[1].GetVec()[0].GetVec(), ElementsAre("1-2", ArrLen(2)));
// Stream not found.
resp = Run({"xread", "streams", "foo", "notfound", "0", "0"});
// Note when the response has length 1, Run returns the first element.
EXPECT_THAT(resp.GetVec(), ElementsAre("foo", ArrLen(3)));
// Not found.
resp = Run({"xread", "streams", "notfound", "0"});
EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY));
}
// Exercises XREAD with the BLOCK option in three situations: entries are
// already available (reply is immediate), nothing new arrives before the
// timeout (nil array), and two blocked readers are released by one XADD.
TEST_F(StreamFamilyTest, XReadBlock) {
// Seed two streams: three entries in "foo" and one in "bar".
Run({"xadd", "foo", "1-*", "k1", "v1"});
Run({"xadd", "foo", "1-*", "k2", "v2"});
Run({"xadd", "foo", "1-*", "k3", "v3"});
Run({"xadd", "bar", "1-*", "k4", "v4"});
// Receive all records from both streams.
// Reading from ID 0 finds existing entries, so BLOCK does not have to wait.
auto resp = Run({"xread", "block", "100", "streams", "foo", "bar", "0", "0"});
EXPECT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("foo", ArrLen(3)));
EXPECT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("bar", ArrLen(1)));
// Timeout.
// "$" asks only for entries added after the call; none are added within the
// 1ms window, so a nil array is expected.
resp = Run({"xread", "block", "1", "streams", "foo", "bar", "$", "$"});
EXPECT_THAT(resp, ArgType(RespExpr::NIL_ARRAY));
// Run XREAD BLOCK from 2 fibers.
// A BLOCK value of 0 waits without a deadline, so each reader can only finish
// once an entry is added.
RespExpr resp0, resp1;
auto fb0 = pp_->at(0)->LaunchFiber(Launch::dispatch, [&] {
resp0 = Run({"xread", "block", "0", "streams", "foo", "bar", "$", "$"});
});
auto fb1 = pp_->at(1)->LaunchFiber(Launch::dispatch, [&] {
resp1 = Run({"xread", "block", "0", "streams", "foo", "bar", "$", "$"});
});
// NOTE(review): presumably this pause gives both readers time to start
// blocking before the entry is added - confirm it is long enough in CI.
ThisFiber::SleepFor(50us);
// Add one entry to "foo"; this is what should release the blocked readers.
resp = pp_->at(1)->Await([&] { return Run("xadd", {"xadd", "foo", "1-*", "k5", "v5"}); });
fb0.Join();
fb1.Join();
// Both xread calls should have been unblocked.
//
// Note when the response has length 1, Run returns the first element.
EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1)));
EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1)));
}
TEST_F(StreamFamilyTest, XReadInvalidArgs) {
// Using BLOCK when it is not supported.
auto resp = Run({"xread", "count", "5", "block", "2000", "streams", "s1", "s2", "0", "0"});
EXPECT_THAT(resp, ErrArg("BLOCK is not supported"));
// Invalid COUNT value.
resp = Run({"xread", "count", "invalid", "streams", "s1", "s2", "0", "0"});
EXPECT_THAT(resp, ErrArg("syntax error"));
auto resp = Run({"xread", "count", "invalid", "streams", "s1", "s2", "0", "0"});
EXPECT_THAT(resp, ErrArg("not an integer or out of range"));
// Missing COUNT value.
resp = Run({"xread", "count"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments for 'xread' command"));
// Invalid BLOCK value.
resp = Run({"xread", "block", "invalid", "streams", "s1", "s2", "0", "0"});
EXPECT_THAT(resp, ErrArg("not an integer or out of range"));
// Missing BLOCK value.
resp = Run({"xread", "block", "streams", "s1", "s2", "0", "0"});
EXPECT_THAT(resp, ErrArg("not an integer or out of range"));
// Missing STREAMS.
resp = Run({"xread", "count", "5"});
EXPECT_THAT(resp, ErrArg("syntax error"));
@ -175,6 +221,11 @@ TEST_F(StreamFamilyTest, XReadInvalidArgs) {
// Unbalanced list of streams.
resp = Run({"xread", "count", "invalid", "streams", "s1", "s2", "s3", "0", "0"});
EXPECT_THAT(resp, ErrArg("syntax error"));
// Wrong type.
Run({"set", "foo", "v"});
resp = Run({"xread", "streams", "foo", "0"});
EXPECT_THAT(resp, ErrArg("key holding the wrong kind of value"));
}
TEST_F(StreamFamilyTest, Issue854) {

View file

@ -1116,6 +1116,9 @@ bool Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_provi
wakeup_requested_.load(memory_order_relaxed) > 0;
};
auto* stats = ServerState::tl_connection_stats();
++stats->num_blocked_clients;
cv_status status = cv_status::no_timeout;
if (tp == time_point::max()) {
DVLOG(1) << "WaitOnWatch foreva " << DebugId();
@ -1131,6 +1134,8 @@ bool Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_provi
DVLOG(1) << "WaitOnWatch await_until " << int(status);
}
--stats->num_blocked_clients;
bool is_expired = (coordinator_state_ & COORD_CANCELLED) || status == cv_status::timeout;
UnwatchBlocking(is_expired, wkeys_provider);
coordinator_state_ &= ~COORD_BLOCKED;