diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index da6f1e4e4..9f238dd7e 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -78,17 +78,11 @@ bool MatchHttp11Line(string_view line) { constexpr size_t kMinReadSize = 256; constexpr size_t kMaxReadSize = 32_KB; -#ifdef ABSL_HAVE_ADDRESS_SANITIZER -constexpr size_t kReqStorageSize = 88; -#else -constexpr size_t kReqStorageSize = 120; -#endif - thread_local uint32_t free_req_release_weight = 0; } // namespace -thread_local vector Connection::pipeline_req_pool_; +thread_local vector Connection::pipeline_req_pool_; struct Connection::Shutdown { absl::flat_hash_map map; @@ -104,77 +98,21 @@ struct Connection::Shutdown { } }; -// Used as custom deleter for Request object -struct Connection::RequestDeleter { - void operator()(Request* req) const; -}; - -using PubMessage = Connection::PubMessage; -using MonitorMessage = std::string; - -// Please note: The call to the Dtor is mandatory for this!! -// This class contain types that don't have trivial destructed objects -class Connection::Request { - public: - struct PipelineMessage { - // mi_stl_allocator uses mi heap internally. - // The capacity is chosen so that we allocate a fully utilized (256 bytes) block. - using StorageType = absl::InlinedVector>; - - PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) { - } - - void Reset(size_t nargs, size_t capacity); - - absl::InlinedVector args; - StorageType storage; - }; - - using MessagePayload = std::variant; - - // Overload to create the a new pipeline message - static RequestPtr New(mi_heap_t* heap, const RespVec& args, size_t capacity); - - // Overload to create a new pubsub message - static RequestPtr New(PubMessage pub_msg); - - // Overload to create a new the monitor message - static RequestPtr New(MonitorMessage msg); - - void Emplace(const RespVec& args, size_t capacity); - - size_t StorageCapacity() const; - - bool IsPipelineMsg() const; - - private: - static constexpr size_t kSizeOfPipelineMsg = sizeof(PipelineMessage); - - Request(size_t nargs, size_t capacity) : payload(PipelineMessage{nargs, capacity}) { - } - - Request(PubMessage msg) : payload(move(msg)) { - } - - Request(MonitorMessage msg) : payload(move(msg)) { - } - - Request(const Request&) = delete; - - // Store arguments for pipeline message. - void SetArgs(const RespVec& args); - - public: - MessagePayload payload; -}; - -Connection::PubMessage::PubMessage(string pattern, shared_ptr channel, - shared_ptr message) - : type{kPublish}, pattern{move(pattern)}, channel{move(channel)}, message{move(message)} { +Connection::PubMessage::PubMessage(string pattern, shared_ptr buf, size_t channel_len, + size_t message_len) + : data{MessageData{pattern, move(buf), channel_len, message_len}} { } -Connection::PubMessage::PubMessage(bool add, shared_ptr channel, uint32_t channel_cnt) - : type{add ? kSubscribe : kUnsubscribe}, channel{move(channel)}, channel_cnt{channel_cnt} { +Connection::PubMessage::PubMessage(bool add, string_view channel, uint32_t channel_cnt) + : data{SubscribeData{add, string{channel}, channel_cnt}} { +} + +string_view Connection::PubMessage::MessageData::Channel() const { + return {buf.get(), channel_len}; +} + +string_view Connection::PubMessage::MessageData::Message() const { + return {buf.get() + channel_len, message_len}; } struct Connection::DispatchOperations { @@ -183,90 +121,67 @@ struct Connection::DispatchOperations { } void operator()(const PubMessage& msg); - void operator()(Request::PipelineMessage& msg); + void operator()(Connection::PipelineMessage& msg); void operator()(const MonitorMessage& msg); + template void operator()(unique_ptr& ptr) { + operator()(*ptr.get()); + } + ConnectionStats* stats = nullptr; SinkReplyBuilder* builder = nullptr; Connection* self = nullptr; }; -Connection::RequestPtr Connection::Request::New(MonitorMessage msg) { - void* ptr = mi_malloc(sizeof(Request)); - Request* req = new (ptr) Request(move(msg)); - return Connection::RequestPtr{req, Connection::RequestDeleter{}}; -} - -Connection::RequestPtr Connection::Request::New(mi_heap_t* heap, const RespVec& args, - size_t capacity) { - constexpr auto kReqSz = sizeof(Request); - void* ptr = mi_heap_malloc_small(heap, kReqSz); - - // We must construct in place here, since there is a slice that uses memory locations - Request* req = new (ptr) Request(args.size(), capacity); - req->SetArgs(args); - - return Connection::RequestPtr{req, Connection::RequestDeleter{}}; -} - -Connection::RequestPtr Connection::Request::New(PubMessage pub_msg) { - // This will generate a new request for pubsub message - // Please note that unlike the above case, we don't need to "protect", the internals here - // since we are currently using a borrow token for it - i.e. the BlockingCounter will - // ensure that the message is not deleted until we are finish sending it at the other - // side of the queue - void* ptr = mi_malloc(sizeof(Request)); - Request* req = new (ptr) Request(move(pub_msg)); - return Connection::RequestPtr{req, Connection::RequestDeleter{}}; -} - -void Connection::Request::SetArgs(const RespVec& args) { - // At this point we know that we have PipelineMessage in Request so next op is safe. - PipelineMessage& pipeline_msg = std::get(payload); - auto* next = pipeline_msg.storage.data(); +void Connection::PipelineMessage::SetArgs(const RespVec& args) { + auto* next = storage.data(); for (size_t i = 0; i < args.size(); ++i) { auto buf = args[i].GetBuf(); size_t s = buf.size(); memcpy(next, buf.data(), s); - pipeline_msg.args[i] = MutableSlice(next, s); + this->args[i] = MutableSlice(next, s); next += s; } } -void Connection::RequestDeleter::operator()(Request* req) const { - req->~Request(); - mi_free(req); +void Connection::MessageDeleter::operator()(PipelineMessage* msg) const { + msg->~PipelineMessage(); + mi_free(msg); } -void Connection::Request::Emplace(const RespVec& args, size_t capacity) { - PipelineMessage* msg = get_if(&payload); - if (msg) { - msg->Reset(args.size(), capacity); - } else { - payload = PipelineMessage{args.size(), capacity}; - } - SetArgs(args); +void Connection::MessageDeleter::operator()(PubMessage* msg) const { + msg->~PubMessage(); + mi_free(msg); } -void Connection::Request::PipelineMessage::Reset(size_t nargs, size_t capacity) { +void Connection::PipelineMessage::Reset(size_t nargs, size_t capacity) { storage.resize(capacity); args.resize(nargs); } -template struct Overloaded : Ts... { using Ts::operator()...; }; -template Overloaded(Ts...) -> Overloaded; - -size_t Connection::Request::StorageCapacity() const { - return std::visit(Overloaded{[](const PubMessage& msg) -> size_t { return 0; }, - [](const PipelineMessage& arg) -> size_t { - return arg.storage.capacity() + arg.args.capacity(); - }, - [](const MonitorMessage& arg) -> size_t { return arg.capacity(); }}, - payload); +size_t Connection::PipelineMessage::StorageCapacity() const { + return storage.capacity() + args.capacity(); } -bool Connection::Request::IsPipelineMsg() const { - return std::get_if(&payload) != nullptr; +template struct Overloaded : Ts... { + using Ts::operator()...; + + template size_t operator()(const unique_ptr& ptr) { + return operator()(*ptr.get()); + } +}; + +template Overloaded(Ts...) -> Overloaded; + +size_t Connection::MessageHandle::StorageCapacity() const { + auto pub_size = [](const PubMessage& msg) -> size_t { return 0; }; + auto msg_size = [](const PipelineMessage& arg) -> size_t { return arg.StorageCapacity(); }; + auto monitor_size = [](const MonitorMessage& arg) -> size_t { return 0; }; + return visit(Overloaded{pub_size, msg_size, monitor_size}, this->handle); +} + +bool Connection::MessageHandle::IsPipelineMsg() const { + return get_if(&this->handle) != nullptr; } void Connection::DispatchOperations::operator()(const MonitorMessage& msg) { @@ -277,33 +192,31 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) { void Connection::DispatchOperations::operator()(const PubMessage& pub_msg) { RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; ++stats->async_writes_cnt; - string_view arr[4]; - if (pub_msg.type == PubMessage::kPublish) { - if (pub_msg.pattern.empty()) { - DVLOG(1) << "Sending message, from channel: " << *pub_msg.channel << " " << *pub_msg.message; - arr[0] = "message"; - arr[1] = *pub_msg.channel; - arr[2] = *pub_msg.message; - rbuilder->SendStringArr(absl::Span{arr, 3}, - RedisReplyBuilder::CollectionType::PUSH); + auto send_msg = [rbuilder](const PubMessage::MessageData& data) { + unsigned i = 0; + string_view arr[4]; + if (data.pattern.empty()) { + arr[i++] = "message"; } else { - arr[0] = "pmessage"; - arr[1] = pub_msg.pattern; - arr[2] = *pub_msg.channel; - arr[3] = *pub_msg.message; - rbuilder->SendStringArr(absl::Span{arr, 4}, - RedisReplyBuilder::CollectionType::PUSH); + arr[i++] = "pmessage"; + arr[i++] = data.pattern; } - } else { + arr[i++] = data.Channel(); + arr[i++] = data.Message(); + rbuilder->SendStringArr(absl::Span{arr, i}, + RedisReplyBuilder::CollectionType::PUSH); + }; + auto send_sub = [rbuilder](const PubMessage::SubscribeData& data) { const char* action[2] = {"unsubscribe", "subscribe"}; rbuilder->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH); - rbuilder->SendBulkString(action[pub_msg.type == PubMessage::kSubscribe]); - rbuilder->SendBulkString(*pub_msg.channel); - rbuilder->SendLong(pub_msg.channel_cnt); - } + rbuilder->SendBulkString(action[data.add]); + rbuilder->SendBulkString(data.channel); + rbuilder->SendLong(data.channel_cnt); + }; + visit(Overloaded{send_msg, send_sub}, pub_msg.data); } -void Connection::DispatchOperations::operator()(Request::PipelineMessage& msg) { +void Connection::DispatchOperations::operator()(Connection::PipelineMessage& msg) { ++stats->pipelined_cmd_cnt; self->pipeline_msg_cnt_--; @@ -324,7 +237,7 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, protocol_ = protocol; - constexpr size_t kReqSz = sizeof(Connection::Request); + constexpr size_t kReqSz = sizeof(Connection::PipelineMessage); static_assert(kReqSz <= 256 && kReqSz >= 232); switch (protocol) { @@ -461,20 +374,6 @@ void Connection::RegisterBreakHook(BreakerCb breaker_cb) { breaker_cb_ = breaker_cb; } -void Connection::SendPubMessageAsync(PubMessage pub_msg) { - DCHECK(cc_); - - if (cc_->conn_closing) { - return; - } - - RequestPtr req = Request::New(move(pub_msg)); - dispatch_q_.push_back(std::move(req)); - if (dispatch_q_.size() == 1) { - evc_.notify(); - } -} - std::string Connection::LocalBindAddress() const { LinuxSocketBase* lsb = static_cast(socket_.get()); auto le = lsb->LocalEndpoint(); @@ -668,14 +567,10 @@ auto Connection::ParseRedis() -> ParserStatus { last_interaction_ = time(nullptr); } else { // Dispatch via queue to speedup input reading. - RequestPtr req = FromArgs(std::move(tmp_parse_args_), tlh); + SendAsync(MessageHandle{FromArgs(move(tmp_parse_args_), tlh)}); ++pipeline_msg_cnt_; - dispatch_q_.push_back(std::move(req)); - if (dispatch_q_.size() == 1) { - evc_.notify(); - } else if (dispatch_q_.size() > 10) { + if (dispatch_q_.size() > 10) ThisFiber::Yield(); - } } } io_buf_.ConsumeInput(consumed); @@ -843,13 +738,15 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) { if (cc_->conn_closing) break; - RequestPtr req{std::move(dispatch_q_.front())}; + MessageHandle msg = move(dispatch_q_.front()); dispatch_q_.pop_front(); - std::visit(dispatch_op, req->payload); + std::visit(dispatch_op, msg.handle); - if (req->IsPipelineMsg() && stats_->pipeline_cache_capacity < request_cache_limit) { - stats_->pipeline_cache_capacity += req->StorageCapacity(); - pipeline_req_pool_.push_back(std::move(req)); + if (auto* pipe = get_if(&msg.handle); pipe) { + if (stats_->pipeline_cache_capacity < request_cache_limit) { + stats_->pipeline_cache_capacity += (*pipe)->StorageCapacity(); + pipeline_req_pool_.push_back(move(*pipe)); + } } } @@ -859,7 +756,7 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) { dispatch_q_.clear(); } -auto Connection::FromArgs(RespVec args, mi_heap_t* heap) -> RequestPtr { +Connection::PipelineMessagePtr Connection::FromArgs(RespVec args, mi_heap_t* heap) { DCHECK(!args.empty()); size_t backed_sz = 0; for (const auto& arg : args) { @@ -868,18 +765,21 @@ auto Connection::FromArgs(RespVec args, mi_heap_t* heap) -> RequestPtr { } DCHECK(backed_sz); - constexpr auto kReqSz = sizeof(Request); + constexpr auto kReqSz = sizeof(PipelineMessage); static_assert(kReqSz < MI_SMALL_SIZE_MAX); - static_assert(alignof(Request) == 8); + static_assert(alignof(PipelineMessage) == 8); - RequestPtr req; - - if (req = GetFromPipelinePool(); req) { - req->Emplace(move(args), backed_sz); + PipelineMessagePtr ptr; + if (ptr = GetFromPipelinePool(); ptr) { + ptr->Reset(args.size(), backed_sz); } else { - req = Request::New(heap, args, backed_sz); + void* heap_ptr = mi_heap_malloc_small(heap, sizeof(PipelineMessage)); + // We must construct in place here, since there is a slice that uses memory locations + ptr.reset(new (heap_ptr) PipelineMessage(args.size(), backed_sz)); } - return req; + + ptr->SetArgs(args); + return ptr; } void Connection::ShrinkPipelinePool() { @@ -899,15 +799,15 @@ void Connection::ShrinkPipelinePool() { } } -Connection::RequestPtr Connection::GetFromPipelinePool() { +Connection::PipelineMessagePtr Connection::GetFromPipelinePool() { if (pipeline_req_pool_.empty()) - return {}; + return nullptr; free_req_release_weight = 0; // Reset the release weight. - RequestPtr req = move(pipeline_req_pool_.back()); - stats_->pipeline_cache_capacity -= req->StorageCapacity(); + auto ptr = move(pipeline_req_pool_.back()); + stats_->pipeline_cache_capacity -= ptr->StorageCapacity(); pipeline_req_pool_.pop_back(); - return req; + return ptr; } void Connection::ShutdownSelf() { @@ -927,15 +827,24 @@ void RespToArgList(const RespVec& src, CmdArgVec* dest) { } } -void Connection::SendMonitorMessageAsync(std::string monitor_msg) { +void Connection::SendPubMessageAsync(PubMessage msg) { + void* ptr = mi_malloc(sizeof(PubMessage)); + SendAsync({PubMessagePtr{new (ptr) PubMessage{move(msg)}, MessageDeleter{}}}); +} + +void Connection::SendMonitorMessageAsync(string msg) { + SendAsync({MonitorMessage{move(msg)}}); +} + +void Connection::SendAsync(MessageHandle msg) { DCHECK(cc_); - if (!cc_->conn_closing) { - RequestPtr req = Request::New(std::move(monitor_msg)); - dispatch_q_.push_back(std::move(req)); - if (dispatch_q_.size() == 1) { - evc_.notify(); - } + if (cc_->conn_closing) + return; + + dispatch_q_.push_back(move(msg)); + if (dispatch_q_.size() == 1) { + evc_.notify(); } } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index d7e783d60..4427ec99c 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -31,6 +32,12 @@ typedef struct mi_heap_s mi_heap_t; #define SO_INCOMING_NAPI_ID 56 #endif +#ifdef ABSL_HAVE_ADDRESS_SANITIZER +constexpr size_t kReqStorageSize = 88; +#else +constexpr size_t kReqStorageSize = 120; +#endif + namespace facade { class ConnectionContext; @@ -55,21 +62,65 @@ class Connection : public util::Connection { // PubSub message, either incoming message for active subscription or reply for new subscription. struct PubMessage { - enum Type { kSubscribe, kUnsubscribe, kPublish } type; + // Represents incoming message. + struct MessageData { + std::string pattern{}; // non-empty for pattern subscriber + std::shared_ptr buf; // stores channel name and message + size_t channel_len, message_len; // lengths in buf - std::string pattern{}; // non-empty for pattern subscriber - std::shared_ptr channel{}; - std::shared_ptr message{}; + std::string_view Channel() const; + std::string_view Message() const; + }; - uint32_t channel_cnt = 0; + // Represents reply for subscribe/unsubscribe. + struct SubscribeData { + bool add; + std::string channel; + uint32_t channel_cnt; + }; - PubMessage(bool add, std::shared_ptr channel, uint32_t channel_cnt); - PubMessage(std::string pattern, std::shared_ptr channel, - std::shared_ptr message); + std::variant data; - PubMessage(const PubMessage&) = delete; - PubMessage& operator=(const PubMessage&) = delete; - PubMessage(PubMessage&&) = default; + PubMessage(bool add, std::string_view channel, uint32_t channel_cnt); + PubMessage(std::string pattern, std::shared_ptr buf, size_t channel_len, + size_t message_len); + }; + + struct MonitorMessage : public std::string {}; + + struct PipelineMessage { + PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) { + } + + void Reset(size_t nargs, size_t capacity); + + void SetArgs(const RespVec& args); + + size_t StorageCapacity() const; + + // mi_stl_allocator uses mi heap internally. + // The capacity is chosen so that we allocate a fully utilized (256 bytes) block. + using StorageType = absl::InlinedVector>; + + absl::InlinedVector args; + StorageType storage; + }; + + struct MessageDeleter { + void operator()(PipelineMessage* msg) const; + void operator()(PubMessage* msg) const; + }; + + // Requests are allocated on the mimalloc heap and thus require a custom deleter. + using PipelineMessagePtr = std::unique_ptr; + using PubMessagePtr = std::unique_ptr; + + struct MessageHandle { + size_t StorageCapacity() const; + + bool IsPipelineMsg() const; + + std::variant handle; }; enum Phase { READ_SOCKET, PROCESS }; @@ -77,10 +128,10 @@ class Connection : public util::Connection { public: // Add PubMessage to dispatch queue. // Virtual because behaviour is overwritten in test_utils. - virtual void SendPubMessageAsync(PubMessage pub_msg); + virtual void SendPubMessageAsync(PubMessage); // Add monitor message to dispatch queue. - void SendMonitorMessageAsync(std::string monitor_msg); + void SendMonitorMessageAsync(std::string); // Register hook that is executed on connection shutdown. ShutdownHandle RegisterShutdownHook(ShutdownCb cb); @@ -121,15 +172,10 @@ class Connection : public util::Connection { private: enum ParserStatus { OK, NEED_MORE, ERROR }; - class Request; struct DispatchOperations; struct DispatchCleanup; - struct RequestDeleter; struct Shutdown; - // Requests are allocated on the mimalloc heap and thus require a custom deleter. - using RequestPtr = std::unique_ptr; - private: // Check protocol and handle connection. void HandleRequests() final; @@ -146,8 +192,10 @@ class Connection : public util::Connection { // Handles events from dispatch queue. void DispatchFiber(util::FiberSocketBase* peer); + void SendAsync(MessageHandle msg); + // Create new pipeline request, re-use from pool when possible. - RequestPtr FromArgs(RespVec args, mi_heap_t* heap); + PipelineMessagePtr FromArgs(RespVec args, mi_heap_t* heap); ParserStatus ParseRedis(); ParserStatus ParseMemcache(); @@ -158,11 +206,11 @@ class Connection : public util::Connection { void ShrinkPipelinePool(); // Returns non-null request ptr if pool has vacant entries. - RequestPtr GetFromPipelinePool(); + PipelineMessagePtr GetFromPipelinePool(); private: - std::deque dispatch_q_; // dispatch queue - dfly::EventCount evc_; // dispatch queue waker + std::deque dispatch_q_; // dispatch queue + dfly::EventCount evc_; // dispatch queue waker base::IoBuf io_buf_; // used in io loop and parsers std::unique_ptr redis_parser_; @@ -198,7 +246,7 @@ class Connection : public util::Connection { // Pooled pipieline messages per-thread. // Aggregated while handling pipelines, // graudally released while handling regular commands. - static thread_local std::vector pipeline_req_pool_; + static thread_local std::vector pipeline_req_pool_; }; void RespToArgList(const RespVec& src, CmdArgVec* dest); diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 90b1fe766..6135ec317 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -127,7 +127,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis if (to_reply) { for (size_t i = 0; i < result.size(); ++i) { - owner()->SendPubMessageAsync({to_add, make_shared(ArgS(args, i)), result[i]}); + owner()->SendPubMessageAsync({to_add, ArgS(args, i), result[i]}); } } } diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index ed8c3acac..c6316e0c0 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -403,9 +403,9 @@ TEST_F(DflyEngineTest, PSubscribe) { ASSERT_EQ(1, SubscriberMessagesLen("IO1")); - const facade::Connection::PubMessage& msg = GetPublishedMessage("IO1", 0); - EXPECT_EQ("foo", *msg.message); - EXPECT_EQ("ab", *msg.channel); + const auto& msg = GetPublishedMessage("IO1", 0); + EXPECT_EQ("foo", msg.Message()); + EXPECT_EQ("ab", msg.Channel()); EXPECT_EQ("a*", msg.pattern); } diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 6e2bbfc9e..bec534ba6 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1446,6 +1446,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { void Service::Publish(CmdArgList args, ConnectionContext* cntx) { string_view channel = ArgS(args, 0); + string_view msg = ArgS(args, 1); auto* cs = ServerState::tlocal()->channel_store(); vector subscribers = cs->FetchSubscribers(channel); @@ -1453,17 +1454,18 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { if (!subscribers.empty()) { auto subscribers_ptr = make_shared(move(subscribers)); - auto msg_ptr = make_shared(ArgS(args, 1)); - auto channel_ptr = make_shared(channel); + auto buf = shared_ptr{new char[channel.size() + msg.size()]}; + memcpy(buf.get(), channel.data(), channel.size()); + memcpy(buf.get() + channel.size(), msg.data(), msg.size()); - auto cb = [subscribers_ptr, msg_ptr, channel_ptr](unsigned idx, util::ProactorBase*) { + auto cb = [subscribers_ptr, buf, channel, msg](unsigned idx, util::ProactorBase*) { auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx, ChannelStore::Subscriber::ByThread); while (it != subscribers_ptr->end() && it->thread_id == idx) { facade::Connection* conn = it->conn_cntx->owner(); DCHECK(conn); - conn->SendPubMessageAsync({move(it->pattern), move(channel_ptr), move(msg_ptr)}); + conn->SendPubMessageAsync({move(it->pattern), move(buf), channel.size(), msg.size()}); it->borrow_token.Dec(); it++; } diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index 4805f2cb5..9af7d18b9 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -62,15 +62,15 @@ TestConnection::TestConnection(Protocol protocol, io::StringSink* sink) } void TestConnection::SendPubMessageAsync(PubMessage pmsg) { - if (pmsg.type == PubMessage::kPublish) { - messages.push_back(move(pmsg)); - } else { + if (auto* ptr = std::get_if(&pmsg.data); ptr != nullptr) { + messages.push_back(move(*ptr)); + } else if (auto* ptr = std::get_if(&pmsg.data); ptr != nullptr) { RedisReplyBuilder builder(sink_); const char* action[2] = {"unsubscribe", "subscribe"}; builder.StartArray(3); - builder.SendBulkString(action[pmsg.type == PubMessage::kSubscribe]); - builder.SendBulkString(*pmsg.channel); - builder.SendLong(pmsg.channel_cnt); + builder.SendBulkString(action[ptr->add]); + builder.SendBulkString(ptr->channel); + builder.SendLong(ptr->channel_cnt); } } @@ -84,7 +84,7 @@ class BaseFamilyTest::TestConnWrapper { RespVec ParseResponse(bool fully_consumed); // returns: type(pmessage), pattern, channel, message. - const facade::Connection::PubMessage& GetPubMessage(size_t index) const; + const facade::Connection::PubMessage::MessageData& GetPubMessage(size_t index) const; ConnectionContext* cmd_cntx() { return &cmd_cntx_; @@ -360,7 +360,7 @@ RespVec BaseFamilyTest::TestConnWrapper::ParseResponse(bool fully_consumed) { return res; } -const facade::Connection::PubMessage& BaseFamilyTest::TestConnWrapper::GetPubMessage( +const facade::Connection::PubMessage::MessageData& BaseFamilyTest::TestConnWrapper::GetPubMessage( size_t index) const { CHECK_LT(index, dummy_conn_->messages.size()); return dummy_conn_->messages[index]; @@ -391,8 +391,8 @@ size_t BaseFamilyTest::SubscriberMessagesLen(string_view conn_id) const { return it->second->conn()->messages.size(); } -const facade::Connection::PubMessage& BaseFamilyTest::GetPublishedMessage(string_view conn_id, - size_t index) const { +const facade::Connection::PubMessage::MessageData& BaseFamilyTest::GetPublishedMessage( + string_view conn_id, size_t index) const { auto it = connections_.find(conn_id); CHECK(it != connections_.end()); diff --git a/src/server/test_utils.h b/src/server/test_utils.h index e36a8ab71..c7f5f059b 100644 --- a/src/server/test_utils.h +++ b/src/server/test_utils.h @@ -23,7 +23,7 @@ class TestConnection : public facade::Connection { void SendPubMessageAsync(PubMessage pmsg) final; - std::vector messages; + std::vector messages; private: io::StringSink* sink_; @@ -87,8 +87,8 @@ class BaseFamilyTest : public ::testing::Test { std::string GetId() const; size_t SubscriberMessagesLen(std::string_view conn_id) const; - const facade::Connection::PubMessage& GetPublishedMessage(std::string_view conn_id, - size_t index) const; + const facade::Connection::PubMessage::MessageData& GetPublishedMessage(std::string_view conn_id, + size_t index) const; std::unique_ptr pp_; std::unique_ptr service_;