From ec9754150f2df8e012f479eaf0682a0764abf8ce Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Thu, 2 Jun 2022 07:54:34 +0300 Subject: [PATCH] Implement PSUBSCRIBE/PUNSUBSCRIBE commands. Add minimal tests. --- README.md | 8 +- src/facade/dragonfly_connection.cc | 31 +++++-- src/facade/dragonfly_connection.h | 15 ++- src/server/channel_slice.cc | 47 ++++++++-- src/server/channel_slice.h | 18 +++- src/server/conn_context.cc | 99 +++++++++++++++++++- src/server/conn_context.h | 9 +- src/server/dragonfly_test.cc | 18 +++- src/server/main_service.cc | 32 +++++-- src/server/main_service.h | 3 + src/server/test_utils.cc | 141 +++++++++++++++++++++++------ src/server/test_utils.h | 36 ++++---- src/server/zset_family_test.cc | 1 - 13 files changed, 373 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index e6f6f7bb9..33c432e11 100644 --- a/README.md +++ b/README.md @@ -268,15 +268,15 @@ API 2.0 - [X] HSETNX - [X] HVALS - [X] HSCAN -- [ ] PubSub family +- [X] PubSub family - [X] PUBLISH - [ ] PUBSUB - [ ] PUBSUB CHANNELS - [X] SUBSCRIBE - [X] UNSUBSCRIBE - - [ ] PSUBSCRIBE - - [ ] PUNSUBSCRIBE -- [ ] Server Family + - [X] PSUBSCRIBE + - [X] PUNSUBSCRIBE +- [X] Server Family - [ ] WATCH - [ ] UNWATCH - [X] DISCARD diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 89e6df9db..e68497494 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -69,11 +69,11 @@ constexpr size_t kMinReadSize = 256; constexpr size_t kMaxReadSize = 32_KB; struct AsyncMsg { - absl::Span msg_vec; + Connection::PubMessage pub_msg; fibers_ext::BlockingCounter bc; - AsyncMsg(absl::Span vec, fibers_ext::BlockingCounter b) - : msg_vec(vec), bc(move(b)) { + AsyncMsg(const Connection::PubMessage& pmsg, fibers_ext::BlockingCounter b) + : pub_msg(pmsg), bc(move(b)) { } }; @@ -245,15 +245,17 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) { breaker_cb_ = breaker_cb; } -void Connection::SendMsgVecAsync(absl::Span msg_vec, +void Connection::SendMsgVecAsync(const PubMessage& pub_msg, fibers_ext::BlockingCounter bc) { + DCHECK(cc_); + if (cc_->conn_closing) { bc.Dec(); return; } void* ptr = mi_malloc(sizeof(AsyncMsg)); - AsyncMsg* amsg = new (ptr) AsyncMsg(msg_vec, move(bc)); + AsyncMsg* amsg = new (ptr) AsyncMsg(pub_msg, move(bc)); ptr = mi_malloc(sizeof(Request)); Request* req = new (ptr) Request(0, 0); @@ -571,7 +573,24 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) { if (req->async_msg) { ++stats->async_writes_cnt; - builder->SendRawVec(req->async_msg->msg_vec); + + RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; + const PubMessage& pub_msg = req->async_msg->pub_msg; + string_view arr[4]; + + if (pub_msg.pattern.empty()) { + arr[0] = "message"; + arr[1] = pub_msg.channel; + arr[2] = pub_msg.message; + rbuilder->SendStringArr(absl::Span{arr, 3}); + } else { + arr[0] = "pmessage"; + arr[1] = pub_msg.pattern; + arr[2] = pub_msg.channel; + arr[3] = pub_msg.message; + rbuilder->SendStringArr(absl::Span{arr, 4}); + } + req->async_msg->bc.Dec(); req->async_msg->~AsyncMsg(); diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index ea06a7452..dbee2a47e 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -46,11 +46,20 @@ class Connection : public util::Connection { using BreakerCb = std::function; void RegisterOnBreak(BreakerCb breaker_cb); - // This interface is used to pass a raw message directly to the socket via zero-copy interface. + // This interface is used to pass a published message directly to the socket without + // copying strings. // Once the msg is sent "bc" will be decreased so that caller could release the underlying // storage for the message. - void SendMsgVecAsync(absl::Span msg_vec, - util::fibers_ext::BlockingCounter bc); + // virtual - to allow the testing code to override it. + + struct PubMessage { + // if empty - means its a regular message, otherwise it's pmessage. + std::string_view pattern; + std::string_view channel; + std::string_view message; + }; + + virtual void SendMsgVecAsync(const PubMessage& pub_msg, util::fibers_ext::BlockingCounter bc); void SetName(std::string_view name) { CopyCharBuf(name, sizeof(name_), name_); diff --git a/src/server/channel_slice.cc b/src/server/channel_slice.cc index 69772ffaa..93f5ec637 100644 --- a/src/server/channel_slice.cc +++ b/src/server/channel_slice.cc @@ -4,6 +4,10 @@ #include "server/channel_slice.h" +extern "C" { +#include "redis/util.h" +} + namespace dfly { using namespace std; @@ -11,6 +15,14 @@ ChannelSlice::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid) : conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) { } +void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) { + auto [it, added] = channels_.emplace(channel, nullptr); + if (added) { + it->second.reset(new Channel); + } + it->second->subscribers.emplace(me, SubscriberInternal{thread_id}); +} + void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) { auto it = channels_.find(channel); if (it != channels_.end()) { @@ -20,29 +32,52 @@ void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me } } -void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) { - auto [it, added] = channels_.emplace(channel, nullptr); +void ChannelSlice::AddGlobPattern(string_view pattern, ConnectionContext* me, uint32_t thread_id) { + auto [it, added] = patterns_.emplace(pattern, nullptr); if (added) { it->second.reset(new Channel); } it->second->subscribers.emplace(me, SubscriberInternal{thread_id}); } +void ChannelSlice::RemoveGlobPattern(string_view pattern, ConnectionContext* me) { + auto it = patterns_.find(pattern); + if (it != patterns_.end()) { + it->second->subscribers.erase(me); + if (it->second->subscribers.empty()) + patterns_.erase(it); + } +} + auto ChannelSlice::FetchSubscribers(string_view channel) -> vector { vector res; auto it = channels_.find(channel); if (it != channels_.end()) { res.reserve(it->second->subscribers.size()); - for (const auto& k_v : it->second->subscribers) { - Subscriber s(k_v.first, k_v.second.thread_id); - s.borrow_token.Inc(); + CopySubsribers(it->second->subscribers, string{}, &res); + } - res.push_back(std::move(s)); + for (const auto& k_v : patterns_) { + const string& pat = k_v.first; + // 1 - match + if (stringmatchlen(pat.data(), pat.size(), channel.data(), channel.size(), 0) == 1) { + CopySubsribers(k_v.second->subscribers, pat, &res); } } return res; } +void ChannelSlice::CopySubsribers(const SubsribeMap& src, const std::string& pattern, + vector* dest) { + for (const auto& sub : src) { + Subscriber s(sub.first, sub.second.thread_id); + s.pattern = pattern; + s.borrow_token.Inc(); + + dest->push_back(std::move(s)); + } +} + } // namespace dfly diff --git a/src/server/channel_slice.h b/src/server/channel_slice.h index 4876759af..9403aabed 100644 --- a/src/server/channel_slice.h +++ b/src/server/channel_slice.h @@ -19,6 +19,9 @@ class ChannelSlice { util::fibers_ext::BlockingCounter borrow_token; uint32_t thread_id; + // non-empty if was registered via psubscribe + std::string pattern; + Subscriber(ConnectionContext* cntx, uint32_t tid); // Subscriber() : borrow_token(0) {} @@ -31,18 +34,27 @@ class ChannelSlice { std::vector FetchSubscribers(std::string_view channel); - void RemoveSubscription(std::string_view channel, ConnectionContext* me); void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id); + void RemoveSubscription(std::string_view channel, ConnectionContext* me); + + void AddGlobPattern(std::string_view pattern, ConnectionContext* me, uint32_t thread_id); + void RemoveGlobPattern(std::string_view pattern, ConnectionContext* me); private: struct SubscriberInternal { uint32_t thread_id; // proactor thread id. - SubscriberInternal(uint32_t tid) : thread_id(tid) {} + SubscriberInternal(uint32_t tid) : thread_id(tid) { + } }; + using SubsribeMap = absl::flat_hash_map; + + static void CopySubsribers(const SubsribeMap& src, const std::string& pattern, + std::vector* dest); + struct Channel { - absl::flat_hash_map subscribers; + SubsribeMap subscribers; }; absl::flat_hash_map> channels_; diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index d5697ea16..b34358175 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -23,9 +23,11 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis DCHECK(to_add); conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo); + // to be able to read input and still write the output. this->force_dispatch = true; } + // Gather all the channels we need to subsribe to / remove. for (size_t i = 0; i < args.size(); ++i) { bool res = false; string_view channel = ArgS(args, i); @@ -44,13 +46,14 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis } } - if (!to_add && conn_state.subscribe_info->channels.empty()) { + if (!to_add && conn_state.subscribe_info->IsEmpty()) { conn_state.subscribe_info.reset(); force_dispatch = false; } sort(channels.begin(), channels.end()); + // prepare the array in order to distribute the updates to the shards. vector shard_idx(shard_set->size() + 1, 0); for (const auto& k_v : channels) { shard_idx[k_v.first]++; @@ -68,6 +71,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis int32_t tid = util::ProactorBase::GetIndex(); DCHECK_GE(tid, 0); + // Update the subsribers on publisher's side. auto cb = [&](EngineShard* shard) { ChannelSlice& cs = shard->channel_slice(); unsigned start = shard_idx[shard->shard_id()]; @@ -83,6 +87,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis } }; + // Update subscription shard_set->RunBriefInParallel(move(cb), [&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; }); } @@ -90,6 +95,77 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis if (to_reply) { const char* action[2] = {"unsubscribe", "subscribe"}; + for (size_t i = 0; i < result.size(); ++i) { + (*this)->StartArray(3); + (*this)->SendBulkString(action[to_add]); + (*this)->SendBulkString(ArgS(args, i)); // channel + + // number of subsribed channels for this connection *right after* + // we subsribe. + (*this)->SendLong(result[i]); + } + } +} + +void ConnectionContext::ChangePSub(bool to_add, bool to_reply, CmdArgList args) { + vector result(to_reply ? args.size() : 0, 0); + + if (to_add || conn_state.subscribe_info) { + std::vector patterns; + patterns.reserve(args.size()); + + if (!conn_state.subscribe_info) { + DCHECK(to_add); + + conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo); + this->force_dispatch = true; + } + + // Gather all the patterns we need to subsribe to / remove. + for (size_t i = 0; i < args.size(); ++i) { + bool res = false; + string_view pattern = ArgS(args, i); + if (to_add) { + res = conn_state.subscribe_info->patterns.emplace(pattern).second; + } else { + res = conn_state.subscribe_info->patterns.erase(pattern) > 0; + } + + if (to_reply) + result[i] = conn_state.subscribe_info->patterns.size(); + + if (res) { + patterns.emplace_back(pattern); + } + } + + if (!to_add && conn_state.subscribe_info->IsEmpty()) { + conn_state.subscribe_info.reset(); + force_dispatch = false; + } + + int32_t tid = util::ProactorBase::GetIndex(); + DCHECK_GE(tid, 0); + + // Update the subsribers on publisher's side. + auto cb = [&](EngineShard* shard) { + ChannelSlice& cs = shard->channel_slice(); + for (string_view pattern : patterns) { + if (to_add) { + cs.AddGlobPattern(pattern, this, tid); + } else { + cs.RemoveGlobPattern(pattern, this); + } + } + }; + + // Update pattern subscription. Run on all shards. + shard_set->RunBriefInParallel(move(cb)); + } + + if (to_reply) { + const char* action[2] = {"punsubscribe", "psubscribe"}; + for (size_t i = 0; i < result.size(); ++i) { (*this)->StartArray(3); (*this)->SendBulkString(action[to_add]); @@ -100,18 +176,35 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis } void ConnectionContext::OnClose() { - if (conn_state.subscribe_info) { + if (!conn_state.subscribe_info) + return; + + if (!conn_state.subscribe_info->channels.empty()) { StringVec channels(conn_state.subscribe_info->channels.begin(), conn_state.subscribe_info->channels.end()); CmdArgVec arg_vec(channels.begin(), channels.end()); auto token = conn_state.subscribe_info->borrow_token; ChangeSubscription(false, false, CmdArgList{arg_vec}); - DCHECK(!conn_state.subscribe_info); // Check that all borrowers finished processing token.Wait(); } + + if (conn_state.subscribe_info) { + DCHECK(!conn_state.subscribe_info->patterns.empty()); + + StringVec patterns(conn_state.subscribe_info->patterns.begin(), + conn_state.subscribe_info->patterns.end()); + CmdArgVec arg_vec(patterns.begin(), patterns.end()); + + auto token = conn_state.subscribe_info->borrow_token; + ChangePSub(false, false, CmdArgList{arg_vec}); + + // Check that all borrowers finished processing + token.Wait(); + DCHECK(!conn_state.subscribe_info); + } } } // namespace dfly diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 8533bbd8c..f6af3876e 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -50,10 +50,16 @@ struct ConnectionState { struct SubscribeInfo { // TODO: to provide unique_strings across service. This will allow us to use string_view here. absl::flat_hash_set channels; + absl::flat_hash_set patterns; util::fibers_ext::BlockingCounter borrow_token; - SubscribeInfo() : borrow_token(0) {} + bool IsEmpty() const { + return channels.empty() && patterns.empty(); + } + + SubscribeInfo() : borrow_token(0) { + } }; std::unique_ptr subscribe_info; @@ -85,6 +91,7 @@ class ConnectionContext : public facade::ConnectionContext { } void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args); + void ChangePSub(bool to_add, bool to_reply, CmdArgList args); bool is_replicating = false; }; diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index 2c401a606..9d4d95617 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -22,13 +22,13 @@ extern "C" { namespace dfly { -using namespace absl; -using namespace boost; using namespace std; using namespace util; using ::io::Result; using testing::ElementsAre; using testing::HasSubstr; +using absl::StrCat; +namespace this_fiber = boost::this_fiber; namespace { @@ -411,6 +411,20 @@ TEST_F(DflyEngineTest, OOM) { } } +TEST_F(DflyEngineTest, PSubscribe) { + auto resp = pp_->at(1)->Await([&] { return Run({"psubscribe", "a*", "b*"}); }); + EXPECT_THAT(resp, ArrLen(3)); + resp = pp_->at(0)->Await([&] { return Run({"publish", "ab", "foo"}); }); + EXPECT_THAT(resp, IntArg(1)); + + ASSERT_EQ(1, SubsriberMessagesLen("IO1")); + + facade::Connection::PubMessage msg = GetPublishedMessage("IO1", 0); + EXPECT_EQ("foo", msg.message); + EXPECT_EQ("ab", msg.channel); + EXPECT_EQ("a*", msg.pattern); +} + // TODO: to test transactions with a single shard since then all transactions become local. // To consider having a parameter in dragonfly engine controlling number of shards // unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case. diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 3ca6bec28..4c70d497b 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -912,23 +912,22 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { } fibers_ext::BlockingCounter bc(subsriber_arr.size()); - char prefix[] = "*3\r\n$7\r\nmessage\r\n$"; - char msg_size[32] = {0}; - char channel_size[32] = {0}; - absl::SNPrintF(msg_size, sizeof(msg_size), "%u\r\n", message.size()); - absl::SNPrintF(channel_size, sizeof(channel_size), "%u\r\n", channel.size()); - - string_view msg_arr[] = {prefix, channel_size, channel, "\r\n$", msg_size, message, "\r\n"}; - auto publish_cb = [&, bc](unsigned idx, util::ProactorBase*) mutable { unsigned start = slices[idx]; for (unsigned i = start; i < subsriber_arr.size(); ++i) { - if (subsriber_arr[i].thread_id != idx) + const ChannelSlice::Subscriber& subscriber = subsriber_arr[i]; + if (subscriber.thread_id != idx) break; published.fetch_add(1, memory_order_relaxed); - subsriber_arr[i].conn_cntx->owner()->SendMsgVecAsync(msg_arr, bc); + facade::Connection* conn = subsriber_arr[i].conn_cntx->owner(); + DCHECK(conn); + facade::Connection::PubMessage pmsg; + pmsg.channel = channel; + pmsg.message = message; + pmsg.pattern = subscriber.pattern; + conn->SendMsgVecAsync(pmsg, bc); } }; @@ -959,6 +958,17 @@ void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) { cntx->ChangeSubscription(false, true, std::move(args)); } +void Service::PSubscribe(CmdArgList args, ConnectionContext* cntx) { + args.remove_prefix(1); + cntx->ChangePSub(true, true, args); +} + +void Service::PUnsubscribe(CmdArgList args, ConnectionContext* cntx) { + args.remove_prefix(1); + + cntx->ChangePSub(false, true, args); +} + // Not a real implementation. Serves as a decorator to accept some function commands // for testing. void Service::Function(CmdArgList args, ConnectionContext* cntx) { @@ -1024,6 +1034,8 @@ void Service::RegisterCommands() { << CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish) << CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe) << CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe) + << CI{"PSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PSubscribe) + << CI{"PUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PUnsubscribe) << CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function); StringFamily::Register(®istry_); diff --git a/src/server/main_service.h b/src/server/main_service.h index c3ef1fcac..cb40ea3dd 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -94,6 +94,8 @@ class Service : public facade::ServiceInterface { void Publish(CmdArgList args, ConnectionContext* cntx); void Subscribe(CmdArgList args, ConnectionContext* cntx); void Unsubscribe(CmdArgList args, ConnectionContext* cntx); + void PSubscribe(CmdArgList args, ConnectionContext* cntx); + void PUnsubscribe(CmdArgList args, ConnectionContext* cntx); void Function(CmdArgList args, ConnectionContext* cntx); struct EvalArgs { @@ -113,6 +115,7 @@ class Service : public facade::ServiceInterface { ServerFamily server_family_; CommandRegistry registry_; absl::flat_hash_map unknown_cmds_; + mutable ::boost::fibers::mutex mu_; GlobalState global_state_ = GlobalState::ACTIVE; // protected by mu_; diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index 6c5c2a6f8..c90651640 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -37,9 +37,68 @@ static vector SplitLines(const std::string& src) { return res; } +TestConnection::TestConnection(Protocol protocol) + : facade::Connection(protocol, nullptr, nullptr, nullptr) { +} + +void TestConnection::SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) { + backing_str_.emplace_back(new string(pmsg.channel)); + PubMessage dest; + dest.channel = *backing_str_.back(); + + backing_str_.emplace_back(new string(pmsg.message)); + dest.message = *backing_str_.back(); + + if (!pmsg.pattern.empty()) { + backing_str_.emplace_back(new string(pmsg.pattern)); + dest.pattern = *backing_str_.back(); + } + messages.push_back(dest); + + bc.Dec(); +} + +class BaseFamilyTest::TestConnWrapper { + public: + TestConnWrapper(Protocol proto); + ~TestConnWrapper(); + + CmdArgVec Args(ArgSlice list); + + RespVec ParseResponse(); + + // returns: type(pmessage), pattern, channel, message. + facade::Connection::PubMessage GetPubMessage(size_t index) const; + + ConnectionContext* cmd_cntx() { + return &cmd_cntx_; + } + + StringVec SplitLines() const { + return dfly::SplitLines(sink_.str()); + } + + void ClearSink() { + sink_.Clear(); + } + + TestConnection* conn() { + return dummy_conn_.get(); + } + + private: + ::io::StringSink sink_; // holds the response blob + + std::unique_ptr dummy_conn_; + + ConnectionContext cmd_cntx_; + std::vector> tmp_str_vec_; + + std::unique_ptr parser_; +}; + BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto) - : dummy_conn(new facade::Connection(proto, nullptr, nullptr, nullptr)), - cmd_cntx(&sink, dummy_conn.get()) { + : dummy_conn_(new TestConnection(proto)), cmd_cntx_(&sink_, dummy_conn_.get()) { } BaseFamilyTest::TestConnWrapper::~TestConnWrapper() { @@ -102,22 +161,22 @@ RespExpr BaseFamilyTest::Run(ArgSlice list) { } RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { - TestConnWrapper* conn = AddFindConn(Protocol::REDIS, id); + TestConnWrapper* conn_wrapper = AddFindConn(Protocol::REDIS, id); - CmdArgVec args = conn->Args(slice); + CmdArgVec args = conn_wrapper->Args(slice); - auto& context = conn->cmd_cntx; + auto* context = conn_wrapper->cmd_cntx(); - DCHECK(context.transaction == nullptr); + DCHECK(context->transaction == nullptr); - service_->DispatchCommand(CmdArgList{args}, &context); + service_->DispatchCommand(CmdArgList{args}, context); - DCHECK(context.transaction == nullptr); + DCHECK(context->transaction == nullptr); unique_lock lk(mu_); - last_cmd_dbg_info_ = context.last_command_debug; + last_cmd_dbg_info_ = context->last_command_debug; - RespVec vec = conn->ParseResponse(); + RespVec vec = conn_wrapper->ParseResponse(); if (vec.size() == 1) return vec.front(); RespVec* new_vec = new RespVec(vec); @@ -144,15 +203,15 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); - auto& context = conn->cmd_cntx; + auto* context = conn->cmd_cntx(); - DCHECK(context.transaction == nullptr); + DCHECK(context->transaction == nullptr); - service_->DispatchMC(cmd, value, &context); + service_->DispatchMC(cmd, value, context); - DCHECK(context.transaction == nullptr); + DCHECK(context->transaction == nullptr); - return SplitLines(conn->sink.str()); + return conn->SplitLines(); } auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResponse { @@ -165,11 +224,11 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp cmd.key = key; TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); - auto& context = conn->cmd_cntx; + auto* context = conn->cmd_cntx(); - service_->DispatchMC(cmd, string_view{}, &context); + service_->DispatchMC(cmd, string_view{}, context); - return SplitLines(conn->sink.str()); + return conn->SplitLines(); } auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list list) @@ -191,11 +250,11 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_listcmd_cntx; + auto* context = conn->cmd_cntx(); - service_->DispatchMC(cmd, string_view{}, &context); + service_->DispatchMC(cmd, string_view{}, context); - return SplitLines(conn->sink.str()); + return conn->SplitLines(); } int64_t BaseFamilyTest::CheckedInt(std::initializer_list list) { @@ -222,8 +281,8 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) { if (v.empty()) { res.push_back(MutableSlice{}); } else { - tmp_str_vec.emplace_back(new string{v}); - auto& s = *tmp_str_vec.back(); + tmp_str_vec_.emplace_back(new string{v}); + auto& s = *tmp_str_vec_.back(); res.emplace_back(s.data(), s.size()); } @@ -233,19 +292,24 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) { } RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() { - tmp_str_vec.emplace_back(new string{sink.str()}); - auto& s = *tmp_str_vec.back(); + tmp_str_vec_.emplace_back(new string{sink_.str()}); + auto& s = *tmp_str_vec_.back(); auto buf = RespExpr::buffer(&s); uint32_t consumed = 0; - parser.reset(new RedisParser{false}); // Client mode. + parser_.reset(new RedisParser{false}); // Client mode. RespVec res; - RedisParser::Result st = parser->Parse(buf, &consumed, &res); + RedisParser::Result st = parser_->Parse(buf, &consumed, &res); CHECK_EQ(RedisParser::OK, st); return res; } +facade::Connection::PubMessage BaseFamilyTest::TestConnWrapper::GetPubMessage(size_t index) const { + CHECK_LT(index, dummy_conn_->messages.size()); + return dummy_conn_->messages[index]; +} + bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const { ShardId sid = Shard(key, shard_set->size()); KeyLockArgs args; @@ -263,11 +327,30 @@ string BaseFamilyTest::GetId() const { return absl::StrCat("IO", id); } +size_t BaseFamilyTest::SubsriberMessagesLen(string_view conn_id) const { + auto it = connections_.find(conn_id); + if (it == connections_.end()) + return 0; + + return it->second->conn()->messages.size(); +} + +facade::Connection::PubMessage BaseFamilyTest::GetPublishedMessage(string_view conn_id, + size_t index) const { + facade::Connection::PubMessage res; + + auto it = connections_.find(conn_id); + if (it == connections_.end()) + return res; + + return it->second->GetPubMessage(index); +} + ConnectionContext::DebugInfo BaseFamilyTest::GetDebugInfo(const std::string& id) const { auto it = connections_.find(id); CHECK(it != connections_.end()); - return it->second->cmd_cntx.last_command_debug; + return it->second->cmd_cntx()->last_command_debug; } auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestConnWrapper* { @@ -278,7 +361,7 @@ auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestCon if (inserted) { it->second.reset(new TestConnWrapper(proto)); } else { - it->second->sink.Clear(); + it->second->ClearSink(); } return it->second.get(); } diff --git a/src/server/test_utils.h b/src/server/test_utils.h index 6011897aa..0bfbbb687 100644 --- a/src/server/test_utils.h +++ b/src/server/test_utils.h @@ -6,6 +6,7 @@ #include +#include "facade/dragonfly_connection.h" #include "facade/memcache_parser.h" #include "facade/redis_parser.h" #include "io/io.h" @@ -16,6 +17,18 @@ namespace dfly { using namespace facade; +class TestConnection : public facade::Connection { + public: + TestConnection(Protocol protocol); + + void SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) final; + + std::vector messages; + + private: + std::vector> backing_str_; +}; + class BaseFamilyTest : public ::testing::Test { protected: BaseFamilyTest(); @@ -27,23 +40,7 @@ class BaseFamilyTest : public ::testing::Test { void TearDown() override; protected: - struct TestConnWrapper { - ::io::StringSink sink; // holds the response blob - - std::unique_ptr dummy_conn; - - ConnectionContext cmd_cntx; - std::vector> tmp_str_vec; - - std::unique_ptr parser; - - TestConnWrapper(Protocol proto); - ~TestConnWrapper(); - - CmdArgVec Args(ArgSlice list); - - RespVec ParseResponse(); - }; + class TestConnWrapper; RespExpr Run(std::initializer_list list) { return Run(ArgSlice{list.begin(), list.size()}); @@ -75,6 +72,11 @@ class BaseFamilyTest : public ::testing::Test { void UpdateTime(uint64_t ms); std::string GetId() const; + size_t SubsriberMessagesLen(std::string_view conn_id) const; + + // Returns message parts as returned by RESP: + // pmessage, pattern, channel, message + facade::Connection::PubMessage GetPublishedMessage(std::string_view conn_id, size_t index) const; std::unique_ptr pp_; std::unique_ptr service_; diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 5d322ef45..a77030906 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -13,7 +13,6 @@ using namespace testing; using namespace std; using namespace util; -using namespace boost; namespace dfly {