Implement PSUBSCRIBE/PUNSUBSCRIBE commands.

Add minimal tests.
This commit is contained in:
Roman Gershman 2022-06-02 07:54:34 +03:00
parent 8570a12d81
commit ec9754150f
13 changed files with 373 additions and 85 deletions

View file

@ -268,15 +268,15 @@ API 2.0
- [X] HSETNX - [X] HSETNX
- [X] HVALS - [X] HVALS
- [X] HSCAN - [X] HSCAN
- [ ] PubSub family - [X] PubSub family
- [X] PUBLISH - [X] PUBLISH
- [ ] PUBSUB - [ ] PUBSUB
- [ ] PUBSUB CHANNELS - [ ] PUBSUB CHANNELS
- [X] SUBSCRIBE - [X] SUBSCRIBE
- [X] UNSUBSCRIBE - [X] UNSUBSCRIBE
- [ ] PSUBSCRIBE - [X] PSUBSCRIBE
- [ ] PUNSUBSCRIBE - [X] PUNSUBSCRIBE
- [ ] Server Family - [X] Server Family
- [ ] WATCH - [ ] WATCH
- [ ] UNWATCH - [ ] UNWATCH
- [X] DISCARD - [X] DISCARD

View file

@ -69,11 +69,11 @@ constexpr size_t kMinReadSize = 256;
constexpr size_t kMaxReadSize = 32_KB; constexpr size_t kMaxReadSize = 32_KB;
struct AsyncMsg { struct AsyncMsg {
absl::Span<const std::string_view> msg_vec; Connection::PubMessage pub_msg;
fibers_ext::BlockingCounter bc; fibers_ext::BlockingCounter bc;
AsyncMsg(absl::Span<const std::string_view> vec, fibers_ext::BlockingCounter b) AsyncMsg(const Connection::PubMessage& pmsg, fibers_ext::BlockingCounter b)
: msg_vec(vec), bc(move(b)) { : pub_msg(pmsg), bc(move(b)) {
} }
}; };
@ -245,15 +245,17 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) {
breaker_cb_ = breaker_cb; breaker_cb_ = breaker_cb;
} }
void Connection::SendMsgVecAsync(absl::Span<const std::string_view> msg_vec, void Connection::SendMsgVecAsync(const PubMessage& pub_msg,
fibers_ext::BlockingCounter bc) { fibers_ext::BlockingCounter bc) {
DCHECK(cc_);
if (cc_->conn_closing) { if (cc_->conn_closing) {
bc.Dec(); bc.Dec();
return; return;
} }
void* ptr = mi_malloc(sizeof(AsyncMsg)); void* ptr = mi_malloc(sizeof(AsyncMsg));
AsyncMsg* amsg = new (ptr) AsyncMsg(msg_vec, move(bc)); AsyncMsg* amsg = new (ptr) AsyncMsg(pub_msg, move(bc));
ptr = mi_malloc(sizeof(Request)); ptr = mi_malloc(sizeof(Request));
Request* req = new (ptr) Request(0, 0); Request* req = new (ptr) Request(0, 0);
@ -571,7 +573,24 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
if (req->async_msg) { if (req->async_msg) {
++stats->async_writes_cnt; ++stats->async_writes_cnt;
builder->SendRawVec(req->async_msg->msg_vec);
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;
const PubMessage& pub_msg = req->async_msg->pub_msg;
string_view arr[4];
if (pub_msg.pattern.empty()) {
arr[0] = "message";
arr[1] = pub_msg.channel;
arr[2] = pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 3});
} else {
arr[0] = "pmessage";
arr[1] = pub_msg.pattern;
arr[2] = pub_msg.channel;
arr[3] = pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr, 4});
}
req->async_msg->bc.Dec(); req->async_msg->bc.Dec();
req->async_msg->~AsyncMsg(); req->async_msg->~AsyncMsg();

View file

@ -46,11 +46,20 @@ class Connection : public util::Connection {
using BreakerCb = std::function<void(uint32_t)>; using BreakerCb = std::function<void(uint32_t)>;
void RegisterOnBreak(BreakerCb breaker_cb); void RegisterOnBreak(BreakerCb breaker_cb);
// This interface is used to pass a raw message directly to the socket via zero-copy interface. // This interface is used to pass a published message directly to the socket without
// copying strings.
// Once the msg is sent "bc" will be decreased so that caller could release the underlying // Once the msg is sent "bc" will be decreased so that caller could release the underlying
// storage for the message. // storage for the message.
void SendMsgVecAsync(absl::Span<const std::string_view> msg_vec, // virtual - to allow the testing code to override it.
util::fibers_ext::BlockingCounter bc);
struct PubMessage {
// if empty - means its a regular message, otherwise it's pmessage.
std::string_view pattern;
std::string_view channel;
std::string_view message;
};
virtual void SendMsgVecAsync(const PubMessage& pub_msg, util::fibers_ext::BlockingCounter bc);
void SetName(std::string_view name) { void SetName(std::string_view name) {
CopyCharBuf(name, sizeof(name_), name_); CopyCharBuf(name, sizeof(name_), name_);

View file

@ -4,6 +4,10 @@
#include "server/channel_slice.h" #include "server/channel_slice.h"
extern "C" {
#include "redis/util.h"
}
namespace dfly { namespace dfly {
using namespace std; using namespace std;
@ -11,6 +15,14 @@ ChannelSlice::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid)
: conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) { : conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) {
} }
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) {
auto [it, added] = channels_.emplace(channel, nullptr);
if (added) {
it->second.reset(new Channel);
}
it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
}
void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) { void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me) {
auto it = channels_.find(channel); auto it = channels_.find(channel);
if (it != channels_.end()) { if (it != channels_.end()) {
@ -20,29 +32,52 @@ void ChannelSlice::RemoveSubscription(string_view channel, ConnectionContext* me
} }
} }
void ChannelSlice::AddSubscription(string_view channel, ConnectionContext* me, uint32_t thread_id) { void ChannelSlice::AddGlobPattern(string_view pattern, ConnectionContext* me, uint32_t thread_id) {
auto [it, added] = channels_.emplace(channel, nullptr); auto [it, added] = patterns_.emplace(pattern, nullptr);
if (added) { if (added) {
it->second.reset(new Channel); it->second.reset(new Channel);
} }
it->second->subscribers.emplace(me, SubscriberInternal{thread_id}); it->second->subscribers.emplace(me, SubscriberInternal{thread_id});
} }
void ChannelSlice::RemoveGlobPattern(string_view pattern, ConnectionContext* me) {
auto it = patterns_.find(pattern);
if (it != patterns_.end()) {
it->second->subscribers.erase(me);
if (it->second->subscribers.empty())
patterns_.erase(it);
}
}
auto ChannelSlice::FetchSubscribers(string_view channel) -> vector<Subscriber> { auto ChannelSlice::FetchSubscribers(string_view channel) -> vector<Subscriber> {
vector<Subscriber> res; vector<Subscriber> res;
auto it = channels_.find(channel); auto it = channels_.find(channel);
if (it != channels_.end()) { if (it != channels_.end()) {
res.reserve(it->second->subscribers.size()); res.reserve(it->second->subscribers.size());
for (const auto& k_v : it->second->subscribers) { CopySubsribers(it->second->subscribers, string{}, &res);
Subscriber s(k_v.first, k_v.second.thread_id); }
s.borrow_token.Inc();
res.push_back(std::move(s)); for (const auto& k_v : patterns_) {
const string& pat = k_v.first;
// 1 - match
if (stringmatchlen(pat.data(), pat.size(), channel.data(), channel.size(), 0) == 1) {
CopySubsribers(k_v.second->subscribers, pat, &res);
} }
} }
return res; return res;
} }
void ChannelSlice::CopySubsribers(const SubsribeMap& src, const std::string& pattern,
vector<Subscriber>* dest) {
for (const auto& sub : src) {
Subscriber s(sub.first, sub.second.thread_id);
s.pattern = pattern;
s.borrow_token.Inc();
dest->push_back(std::move(s));
}
}
} // namespace dfly } // namespace dfly

View file

@ -19,6 +19,9 @@ class ChannelSlice {
util::fibers_ext::BlockingCounter borrow_token; util::fibers_ext::BlockingCounter borrow_token;
uint32_t thread_id; uint32_t thread_id;
// non-empty if was registered via psubscribe
std::string pattern;
Subscriber(ConnectionContext* cntx, uint32_t tid); Subscriber(ConnectionContext* cntx, uint32_t tid);
// Subscriber() : borrow_token(0) {} // Subscriber() : borrow_token(0) {}
@ -31,18 +34,27 @@ class ChannelSlice {
std::vector<Subscriber> FetchSubscribers(std::string_view channel); std::vector<Subscriber> FetchSubscribers(std::string_view channel);
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id); void AddSubscription(std::string_view channel, ConnectionContext* me, uint32_t thread_id);
void RemoveSubscription(std::string_view channel, ConnectionContext* me);
void AddGlobPattern(std::string_view pattern, ConnectionContext* me, uint32_t thread_id);
void RemoveGlobPattern(std::string_view pattern, ConnectionContext* me);
private: private:
struct SubscriberInternal { struct SubscriberInternal {
uint32_t thread_id; // proactor thread id. uint32_t thread_id; // proactor thread id.
SubscriberInternal(uint32_t tid) : thread_id(tid) {} SubscriberInternal(uint32_t tid) : thread_id(tid) {
}
}; };
using SubsribeMap = absl::flat_hash_map<ConnectionContext*, SubscriberInternal>;
static void CopySubsribers(const SubsribeMap& src, const std::string& pattern,
std::vector<Subscriber>* dest);
struct Channel { struct Channel {
absl::flat_hash_map<ConnectionContext*, SubscriberInternal> subscribers; SubsribeMap subscribers;
}; };
absl::flat_hash_map<std::string, std::unique_ptr<Channel>> channels_; absl::flat_hash_map<std::string, std::unique_ptr<Channel>> channels_;

View file

@ -23,9 +23,11 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
DCHECK(to_add); DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo); conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
// to be able to read input and still write the output.
this->force_dispatch = true; this->force_dispatch = true;
} }
// Gather all the channels we need to subsribe to / remove.
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
bool res = false; bool res = false;
string_view channel = ArgS(args, i); string_view channel = ArgS(args, i);
@ -44,13 +46,14 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
} }
} }
if (!to_add && conn_state.subscribe_info->channels.empty()) { if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset(); conn_state.subscribe_info.reset();
force_dispatch = false; force_dispatch = false;
} }
sort(channels.begin(), channels.end()); sort(channels.begin(), channels.end());
// prepare the array in order to distribute the updates to the shards.
vector<unsigned> shard_idx(shard_set->size() + 1, 0); vector<unsigned> shard_idx(shard_set->size() + 1, 0);
for (const auto& k_v : channels) { for (const auto& k_v : channels) {
shard_idx[k_v.first]++; shard_idx[k_v.first]++;
@ -68,6 +71,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
int32_t tid = util::ProactorBase::GetIndex(); int32_t tid = util::ProactorBase::GetIndex();
DCHECK_GE(tid, 0); DCHECK_GE(tid, 0);
// Update the subsribers on publisher's side.
auto cb = [&](EngineShard* shard) { auto cb = [&](EngineShard* shard) {
ChannelSlice& cs = shard->channel_slice(); ChannelSlice& cs = shard->channel_slice();
unsigned start = shard_idx[shard->shard_id()]; unsigned start = shard_idx[shard->shard_id()];
@ -83,6 +87,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
} }
}; };
// Update subscription
shard_set->RunBriefInParallel(move(cb), shard_set->RunBriefInParallel(move(cb),
[&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; }); [&](ShardId sid) { return shard_idx[sid + 1] > shard_idx[sid]; });
} }
@ -90,6 +95,77 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
if (to_reply) { if (to_reply) {
const char* action[2] = {"unsubscribe", "subscribe"}; const char* action[2] = {"unsubscribe", "subscribe"};
for (size_t i = 0; i < result.size(); ++i) {
(*this)->StartArray(3);
(*this)->SendBulkString(action[to_add]);
(*this)->SendBulkString(ArgS(args, i)); // channel
// number of subsribed channels for this connection *right after*
// we subsribe.
(*this)->SendLong(result[i]);
}
}
}
void ConnectionContext::ChangePSub(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result(to_reply ? args.size() : 0, 0);
if (to_add || conn_state.subscribe_info) {
std::vector<string_view> patterns;
patterns.reserve(args.size());
if (!conn_state.subscribe_info) {
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
this->force_dispatch = true;
}
// Gather all the patterns we need to subsribe to / remove.
for (size_t i = 0; i < args.size(); ++i) {
bool res = false;
string_view pattern = ArgS(args, i);
if (to_add) {
res = conn_state.subscribe_info->patterns.emplace(pattern).second;
} else {
res = conn_state.subscribe_info->patterns.erase(pattern) > 0;
}
if (to_reply)
result[i] = conn_state.subscribe_info->patterns.size();
if (res) {
patterns.emplace_back(pattern);
}
}
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
force_dispatch = false;
}
int32_t tid = util::ProactorBase::GetIndex();
DCHECK_GE(tid, 0);
// Update the subsribers on publisher's side.
auto cb = [&](EngineShard* shard) {
ChannelSlice& cs = shard->channel_slice();
for (string_view pattern : patterns) {
if (to_add) {
cs.AddGlobPattern(pattern, this, tid);
} else {
cs.RemoveGlobPattern(pattern, this);
}
}
};
// Update pattern subscription. Run on all shards.
shard_set->RunBriefInParallel(move(cb));
}
if (to_reply) {
const char* action[2] = {"punsubscribe", "psubscribe"};
for (size_t i = 0; i < result.size(); ++i) { for (size_t i = 0; i < result.size(); ++i) {
(*this)->StartArray(3); (*this)->StartArray(3);
(*this)->SendBulkString(action[to_add]); (*this)->SendBulkString(action[to_add]);
@ -100,18 +176,35 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
} }
void ConnectionContext::OnClose() { void ConnectionContext::OnClose() {
if (conn_state.subscribe_info) { if (!conn_state.subscribe_info)
return;
if (!conn_state.subscribe_info->channels.empty()) {
StringVec channels(conn_state.subscribe_info->channels.begin(), StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end()); conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end()); CmdArgVec arg_vec(channels.begin(), channels.end());
auto token = conn_state.subscribe_info->borrow_token; auto token = conn_state.subscribe_info->borrow_token;
ChangeSubscription(false, false, CmdArgList{arg_vec}); ChangeSubscription(false, false, CmdArgList{arg_vec});
DCHECK(!conn_state.subscribe_info);
// Check that all borrowers finished processing // Check that all borrowers finished processing
token.Wait(); token.Wait();
} }
if (conn_state.subscribe_info) {
DCHECK(!conn_state.subscribe_info->patterns.empty());
StringVec patterns(conn_state.subscribe_info->patterns.begin(),
conn_state.subscribe_info->patterns.end());
CmdArgVec arg_vec(patterns.begin(), patterns.end());
auto token = conn_state.subscribe_info->borrow_token;
ChangePSub(false, false, CmdArgList{arg_vec});
// Check that all borrowers finished processing
token.Wait();
DCHECK(!conn_state.subscribe_info);
}
} }
} // namespace dfly } // namespace dfly

View file

@ -50,10 +50,16 @@ struct ConnectionState {
struct SubscribeInfo { struct SubscribeInfo {
// TODO: to provide unique_strings across service. This will allow us to use string_view here. // TODO: to provide unique_strings across service. This will allow us to use string_view here.
absl::flat_hash_set<std::string> channels; absl::flat_hash_set<std::string> channels;
absl::flat_hash_set<std::string> patterns;
util::fibers_ext::BlockingCounter borrow_token; util::fibers_ext::BlockingCounter borrow_token;
SubscribeInfo() : borrow_token(0) {} bool IsEmpty() const {
return channels.empty() && patterns.empty();
}
SubscribeInfo() : borrow_token(0) {
}
}; };
std::unique_ptr<SubscribeInfo> subscribe_info; std::unique_ptr<SubscribeInfo> subscribe_info;
@ -85,6 +91,7 @@ class ConnectionContext : public facade::ConnectionContext {
} }
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args); void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
void ChangePSub(bool to_add, bool to_reply, CmdArgList args);
bool is_replicating = false; bool is_replicating = false;
}; };

View file

@ -22,13 +22,13 @@ extern "C" {
namespace dfly { namespace dfly {
using namespace absl;
using namespace boost;
using namespace std; using namespace std;
using namespace util; using namespace util;
using ::io::Result; using ::io::Result;
using testing::ElementsAre; using testing::ElementsAre;
using testing::HasSubstr; using testing::HasSubstr;
using absl::StrCat;
namespace this_fiber = boost::this_fiber;
namespace { namespace {
@ -411,6 +411,20 @@ TEST_F(DflyEngineTest, OOM) {
} }
} }
TEST_F(DflyEngineTest, PSubscribe) {
auto resp = pp_->at(1)->Await([&] { return Run({"psubscribe", "a*", "b*"}); });
EXPECT_THAT(resp, ArrLen(3));
resp = pp_->at(0)->Await([&] { return Run({"publish", "ab", "foo"}); });
EXPECT_THAT(resp, IntArg(1));
ASSERT_EQ(1, SubsriberMessagesLen("IO1"));
facade::Connection::PubMessage msg = GetPublishedMessage("IO1", 0);
EXPECT_EQ("foo", msg.message);
EXPECT_EQ("ab", msg.channel);
EXPECT_EQ("a*", msg.pattern);
}
// TODO: to test transactions with a single shard since then all transactions become local. // TODO: to test transactions with a single shard since then all transactions become local.
// To consider having a parameter in dragonfly engine controlling number of shards // To consider having a parameter in dragonfly engine controlling number of shards
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case. // unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.

View file

@ -912,23 +912,22 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
} }
fibers_ext::BlockingCounter bc(subsriber_arr.size()); fibers_ext::BlockingCounter bc(subsriber_arr.size());
char prefix[] = "*3\r\n$7\r\nmessage\r\n$";
char msg_size[32] = {0};
char channel_size[32] = {0};
absl::SNPrintF(msg_size, sizeof(msg_size), "%u\r\n", message.size());
absl::SNPrintF(channel_size, sizeof(channel_size), "%u\r\n", channel.size());
string_view msg_arr[] = {prefix, channel_size, channel, "\r\n$", msg_size, message, "\r\n"};
auto publish_cb = [&, bc](unsigned idx, util::ProactorBase*) mutable { auto publish_cb = [&, bc](unsigned idx, util::ProactorBase*) mutable {
unsigned start = slices[idx]; unsigned start = slices[idx];
for (unsigned i = start; i < subsriber_arr.size(); ++i) { for (unsigned i = start; i < subsriber_arr.size(); ++i) {
if (subsriber_arr[i].thread_id != idx) const ChannelSlice::Subscriber& subscriber = subsriber_arr[i];
if (subscriber.thread_id != idx)
break; break;
published.fetch_add(1, memory_order_relaxed); published.fetch_add(1, memory_order_relaxed);
subsriber_arr[i].conn_cntx->owner()->SendMsgVecAsync(msg_arr, bc); facade::Connection* conn = subsriber_arr[i].conn_cntx->owner();
DCHECK(conn);
facade::Connection::PubMessage pmsg;
pmsg.channel = channel;
pmsg.message = message;
pmsg.pattern = subscriber.pattern;
conn->SendMsgVecAsync(pmsg, bc);
} }
}; };
@ -959,6 +958,17 @@ void Service::Unsubscribe(CmdArgList args, ConnectionContext* cntx) {
cntx->ChangeSubscription(false, true, std::move(args)); cntx->ChangeSubscription(false, true, std::move(args));
} }
void Service::PSubscribe(CmdArgList args, ConnectionContext* cntx) {
args.remove_prefix(1);
cntx->ChangePSub(true, true, args);
}
void Service::PUnsubscribe(CmdArgList args, ConnectionContext* cntx) {
args.remove_prefix(1);
cntx->ChangePSub(false, true, args);
}
// Not a real implementation. Serves as a decorator to accept some function commands // Not a real implementation. Serves as a decorator to accept some function commands
// for testing. // for testing.
void Service::Function(CmdArgList args, ConnectionContext* cntx) { void Service::Function(CmdArgList args, ConnectionContext* cntx) {
@ -1024,6 +1034,8 @@ void Service::RegisterCommands() {
<< CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish) << CI{"PUBLISH", CO::LOADING | CO::FAST, 3, 0, 0, 0}.MFUNC(Publish)
<< CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe) << CI{"SUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Subscribe)
<< CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe) << CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(Unsubscribe)
<< CI{"PSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PSubscribe)
<< CI{"PUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PUnsubscribe)
<< CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function); << CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function);
StringFamily::Register(&registry_); StringFamily::Register(&registry_);

View file

@ -94,6 +94,8 @@ class Service : public facade::ServiceInterface {
void Publish(CmdArgList args, ConnectionContext* cntx); void Publish(CmdArgList args, ConnectionContext* cntx);
void Subscribe(CmdArgList args, ConnectionContext* cntx); void Subscribe(CmdArgList args, ConnectionContext* cntx);
void Unsubscribe(CmdArgList args, ConnectionContext* cntx); void Unsubscribe(CmdArgList args, ConnectionContext* cntx);
void PSubscribe(CmdArgList args, ConnectionContext* cntx);
void PUnsubscribe(CmdArgList args, ConnectionContext* cntx);
void Function(CmdArgList args, ConnectionContext* cntx); void Function(CmdArgList args, ConnectionContext* cntx);
struct EvalArgs { struct EvalArgs {
@ -113,6 +115,7 @@ class Service : public facade::ServiceInterface {
ServerFamily server_family_; ServerFamily server_family_;
CommandRegistry registry_; CommandRegistry registry_;
absl::flat_hash_map<std::string, unsigned> unknown_cmds_; absl::flat_hash_map<std::string, unsigned> unknown_cmds_;
mutable ::boost::fibers::mutex mu_; mutable ::boost::fibers::mutex mu_;
GlobalState global_state_ = GlobalState::ACTIVE; // protected by mu_; GlobalState global_state_ = GlobalState::ACTIVE; // protected by mu_;

View file

@ -37,9 +37,68 @@ static vector<string> SplitLines(const std::string& src) {
return res; return res;
} }
TestConnection::TestConnection(Protocol protocol)
: facade::Connection(protocol, nullptr, nullptr, nullptr) {
}
void TestConnection::SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) {
backing_str_.emplace_back(new string(pmsg.channel));
PubMessage dest;
dest.channel = *backing_str_.back();
backing_str_.emplace_back(new string(pmsg.message));
dest.message = *backing_str_.back();
if (!pmsg.pattern.empty()) {
backing_str_.emplace_back(new string(pmsg.pattern));
dest.pattern = *backing_str_.back();
}
messages.push_back(dest);
bc.Dec();
}
class BaseFamilyTest::TestConnWrapper {
public:
TestConnWrapper(Protocol proto);
~TestConnWrapper();
CmdArgVec Args(ArgSlice list);
RespVec ParseResponse();
// returns: type(pmessage), pattern, channel, message.
facade::Connection::PubMessage GetPubMessage(size_t index) const;
ConnectionContext* cmd_cntx() {
return &cmd_cntx_;
}
StringVec SplitLines() const {
return dfly::SplitLines(sink_.str());
}
void ClearSink() {
sink_.Clear();
}
TestConnection* conn() {
return dummy_conn_.get();
}
private:
::io::StringSink sink_; // holds the response blob
std::unique_ptr<TestConnection> dummy_conn_;
ConnectionContext cmd_cntx_;
std::vector<std::unique_ptr<std::string>> tmp_str_vec_;
std::unique_ptr<RedisParser> parser_;
};
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto) BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
: dummy_conn(new facade::Connection(proto, nullptr, nullptr, nullptr)), : dummy_conn_(new TestConnection(proto)), cmd_cntx_(&sink_, dummy_conn_.get()) {
cmd_cntx(&sink, dummy_conn.get()) {
} }
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() { BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {
@ -102,22 +161,22 @@ RespExpr BaseFamilyTest::Run(ArgSlice list) {
} }
RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
TestConnWrapper* conn = AddFindConn(Protocol::REDIS, id); TestConnWrapper* conn_wrapper = AddFindConn(Protocol::REDIS, id);
CmdArgVec args = conn->Args(slice); CmdArgVec args = conn_wrapper->Args(slice);
auto& context = conn->cmd_cntx; auto* context = conn_wrapper->cmd_cntx();
DCHECK(context.transaction == nullptr); DCHECK(context->transaction == nullptr);
service_->DispatchCommand(CmdArgList{args}, &context); service_->DispatchCommand(CmdArgList{args}, context);
DCHECK(context.transaction == nullptr); DCHECK(context->transaction == nullptr);
unique_lock lk(mu_); unique_lock lk(mu_);
last_cmd_dbg_info_ = context.last_command_debug; last_cmd_dbg_info_ = context->last_command_debug;
RespVec vec = conn->ParseResponse(); RespVec vec = conn_wrapper->ParseResponse();
if (vec.size() == 1) if (vec.size() == 1)
return vec.front(); return vec.front();
RespVec* new_vec = new RespVec(vec); RespVec* new_vec = new RespVec(vec);
@ -144,15 +203,15 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx; auto* context = conn->cmd_cntx();
DCHECK(context.transaction == nullptr); DCHECK(context->transaction == nullptr);
service_->DispatchMC(cmd, value, &context); service_->DispatchMC(cmd, value, context);
DCHECK(context.transaction == nullptr); DCHECK(context->transaction == nullptr);
return SplitLines(conn->sink.str()); return conn->SplitLines();
} }
auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResponse { auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResponse {
@ -165,11 +224,11 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp
cmd.key = key; cmd.key = key;
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx; auto* context = conn->cmd_cntx();
service_->DispatchMC(cmd, string_view{}, &context); service_->DispatchMC(cmd, string_view{}, context);
return SplitLines(conn->sink.str()); return conn->SplitLines();
} }
auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::string_view> list) auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::string_view> list)
@ -191,11 +250,11 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::stri
TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId());
auto& context = conn->cmd_cntx; auto* context = conn->cmd_cntx();
service_->DispatchMC(cmd, string_view{}, &context); service_->DispatchMC(cmd, string_view{}, context);
return SplitLines(conn->sink.str()); return conn->SplitLines();
} }
int64_t BaseFamilyTest::CheckedInt(std::initializer_list<std::string_view> list) { int64_t BaseFamilyTest::CheckedInt(std::initializer_list<std::string_view> list) {
@ -222,8 +281,8 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
if (v.empty()) { if (v.empty()) {
res.push_back(MutableSlice{}); res.push_back(MutableSlice{});
} else { } else {
tmp_str_vec.emplace_back(new string{v}); tmp_str_vec_.emplace_back(new string{v});
auto& s = *tmp_str_vec.back(); auto& s = *tmp_str_vec_.back();
res.emplace_back(s.data(), s.size()); res.emplace_back(s.data(), s.size());
} }
@ -233,19 +292,24 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
} }
RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() { RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() {
tmp_str_vec.emplace_back(new string{sink.str()}); tmp_str_vec_.emplace_back(new string{sink_.str()});
auto& s = *tmp_str_vec.back(); auto& s = *tmp_str_vec_.back();
auto buf = RespExpr::buffer(&s); auto buf = RespExpr::buffer(&s);
uint32_t consumed = 0; uint32_t consumed = 0;
parser.reset(new RedisParser{false}); // Client mode. parser_.reset(new RedisParser{false}); // Client mode.
RespVec res; RespVec res;
RedisParser::Result st = parser->Parse(buf, &consumed, &res); RedisParser::Result st = parser_->Parse(buf, &consumed, &res);
CHECK_EQ(RedisParser::OK, st); CHECK_EQ(RedisParser::OK, st);
return res; return res;
} }
facade::Connection::PubMessage BaseFamilyTest::TestConnWrapper::GetPubMessage(size_t index) const {
CHECK_LT(index, dummy_conn_->messages.size());
return dummy_conn_->messages[index];
}
bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const { bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const {
ShardId sid = Shard(key, shard_set->size()); ShardId sid = Shard(key, shard_set->size());
KeyLockArgs args; KeyLockArgs args;
@ -263,11 +327,30 @@ string BaseFamilyTest::GetId() const {
return absl::StrCat("IO", id); return absl::StrCat("IO", id);
} }
size_t BaseFamilyTest::SubsriberMessagesLen(string_view conn_id) const {
auto it = connections_.find(conn_id);
if (it == connections_.end())
return 0;
return it->second->conn()->messages.size();
}
facade::Connection::PubMessage BaseFamilyTest::GetPublishedMessage(string_view conn_id,
size_t index) const {
facade::Connection::PubMessage res;
auto it = connections_.find(conn_id);
if (it == connections_.end())
return res;
return it->second->GetPubMessage(index);
}
ConnectionContext::DebugInfo BaseFamilyTest::GetDebugInfo(const std::string& id) const { ConnectionContext::DebugInfo BaseFamilyTest::GetDebugInfo(const std::string& id) const {
auto it = connections_.find(id); auto it = connections_.find(id);
CHECK(it != connections_.end()); CHECK(it != connections_.end());
return it->second->cmd_cntx.last_command_debug; return it->second->cmd_cntx()->last_command_debug;
} }
auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestConnWrapper* { auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestConnWrapper* {
@ -278,7 +361,7 @@ auto BaseFamilyTest::AddFindConn(Protocol proto, std::string_view id) -> TestCon
if (inserted) { if (inserted) {
it->second.reset(new TestConnWrapper(proto)); it->second.reset(new TestConnWrapper(proto));
} else { } else {
it->second->sink.Clear(); it->second->ClearSink();
} }
return it->second.get(); return it->second.get();
} }

View file

@ -6,6 +6,7 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include "facade/dragonfly_connection.h"
#include "facade/memcache_parser.h" #include "facade/memcache_parser.h"
#include "facade/redis_parser.h" #include "facade/redis_parser.h"
#include "io/io.h" #include "io/io.h"
@ -16,6 +17,18 @@
namespace dfly { namespace dfly {
using namespace facade; using namespace facade;
class TestConnection : public facade::Connection {
public:
TestConnection(Protocol protocol);
void SendMsgVecAsync(const PubMessage& pmsg, util::fibers_ext::BlockingCounter bc) final;
std::vector<PubMessage> messages;
private:
std::vector<std::unique_ptr<std::string>> backing_str_;
};
class BaseFamilyTest : public ::testing::Test { class BaseFamilyTest : public ::testing::Test {
protected: protected:
BaseFamilyTest(); BaseFamilyTest();
@ -27,23 +40,7 @@ class BaseFamilyTest : public ::testing::Test {
void TearDown() override; void TearDown() override;
protected: protected:
struct TestConnWrapper { class TestConnWrapper;
::io::StringSink sink; // holds the response blob
std::unique_ptr<facade::Connection> dummy_conn;
ConnectionContext cmd_cntx;
std::vector<std::unique_ptr<std::string>> tmp_str_vec;
std::unique_ptr<RedisParser> parser;
TestConnWrapper(Protocol proto);
~TestConnWrapper();
CmdArgVec Args(ArgSlice list);
RespVec ParseResponse();
};
RespExpr Run(std::initializer_list<const std::string_view> list) { RespExpr Run(std::initializer_list<const std::string_view> list) {
return Run(ArgSlice{list.begin(), list.size()}); return Run(ArgSlice{list.begin(), list.size()});
@ -75,6 +72,11 @@ class BaseFamilyTest : public ::testing::Test {
void UpdateTime(uint64_t ms); void UpdateTime(uint64_t ms);
std::string GetId() const; std::string GetId() const;
size_t SubsriberMessagesLen(std::string_view conn_id) const;
// Returns message parts as returned by RESP:
// pmessage, pattern, channel, message
facade::Connection::PubMessage GetPublishedMessage(std::string_view conn_id, size_t index) const;
std::unique_ptr<util::ProactorPool> pp_; std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_; std::unique_ptr<Service> service_;

View file

@ -13,7 +13,6 @@
using namespace testing; using namespace testing;
using namespace std; using namespace std;
using namespace util; using namespace util;
using namespace boost;
namespace dfly { namespace dfly {