mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 10:25:47 +02:00
fix: data race during Publish in PubSub
The issue happens when SendMsgVecAsync is called with PubMessage that has string_view objects referencing objects in stack. We replace string_view with either string or shared_ptr<string> Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
parent
7975848c36
commit
2ec3d48534
6 changed files with 56 additions and 53 deletions
|
@ -87,7 +87,7 @@ constexpr size_t kMaxReadSize = 32_KB;
|
|||
struct PubMsgRecord {
|
||||
Connection::PubMessage pub_msg;
|
||||
|
||||
PubMsgRecord(const Connection::PubMessage& pmsg) : pub_msg(pmsg) {
|
||||
PubMsgRecord(Connection::PubMessage pmsg) : pub_msg(move(pmsg)) {
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -149,7 +149,7 @@ class Connection::Request {
|
|||
static RequestPtr New(mi_heap_t* heap, const RespVec& args, size_t capacity);
|
||||
|
||||
// Overload to create a new pubsub message
|
||||
static RequestPtr New(const PubMessage& pub_msg);
|
||||
static RequestPtr New(PubMessage pub_msg);
|
||||
|
||||
// Overload to create a new the monitor message
|
||||
static RequestPtr New(MonitorMessage msg);
|
||||
|
@ -221,13 +221,13 @@ void Connection::Request::SetArgs(const RespVec& args) {
|
|||
}
|
||||
}
|
||||
|
||||
Connection::RequestPtr Connection::Request::New(const PubMessage& pub_msg) {
|
||||
Connection::RequestPtr Connection::Request::New(PubMessage pub_msg) {
|
||||
// This will generate a new request for pubsub message
|
||||
// Please note that unlike the above case, we don't need to "protect", the internals here
|
||||
// since we are currently using a borrow token for it - i.e. the BlockingCounter will
|
||||
// ensure that the message is not deleted until we are finish sending it at the other
|
||||
// side of the queue
|
||||
PubMsgRecord new_msg{pub_msg};
|
||||
PubMsgRecord new_msg{move(pub_msg)};
|
||||
void* ptr = mi_malloc(sizeof(Request));
|
||||
Request* req = new (ptr) Request(std::move(new_msg));
|
||||
return Connection::RequestPtr{req, Connection::RequestDeleter{}};
|
||||
|
@ -278,15 +278,15 @@ void Connection::DispatchOperations::operator()(const PubMsgRecord& msg) {
|
|||
DCHECK(!rbuilder->is_sending);
|
||||
rbuilder->is_sending = true;
|
||||
if (pub_msg.pattern.empty()) {
|
||||
DVLOG(1) << "Sending message, from channel: " << pub_msg.channel << " " << *pub_msg.message;
|
||||
DVLOG(1) << "Sending message, from channel: " << *pub_msg.channel << " " << *pub_msg.message;
|
||||
arr[0] = "message";
|
||||
arr[1] = pub_msg.channel;
|
||||
arr[1] = *pub_msg.channel;
|
||||
arr[2] = *pub_msg.message;
|
||||
rbuilder->SendStringArr(absl::Span<string_view>{arr, 3});
|
||||
} else {
|
||||
arr[0] = "pmessage";
|
||||
arr[1] = pub_msg.pattern;
|
||||
arr[2] = pub_msg.channel;
|
||||
arr[2] = *pub_msg.channel;
|
||||
arr[3] = *pub_msg.message;
|
||||
rbuilder->SendStringArr(absl::Span<string_view>{arr, 4});
|
||||
}
|
||||
|
@ -449,13 +449,14 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) {
|
|||
breaker_cb_ = breaker_cb;
|
||||
}
|
||||
|
||||
void Connection::SendMsgVecAsync(const PubMessage& pub_msg) {
|
||||
void Connection::SendMsgVecAsync(PubMessage pub_msg) {
|
||||
DCHECK(cc_);
|
||||
|
||||
if (cc_->conn_closing) {
|
||||
return;
|
||||
}
|
||||
RequestPtr req = Request::New(pub_msg); // new (ptr) Request(0, 0);
|
||||
|
||||
RequestPtr req = Request::New(move(pub_msg));
|
||||
dispatch_q_.push_back(std::move(req));
|
||||
if (dispatch_q_.size() == 1) {
|
||||
evc_.notify();
|
||||
|
|
|
@ -54,14 +54,19 @@ class Connection : public util::Connection {
|
|||
|
||||
struct PubMessage {
|
||||
// if empty - means its a regular message, otherwise it's pmessage.
|
||||
std::string_view pattern;
|
||||
std::string_view channel;
|
||||
std::shared_ptr<const std::string> message; // ensure that this message would out live passing
|
||||
// between different threads/fibers
|
||||
std::string pattern;
|
||||
std::shared_ptr<std::string> channel;
|
||||
std::shared_ptr<std::string> message; // ensure that this message would out live passing
|
||||
// between different threads/fibers
|
||||
|
||||
PubMessage() = default;
|
||||
PubMessage(const PubMessage&) = delete;
|
||||
PubMessage& operator=(const PubMessage&) = delete;
|
||||
PubMessage(PubMessage&&) = default;
|
||||
};
|
||||
|
||||
// this function is overriden at test_utils TestConnection
|
||||
virtual void SendMsgVecAsync(const PubMessage& pub_msg);
|
||||
virtual void SendMsgVecAsync(PubMessage pub_msg);
|
||||
|
||||
// Note that this is accepted by value because the message is processed asynchronously.
|
||||
void SendMonitorMsg(std::string monitor_msg);
|
||||
|
|
|
@ -398,9 +398,9 @@ TEST_F(DflyEngineTest, PSubscribe) {
|
|||
|
||||
ASSERT_EQ(1, SubscriberMessagesLen("IO1"));
|
||||
|
||||
facade::Connection::PubMessage msg = GetPublishedMessage("IO1", 0);
|
||||
const facade::Connection::PubMessage& msg = GetPublishedMessage("IO1", 0);
|
||||
EXPECT_EQ("foo", *msg.message);
|
||||
EXPECT_EQ("ab", msg.channel);
|
||||
EXPECT_EQ("ab", *msg.channel);
|
||||
EXPECT_EQ("a*", msg.pattern);
|
||||
}
|
||||
|
||||
|
|
|
@ -1366,9 +1366,6 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
|
|||
void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view channel = ArgS(args, 1);
|
||||
|
||||
// shared_ptr ensures that the message lives until it's been sent to all subscribers and handled
|
||||
// by DispatchOperations.
|
||||
std::shared_ptr<const std::string> message = std::make_shared<const std::string>(ArgS(args, 2));
|
||||
ShardId sid = Shard(channel, shard_count());
|
||||
|
||||
auto cb = [&] { return EngineShard::tlocal()->channel_slice().FetchSubscribers(channel); };
|
||||
|
@ -1390,12 +1387,17 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
}
|
||||
|
||||
// shared_ptr ensures that the message lives until it's been sent to all subscribers and handled
|
||||
// by DispatchOperations.
|
||||
shared_ptr<string> msg_ptr = make_shared<string>(ArgS(args, 2));
|
||||
shared_ptr<string> channel_ptr = make_shared<string>(channel);
|
||||
|
||||
// We run publish_cb in each subscriber's thread.
|
||||
auto publish_cb = [&](unsigned idx, util::ProactorBase*) mutable {
|
||||
unsigned start = slices[idx];
|
||||
|
||||
for (unsigned i = start; i < subscriber_arr.size(); ++i) {
|
||||
const ChannelSlice::Subscriber& subscriber = subscriber_arr[i];
|
||||
ChannelSlice::Subscriber& subscriber = subscriber_arr[i];
|
||||
if (subscriber.thread_id != idx)
|
||||
break;
|
||||
|
||||
|
@ -1404,10 +1406,10 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
|
|||
facade::Connection* conn = subscriber_arr[i].conn_cntx->owner();
|
||||
DCHECK(conn);
|
||||
facade::Connection::PubMessage pmsg;
|
||||
pmsg.channel = channel;
|
||||
pmsg.message = message;
|
||||
pmsg.pattern = subscriber.pattern;
|
||||
conn->SendMsgVecAsync(pmsg);
|
||||
pmsg.channel = channel_ptr;
|
||||
pmsg.message = msg_ptr;
|
||||
pmsg.pattern = move(subscriber.pattern);
|
||||
conn->SendMsgVecAsync(move(pmsg));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -58,23 +58,21 @@ static vector<string> SplitLines(const std::string& src) {
|
|||
return res;
|
||||
}
|
||||
|
||||
TestConnection::TestConnection(Protocol protocol)
|
||||
: facade::Connection(protocol, nullptr, nullptr, nullptr) {
|
||||
TestConnection::TestConnection(Protocol protocol, io::StringSink* sink)
|
||||
: facade::Connection(protocol, nullptr, nullptr, nullptr), sink_(sink) {
|
||||
}
|
||||
|
||||
void TestConnection::SendMsgVecAsync(const PubMessage& pmsg) {
|
||||
backing_str_.emplace_back(new string(pmsg.channel));
|
||||
PubMessage dest;
|
||||
dest.channel = *backing_str_.back();
|
||||
|
||||
backing_str_.emplace_back(new string(*pmsg.message));
|
||||
dest.message = pmsg.message;
|
||||
|
||||
if (!pmsg.pattern.empty()) {
|
||||
backing_str_.emplace_back(new string(pmsg.pattern));
|
||||
dest.pattern = *backing_str_.back();
|
||||
void TestConnection::SendMsgVecAsync(PubMessage pmsg) {
|
||||
if (pmsg.type == PubMessage::kPublish) {
|
||||
messages.push_back(move(pmsg));
|
||||
} else {
|
||||
RedisReplyBuilder builder(sink_);
|
||||
const char* action[2] = {"unsubscribe", "subscribe"};
|
||||
builder.StartArray(3);
|
||||
builder.SendBulkString(action[pmsg.type == PubMessage::kSubscribe]);
|
||||
builder.SendBulkString(*pmsg.channel);
|
||||
builder.SendLong(pmsg.channel_cnt);
|
||||
}
|
||||
messages.push_back(dest);
|
||||
}
|
||||
|
||||
class BaseFamilyTest::TestConnWrapper {
|
||||
|
@ -87,7 +85,7 @@ class BaseFamilyTest::TestConnWrapper {
|
|||
RespVec ParseResponse(bool fully_consumed);
|
||||
|
||||
// returns: type(pmessage), pattern, channel, message.
|
||||
facade::Connection::PubMessage GetPubMessage(size_t index) const;
|
||||
const facade::Connection::PubMessage& GetPubMessage(size_t index) const;
|
||||
|
||||
ConnectionContext* cmd_cntx() {
|
||||
return &cmd_cntx_;
|
||||
|
@ -117,7 +115,7 @@ class BaseFamilyTest::TestConnWrapper {
|
|||
};
|
||||
|
||||
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
|
||||
: dummy_conn_(new TestConnection(proto)), cmd_cntx_(&sink_, dummy_conn_.get()) {
|
||||
: dummy_conn_(new TestConnection(proto, &sink_)), cmd_cntx_(&sink_, dummy_conn_.get()) {
|
||||
}
|
||||
|
||||
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {
|
||||
|
@ -359,7 +357,8 @@ RespVec BaseFamilyTest::TestConnWrapper::ParseResponse(bool fully_consumed) {
|
|||
return res;
|
||||
}
|
||||
|
||||
facade::Connection::PubMessage BaseFamilyTest::TestConnWrapper::GetPubMessage(size_t index) const {
|
||||
const facade::Connection::PubMessage& BaseFamilyTest::TestConnWrapper::GetPubMessage(
|
||||
size_t index) const {
|
||||
CHECK_LT(index, dummy_conn_->messages.size());
|
||||
return dummy_conn_->messages[index];
|
||||
}
|
||||
|
@ -389,13 +388,10 @@ size_t BaseFamilyTest::SubscriberMessagesLen(string_view conn_id) const {
|
|||
return it->second->conn()->messages.size();
|
||||
}
|
||||
|
||||
facade::Connection::PubMessage BaseFamilyTest::GetPublishedMessage(string_view conn_id,
|
||||
size_t index) const {
|
||||
facade::Connection::PubMessage res;
|
||||
|
||||
const facade::Connection::PubMessage& BaseFamilyTest::GetPublishedMessage(string_view conn_id,
|
||||
size_t index) const {
|
||||
auto it = connections_.find(conn_id);
|
||||
if (it == connections_.end())
|
||||
return res;
|
||||
CHECK(it != connections_.end());
|
||||
|
||||
return it->second->GetPubMessage(index);
|
||||
}
|
||||
|
|
|
@ -19,14 +19,14 @@ using namespace facade;
|
|||
|
||||
class TestConnection : public facade::Connection {
|
||||
public:
|
||||
TestConnection(Protocol protocol);
|
||||
TestConnection(Protocol protocol, io::StringSink* sink);
|
||||
|
||||
void SendMsgVecAsync(const PubMessage& pmsg) final;
|
||||
void SendMsgVecAsync(PubMessage pmsg) final;
|
||||
|
||||
std::vector<PubMessage> messages;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<std::string>> backing_str_;
|
||||
io::StringSink* sink_;
|
||||
};
|
||||
|
||||
class BaseFamilyTest : public ::testing::Test {
|
||||
|
@ -87,9 +87,8 @@ class BaseFamilyTest : public ::testing::Test {
|
|||
std::string GetId() const;
|
||||
size_t SubscriberMessagesLen(std::string_view conn_id) const;
|
||||
|
||||
// Returns message parts as returned by RESP:
|
||||
// pmessage, pattern, channel, message
|
||||
facade::Connection::PubMessage GetPublishedMessage(std::string_view conn_id, size_t index) const;
|
||||
const facade::Connection::PubMessage& GetPublishedMessage(std::string_view conn_id,
|
||||
size_t index) const;
|
||||
|
||||
std::unique_ptr<util::ProactorPool> pp_;
|
||||
std::unique_ptr<Service> service_;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue