fix: data race during Publish in PubSub

The issue happens when SendMsgVecAsync is called with a PubMessage that holds
string_view objects referencing objects on the caller's stack. We replace the
string_view fields with either string or shared_ptr<string>.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2023-03-11 21:24:33 +02:00 committed by Roman Gershman
parent 7975848c36
commit 2ec3d48534
6 changed files with 56 additions and 53 deletions

View file

@ -87,7 +87,7 @@ constexpr size_t kMaxReadSize = 32_KB;
struct PubMsgRecord {
Connection::PubMessage pub_msg;
PubMsgRecord(const Connection::PubMessage& pmsg) : pub_msg(pmsg) {
PubMsgRecord(Connection::PubMessage pmsg) : pub_msg(move(pmsg)) {
}
};
@ -149,7 +149,7 @@ class Connection::Request {
static RequestPtr New(mi_heap_t* heap, const RespVec& args, size_t capacity);
// Overload to create a new pubsub message
static RequestPtr New(const PubMessage& pub_msg);
static RequestPtr New(PubMessage pub_msg);
// Overload to create a new the monitor message
static RequestPtr New(MonitorMessage msg);
@ -221,13 +221,13 @@ void Connection::Request::SetArgs(const RespVec& args) {
}
}
Connection::RequestPtr Connection::Request::New(const PubMessage& pub_msg) {
Connection::RequestPtr Connection::Request::New(PubMessage pub_msg) {
// This will generate a new request for pubsub message
// Please note that unlike the above case, we don't need to "protect", the internals here
// since we are currently using a borrow token for it - i.e. the BlockingCounter will
// ensure that the message is not deleted until we are finish sending it at the other
// side of the queue
PubMsgRecord new_msg{pub_msg};
PubMsgRecord new_msg{move(pub_msg)};
void* ptr = mi_malloc(sizeof(Request));
Request* req = new (ptr) Request(std::move(new_msg));
return Connection::RequestPtr{req, Connection::RequestDeleter{}};
@ -278,15 +278,15 @@ void Connection::DispatchOperations::operator()(const PubMsgRecord& msg) {
DCHECK(!rbuilder->is_sending);
rbuilder->is_sending = true;
if (pub_msg.pattern.empty()) {
DVLOG(1) << "Sending message, from channel: " << pub_msg.channel << " " << *pub_msg.message;
DVLOG(1) << "Sending message, from channel: " << *pub_msg.channel << " " << *pub_msg.message;
arr[0] = "message";
arr[1] = pub_msg.channel;
arr[1] = *pub_msg.channel;
arr[2] = *pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 3});
} else {
arr[0] = "pmessage";
arr[1] = pub_msg.pattern;
arr[2] = pub_msg.channel;
arr[2] = *pub_msg.channel;
arr[3] = *pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 4});
}
@ -449,13 +449,14 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) {
breaker_cb_ = breaker_cb;
}
void Connection::SendMsgVecAsync(const PubMessage& pub_msg) {
void Connection::SendMsgVecAsync(PubMessage pub_msg) {
DCHECK(cc_);
if (cc_->conn_closing) {
return;
}
RequestPtr req = Request::New(pub_msg); // new (ptr) Request(0, 0);
RequestPtr req = Request::New(move(pub_msg));
dispatch_q_.push_back(std::move(req));
if (dispatch_q_.size() == 1) {
evc_.notify();

View file

@ -54,14 +54,19 @@ class Connection : public util::Connection {
struct PubMessage {
  // If empty - means it's a regular `message`; otherwise it's a `pmessage`
  // delivered due to a pattern subscription.
  std::string pattern;

  // channel/message are shared_ptr<string> to ensure the payload outlives the
  // publishing call while the message is passed between threads/fibers.
  std::shared_ptr<std::string> channel;
  std::shared_ptr<std::string> message;

  PubMessage() = default;

  // Copying is disabled; messages are moved through the dispatch queue.
  PubMessage(const PubMessage&) = delete;
  PubMessage& operator=(const PubMessage&) = delete;
  PubMessage(PubMessage&&) = default;
};
// this function is overriden at test_utils TestConnection
virtual void SendMsgVecAsync(const PubMessage& pub_msg);
virtual void SendMsgVecAsync(PubMessage pub_msg);
// Note that this is accepted by value because the message is processed asynchronously.
void SendMonitorMsg(std::string monitor_msg);

View file

@ -398,9 +398,9 @@ TEST_F(DflyEngineTest, PSubscribe) {
ASSERT_EQ(1, SubscriberMessagesLen("IO1"));
facade::Connection::PubMessage msg = GetPublishedMessage("IO1", 0);
const facade::Connection::PubMessage& msg = GetPublishedMessage("IO1", 0);
EXPECT_EQ("foo", *msg.message);
EXPECT_EQ("ab", msg.channel);
EXPECT_EQ("ab", *msg.channel);
EXPECT_EQ("a*", msg.pattern);
}

View file

@ -1366,9 +1366,6 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
string_view channel = ArgS(args, 1);
// shared_ptr ensures that the message lives until it's been sent to all subscribers and handled
// by DispatchOperations.
std::shared_ptr<const std::string> message = std::make_shared<const std::string>(ArgS(args, 2));
ShardId sid = Shard(channel, shard_count());
auto cb = [&] { return EngineShard::tlocal()->channel_slice().FetchSubscribers(channel); };
@ -1390,12 +1387,17 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
}
}
// shared_ptr ensures that the message lives until it's been sent to all subscribers and handled
// by DispatchOperations.
shared_ptr<string> msg_ptr = make_shared<string>(ArgS(args, 2));
shared_ptr<string> channel_ptr = make_shared<string>(channel);
// We run publish_cb in each subscriber's thread.
auto publish_cb = [&](unsigned idx, util::ProactorBase*) mutable {
unsigned start = slices[idx];
for (unsigned i = start; i < subscriber_arr.size(); ++i) {
const ChannelSlice::Subscriber& subscriber = subscriber_arr[i];
ChannelSlice::Subscriber& subscriber = subscriber_arr[i];
if (subscriber.thread_id != idx)
break;
@ -1404,10 +1406,10 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
facade::Connection* conn = subscriber_arr[i].conn_cntx->owner();
DCHECK(conn);
facade::Connection::PubMessage pmsg;
pmsg.channel = channel;
pmsg.message = message;
pmsg.pattern = subscriber.pattern;
conn->SendMsgVecAsync(pmsg);
pmsg.channel = channel_ptr;
pmsg.message = msg_ptr;
pmsg.pattern = move(subscriber.pattern);
conn->SendMsgVecAsync(move(pmsg));
}
};

View file

@ -58,23 +58,21 @@ static vector<string> SplitLines(const std::string& src) {
return res;
}
TestConnection::TestConnection(Protocol protocol)
: facade::Connection(protocol, nullptr, nullptr, nullptr) {
TestConnection::TestConnection(Protocol protocol, io::StringSink* sink)
: facade::Connection(protocol, nullptr, nullptr, nullptr), sink_(sink) {
}
void TestConnection::SendMsgVecAsync(const PubMessage& pmsg) {
backing_str_.emplace_back(new string(pmsg.channel));
PubMessage dest;
dest.channel = *backing_str_.back();
backing_str_.emplace_back(new string(*pmsg.message));
dest.message = pmsg.message;
if (!pmsg.pattern.empty()) {
backing_str_.emplace_back(new string(pmsg.pattern));
dest.pattern = *backing_str_.back();
void TestConnection::SendMsgVecAsync(PubMessage pmsg) {
if (pmsg.type == PubMessage::kPublish) {
messages.push_back(move(pmsg));
} else {
RedisReplyBuilder builder(sink_);
const char* action[2] = {"unsubscribe", "subscribe"};
builder.StartArray(3);
builder.SendBulkString(action[pmsg.type == PubMessage::kSubscribe]);
builder.SendBulkString(*pmsg.channel);
builder.SendLong(pmsg.channel_cnt);
}
messages.push_back(dest);
}
class BaseFamilyTest::TestConnWrapper {
@ -87,7 +85,7 @@ class BaseFamilyTest::TestConnWrapper {
RespVec ParseResponse(bool fully_consumed);
// returns: type(pmessage), pattern, channel, message.
facade::Connection::PubMessage GetPubMessage(size_t index) const;
const facade::Connection::PubMessage& GetPubMessage(size_t index) const;
ConnectionContext* cmd_cntx() {
return &cmd_cntx_;
@ -117,7 +115,7 @@ class BaseFamilyTest::TestConnWrapper {
};
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
: dummy_conn_(new TestConnection(proto)), cmd_cntx_(&sink_, dummy_conn_.get()) {
: dummy_conn_(new TestConnection(proto, &sink_)), cmd_cntx_(&sink_, dummy_conn_.get()) {
}
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {
@ -359,7 +357,8 @@ RespVec BaseFamilyTest::TestConnWrapper::ParseResponse(bool fully_consumed) {
return res;
}
facade::Connection::PubMessage BaseFamilyTest::TestConnWrapper::GetPubMessage(size_t index) const {
const facade::Connection::PubMessage& BaseFamilyTest::TestConnWrapper::GetPubMessage(
size_t index) const {
CHECK_LT(index, dummy_conn_->messages.size());
return dummy_conn_->messages[index];
}
@ -389,13 +388,10 @@ size_t BaseFamilyTest::SubscriberMessagesLen(string_view conn_id) const {
return it->second->conn()->messages.size();
}
facade::Connection::PubMessage BaseFamilyTest::GetPublishedMessage(string_view conn_id,
size_t index) const {
facade::Connection::PubMessage res;
const facade::Connection::PubMessage& BaseFamilyTest::GetPublishedMessage(string_view conn_id,
size_t index) const {
auto it = connections_.find(conn_id);
if (it == connections_.end())
return res;
CHECK(it != connections_.end());
return it->second->GetPubMessage(index);
}

View file

@ -19,14 +19,14 @@ using namespace facade;
class TestConnection : public facade::Connection {
 public:
  // `sink` receives subscribe/unsubscribe confirmations rendered as RESP.
  TestConnection(Protocol protocol, io::StringSink* sink);

  // Accepts the message by value: PubMessage owns its payload, so it can be
  // stored in `messages` without referencing the caller's stack.
  void SendMsgVecAsync(PubMessage pmsg) final;

  // Published messages captured for inspection by tests.
  std::vector<PubMessage> messages;

 private:
  io::StringSink* sink_;
};
class BaseFamilyTest : public ::testing::Test {
@ -87,9 +87,8 @@ class BaseFamilyTest : public ::testing::Test {
std::string GetId() const;
size_t SubscriberMessagesLen(std::string_view conn_id) const;
// Returns message parts as returned by RESP:
// pmessage, pattern, channel, message
facade::Connection::PubMessage GetPublishedMessage(std::string_view conn_id, size_t index) const;
const facade::Connection::PubMessage& GetPublishedMessage(std::string_view conn_id,
size_t index) const;
std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_;