Factor out client connections module into a separate library called facade

This commit is contained in:
Roman Gershman 2022-03-03 01:59:29 +02:00
parent 28a2db1044
commit 3f0fcbf99f
64 changed files with 946 additions and 666 deletions

View file

@ -2,25 +2,22 @@ add_executable(dragonfly dfly_main.cc)
cxx_link(dragonfly base dragonfly_lib)
add_library(dragonfly_lib command_registry.cc common.cc config_flags.cc
conn_context.cc db_slice.cc debugcmd.cc dragonfly_listener.cc
dragonfly_connection.cc engine_shard_set.cc generic_family.cc hset_family.cc
list_family.cc main_service.cc memcache_parser.cc rdb_load.cc rdb_save.cc replica.cc
snapshot.cc redis_parser.cc reply_builder.cc script_mgr.cc server_family.cc
db_slice.cc debugcmd.cc
engine_shard_set.cc generic_family.cc hset_family.cc
list_family.cc main_service.cc rdb_load.cc rdb_save.cc replica.cc
snapshot.cc script_mgr.cc server_family.cc
set_family.cc
string_family.cc transaction.cc zset_family.cc)
cxx_link(dragonfly_lib dfly_core redis_lib uring_fiber_lib
fibers_ext strings_lib http_server_lib tls_lib)
cxx_link(dragonfly_lib dfly_core dfly_facade redis_lib strings_lib)
add_library(dfly_test_lib test_utils.cc)
cxx_link(dfly_test_lib dragonfly_lib gtest_main_ext)
cxx_link(dfly_test_lib dragonfly_lib facade_test gtest_main_ext)
cxx_test(dragonfly_test dfly_test_lib LABELS DFLY)
cxx_test(generic_family_test dfly_test_lib LABELS DFLY)
cxx_test(hset_family_test dfly_test_lib LABELS DFLY)
cxx_test(list_family_test dfly_test_lib LABELS DFLY)
cxx_test(memcache_parser_test dfly_test_lib LABELS DFLY)
cxx_test(redis_parser_test dfly_test_lib LABELS DFLY)
cxx_test(set_family_test dfly_test_lib LABELS DFLY)
cxx_test(string_family_test dfly_test_lib LABELS DFLY)
cxx_test(rdb_test dfly_test_lib DATA testdata/empty.rdb testdata/small.rdb LABELS DFLY)

View file

@ -38,41 +38,6 @@ Interpreter& ServerState::GetInterpreter() {
return interpreter_.value();
}
#define ADD(x) (x) += o.x
ConnectionStats& ConnectionStats::operator+=(const ConnectionStats& o) {
// To break this code deliberately if we add/remove a field to this struct.
static_assert(sizeof(ConnectionStats) == 64);
ADD(num_conns);
ADD(num_replicas);
ADD(read_buf_capacity);
ADD(io_read_cnt);
ADD(io_read_bytes);
ADD(io_write_cnt);
ADD(io_write_bytes);
ADD(pipelined_cmd_cnt);
ADD(command_cnt);
return *this;
}
#undef ADD
string WrongNumArgsError(std::string_view cmd) {
return absl::StrCat("wrong number of arguments for '", cmd, "' command");
}
const char kSyntaxErr[] = "syntax error";
const char kWrongTypeErr[] = "-WRONGTYPE Operation against a key holding the wrong kind of value";
const char kKeyNotFoundErr[] = "no such key";
const char kInvalidIntErr[] = "value is not an integer or out of range";
const char kUintErr[] = "value is out of range, must be positive";
const char kDbIndOutOfRangeErr[] = "DB index is out of range";
const char kInvalidDbIndErr[] = "invalid DB index";
const char kScriptNotFound[] = "-NOSCRIPT No matching script. Please use EVAL.";
const char kAuthRejected[] = "-WRONGPASS invalid username-password pair or user is disabled.";
const char* GlobalState::Name(S s) {
switch (s) {
case GlobalState::IDLE:
@ -88,20 +53,3 @@ const char* GlobalState::Name(S s) {
}
} // namespace dfly
namespace std {
ostream& operator<<(ostream& os, dfly::CmdArgList ras) {
os << "[";
if (!ras.empty()) {
for (size_t i = 0; i < ras.size() - 1; ++i) {
os << dfly::ArgS(ras, i) << ",";
}
os << dfly::ArgS(ras, ras.size() - 1);
}
os << "]";
return os;
}
} // namespace std

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
@ -9,25 +9,24 @@
#include <string_view>
#include <vector>
#include "facade/facade_types.h"
namespace dfly {
enum class ListDir : uint8_t { LEFT, RIGHT };
enum class Protocol : uint8_t {
MEMCACHE = 1,
REDIS = 2
};
using DbIndex = uint16_t;
using ShardId = uint16_t;
using TxId = uint64_t;
using TxClock = uint64_t;
using facade::MutableSlice;
using facade::CmdArgList;
using facade::CmdArgVec;
using facade::ArgS;
using ArgSlice = absl::Span<const std::string_view>;
using MutableSlice = absl::Span<char>;
using CmdArgList = absl::Span<MutableSlice>;
using CmdArgVec = std::vector<MutableSlice>;
constexpr DbIndex kInvalidDbId = DbIndex(-1);
constexpr ShardId kInvalidSid = ShardId(-1);
@ -50,42 +49,11 @@ struct KeyIndex {
unsigned step; // 1 for commands like mget. 2 for commands like mset.
};
struct ConnectionStats {
uint32_t num_conns = 0;
uint32_t num_replicas = 0;
size_t read_buf_capacity = 0;
size_t io_read_cnt = 0;
size_t io_read_bytes = 0;
size_t io_write_cnt = 0;
size_t io_write_bytes = 0;
size_t command_cnt = 0;
size_t pipelined_cmd_cnt = 0;
ConnectionStats& operator+=(const ConnectionStats& o);
};
struct OpArgs {
EngineShard* shard;
DbIndex db_ind;
};
constexpr inline unsigned long long operator""_MB(unsigned long long x) {
return 1024L * 1024L * x;
}
constexpr inline unsigned long long operator""_KB(unsigned long long x) {
return 1024L * x;
}
inline std::string_view ArgS(CmdArgList args, size_t i) {
auto arg = args[i];
return std::string_view(arg.data(), arg.size());
}
inline MutableSlice ToMSS(absl::Span<uint8_t> span) {
return MutableSlice{reinterpret_cast<char*>(span.data()), span.size()};
}
inline void ToUpper(const MutableSlice* val) {
for (auto& c : *val) {
c = absl::ascii_toupper(c);
@ -99,8 +67,3 @@ inline void ToLower(const MutableSlice* val) {
}
} // namespace dfly
namespace std {
ostream& operator<<(ostream& os, dfly::CmdArgList args);
} // namespace std

View file

@ -1,33 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/conn_context.h"
#include "base/logging.h"
#include "server/dragonfly_connection.h"
namespace dfly {
ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : owner_(owner) {
switch (owner->protocol()) {
case Protocol::REDIS:
rbuilder_.reset(new RedisReplyBuilder(stream));
break;
case Protocol::MEMCACHE:
rbuilder_.reset(new MCReplyBuilder(stream));
break;
}
}
Protocol ConnectionContext::protocol() const {
return owner_->protocol();
}
RedisReplyBuilder* ConnectionContext::operator->() {
CHECK(Protocol::REDIS == protocol());
return static_cast<RedisReplyBuilder*>(rbuilder_.get());
}
} // namespace dfly

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
@ -6,12 +6,11 @@
#include <absl/container/flat_hash_set.h>
#include "facade/conn_context.h"
#include "server/common_types.h"
#include "server/reply_builder.h"
namespace dfly {
class Connection;
class EngineShardSet;
struct StoredCmd {
@ -30,19 +29,6 @@ struct ConnectionState {
ExecState exec_state = EXEC_INACTIVE;
std::vector<StoredCmd> exec_body;
enum Mask : uint32_t {
ASYNC_DISPATCH = 1, // whether a command is handled via async dispatch.
CONN_CLOSING = 2, // could be because of unrecoverable error or planned action.
// Whether this connection belongs to replica, i.e. a dragonfly slave is connected to this
// host (master) via this connection to sync from it.
REPL_CONNECTION = 4,
REQ_AUTH = 8,
AUTHENTICATED = 0x10,
};
uint32_t mask = 0; // A bitmask of Mask values.
enum MCGetMask {
FETCH_CAS_VER = 1,
};
@ -52,14 +38,6 @@ struct ConnectionState {
// For get op - we use it as a mask of MCGetMask values.
uint32_t memcache_flag = 0;
bool IsClosing() const {
return mask & CONN_CLOSING;
}
bool IsRunViaDispatch() const {
return mask & ASYNC_DISPATCH;
}
// Lua-script related data.
struct Script {
bool is_write = true;
@ -69,10 +47,11 @@ struct ConnectionState {
std::optional<Script> script_info;
};
class ConnectionContext {
class ConnectionContext : public facade::ConnectionContext {
public:
ConnectionContext(::io::Sink* stream, Connection* owner);
ConnectionContext(::io::Sink* stream, facade::Connection* owner)
: facade::ConnectionContext(stream, owner) {
}
struct DebugInfo {
uint32_t shards_count = 0;
TxClock clock = 0;
@ -85,36 +64,11 @@ class ConnectionContext {
Transaction* transaction = nullptr;
const CommandId* cid = nullptr;
EngineShardSet* shard_set = nullptr;
Connection* owner() {
return owner_;
}
Protocol protocol() const;
ConnectionState conn_state;
DbIndex db_index() const {
return conn_state.db_index;
}
ConnectionState conn_state;
// A convenient proxy for redis interface.
RedisReplyBuilder* operator->();
ReplyBuilderInterface* reply_builder() {
return rbuilder_.get();
}
// Allows receiving the output data from the commands called from scripts.
ReplyBuilderInterface* Inject(ReplyBuilderInterface* new_i) {
ReplyBuilderInterface* res = rbuilder_.release();
rbuilder_.reset(new_i);
return res;
}
private:
Connection* owner_;
std::unique_ptr<ReplyBuilderInterface> rbuilder_;
};
} // namespace dfly

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
@ -21,6 +21,7 @@ namespace dfly {
using namespace boost;
using namespace std;
using namespace util;
using facade::OpStatus;
#define ADD(x) (x) += o.x

View file

@ -8,7 +8,7 @@
#include <absl/container/flat_hash_set.h>
#include "core/intent_lock.h"
#include "core/op_status.h"
#include "facade/op_status.h"
#include "server/common_types.h"
#include "server/table.h"
@ -18,6 +18,8 @@ class ProactorBase;
namespace dfly {
using facade::OpResult;
struct DbStats {
// number of active keys.
size_t key_count = 0;

View file

@ -25,6 +25,7 @@ using namespace std;
using namespace util;
namespace this_fiber = ::boost::this_fiber;
using boost::fibers::fiber;
using facade::kUintErr;
namespace fs = std::filesystem;
struct PopulateBatch {

View file

@ -1,9 +1,9 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "base/init.h"
#include "server/dragonfly_listener.h"
#include "facade/dragonfly_listener.h"
#include "server/main_service.h"
#include "util/accept_server.h"
#include "util/uring/uring_pool.h"
@ -14,13 +14,13 @@ DECLARE_uint32(memcache_port);
using namespace util;
using namespace std;
using namespace facade;
namespace dfly {
void RunEngine(ProactorPool* pool, AcceptServer* acceptor, HttpListener<>* http) {
void RunEngine(ProactorPool* pool, AcceptServer* acceptor) {
Service service(pool);
service.RegisterHttp(http);
service.Init(acceptor);
acceptor->AddListener(FLAGS_port, new Listener{Protocol::REDIS, &service});
if (FLAGS_memcache_port > 0) {
@ -47,11 +47,8 @@ int main(int argc, char* argv[]) {
pp.Run();
AcceptServer acceptor(&pp);
unique_ptr<HttpListener<>> http_listener(new HttpListener<>);
http_listener->enable_metrics();
dfly::RunEngine(&pp, &acceptor, http_listener.get());
dfly::RunEngine(&pp, &acceptor);
pp.Stop();

View file

@ -1,539 +0,0 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/dragonfly_connection.h"
#include <absl/container/flat_hash_map.h>
#include <absl/strings/match.h>
#include <mimalloc.h>
#include <boost/fiber/operations.hpp>
#include "base/logging.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/main_service.h"
#include "server/memcache_parser.h"
#include "server/redis_parser.h"
#include "server/server_state.h"
#include "server/transaction.h"
#include "util/fiber_sched_algo.h"
#include "util/tls/tls_socket.h"
#include "util/uring/uring_socket.h"
DEFINE_bool(tcp_nodelay, true, "Configures dragonfly connections with socket option TCP_NODELAY");
using namespace util;
using namespace std;
using nonstd::make_unexpected;
namespace this_fiber = boost::this_fiber;
namespace fibers = boost::fibers;
namespace dfly {
namespace {
void SendProtocolError(RedisParser::Result pres, FiberSocketBase* peer) {
string res("-ERR Protocol error: ");
if (pres == RedisParser::BAD_BULKLEN) {
res.append("invalid bulk length\r\n");
} else {
CHECK_EQ(RedisParser::BAD_ARRAYLEN, pres);
res.append("invalid multibulk length\r\n");
}
auto size_res = peer->Send(::io::Buffer(res));
if (!size_res) {
LOG(WARNING) << "Error " << size_res.error();
}
}
void RespToArgList(const RespVec& src, CmdArgVec* dest) {
dest->resize(src.size());
for (size_t i = 0; i < src.size(); ++i) {
(*dest)[i] = ToMSS(src[i].GetBuf());
}
}
// TODO: to implement correct matcher according to HTTP spec
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html
// One place to find a good implementation would be https://github.com/h2o/picohttpparser
bool MatchHttp11Line(string_view line) {
return absl::StartsWith(line, "GET ") && absl::EndsWith(line, "HTTP/1.1");
}
constexpr size_t kMinReadSize = 256;
constexpr size_t kMaxReadSize = 32_KB;
} // namespace
struct Connection::Shutdown {
absl::flat_hash_map<ShutdownHandle, ShutdownCb> map;
ShutdownHandle next_handle = 1;
ShutdownHandle Add(ShutdownCb cb) {
map[next_handle] = move(cb);
return next_handle++;
}
void Remove(ShutdownHandle sh) {
map.erase(sh);
}
};
struct Connection::Request {
absl::FixedArray<MutableSlice> args;
// I do not use mi_heap_t explicitly but mi_stl_allocator at the end does the same job
// of using the thread's heap.
absl::FixedArray<char, 256, mi_stl_allocator<char>> storage;
Request(size_t nargs, size_t capacity) : args(nargs), storage(capacity) {
}
Request(const Request&) = delete;
};
Connection::Connection(Protocol protocol, Service* service, SSL_CTX* ctx)
: io_buf_{kMinReadSize}, service_(service), ctx_(ctx) {
protocol_ = protocol;
switch (protocol) {
case Protocol::REDIS:
redis_parser_.reset(new RedisParser);
break;
case Protocol::MEMCACHE:
memcache_parser_.reset(new MemcacheParser);
break;
}
}
Connection::~Connection() {
}
void Connection::OnShutdown() {
VLOG(1) << "Connection::OnShutdown";
if (shutdown_) {
for (const auto& k_v : shutdown_->map) {
k_v.second();
}
}
}
auto Connection::RegisterShutdownHook(ShutdownCb cb) -> ShutdownHandle {
if (!shutdown_) {
shutdown_ = make_unique<Shutdown>();
}
return shutdown_->Add(std::move(cb));
}
void Connection::UnregisterShutdownHook(ShutdownHandle id) {
if (shutdown_) {
shutdown_->Remove(id);
if (shutdown_->map.empty())
shutdown_.reset();
}
}
void Connection::HandleRequests() {
this_fiber::properties<FiberProps>().set_name("DflyConnection");
LinuxSocketBase* lsb = static_cast<LinuxSocketBase*>(socket_.get());
if (FLAGS_tcp_nodelay) {
int val = 1;
CHECK_EQ(0, setsockopt(lsb->native_handle(), SOL_TCP, TCP_NODELAY, &val, sizeof(val)));
}
auto remote_ep = lsb->RemoteEndpoint();
unique_ptr<tls::TlsSocket> tls_sock;
if (ctx_) {
tls_sock.reset(new tls::TlsSocket(socket_.get()));
tls_sock->InitSSL(ctx_);
FiberSocketBase::AcceptResult aresult = tls_sock->Accept();
if (!aresult) {
LOG(WARNING) << "Error handshaking " << aresult.error().message();
return;
}
VLOG(1) << "TLS handshake succeeded";
}
FiberSocketBase* peer = tls_sock ? (FiberSocketBase*)tls_sock.get() : socket_.get();
io::Result<bool> http_res = CheckForHttpProto(peer);
if (http_res) {
if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
HttpConnection http_conn{service_->http_listener()};
http_conn.SetSocket(peer);
auto ec = http_conn.ParseFromBuffer(io_buf_.InputBuffer());
io_buf_.ConsumeInput(io_buf_.InputLen());
if (!ec) {
http_conn.HandleRequests();
}
http_conn.ReleaseSocket();
} else {
cc_.reset(new ConnectionContext(peer, this));
cc_->shard_set = &service_->shard_set();
if (service_->IsPassProtected())
cc_->conn_state.mask |= ConnectionState::REQ_AUTH;
// TODO: to move this interface to LinuxSocketBase so we won't need to cast.
uring::UringSocket* us = static_cast<uring::UringSocket*>(socket_.get());
bool poll_armed = true;
uint32_t poll_id = us->PollEvent(POLLERR | POLLHUP, [&](uint32_t mask) {
VLOG(1) << "Got event " << mask;
cc_->conn_state.mask |= ConnectionState::CONN_CLOSING;
if (cc_->transaction) {
cc_->transaction->BreakOnClose();
}
evc_.notify(); // Notify dispatch fiber.
poll_armed = false;
});
ConnectionFlow(peer);
if (poll_armed) {
us->CancelPoll(poll_id);
}
}
}
VLOG(1) << "Closed connection for peer " << remote_ep;
}
io::Result<bool> Connection::CheckForHttpProto(util::FiberSocketBase* peer) {
size_t last_len = 0;
do {
auto buf = io_buf_.AppendBuffer();
::io::Result<size_t> recv_sz = peer->Recv(buf);
if (!recv_sz) {
return make_unexpected(recv_sz.error());
}
io_buf_.CommitWrite(*recv_sz);
string_view ib = ToSV(io_buf_.InputBuffer().subspan(last_len));
size_t pos = ib.find('\n');
if (pos != string_view::npos) {
ib = ToSV(io_buf_.InputBuffer().first(last_len + pos));
if (ib.size() < 10 || ib.back() != '\r')
return false;
ib.remove_suffix(1);
return MatchHttp11Line(ib);
}
last_len = io_buf_.InputLen();
} while (last_len < 1024);
return false;
}
void Connection::ConnectionFlow(FiberSocketBase* peer) {
auto dispatch_fb = fibers::fiber(fibers::launch::dispatch, [&] { DispatchFiber(peer); });
ConnectionStats* stats = ServerState::tl_connection_stats();
stats->num_conns++;
stats->read_buf_capacity += io_buf_.Capacity();
ParserStatus parse_status = OK;
// At the start we read from the socket to determine the HTTP/Memstore protocol.
// Therefore we may already have some data in the buffer.
if (io_buf_.InputLen() > 0) {
if (redis_parser_) {
parse_status = ParseRedis();
} else {
DCHECK(memcache_parser_);
parse_status = ParseMemcache();
}
}
error_code ec;
// Main loop.
if (parse_status != ERROR) {
auto res = IoLoop(peer);
if (holds_alternative<error_code>(res)) {
ec = get<error_code>(res);
} else {
parse_status = get<ParserStatus>(res);
}
}
cc_->conn_state.mask |= ConnectionState::CONN_CLOSING; // Signal dispatch to close.
evc_.notify();
dispatch_fb.join();
stats->read_buf_capacity -= io_buf_.Capacity();
// Update num_replicas if this was a replica connection.
if (cc_->conn_state.mask & ConnectionState::REPL_CONNECTION) {
--stats->num_replicas;
}
// We wait for dispatch_fb to finish writing the previous replies before replying to the last
// offending request.
if (parse_status == ERROR) {
VLOG(1) << "Error stats " << parse_status;
if (redis_parser_) {
SendProtocolError(RedisParser::Result(parser_error_), peer);
} else {
string_view sv{"CLIENT_ERROR bad command line format\r\n"};
auto size_res = peer->Send(::io::Buffer(sv));
if (!size_res) {
LOG(WARNING) << "Error " << size_res.error();
ec = size_res.error();
}
}
}
if (ec && !FiberSocketBase::IsConnClosed(ec)) {
LOG(WARNING) << "Socket error " << ec;
}
--stats->num_conns;
}
auto Connection::ParseRedis() -> ParserStatus {
RespVec args;
CmdArgVec arg_vec;
uint32_t consumed = 0;
RedisParser::Result result = RedisParser::OK;
ReplyBuilderInterface* builder = cc_->reply_builder();
mi_heap_t* tlh = mi_heap_get_backing();
do {
result = redis_parser_->Parse(io_buf_.InputBuffer(), &consumed, &args);
if (result == RedisParser::OK && !args.empty()) {
RespExpr& first = args.front();
if (first.type == RespExpr::STRING) {
DVLOG(2) << "Got Args with first token " << ToSV(first.GetBuf());
}
// An optimization to skip dispatch_q_ if no pipelining is identified.
// We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the
// dispatch fiber pulls the last record but is still processing the command and then this
// fiber enters the condition below and executes out of order.
bool is_sync_dispatch = !cc_->conn_state.IsRunViaDispatch();
if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf_.InputLen()) {
RespToArgList(args, &arg_vec);
service_->DispatchCommand(CmdArgList{arg_vec.data(), arg_vec.size()}, cc_.get());
} else {
// Dispatch via queue to speedup input reading
// We could use
Request* req = FromArgs(std::move(args), tlh);
dispatch_q_.emplace_back(req);
if (dispatch_q_.size() == 1) {
evc_.notify();
} else if (dispatch_q_.size() > 10) {
this_fiber::yield();
}
}
}
io_buf_.ConsumeInput(consumed);
} while (RedisParser::OK == result && !builder->GetError());
parser_error_ = result;
if (result == RedisParser::OK)
return OK;
if (result == RedisParser::INPUT_PENDING)
return NEED_MORE;
return ERROR;
}
auto Connection::ParseMemcache() -> ParserStatus {
MemcacheParser::Result result = MemcacheParser::OK;
uint32_t consumed = 0;
MemcacheParser::Command cmd;
string_view value;
MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(cc_->reply_builder());
do {
string_view str = ToSV(io_buf_.InputBuffer());
result = memcache_parser_->Parse(str, &consumed, &cmd);
if (result != MemcacheParser::OK) {
io_buf_.ConsumeInput(consumed);
break;
}
size_t total_len = consumed;
if (MemcacheParser::IsStoreCmd(cmd.type)) {
total_len += cmd.bytes_len + 2;
if (io_buf_.InputLen() >= total_len) {
value = str.substr(consumed, cmd.bytes_len);
// TODO: dispatch.
} else {
return NEED_MORE;
}
}
// An optimization to skip dispatch_q_ if no pipelining is identified.
// We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the
// dispatch fiber pulls the last record but is still processing the command and then this
// fiber enters the condition below and executes out of order.
bool is_sync_dispatch = (cc_->conn_state.mask & ConnectionState::ASYNC_DISPATCH) == 0;
if (dispatch_q_.empty() && is_sync_dispatch) {
service_->DispatchMC(cmd, value, cc_.get());
}
io_buf_.ConsumeInput(total_len);
} while (!builder->GetError());
parser_error_ = result;
if (result == MemcacheParser::INPUT_PENDING) {
return NEED_MORE;
}
if (result == MemcacheParser::PARSE_ERROR) {
builder->SendError(""); // ERROR.
} else if (result == MemcacheParser::BAD_DELTA) {
builder->SendClientError("invalid numeric delta argument");
} else if (result != MemcacheParser::OK) {
builder->SendClientError("bad command line format");
}
return OK;
}
auto Connection::IoLoop(util::FiberSocketBase* peer) -> variant<error_code, ParserStatus> {
SinkReplyBuilder* builder = static_cast<SinkReplyBuilder*>(cc_->reply_builder());
ConnectionStats* stats = ServerState::tl_connection_stats();
error_code ec;
ParserStatus parse_status = OK;
auto fetch_builder_stats = [&] {
stats->io_write_cnt += builder->io_write_cnt();
stats->io_write_bytes += builder->io_write_bytes();
builder->reset_io_stats();
};
do {
fetch_builder_stats();
io::MutableBytes append_buf = io_buf_.AppendBuffer();
::io::Result<size_t> recv_sz = peer->Recv(append_buf);
if (!recv_sz) {
ec = recv_sz.error();
parse_status = OK;
break;
}
io_buf_.CommitWrite(*recv_sz);
stats->io_read_bytes += *recv_sz;
++stats->io_read_cnt;
if (redis_parser_)
parse_status = ParseRedis();
else {
DCHECK(memcache_parser_);
parse_status = ParseMemcache();
}
if (parse_status == NEED_MORE) {
parse_status = OK;
size_t capacity = io_buf_.Capacity();
if (capacity < kMaxReadSize) {
size_t parser_hint = 0;
if (redis_parser_)
parser_hint = redis_parser_->parselen_hint(); // Could be done for MC as well.
if (parser_hint > capacity) {
io_buf_.Reserve(std::min(kMaxReadSize, parser_hint));
} else if (append_buf.size() == *recv_sz && append_buf.size() > capacity / 2) {
// Last io used most of the io_buf to the end.
io_buf_.Reserve(capacity * 2); // Valid growth range.
}
if (capacity < io_buf_.Capacity()) {
VLOG(1) << "Growing io_buf to " << io_buf_.Capacity();
stats->read_buf_capacity += (io_buf_.Capacity() - capacity);
}
}
} else if (parse_status != OK) {
break;
}
ec = builder->GetError();
} while (peer->IsOpen() && !ec);
fetch_builder_stats();
if (ec)
return ec;
return parse_status;
}
// DispatchFiber handles commands coming from the InputLoop.
// Thus, InputLoop can quickly read data from the input buffer, parse it and push
// into the dispatch queue and DispatchFiber will run those commands asynchronously with InputLoop.
// Note: in some cases, InputLoop may decide to dispatch directly and bypass the DispatchFiber.
void Connection::DispatchFiber(util::FiberSocketBase* peer) {
this_fiber::properties<FiberProps>().set_name("DispatchFiber");
ConnectionStats* stats = ServerState::tl_connection_stats();
SinkReplyBuilder* builder = static_cast<SinkReplyBuilder*>(cc_->reply_builder());
while (!builder->GetError()) {
evc_.await([this] { return cc_->conn_state.IsClosing() || !dispatch_q_.empty(); });
if (cc_->conn_state.IsClosing())
break; // TODO: We have a memory leak with pending requests in the queue.
Request* req = dispatch_q_.front();
dispatch_q_.pop_front();
++stats->pipelined_cmd_cnt;
builder->SetBatchMode(!dispatch_q_.empty());
cc_->conn_state.mask |= ConnectionState::ASYNC_DISPATCH;
service_->DispatchCommand(CmdArgList{req->args.data(), req->args.size()}, cc_.get());
cc_->conn_state.mask &= ~ConnectionState::ASYNC_DISPATCH;
req->~Request();
mi_free(req);
}
cc_->conn_state.mask |= ConnectionState::CONN_CLOSING;
}
auto Connection::FromArgs(RespVec args, mi_heap_t* heap) -> Request* {
DCHECK(!args.empty());
size_t backed_sz = 0;
for (const auto& arg : args) {
CHECK_EQ(RespExpr::STRING, arg.type);
backed_sz += arg.GetBuf().size();
}
DCHECK(backed_sz);
constexpr auto kReqSz = sizeof(Request);
static_assert(kReqSz < MI_SMALL_SIZE_MAX);
static_assert(alignof(Request) == 8);
void* ptr = mi_heap_malloc_small(heap, kReqSz);
Request* req = new (ptr) Request{args.size(), backed_sz};
auto* next = req->storage.data();
for (size_t i = 0; i < args.size(); ++i) {
auto buf = args[i].GetBuf();
size_t s = buf.size();
memcpy(next, buf.data(), s);
req->args[i] = MutableSlice(next, s);
next += s;
}
return req;
}
} // namespace dfly

View file

@ -1,82 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/container/fixed_array.h>
#include <deque>
#include <variant>
#include "base/io_buf.h"
#include "core/resp_expr.h"
#include "server/common_types.h"
#include "util/connection.h"
#include "util/fibers/event_count.h"
typedef struct ssl_ctx_st SSL_CTX;
typedef struct mi_heap_s mi_heap_t;
namespace dfly {
class ConnectionContext;
class RedisParser;
class Service;
class MemcacheParser;
class Connection : public util::Connection {
public:
Connection(Protocol protocol, Service* service, SSL_CTX* ctx);
~Connection();
using error_code = std::error_code;
using ShutdownCb = std::function<void()>;
using ShutdownHandle = unsigned;
ShutdownHandle RegisterShutdownHook(ShutdownCb cb);
void UnregisterShutdownHook(ShutdownHandle id);
Protocol protocol() const {
return protocol_;
}
protected:
void OnShutdown() override;
private:
enum ParserStatus { OK, NEED_MORE, ERROR };
void HandleRequests() final;
//
io::Result<bool> CheckForHttpProto(util::FiberSocketBase* peer);
void ConnectionFlow(util::FiberSocketBase* peer);
std::variant<std::error_code, ParserStatus> IoLoop(util::FiberSocketBase* peer);
void DispatchFiber(util::FiberSocketBase* peer);
ParserStatus ParseRedis();
ParserStatus ParseMemcache();
base::IoBuf io_buf_;
std::unique_ptr<RedisParser> redis_parser_;
std::unique_ptr<MemcacheParser> memcache_parser_;
Service* service_;
SSL_CTX* ctx_;
std::unique_ptr<ConnectionContext> cc_;
struct Request;
static Request* FromArgs(RespVec args, mi_heap_t* heap);
std::deque<Request*> dispatch_q_; // coordinated via evc_.
util::fibers_ext::EventCount evc_;
unsigned parser_error_ = 0;
Protocol protocol_;
struct Shutdown;
std::unique_ptr<Shutdown> shutdown_;
};
} // namespace dfly

View file

@ -1,142 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/dragonfly_listener.h"
#include <openssl/ssl.h>
#include "base/logging.h"
#include "server/config_flags.h"
#include "server/dragonfly_connection.h"
#include "util/proactor_pool.h"
DEFINE_uint32(conn_threads, 0, "Number of threads used for handing server connections");
DEFINE_bool(tls, false, "");
DEFINE_bool(conn_use_incoming_cpu, false,
"If true uses incoming cpu of a socket in order to distribute"
" incoming connections");
CONFIG_string(tls_client_cert_file, "", "", TrueValidator);
CONFIG_string(tls_client_key_file, "", "", TrueValidator);
enum TlsClientAuth {
CL_AUTH_NO = 0,
CL_AUTH_YES = 1,
CL_AUTH_OPTIONAL = 2,
};
dfly::ConfigEnum tls_auth_clients_enum[] = {
{"no", CL_AUTH_NO},
{"yes", CL_AUTH_YES},
{"optional", CL_AUTH_OPTIONAL},
};
static int tls_auth_clients_opt = CL_AUTH_YES;
CONFIG_enum(tls_auth_clients, "yes", "", tls_auth_clients_enum, tls_auth_clients_opt);
namespace dfly {
using namespace util;
using namespace std;
// To connect: openssl s_client -cipher "ADH:@SECLEVEL=0" -state -crlf -connect 127.0.0.1:6380
static SSL_CTX* CreateSslCntx() {
SSL_CTX* ctx = SSL_CTX_new(TLS_server_method());
if (FLAGS_tls_client_key_file.empty()) {
// To connect - use openssl s_client -cipher with either:
// "AECDH:@SECLEVEL=0" or "ADH:@SECLEVEL=0" setting.
CHECK_EQ(1, SSL_CTX_set_cipher_list(ctx, "aNULL"));
// To allow anonymous ciphers.
SSL_CTX_set_security_level(ctx, 0);
// you can still connect with redis-cli with :
// redis-cli --tls --insecure --tls-ciphers "ADH:@SECLEVEL=0"
LOG(WARNING)
<< "tls-client-key-file not set, no keys are loaded and anonymous ciphers are enabled. "
<< "Do not use in production!";
} else { // tls_client_key_file is set.
CHECK_EQ(1,
SSL_CTX_use_PrivateKey_file(ctx, FLAGS_tls_client_key_file.c_str(), SSL_FILETYPE_PEM));
if (!FLAGS_tls_client_cert_file.empty()) {
// TO connect with redis-cli you need both tls-client-key-file and tls-client-cert-file
// loaded. Use `redis-cli --tls -p 6380 --insecure PING` to test
CHECK_EQ(1, SSL_CTX_use_certificate_chain_file(ctx, FLAGS_tls_client_cert_file.c_str()));
}
CHECK_EQ(1, SSL_CTX_set_cipher_list(ctx, "DEFAULT"));
}
SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION);
SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
unsigned mask = SSL_VERIFY_NONE;
// if (tls_auth_clients_opt)
// mask |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
SSL_CTX_set_verify(ctx, mask, NULL);
CHECK_EQ(1, SSL_CTX_set_dh_auto(ctx, 1));
return ctx;
}
Listener::Listener(Protocol protocol, Service* e) : engine_(e), protocol_(protocol) {
if (FLAGS_tls) {
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL);
ctx_ = CreateSslCntx();
}
}
Listener::~Listener() {
SSL_CTX_free(ctx_);
}
util::Connection* Listener::NewConnection(ProactorBase* proactor) {
return new Connection{protocol_, engine_, ctx_};
}
void Listener::PreShutdown() {
}
void Listener::PostShutdown() {
}
// We can limit number of threads handling dragonfly connections.
ProactorBase* Listener::PickConnectionProactor(LinuxSocketBase* sock) {
util::ProactorPool* pp = pool();
uint32_t total = FLAGS_conn_threads;
uint32_t id = kuint32max;
if (total == 0 || total > pp->size()) {
total = pp->size();
}
if (FLAGS_conn_use_incoming_cpu) {
int fd = sock->native_handle();
int cpu, napi_id;
socklen_t len = sizeof(cpu);
CHECK_EQ(0, getsockopt(fd, SOL_SOCKET, SO_INCOMING_CPU, &cpu, &len));
CHECK_EQ(0, getsockopt(fd, SOL_SOCKET, SO_INCOMING_NAPI_ID, &napi_id, &len));
VLOG(1) << "CPU/NAPI for connection " << fd << " is " << cpu << "/" << napi_id;
vector<unsigned> ids = pool()->MapCpuToThreads(cpu);
if (!ids.empty()) {
id = ids.front();
}
}
if (id == kuint32max) {
id = next_id_.fetch_add(1, std::memory_order_relaxed);
}
return pp->at(id % total);
}
} // namespace dfly

View file

@ -1,36 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include "util/listener_interface.h"
#include "server/common_types.h"
typedef struct ssl_ctx_st SSL_CTX;
namespace dfly {
class Service;
class Listener : public util::ListenerInterface {
public:
Listener(Protocol protocol, Service*);
~Listener();
private:
util::Connection* NewConnection(util::ProactorBase* proactor) final;
util::ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final;
void PreShutdown();
void PostShutdown();
Service* engine_;
std::atomic_uint32_t next_id_{0};
Protocol protocol_;
SSL_CTX* ctx_ = nullptr;
};
} // namespace dfly

View file

@ -2,6 +2,11 @@
// See LICENSE for licensing terms.
//
extern "C" {
#include "redis/sds.h"
#include "redis/zmalloc.h"
}
#include <absl/strings/ascii.h>
#include <absl/strings/str_join.h>
#include <absl/strings/strip.h>
@ -9,9 +14,9 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
#include "server/conn_context.h"
#include "server/main_service.h"
#include "server/redis_parser.h"
#include "server/test_utils.h"
#include "util/uring/uring_pool.h"
@ -43,8 +48,41 @@ class DflyEngineTest : public BaseFamilyTest {
DflyEngineTest() : BaseFamilyTest() {
num_threads_ = kPoolThreadCount;
}
static void SetUpTestSuite() {
init_zmalloc_threadlocal();
}
};
// TODO: to implement equivalent parsing in redis parser.
TEST_F(DflyEngineTest, Sds) {
int argc;
sds* argv = sdssplitargs("\r\n", &argc);
EXPECT_EQ(0, argc);
sdsfreesplitres(argv, argc);
argv = sdssplitargs("\026 \020 \200 \277 \r\n", &argc);
EXPECT_EQ(4, argc);
EXPECT_STREQ("\026", argv[0]);
sdsfreesplitres(argv, argc);
argv = sdssplitargs(R"(abc "oops\n" )"
"\r\n",
&argc);
EXPECT_EQ(2, argc);
EXPECT_STREQ("oops\n", argv[1]);
sdsfreesplitres(argv, argc);
argv = sdssplitargs(R"( "abc\xf0" )"
"\t'oops\n' \r\n",
&argc);
ASSERT_EQ(2, argc);
EXPECT_STREQ("abc\xf0", argv[0]);
EXPECT_STREQ("oops\n", argv[1]);
sdsfreesplitres(argv, argc);
}
TEST_F(DflyEngineTest, Multi) {
RespVec resp = Run({"multi"});
ASSERT_THAT(resp, RespEq("OK"));

View file

@ -7,19 +7,15 @@
#include <atomic>
#include <string>
#include "facade/error.h"
namespace dfly {
std::string WrongNumArgsError(std::string_view cmd);
extern const char kSyntaxErr[];
extern const char kWrongTypeErr[];
extern const char kKeyNotFoundErr[];
extern const char kInvalidIntErr[];
extern const char kUintErr[];
extern const char kDbIndOutOfRangeErr[];
extern const char kInvalidDbIndErr[];
extern const char kScriptNotFound[];
extern const char kAuthRejected[];
using facade::kWrongTypeErr;
using facade::kInvalidIntErr;
using facade::kSyntaxErr;
using facade::kInvalidDbIndErr;
using facade::kDbIndOutOfRangeErr;
#ifndef RETURN_ON_ERR

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
@ -20,6 +20,7 @@ DEFINE_uint32(dbnum, 16, "Number of databases");
namespace dfly {
using namespace std;
using facade::Protocol;
namespace {
@ -177,6 +178,7 @@ void GenericFamily::Del(CmdArgList args, ConnectionContext* cntx) {
uint32_t del_cnt = result.load(memory_order_relaxed);
if (is_mc) {
using facade::MCReplyBuilder;
MCReplyBuilder* mc_builder = static_cast<MCReplyBuilder*>(cntx->reply_builder());
if (del_cnt == 0) {

View file

@ -4,7 +4,7 @@
#pragma once
#include "core/op_status.h"
#include "facade/op_status.h"
#include "server/common_types.h"
namespace util {
@ -13,6 +13,9 @@ class ProactorPool;
namespace dfly {
using facade::OpResult;
using facade::OpStatus;
class ConnectionContext;
class CommandRegistry;
class EngineShard;

View file

@ -6,6 +6,7 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/engine_shard_set.h"

View file

@ -4,13 +4,14 @@
#pragma once
#include "core/op_status.h"
#include "facade/op_status.h"
#include "server/common_types.h"
namespace dfly {
class ConnectionContext;
class CommandRegistry;
using facade::OpResult;
class HSetFamily {
public:

View file

@ -7,11 +7,13 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "server/test_utils.h"
#include "facade/facade_test.h"
using namespace testing;
using namespace std;
using namespace util;
using namespace boost;
using namespace facade;
namespace dfly {

View file

@ -4,11 +4,13 @@
#pragma once
#include "core/op_status.h"
#include "facade/op_status.h"
#include "server/common_types.h"
namespace dfly {
using facade::OpResult;
class ConnectionContext;
class CommandRegistry;
class EngineShard;

View file

@ -8,6 +8,7 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/engine_shard_set.h"

View file

@ -16,6 +16,8 @@ extern "C" {
#include <filesystem>
#include "base/logging.h"
#include "facade/dragonfly_connection.h"
#include "facade/error.h"
#include "server/conn_context.h"
#include "server/error.h"
#include "server/generic_family.h"
@ -43,6 +45,8 @@ using base::VarzValue;
using ::boost::intrusive_ptr;
namespace fibers = ::boost::fibers;
namespace this_fiber = ::boost::this_fiber;
using facade::MCReplyBuilder;
using facade::RedisReplyBuilder;
namespace {
@ -259,7 +263,7 @@ bool EvalValidator(CmdArgList args, ConnectionContext* cntx) {
int32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
(*cntx)->SendError(kInvalidIntErr);
(*cntx)->SendError(facade::kInvalidIntErr);
return false;
}
@ -329,7 +333,7 @@ void Service::Shutdown() {
shard_set_.RunBlockingInParallel([&](EngineShard*) { EngineShard::DestroyThreadLocal(); });
}
void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) {
CHECK(!args.empty());
DCHECK_NE(0u, shard_set_.size()) << "Init was not called";
@ -344,9 +348,10 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
etl.RecordCmd();
absl::Cleanup multi_error = [cntx] {
if (cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE) {
cntx->conn_state.exec_state = ConnectionState::EXEC_ERROR;
ConnectionContext* dfly_cntx = static_cast<ConnectionContext*>(cntx);
absl::Cleanup multi_error = [dfly_cntx] {
if (dfly_cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE) {
dfly_cntx->conn_state.exec_state = ConnectionState::EXEC_ERROR;
}
};
@ -362,22 +367,22 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
string_view cmd_name{cid->name()};
if ((cntx->conn_state.mask & (ConnectionState::REQ_AUTH | ConnectionState::AUTHENTICATED)) ==
ConnectionState::REQ_AUTH) {
if (cntx->req_auth && !cntx->authenticated) {
if (cmd_name != "AUTH") {
return (*cntx)->SendError("-NOAUTH Authentication required.");
}
}
bool under_script = cntx->conn_state.script_info.has_value();
bool under_script = dfly_cntx->conn_state.script_info.has_value();
if (under_script && (cid->opt_mask() & CO::NOSCRIPT)) {
return (*cntx)->SendError("This Redis command is not allowed from script");
}
bool is_write_cmd =
(cid->opt_mask() & CO::WRITE) || (under_script && cntx->conn_state.script_info->is_write);
bool under_multi = cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd;
bool is_write_cmd = (cid->opt_mask() & CO::WRITE) ||
(under_script && dfly_cntx->conn_state.script_info->is_write);
bool under_multi =
dfly_cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd;
if (!etl.is_master && is_write_cmd) {
(*cntx)->SendError("-READONLY You can't write against a read only replica.");
@ -386,15 +391,15 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
if ((cid->arity() > 0 && args.size() != size_t(cid->arity())) ||
(cid->arity() < 0 && args.size() < size_t(-cid->arity()))) {
return (*cntx)->SendError(WrongNumArgsError(cmd_str));
return (*cntx)->SendError(facade::WrongNumArgsError(cmd_str));
}
if (cid->key_arg_step() == 2 && (args.size() % 2) == 0) {
return (*cntx)->SendError(WrongNumArgsError(cmd_str));
return (*cntx)->SendError(facade::WrongNumArgsError(cmd_str));
}
// Validate more complicated cases with custom validators.
if (!cid->Validate(args, cntx)) {
if (!cid->Validate(args, dfly_cntx)) {
return;
}
@ -412,14 +417,14 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
std::move(multi_error).Cancel();
if (cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd) {
if (dfly_cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd) {
// TODO: protect against aggregating huge transactions.
StoredCmd stored_cmd{cid};
stored_cmd.cmd.reserve(args.size());
for (size_t i = 0; i < args.size(); ++i) {
stored_cmd.cmd.emplace_back(ArgS(args, i));
}
cntx->conn_state.exec_body.push_back(std::move(stored_cmd));
dfly_cntx->conn_state.exec_body.push_back(std::move(stored_cmd));
return (*cntx)->SendSimpleString("QUEUED");
}
@ -430,48 +435,48 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
intrusive_ptr<Transaction> dist_trans;
if (under_script) {
DCHECK(cntx->transaction);
DCHECK(dfly_cntx->transaction);
KeyIndex key_index = DetermineKeys(cid, args);
for (unsigned i = key_index.start; i < key_index.end; ++i) {
string_view key = ArgS(args, i);
if (!cntx->conn_state.script_info->keys.contains(key)) {
if (!dfly_cntx->conn_state.script_info->keys.contains(key)) {
return (*cntx)->SendError("script tried accessing undeclared key");
}
}
cntx->transaction->SetExecCmd(cid);
cntx->transaction->InitByArgs(cntx->conn_state.db_index, args);
dfly_cntx->transaction->SetExecCmd(cid);
dfly_cntx->transaction->InitByArgs(dfly_cntx->conn_state.db_index, args);
} else {
DCHECK(cntx->transaction == nullptr);
DCHECK(dfly_cntx->transaction == nullptr);
if (IsTransactional(cid)) {
dist_trans.reset(new Transaction{cid, &shard_set_});
cntx->transaction = dist_trans.get();
dfly_cntx->transaction = dist_trans.get();
dist_trans->InitByArgs(cntx->conn_state.db_index, args);
cntx->last_command_debug.shards_count = cntx->transaction->unique_shard_cnt();
dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args);
dfly_cntx->last_command_debug.shards_count = dfly_cntx->transaction->unique_shard_cnt();
} else {
cntx->transaction = nullptr;
dfly_cntx->transaction = nullptr;
}
}
cntx->cid = cid;
dfly_cntx->cid = cid;
cmd_req.Inc({cmd_name});
cid->Invoke(args, cntx);
cid->Invoke(args, dfly_cntx);
end_usec = ProactorBase::GetMonotonicTimeNs();
request_latency_usec.IncBy(cmd_str, (end_usec - start_usec) / 1000);
if (dist_trans) {
cntx->last_command_debug.clock = dist_trans->txid();
cntx->last_command_debug.is_ooo = dist_trans->IsOOO();
dfly_cntx->last_command_debug.clock = dist_trans->txid();
dfly_cntx->last_command_debug.is_ooo = dist_trans->IsOOO();
}
if (!under_script) {
cntx->transaction = nullptr;
dfly_cntx->transaction = nullptr;
}
}
void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
ConnectionContext* cntx) {
facade::ConnectionContext* cntx) {
absl::InlinedVector<MutableSlice, 8> args;
char cmd_name[16];
char ttl[16];
@ -533,6 +538,8 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
args.emplace_back(key, cmd.key.size());
}
ConnectionContext* dfly_cntx = static_cast<ConnectionContext*>(cntx);
if (MemcacheParser::IsStoreCmd(cmd.type)) {
char* v = const_cast<char*>(value.data());
args.emplace_back(v, value.size());
@ -546,7 +553,7 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
args.emplace_back(ttl_op, 2);
args.emplace_back(ttl, next - ttl);
}
cntx->conn_state.memcache_flag = cmd.flags;
dfly_cntx->conn_state.memcache_flag = cmd.flags;
} else if (cmd.type < MemcacheParser::QUIT) { // read commands
for (auto s : cmd.keys_ext) {
char* key = const_cast<char*>(s.data());
@ -561,7 +568,28 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
DispatchCommand(CmdArgList{args}, cntx);
// Reset back.
cntx->conn_state.memcache_flag = 0;
dfly_cntx->conn_state.memcache_flag = 0;
}
facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer,
facade::Connection* owner) {
ConnectionContext* res = new ConnectionContext{peer, owner};
res->shard_set = &shard_set();
res->req_auth = IsPassProtected();
// a bit of a hack. I set up breaker callback here for the owner.
// Should work though it's confusing to have it here.
owner->RegisterOnBreak([res](uint32_t) {
if (res->transaction) {
res->transaction->BreakOnClose();
}
});
return res;
}
facade::ConnectionStats* Service::GetThreadLocalConnectionStats() {
return ServerState::tl_connection_stats();
}
bool Service::IsLocked(DbIndex db_index, std::string_view key) const {
@ -590,14 +618,10 @@ bool Service::IsPassProtected() const {
return !FLAGS_requirepass.empty();
}
void Service::RegisterHttp(HttpListenerBase* listener) {
CHECK_NOTNULL(listener);
http_listener_ = listener;
}
void Service::Quit(CmdArgList args, ConnectionContext* cntx) {
if (cntx->protocol() == Protocol::REDIS)
if (cntx->protocol() == facade::Protocol::REDIS)
(*cntx)->SendOk();
using facade::SinkReplyBuilder;
SinkReplyBuilder* builder = static_cast<SinkReplyBuilder*>(cntx->reply_builder());
builder->CloseConnection();
@ -615,7 +639,7 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) {
DCHECK(cntx->transaction);
InterpreterReplier replier(reply);
ReplyBuilderInterface* orig = cntx->Inject(&replier);
facade::ReplyBuilderInterface* orig = cntx->Inject(&replier);
DispatchCommand(std::move(args), cntx);
@ -670,7 +694,7 @@ void Service::EvalSha(CmdArgList args, ConnectionContext* cntx) {
if (!exists) {
const char* body = (sha.size() == 40) ? server_family_.script_mgr()->Find(sha) : nullptr;
if (!body) {
return (*cntx)->SendError(kScriptNotFound);
return (*cntx)->SendError(facade::kScriptNotFound);
}
string res;
@ -692,7 +716,7 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
// Sanitizing the input to avoid code injection.
if (eval_args.sha.size() != 40 || !IsSHA(eval_args.sha)) {
return (*cntx)->SendError(kScriptNotFound);
return (*cntx)->SendError(facade::kScriptNotFound);
}
bool exists = interpreter->Exists(eval_args.sha);
@ -700,7 +724,7 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
if (!exists) {
const char* body = server_family_.script_mgr()->Find(eval_args.sha);
if (!body) {
return (*cntx)->SendError(kScriptNotFound);
return (*cntx)->SendError(facade::kScriptNotFound);
}
string res;

View file

@ -5,11 +5,10 @@
#pragma once
#include "base/varz_value.h"
#include "facade/service_interface.h"
#include "server/command_registry.h"
#include "server/engine_shard_set.h"
#include "server/memcache_parser.h"
#include "server/server_family.h"
#include "util/http/http_handler.h"
namespace util {
class AcceptServer;
@ -19,8 +18,9 @@ namespace dfly {
class Interpreter;
class ObjectExplorer; // for Interpreter
using facade::MemcacheParser;
class Service {
class Service : public facade::ServiceInterface {
public:
using error_code = std::error_code;
@ -34,15 +34,18 @@ class Service {
explicit Service(util::ProactorPool* pp);
~Service();
void RegisterHttp(util::HttpListenerBase* listener);
void Init(util::AcceptServer* acceptor, const InitOpts& opts = InitOpts{});
void Shutdown();
void DispatchCommand(CmdArgList args, ConnectionContext* cntx);
void DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) final;
void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
ConnectionContext* cntx);
facade::ConnectionContext* cntx) final;
facade::ConnectionContext* CreateContext(util::FiberSocketBase* peer,
facade::Connection* owner) final;
facade::ConnectionStats* GetThreadLocalConnectionStats() final;
uint32_t shard_count() const {
return shard_set_.size();
@ -60,10 +63,6 @@ class Service {
return pp_;
}
util::HttpListenerBase* http_listener() {
return http_listener_;
}
bool IsPassProtected() const;
private:
@ -91,8 +90,6 @@ class Service {
EngineShardSet shard_set_;
ServerFamily server_family_;
CommandRegistry registry_;
util::HttpListenerBase* http_listener_ = nullptr;
};
} // namespace dfly

View file

@ -1,176 +0,0 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/memcache_parser.h"
#include <absl/container/flat_hash_map.h>
#include <absl/strings/ascii.h>
#include <absl/strings/numbers.h>
#include "base/stl_util.h"
namespace dfly {
using namespace std;
using MP = MemcacheParser;
namespace {
MP::CmdType From(string_view token) {
static absl::flat_hash_map<string_view, MP::CmdType> cmd_map{
{"set", MP::SET}, {"add", MP::ADD}, {"replace", MP::REPLACE},
{"append", MP::APPEND}, {"prepend", MP::PREPEND}, {"cas", MP::CAS},
{"get", MP::GET}, {"gets", MP::GETS}, {"gat", MP::GAT},
{"gats", MP::GATS}, {"stats", MP::STATS}, {"incr", MP::INCR},
{"decr", MP::DECR}, {"delete", MP::DELETE}, {"flush_all", MP::FLUSHALL},
{"quit", MP::QUIT}, {"version", MP::VERSION},
};
auto it = cmd_map.find(token);
if (it == cmd_map.end())
return MP::INVALID;
return it->second;
}
MP::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) {
unsigned opt_pos = 3;
if (res->type == MP::CAS) {
if (num_tokens <= opt_pos)
return MP::PARSE_ERROR;
++opt_pos;
}
uint32_t flags;
if (!absl::SimpleAtoi(tokens[0], &flags) || !absl::SimpleAtoi(tokens[1], &res->expire_ts) ||
!absl::SimpleAtoi(tokens[2], &res->bytes_len))
return MP::BAD_INT;
if (res->type == MP::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) {
return MP::BAD_INT;
}
res->flags = flags;
if (num_tokens == opt_pos + 1) {
if (tokens[opt_pos] == "noreply") {
res->no_reply = true;
} else {
return MP::PARSE_ERROR;
}
} else if (num_tokens > opt_pos + 1) {
return MP::PARSE_ERROR;
}
return MP::OK;
}
MP::Result ParseValueless(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) {
unsigned key_pos = 0;
if (res->type == MP::GAT || res->type == MP::GATS) {
if (!absl::SimpleAtoi(tokens[0], &res->expire_ts)) {
return MP::BAD_INT;
}
++key_pos;
}
res->key = tokens[key_pos++];
if (key_pos < num_tokens && base::_in(res->type, {MP::STATS, MP::FLUSHALL}))
return MP::PARSE_ERROR; // we do not support additional arguments for now.
if (res->type == MP::INCR || res->type == MP::DECR) {
if (key_pos == num_tokens)
return MP::PARSE_ERROR;
if (!absl::SimpleAtoi(tokens[key_pos], &res->delta))
return MP::BAD_DELTA;
++key_pos;
}
while (key_pos < num_tokens) {
res->keys_ext.push_back(tokens[key_pos++]);
}
if (res->type >= MP::DELETE) { // write commands
if (!res->keys_ext.empty() && res->keys_ext.back() == "noreply") {
res->no_reply = true;
res->keys_ext.pop_back();
}
}
return MP::OK;
}
} // namespace
auto MP::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result {
auto pos = str.find('\n');
*consumed = 0;
if (pos == string_view::npos) {
// TODO: it's over simplified since we may process GET/GAT command that is not limited to
// 300 characters.
return str.size() > 300 ? PARSE_ERROR : INPUT_PENDING;
}
if (pos == 0) {
return PARSE_ERROR;
}
*consumed = pos + 1;
// cas <key> <flags> <exptime> <bytes> <cas unique> [noreply]\r\n
// get <key>*\r\n
string_view tokens[8];
unsigned num_tokens = 0;
uint32_t cur = 0;
while (cur < pos && str[cur] == ' ')
++cur;
uint32_t s = cur;
for (; cur <= pos; ++cur) {
if (absl::ascii_isspace(str[cur])) {
if (cur != s) {
tokens[num_tokens++] = str.substr(s, cur - s);
if (num_tokens == ABSL_ARRAYSIZE(tokens)) {
++cur;
s = cur;
break;
}
}
s = cur + 1;
}
}
if (num_tokens == 0)
return PARSE_ERROR;
while (cur < pos - 1) {
if (str[cur] != ' ')
return PARSE_ERROR;
++cur;
}
cmd->type = From(tokens[0]);
if (cmd->type == INVALID) {
return UNKNOWN_CMD;
}
if (cmd->type <= CAS) { // Store command
if (num_tokens < 5 || tokens[1].size() > 250) {
return MP::PARSE_ERROR;
}
// memcpy(single_key_, tokens[0].data(), tokens[0].size()); // we copy the key
cmd->key = string_view{tokens[1].data(), tokens[1].size()};
return ParseStore(tokens + 2, num_tokens - 2, cmd);
}
if (num_tokens == 1) {
if (base::_in(cmd->type, {MP::STATS, MP::FLUSHALL, MP::QUIT, MP::VERSION}))
return MP::OK;
return MP::PARSE_ERROR;
}
return ParseValueless(tokens + 1, num_tokens - 1, cmd);
};
} // namespace dfly

View file

@ -1,78 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <string_view>
#include <vector>
namespace dfly {
// Memcache parser does not parse value blobs, only the commands.
// The expectation is that the caller will parse the command and
// then will follow up with reading the blob data directly from source.
class MemcacheParser {
public:
enum CmdType {
INVALID = 0,
SET = 1,
ADD = 2,
REPLACE = 3,
APPEND = 4,
PREPEND = 5,
CAS = 6,
// Retrieval
GET = 10,
GETS = 11,
GAT = 12,
GATS = 13,
STATS = 14,
QUIT = 20,
VERSION = 21,
// The rest of write commands.
DELETE = 31,
INCR = 32,
DECR = 33,
FLUSHALL = 34,
};
// According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol
struct Command {
CmdType type = INVALID;
std::string_view key;
std::vector<std::string_view> keys_ext;
union {
uint64_t cas_unique = 0; // for CAS COMMAND
uint64_t delta; // for DECR/INCR commands.
};
uint32_t expire_ts = 0; // relative time in seconds.
uint32_t bytes_len = 0;
uint32_t flags = 0;
bool no_reply = false;
};
enum Result {
OK,
INPUT_PENDING,
UNKNOWN_CMD,
BAD_INT,
PARSE_ERROR,
BAD_DELTA,
};
static bool IsStoreCmd(CmdType type) {
return type >= SET && type <= CAS;
}
Result Parse(std::string_view str, uint32_t* consumed, Command* res);
private:
};
} // namespace dfly

View file

@ -1,83 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/memcache_parser.h"
#include <gmock/gmock.h>
#include "absl/strings/str_cat.h"
#include "base/gtest.h"
#include "base/logging.h"
#include "server/test_utils.h"
using namespace testing;
using namespace std;
namespace dfly {
class MCParserTest : public testing::Test {
protected:
RedisParser::Result Parse(std::string_view str);
MemcacheParser parser_;
MemcacheParser::Command cmd_;
uint32_t consumed_;
unique_ptr<uint8_t[]> stash_;
};
TEST_F(MCParserTest, Basic) {
MemcacheParser::Result st = parser_.Parse("set a 1 20 3\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ("a", cmd_.key);
EXPECT_EQ(1, cmd_.flags);
EXPECT_EQ(20, cmd_.expire_ts);
EXPECT_EQ(3, cmd_.bytes_len);
EXPECT_EQ(MemcacheParser::SET, cmd_.type);
st = parser_.Parse("quit\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::QUIT, cmd_.type);
}
TEST_F(MCParserTest, Incr) {
MemcacheParser::Result st = parser_.Parse("incr a\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::PARSE_ERROR, st);
st = parser_.Parse("incr a 1\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::INCR, cmd_.type);
EXPECT_EQ("a", cmd_.key);
EXPECT_EQ(1, cmd_.delta);
EXPECT_FALSE(cmd_.no_reply);
st = parser_.Parse("incr a -1\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::BAD_DELTA, st);
st = parser_.Parse("decr b 10 noreply\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::DECR, cmd_.type);
EXPECT_EQ(10, cmd_.delta);
}
TEST_F(MCParserTest, Stats) {
MemcacheParser::Result st = parser_.Parse("stats foo\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(consumed_, 11);
EXPECT_EQ(cmd_.type, MemcacheParser::STATS);
EXPECT_EQ("foo", cmd_.key);
cmd_ = MemcacheParser::Command{};
st = parser_.Parse("stats \r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(consumed_, 9);
EXPECT_EQ(cmd_.type, MemcacheParser::STATS);
EXPECT_EQ("", cmd_.key);
cmd_ = MemcacheParser::Command{};
st = parser_.Parse("stats fpp bar\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::PARSE_ERROR, st);
}
} // namespace dfly

View file

@ -28,6 +28,7 @@ using base::IoBuf;
using nonstd::make_unexpected;
using namespace util;
using rdb::errc;
using facade::operator""_KB;
class error_category : public std::error_category {
public:

View file

@ -24,6 +24,7 @@ using namespace std;
using base::IoBuf;
using io::Bytes;
using nonstd::make_unexpected;
using facade::operator""_KB;
namespace {

View file

@ -1,438 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/redis_parser.h"
#include <absl/strings/numbers.h>
#include "base/logging.h"
namespace dfly {
using namespace std;
namespace {
constexpr int kMaxArrayLen = 1024;
constexpr int64_t kMaxBulkLen = 64 * (1ul << 20); // 64MB.
} // namespace
auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> Result {
*consumed = 0;
res->clear();
if (str.size() < 2) {
return INPUT_PENDING;
}
if (state_ == CMD_COMPLETE_S) {
state_ = INIT_S;
}
if (state_ == INIT_S) {
InitStart(str[0], res);
}
if (!cached_expr_)
cached_expr_ = res;
while (state_ != CMD_COMPLETE_S) {
last_consumed_ = 0;
switch (state_) {
case ARRAY_LEN_S:
last_result_ = ConsumeArrayLen(str);
break;
case PARSE_ARG_S:
if (str.size() < 4) {
last_result_ = INPUT_PENDING;
} else {
last_result_ = ParseArg(str);
}
break;
case INLINE_S:
DCHECK(parse_stack_.empty());
last_result_ = ParseInline(str);
break;
case BULK_STR_S:
last_result_ = ConsumeBulk(str);
break;
case FINISH_ARG_S:
HandleFinishArg();
break;
default:
LOG(FATAL) << "Unexpected state " << int(state_);
}
*consumed += last_consumed_;
if (last_result_ != OK) {
break;
}
str.remove_prefix(last_consumed_);
}
if (last_result_ == INPUT_PENDING) {
StashState(res);
} else if (last_result_ == OK) {
DCHECK(cached_expr_);
if (res != cached_expr_) {
DCHECK(!stash_.empty());
*res = *cached_expr_;
}
}
return last_result_;
}
void RedisParser::InitStart(uint8_t prefix_b, RespExpr::Vec* res) {
buf_stash_.clear();
stash_.clear();
cached_expr_ = res;
parse_stack_.clear();
last_stashed_level_ = 0;
last_stashed_index_ = 0;
switch (prefix_b) {
case '$':
case ':':
case '+':
case '-':
state_ = PARSE_ARG_S;
parse_stack_.emplace_back(1, cached_expr_); // expression of length 1.
break;
case '*':
state_ = ARRAY_LEN_S;
break;
default:
state_ = INLINE_S;
break;
}
}
void RedisParser::StashState(RespExpr::Vec* res) {
if (cached_expr_->empty() && stash_.empty()) {
cached_expr_ = nullptr;
return;
}
if (cached_expr_ == res) {
stash_.emplace_back(new RespExpr::Vec(*res));
cached_expr_ = stash_.back().get();
}
DCHECK_LT(last_stashed_level_, stash_.size());
while (true) {
auto& cur = *stash_[last_stashed_level_];
for (; last_stashed_index_ < cur.size(); ++last_stashed_index_) {
auto& e = cur[last_stashed_index_];
if (RespExpr::STRING == e.type) {
Buffer& ebuf = get<Buffer>(e.u);
if (ebuf.empty() && last_stashed_index_ + 1 == cur.size())
break;
if (!ebuf.empty() && !e.has_support) {
BlobPtr ptr(new uint8_t[ebuf.size()]);
memcpy(ptr.get(), ebuf.data(), ebuf.size());
ebuf = Buffer{ptr.get(), ebuf.size()};
buf_stash_.push_back(std::move(ptr));
e.has_support = true;
}
}
}
if (last_stashed_level_ + 1 == stash_.size())
break;
++last_stashed_level_;
last_stashed_index_ = 0;
}
}
auto RedisParser::ParseInline(Buffer str) -> Result {
DCHECK(!str.empty());
uint8_t* ptr = str.begin();
uint8_t* end = str.end();
uint8_t* token_start = ptr;
if (is_broken_token_) {
while (ptr != end && *ptr > 32)
++ptr;
size_t len = ptr - token_start;
ExtendLastString(Buffer(token_start, len));
if (ptr != end) {
is_broken_token_ = false;
}
}
auto is_finish = [&] { return ptr == end || *ptr == '\n'; };
while (true) {
while (!is_finish() && *ptr <= 32) {
++ptr;
}
// We do not test for \r in order to accept 'nc' input.
if (is_finish())
break;
DCHECK(!is_broken_token_);
token_start = ptr;
while (ptr != end && *ptr > 32)
++ptr;
cached_expr_->emplace_back(RespExpr::STRING);
cached_expr_->back().u = Buffer{token_start, size_t(ptr - token_start)};
}
last_consumed_ = ptr - str.data();
if (ptr == end) { // we have not finished parsing.
if (ptr[-1] > 32) {
// we stopped in the middle of the token.
is_broken_token_ = true;
}
return INPUT_PENDING;
} else {
++last_consumed_; // consume the delimiter as well.
}
state_ = CMD_COMPLETE_S;
return OK;
}
auto RedisParser::ParseNum(Buffer str, int64_t* res) -> Result {
if (str.size() < 4) {
return INPUT_PENDING;
}
char* s = reinterpret_cast<char*>(str.data() + 1);
char* pos = reinterpret_cast<char*>(memchr(s, '\n', str.size() - 1));
if (!pos) {
return str.size() < 32 ? INPUT_PENDING : BAD_INT;
}
if (pos[-1] != '\r') {
return BAD_INT;
}
bool success = absl::SimpleAtoi(std::string_view{s, size_t(pos - s - 1)}, res);
if (!success) {
return BAD_INT;
}
last_consumed_ = (pos - s) + 2;
return OK;
}
auto RedisParser::ConsumeArrayLen(Buffer str) -> Result {
int64_t len;
Result res = ParseNum(str, &len);
switch (res) {
case INPUT_PENDING:
return INPUT_PENDING;
case BAD_INT:
return BAD_ARRAYLEN;
case OK:
if (len < -1 || len > kMaxArrayLen)
return BAD_ARRAYLEN;
break;
default:
LOG(ERROR) << "Unexpected result " << res;
}
if (parse_stack_.size() > 0 && server_mode_)
return BAD_STRING;
if (parse_stack_.size() == 0 && !cached_expr_->empty())
return BAD_STRING;
if (len <= 0) {
cached_expr_->emplace_back(len == -1 ? RespExpr::NIL_ARRAY : RespExpr::ARRAY);
if (len < 0)
cached_expr_->back().u.emplace<RespVec*>(nullptr); // nil
else {
static RespVec empty_vec;
cached_expr_->back().u = &empty_vec;
}
state_ = (parse_stack_.empty()) ? CMD_COMPLETE_S : FINISH_ARG_S;
return OK;
}
parse_stack_.emplace_back(len, cached_expr_);
if (!cached_expr_->empty()) {
DCHECK(!server_mode_);
cached_expr_->emplace_back(RespExpr::ARRAY);
stash_.emplace_back(new RespExpr::Vec());
RespExpr::Vec* arr = stash_.back().get();
arr->reserve(len);
cached_expr_->back().u = arr;
cached_expr_ = arr;
}
state_ = PARSE_ARG_S;
return OK;
}
auto RedisParser::ParseArg(Buffer str) -> Result {
char c = str[0];
if (c == '$') {
int64_t len;
Result res = ParseNum(str, &len);
switch (res) {
case INPUT_PENDING:
return INPUT_PENDING;
case BAD_INT:
return BAD_ARRAYLEN;
case OK:
if (len < -1 || len > kMaxBulkLen)
return BAD_ARRAYLEN;
break;
default:
LOG(ERROR) << "Unexpected result " << res;
}
if (len < 0) {
state_ = FINISH_ARG_S;
cached_expr_->emplace_back(RespExpr::NIL);
} else {
cached_expr_->emplace_back(RespExpr::STRING);
bulk_len_ = len;
state_ = BULK_STR_S;
}
cached_expr_->back().u = Buffer{};
return OK;
}
if (server_mode_) {
return BAD_BULKLEN;
}
if (c == '*') {
return ConsumeArrayLen(str);
}
char* s = reinterpret_cast<char*>(str.data() + 1);
char* eol = reinterpret_cast<char*>(memchr(s, '\n', str.size() - 1));
if (c == '+' || c == '-') { // Simple string or error.
DCHECK(!server_mode_);
if (!eol) {
return str.size() < 256 ? INPUT_PENDING : BAD_STRING;
}
if (eol[-1] != '\r')
return BAD_STRING;
cached_expr_->emplace_back(c == '+' ? RespExpr::STRING : RespExpr::ERROR);
cached_expr_->back().u = Buffer{reinterpret_cast<uint8_t*>(s), size_t((eol - 1) - s)};
} else if (c == ':') {
DCHECK(!server_mode_);
if (!eol) {
return str.size() < 32 ? INPUT_PENDING : BAD_INT;
}
int64_t ival;
std::string_view tok{s, size_t((eol - s) - 1)};
if (eol[-1] != '\r' || !absl::SimpleAtoi(tok, &ival))
return BAD_INT;
cached_expr_->emplace_back(RespExpr::INT64);
cached_expr_->back().u = ival;
} else {
return BAD_STRING;
}
last_consumed_ = (eol - s) + 2;
state_ = FINISH_ARG_S;
return OK;
}
auto RedisParser::ConsumeBulk(Buffer str) -> Result {
auto& bulk_str = get<Buffer>(cached_expr_->back().u);
if (str.size() >= bulk_len_ + 2) {
if (str[bulk_len_] != '\r' || str[bulk_len_ + 1] != '\n') {
return BAD_STRING;
}
if (bulk_len_) {
if (is_broken_token_) {
memcpy(bulk_str.end(), str.data(), bulk_len_);
bulk_str = Buffer{bulk_str.data(), bulk_str.size() + bulk_len_};
} else {
bulk_str = str.subspan(0, bulk_len_);
}
}
is_broken_token_ = false;
state_ = FINISH_ARG_S;
last_consumed_ = bulk_len_ + 2;
bulk_len_ = 0;
return OK;
}
if (str.size() >= 32) {
DCHECK(bulk_len_);
size_t len = std::min<size_t>(str.size(), bulk_len_);
if (is_broken_token_) {
memcpy(bulk_str.end(), str.data(), len);
bulk_str = Buffer{bulk_str.data(), bulk_str.size() + len};
DVLOG(1) << "Extending bulk stash to size " << bulk_str.size();
} else {
DVLOG(1) << "New bulk stash size " << bulk_len_;
std::unique_ptr<uint8_t[]> nb(new uint8_t[bulk_len_]);
memcpy(nb.get(), str.data(), len);
bulk_str = Buffer{nb.get(), len};
buf_stash_.emplace_back(move(nb));
is_broken_token_ = true;
cached_expr_->back().has_support = true;
}
last_consumed_ = len;
bulk_len_ -= len;
}
return INPUT_PENDING;
}
void RedisParser::HandleFinishArg() {
DCHECK(!parse_stack_.empty());
DCHECK_GT(parse_stack_.back().first, 0u);
while (true) {
--parse_stack_.back().first;
state_ = PARSE_ARG_S;
if (parse_stack_.back().first != 0)
break;
parse_stack_.pop_back(); // pop 0.
if (parse_stack_.empty()) {
state_ = CMD_COMPLETE_S;
break;
}
cached_expr_ = parse_stack_.back().second;
}
}
void RedisParser::ExtendLastString(Buffer str) {
DCHECK(!cached_expr_->empty() && cached_expr_->back().type == RespExpr::STRING);
DCHECK(!buf_stash_.empty());
Buffer& last_str = get<Buffer>(cached_expr_->back().u);
DCHECK(last_str.data() == buf_stash_.back().get());
std::unique_ptr<uint8_t[]> nb(new uint8_t[last_str.size() + str.size()]);
memcpy(nb.get(), last_str.data(), last_str.size());
memcpy(nb.get() + last_str.size(), str.data(), str.size());
last_str = RespExpr::Buffer{nb.get(), last_str.size() + str.size()};
buf_stash_.back() = std::move(nb);
}
} // namespace dfly

View file

@ -1,100 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/container/inlined_vector.h>
#include "core/resp_expr.h"
namespace dfly {
/**
* @brief Zero-copy (best-effort) parser.
*
*/
class RedisParser {
public:
enum Result {
OK,
INPUT_PENDING,
BAD_ARRAYLEN,
BAD_BULKLEN,
BAD_STRING,
BAD_INT
};
using Buffer = RespExpr::Buffer;
explicit RedisParser(bool server_mode = true) : server_mode_(server_mode) {
}
/**
* @brief Parses str into res. "consumed" stores number of bytes consumed from str.
*
* A caller should not invalidate str if the parser returns RESP_OK as long as he continues
* accessing res. However, if parser returns MORE_INPUT a caller may discard consumed
* part of str because parser caches the intermediate state internally according to 'consumed'
* result.
*
* Note: A parser does not always guarantee progress, i.e. if a small buffer was passed it may
* returns MORE_INPUT with consumed == 0.
*
*/
Result Parse(Buffer str, uint32_t* consumed, RespVec* res);
void SetClientMode() {
server_mode_ = false;
}
size_t parselen_hint() const {
return bulk_len_;
}
size_t stash_size() const { return stash_.size(); }
const std::vector<std::unique_ptr<RespVec>>& stash() const { return stash_;}
private:
void InitStart(uint8_t prefix_b, RespVec* res);
void StashState(RespVec* res);
// Skips the first character (*).
Result ConsumeArrayLen(Buffer str);
Result ParseArg(Buffer str);
Result ConsumeBulk(Buffer str);
Result ParseInline(Buffer str);
// Updates last_consumed_
Result ParseNum(Buffer str, int64_t* res);
void HandleFinishArg();
void ExtendLastString(Buffer str);
enum State : uint8_t {
INIT_S = 0,
INLINE_S,
ARRAY_LEN_S,
PARSE_ARG_S, // Parse [$:+-]string\r\n
BULK_STR_S,
FINISH_ARG_S,
CMD_COMPLETE_S,
};
State state_ = INIT_S;
Result last_result_ = OK;
uint32_t last_consumed_ = 0;
uint32_t bulk_len_ = 0;
uint32_t last_stashed_level_ = 0, last_stashed_index_ = 0;
// expected expression length, pointer to expression vector.
absl::InlinedVector<std::pair<uint32_t, RespVec*>, 4> parse_stack_;
std::vector<std::unique_ptr<RespVec>> stash_;
using BlobPtr = std::unique_ptr<uint8_t[]>;
std::vector<BlobPtr> buf_stash_;
RespVec* cached_expr_ = nullptr;
bool is_broken_token_ = false;
bool server_mode_ = true;
};
} // namespace dfly

View file

@ -1,220 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/redis_parser.h"
extern "C" {
#include "redis/sds.h"
#include "redis/zmalloc.h"
}
#include <absl/strings/str_cat.h>
#include <gmock/gmock.h>
#include "absl/strings/str_cat.h"
#include "base/gtest.h"
#include "base/logging.h"
#include "server/test_utils.h"
using namespace testing;
using namespace std;
namespace dfly {
MATCHER_P(ArrArg, expected, absl::StrCat(negation ? "is not" : "is", " equal to:\n", expected)) {
if (arg.type != RespExpr::ARRAY) {
*result_listener << "\nWrong type: " << arg.type;
return false;
}
size_t exp_sz = expected;
size_t actual = get<RespVec*>(arg.u)->size();
if (exp_sz != actual) {
*result_listener << "\nActual size: " << actual;
return false;
}
return true;
}
class RedisParserTest : public testing::Test {
protected:
static void SetUpTestSuite() {
init_zmalloc_threadlocal();
}
RedisParser::Result Parse(std::string_view str);
RedisParser parser_;
RespExpr::Vec args_;
uint32_t consumed_;
unique_ptr<uint8_t[]> stash_;
};
RedisParser::Result RedisParserTest::Parse(std::string_view str) {
stash_.reset(new uint8_t[str.size()]);
auto* ptr = stash_.get();
memcpy(ptr, str.data(), str.size());
return parser_.Parse(RedisParser::Buffer{ptr, str.size()}, &consumed_, &args_);
}
TEST_F(RedisParserTest, Inline) {
RespExpr e{RespExpr::STRING};
ASSERT_EQ(RespExpr::STRING, e.type);
const char kCmd1[] = "KEY VAL\r\n";
ASSERT_EQ(RedisParser::OK, Parse(kCmd1));
EXPECT_EQ(strlen(kCmd1), consumed_);
EXPECT_THAT(args_, ElementsAre(StrArg("KEY"), StrArg("VAL")));
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("KEY"));
EXPECT_EQ(3, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" FOO "));
EXPECT_EQ(5, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" BAR"));
EXPECT_EQ(4, consumed_);
ASSERT_EQ(RedisParser::OK, Parse(" \r\n "));
EXPECT_EQ(3, consumed_);
EXPECT_THAT(args_, ElementsAre(StrArg("KEY"), StrArg("FOO"), StrArg("BAR")));
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" 1 2"));
EXPECT_EQ(4, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" 45"));
EXPECT_EQ(3, consumed_);
ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
EXPECT_EQ(2, consumed_);
EXPECT_THAT(args_, ElementsAre(StrArg("1"), StrArg("2"), StrArg("45")));
// Empty queries return RESP_OK.
EXPECT_EQ(RedisParser::OK, Parse("\r\n"));
EXPECT_EQ(2, consumed_);
}
TEST_F(RedisParserTest, Sds) {
int argc;
sds* argv = sdssplitargs("\r\n",&argc);
EXPECT_EQ(0, argc);
sdsfreesplitres(argv,argc);
argv = sdssplitargs("\026 \020 \200 \277 \r\n",&argc);
EXPECT_EQ(4, argc);
EXPECT_STREQ("\026", argv[0]);
sdsfreesplitres(argv,argc);
argv = sdssplitargs(R"(abc "oops\n" )""\r\n",&argc);
EXPECT_EQ(2, argc);
EXPECT_STREQ("oops\n", argv[1]);
sdsfreesplitres(argv,argc);
argv = sdssplitargs(R"( "abc\xf0" )" "\t'oops\n' \r\n",&argc);
ASSERT_EQ(2, argc);
EXPECT_STREQ("abc\xf0", argv[0]);
EXPECT_STREQ("oops\n", argv[1]);
sdsfreesplitres(argv,argc);
}
TEST_F(RedisParserTest, InlineEscaping) {
LOG(ERROR) << "TBD: to be compliant with sdssplitargs"; // TODO:
}
TEST_F(RedisParserTest, Multi1) {
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n"));
EXPECT_EQ(4, consumed_);
EXPECT_EQ(0, parser_.parselen_hint());
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("$4\r\n"));
EXPECT_EQ(4, consumed_);
EXPECT_EQ(4, parser_.parselen_hint());
ASSERT_EQ(RedisParser::OK, Parse("PING\r\n"));
EXPECT_EQ(6, consumed_);
EXPECT_EQ(0, parser_.parselen_hint());
EXPECT_THAT(args_, ElementsAre(StrArg("PING")));
}
TEST_F(RedisParserTest, Multi2) {
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n$"));
EXPECT_EQ(4, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("$4\r\nMSET"));
EXPECT_EQ(4, consumed_);
ASSERT_EQ(RedisParser::OK, Parse("MSET\r\n*2\r\n"));
EXPECT_EQ(6, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*2\r\n$3\r\nKEY\r\n$3\r\nVAL"));
EXPECT_EQ(17, consumed_);
ASSERT_EQ(RedisParser::OK, Parse("VAL\r\n"));
EXPECT_EQ(5, consumed_);
EXPECT_THAT(args_, ElementsAre("KEY", "VAL"));
}
TEST_F(RedisParserTest, Multi3) {
const char kFirst[] = "*3\r\n$3\r\nSET\r\n$16\r\nkey:";
const char kSecond[] = "key:000002273458\r\n$3\r\nVXK";
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kFirst));
ASSERT_EQ(strlen(kFirst) - 4, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kSecond));
ASSERT_EQ(strlen(kSecond) - 3, consumed_);
ASSERT_EQ(RedisParser::OK, Parse("VXK\r\n*3\r\n$3\r\nSET"));
EXPECT_THAT(args_, ElementsAre("SET", "key:000002273458", "VXK"));
}
TEST_F(RedisParserTest, ClientMode) {
parser_.SetClientMode();
ASSERT_EQ(RedisParser::OK, Parse(":-1\r\n"));
EXPECT_THAT(args_, ElementsAre(IntArg(-1)));
ASSERT_EQ(RedisParser::OK, Parse("+OK\r\n"));
EXPECT_THAT(args_, RespEq("OK"));
ASSERT_EQ(RedisParser::OK, Parse("-ERR foo bar\r\n"));
EXPECT_THAT(args_, ElementsAre(ErrArg("ERR foo")));
}
TEST_F(RedisParserTest, Hierarchy) {
parser_.SetClientMode();
const char* kThirdArg = "*2\r\n$3\r\n100\r\n$3\r\n200\r\n";
string resp = absl::StrCat("*3\r\n$3\r\n900\r\n$3\r\n800\r\n", kThirdArg);
ASSERT_EQ(RedisParser::OK, Parse(resp));
EXPECT_THAT(args_, ElementsAre(StrArg("900"), StrArg("800"), ArrArg(2)));
EXPECT_THAT(*get<RespVec*>(args_[2].u), ElementsAre(StrArg("100"), StrArg("200")));
}
TEST_F(RedisParserTest, InvalidMult1) {
ASSERT_EQ(RedisParser::BAD_BULKLEN, Parse("*2\r\n$3\r\nFOO\r\nBAR\r\n"));
}
TEST_F(RedisParserTest, Empty) {
ASSERT_EQ(RedisParser::OK, Parse("*2\r\n$0\r\n\r\n$0\r\n\r\n"));
}
TEST_F(RedisParserTest, LargeBulk) {
std::string_view prefix("*1\r\n$1024\r\n");
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(prefix));
ASSERT_EQ(prefix.size(), consumed_);
ASSERT_GE(parser_.parselen_hint(), 1024);
string half(512, 'a');
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half));
ASSERT_EQ(512, consumed_);
ASSERT_GE(parser_.parselen_hint(), 512);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half));
ASSERT_EQ(512, consumed_);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\r"));
ASSERT_EQ(0, consumed_);
ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
ASSERT_EQ(2, consumed_);
string part1 = absl::StrCat(prefix, half);
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(part1));
ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half));
ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
}
} // namespace dfly

View file

@ -14,10 +14,10 @@ extern "C" {
#include <boost/asio/ip/tcp.hpp>
#include "base/logging.h"
#include "facade/redis_parser.h"
#include "server/error.h"
#include "server/main_service.h"
#include "server/rdb_load.h"
#include "server/redis_parser.h"
#include "util/proactor_base.h"
namespace dfly {
@ -25,6 +25,7 @@ namespace dfly {
using namespace std;
using namespace util;
using namespace boost::asio;
using namespace facade;
namespace this_fiber = ::boost::this_fiber;
namespace {
@ -74,7 +75,6 @@ error_code Recv(FiberSocketBase* input, base::IoBuf* dest) {
return error_code{};
}
// TODO: to remove usages of this macro and make code crash-less.
#define CHECK_EC(x) \
do { \

View file

@ -7,13 +7,12 @@
#include <variant>
#include "base/io_buf.h"
#include "core/resp_expr.h"
#include "facade/redis_parser.h"
#include "server/conn_context.h"
#include "util/fiber_socket_base.h"
namespace dfly {
class RedisParser;
class Service;
class Replica {
@ -53,7 +52,7 @@ class Replica {
// SYNCING means that the initial ack succeeded. It may be optional if we can still load from
// the journal offset.
enum State {
R_ENABLED = 1, // Replication mode is enabled. Serves for signaling shutdown.
R_ENABLED = 1, // Replication mode is enabled. Serves for signaling shutdown.
R_TCP_CONNECTED = 2,
R_SYNCING = 4,
R_SYNC_OK = 8,
@ -76,7 +75,7 @@ class Replica {
// Where the sock_ is handled.
util::ProactorBase* sock_thread_ = nullptr;
std::unique_ptr<RedisParser> parser_;
std::unique_ptr<facade::RedisParser> parser_;
// repl_offs - till what offset we've already read from the master.
// ack_offs_ last acknowledged offset.

View file

@ -1,258 +0,0 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/reply_builder.h"
#include <absl/strings/numbers.h>
#include <absl/strings/str_cat.h>
#include "base/logging.h"
#include "server/error.h"
using namespace std;
using absl::StrAppend;
namespace dfly {
namespace {
inline iovec constexpr IoVec(std::string_view s) {
iovec r{const_cast<char*>(s.data()), s.size()};
return r;
}
constexpr char kCRLF[] = "\r\n";
constexpr char kErrPref[] = "-ERR ";
constexpr char kSimplePref[] = "+";
} // namespace
SinkReplyBuilder::SinkReplyBuilder(::io::Sink* sink) : sink_(sink) {
}
void SinkReplyBuilder::CloseConnection() {
if (!ec_)
ec_ = std::make_error_code(std::errc::connection_aborted);
}
void SinkReplyBuilder::Send(const iovec* v, uint32_t len) {
DCHECK(sink_);
if (should_batch_) {
// TODO: to introduce flushing when too much data is batched.
for (unsigned i = 0; i < len; ++i) {
std::string_view src((char*)v[i].iov_base, v[i].iov_len);
DVLOG(2) << "Appending to stream " << sink_ << " " << src;
batch_.append(src.data(), src.size());
}
return;
}
error_code ec;
++io_write_cnt_;
for (unsigned i = 0; i < len; ++i) {
io_write_bytes_ += v[i].iov_len;
}
if (batch_.empty()) {
ec = sink_->Write(v, len);
} else {
DVLOG(1) << "Sending batch to stream " << sink_ << "\n" << batch_;
io_write_bytes_ += batch_.size();
iovec tmp[len + 1];
tmp[0].iov_base = batch_.data();
tmp[0].iov_len = batch_.size();
copy(v, v + len, tmp + 1);
ec = sink_->Write(tmp, len + 1);
batch_.clear();
}
if (ec) {
ec_ = ec;
}
}
void SinkReplyBuilder::SendDirect(std::string_view raw) {
iovec v = {IoVec(raw)};
Send(&v, 1);
}
MCReplyBuilder::MCReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
}
void MCReplyBuilder::SendStored() {
SendDirect("STORED\r\n");
}
void MCReplyBuilder::SendLong(long val) {
SendDirect(absl::StrCat(val, kCRLF));
}
void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
string header;
for (unsigned i = 0; i < count; ++i) {
if (resp[i]) {
const auto& src = *resp[i];
absl::StrAppend(&header, "VALUE ", src.key, " ", src.mc_flag, " ",
src.value.size());
if (src.mc_ver) {
absl::StrAppend(&header, " ", src.mc_ver);
}
absl::StrAppend(&header, "\r\n");
iovec v[] = {IoVec(header), IoVec(src.value), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
header.clear();
}
}
SendDirect("END\r\n");
}
void MCReplyBuilder::SendError(string_view str) {
SendDirect("ERROR\r\n");
}
void MCReplyBuilder::SendClientError(string_view str) {
iovec v[] = {IoVec("CLIENT_ERROR "), IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
}
void MCReplyBuilder::SendSetSkipped() {
SendDirect("NOT_STORED\r\n");
}
void MCReplyBuilder::SendNotFound() {
SendDirect("NOT_FOUND\r\n");
}
RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
}
void RedisReplyBuilder::SendError(string_view str) {
if (str[0] == '-') {
iovec v[] = {IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
} else {
iovec v[] = {IoVec(kErrPref), IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
}
}
void RedisReplyBuilder::SendStored() {
SendSimpleString("OK");
}
void RedisReplyBuilder::SendSetSkipped() {
SendNull();
}
void RedisReplyBuilder::SendNull() {
constexpr char kNullStr[] = "$-1\r\n";
iovec v[] = {IoVec(kNullStr)};
Send(v, ABSL_ARRAYSIZE(v));
}
void RedisReplyBuilder::SendSimpleString(std::string_view str) {
iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
}
void RedisReplyBuilder::SendBulkString(std::string_view str) {
char tmp[absl::numbers_internal::kFastToBufferSize + 3];
tmp[0] = '$'; // Format length
char* next = absl::numbers_internal::FastIntToBuffer(uint32_t(str.size()), tmp + 1);
*next++ = '\r';
*next++ = '\n';
std::string_view lenpref{tmp, size_t(next - tmp)};
// 3 parts: length, string and CRLF.
iovec v[3] = {IoVec(lenpref), IoVec(str), IoVec(kCRLF)};
return Send(v, ABSL_ARRAYSIZE(v));
}
void RedisReplyBuilder::SendError(OpStatus status) {
switch (status) {
case OpStatus::OK:
SendOk();
break;
case OpStatus::KEY_NOTFOUND:
SendError(kKeyNotFoundErr);
break;
case OpStatus::WRONG_TYPE:
SendError(kWrongTypeErr);
break;
default:
LOG(ERROR) << "Unsupported status " << status;
SendError("Internal error");
break;
}
}
void RedisReplyBuilder::SendLong(long num) {
string str = absl::StrCat(":", num, kCRLF);
SendDirect(str);
}
void RedisReplyBuilder::SendDouble(double val) {
SendBulkString(absl::StrCat(val));
}
void RedisReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
string res = absl::StrCat("*", count, kCRLF);
for (size_t i = 0; i < count; ++i) {
if (resp[i]) {
StrAppend(&res, "$", resp[i]->value.size(), kCRLF);
res.append(resp[i]->value).append(kCRLF);
} else {
res.append("$-1\r\n");
}
}
SendDirect(res);
}
void RedisReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t count) {
string res = absl::StrCat("*", count, kCRLF);
for (size_t i = 0; i < count; ++i) {
StrAppend(&res, "+", arr[i], kCRLF);
}
SendDirect(res);
}
void RedisReplyBuilder::SendNullArray() {
SendDirect("*-1\r\n");
}
void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
string res = absl::StrCat("*", arr.size(), kCRLF);
for (size_t i = 0; i < arr.size(); ++i) {
StrAppend(&res, "$", arr[i].size(), kCRLF);
res.append(arr[i]).append(kCRLF);
}
SendDirect(res);
}
void RedisReplyBuilder::StartArray(unsigned len) {
SendDirect(absl::StrCat("*", len, kCRLF));
}
void ReqSerializer::SendCommand(std::string_view str) {
VLOG(1) << "SendCommand: " << str;
iovec v[] = {IoVec(str), IoVec(kCRLF)};
ec_ = sink_->Write(v, ABSL_ARRAYSIZE(v));
}
} // namespace dfly

View file

@ -1,154 +0,0 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include <optional>
#include <string_view>
#include "core/op_status.h"
#include "io/sync_stream_interface.h"
namespace dfly {
class ReplyBuilderInterface {
public:
virtual ~ReplyBuilderInterface() {
}
// Reply for set commands.
virtual void SendStored() = 0;
// Common for both MC and Redis.
virtual void SendError(std::string_view str) = 0;
virtual std::error_code GetError() const = 0;
struct ResponseValue {
std::string_view key;
std::string value;
uint64_t mc_ver = 0; // 0 means we do not output it (i.e has not been requested).
uint32_t mc_flag = 0;
};
using OptResp = std::optional<ResponseValue>;
virtual void SendMGetResponse(const OptResp* resp, uint32_t count) = 0;
virtual void SendLong(long val) = 0;
virtual void SendSetSkipped() = 0;
};
class SinkReplyBuilder : public ReplyBuilderInterface {
public:
SinkReplyBuilder(const SinkReplyBuilder&) = delete;
void operator=(const SinkReplyBuilder&) = delete;
SinkReplyBuilder(::io::Sink* sink);
// In order to reduce interrupt rate we allow coalescing responses together using
// Batch mode. It is controlled by Connection state machine because it makes sense only
// when pipelined requests are arriving.
void SetBatchMode(bool batch) {
should_batch_ = batch;
}
// Used for QUIT - > should move to conn_context?
void CloseConnection();
std::error_code GetError() const override {
return ec_;
}
size_t io_write_cnt() const {
return io_write_cnt_;
}
size_t io_write_bytes() const {
return io_write_bytes_;
}
void reset_io_stats() {
io_write_cnt_ = 0;
io_write_bytes_ = 0;
}
//! Sends a string as is without any formatting. raw should be encoded according to the protocol.
void SendDirect(std::string_view str);
protected:
void Send(const iovec* v, uint32_t len);
std::string batch_;
::io::Sink* sink_;
std::error_code ec_;
size_t io_write_cnt_ = 0;
size_t io_write_bytes_ = 0;
bool should_batch_ = false;
};
class MCReplyBuilder : public SinkReplyBuilder {
public:
MCReplyBuilder(::io::Sink* stream);
void SendError(std::string_view str) final;
// void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final;
void SendMGetResponse(const OptResp* resp, uint32_t count) final;
void SendStored() final;
void SendLong(long val) final;
void SendSetSkipped() final;
void SendClientError(std::string_view str);
void SendNotFound();
};
class RedisReplyBuilder : public SinkReplyBuilder {
public:
RedisReplyBuilder(::io::Sink* stream);
void SendOk() {
SendSimpleString("OK");
}
void SendError(std::string_view str) override;
void SendMGetResponse(const OptResp* resp, uint32_t count) override;
void SendStored() override;
void SendLong(long val) override;
void SendSetSkipped() override;
void SendError(OpStatus status);
virtual void SendSimpleString(std::string_view str);
virtual void SendSimpleStrArr(const std::string_view* arr, uint32_t count);
virtual void SendNullArray();
virtual void SendStringArr(absl::Span<const std::string_view> arr);
virtual void SendNull();
virtual void SendDouble(double val);
virtual void SendBulkString(std::string_view str);
virtual void StartArray(unsigned len);
private:
};
class ReqSerializer {
public:
explicit ReqSerializer(::io::Sink* stream) : sink_(stream) {
}
void SendCommand(std::string_view str);
std::error_code ec() const {
return ec_;
}
private:
::io::Sink* sink_;
std::error_code ec_;
};
} // namespace dfly

View file

@ -46,6 +46,7 @@ using namespace util;
namespace fibers = ::boost::fibers;
namespace fs = std::filesystem;
using strings::HumanReadableNumBytes;
using facade::MCReplyBuilder;
namespace {
@ -97,7 +98,7 @@ void ServerFamily::Shutdown() {
});
}
void ServerFamily::StatsMC(std::string_view section, ConnectionContext* cntx) {
void ServerFamily::StatsMC(std::string_view section, facade::ConnectionContext* cntx) {
if (!section.empty()) {
return cntx->reply_builder()->SendError("");
}
@ -200,7 +201,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError("ACL is not supported yet");
}
if (!(cntx->conn_state.mask & ConnectionState::REQ_AUTH)) {
if (!cntx->req_auth) {
return (*cntx)->SendError(
"AUTH <password> called without any password configured for the "
"default user. Are you sure your configuration is correct?");
@ -208,10 +209,10 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
string_view pass = ArgS(args, 1);
if (pass == FLAGS_requirepass) {
cntx->conn_state.mask |= ConnectionState::AUTHENTICATED;
cntx->authenticated = true;
(*cntx)->SendOk();
} else {
(*cntx)->SendError(kAuthRejected);
(*cntx)->SendError(facade::kAuthRejected);
}
}
@ -544,14 +545,14 @@ void ServerFamily::_Shutdown(CmdArgList args, ConnectionContext* cntx) {
void ServerFamily::SyncGeneric(std::string_view repl_master_id, uint64_t offs,
ConnectionContext* cntx) {
if (cntx->conn_state.mask & ConnectionState::ASYNC_DISPATCH) {
if (cntx->async_dispatch) {
// SYNC is a special command that should not be sent in batch with other commands.
// It should be the last command since afterwards the server just dumps the replication data.
(*cntx)->SendError("Can not sync in pipeline mode");
return;
}
cntx->conn_state.mask |= ConnectionState::REPL_CONNECTION;
cntx->replica_conn = true;
ServerState::tl_connection_stats()->num_replicas += 1;
// TBD.
}

View file

@ -4,6 +4,8 @@
#pragma once
#include "facade/conn_context.h"
#include "facade/redis_parser.h"
#include "server/engine_shard_set.h"
#include "server/global_state.h"
#include "util/proactor_pool.h"
@ -28,7 +30,7 @@ struct Metrics {
size_t heap_used_bytes = 0;
size_t heap_comitted_bytes = 0;
ConnectionStats conn_stats;
facade::ConnectionStats conn_stats;
};
class ServerFamily {
@ -50,7 +52,7 @@ class ServerFamily {
return script_mgr_.get();
}
void StatsMC(std::string_view section, ConnectionContext* cntx);
void StatsMC(std::string_view section, facade::ConnectionContext* cntx);
private:
uint32_t shard_count() const {

View file

@ -28,7 +28,7 @@ class ServerState { // public struct - to allow initialization.
return &state_;
}
static ConnectionStats* tl_connection_stats() {
static facade::ConnectionStats* tl_connection_stats() {
return &state_.connection_stats;
}
@ -40,7 +40,7 @@ class ServerState { // public struct - to allow initialization.
bool is_master = true;
ConnectionStats connection_stats;
facade::ConnectionStats connection_stats;
void TxCountInc() {
++live_transactions_;

View file

@ -20,6 +20,7 @@ extern "C" {
namespace dfly {
using namespace std;
using ResultStringVec = vector<OpResult<vector<string>>>;
using ResultSetView = OpResult<absl::flat_hash_set<std::string_view>>;
using SvArray = vector<std::string_view>;

View file

@ -4,11 +4,13 @@
#pragma once
#include "core/op_status.h"
#include "facade/op_status.h"
#include "server/common_types.h"
namespace dfly {
using facade::OpResult;
class ConnectionContext;
class CommandRegistry;
class EngineShard;

View file

@ -6,6 +6,7 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
#include "server/command_registry.h"
#include "server/test_utils.h"

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
@ -23,6 +23,8 @@ namespace dfly {
namespace {
using namespace std;
using facade::Protocol;
using facade::ReplyBuilderInterface;
using CI = CommandId;
DEFINE_VARZ(VarzQps, set_qps);
@ -284,7 +286,7 @@ void StringFamily::IncrByGeneric(std::string_view key, int64_t val, ConnectionCo
case OpStatus::WRONG_TYPE:
return builder->SendError(kWrongTypeErr);
case OpStatus::KEY_NOTFOUND: // Relevant only for MC
return reinterpret_cast<MCReplyBuilder*>(builder)->SendNotFound();
return reinterpret_cast<facade::MCReplyBuilder*>(builder)->SendNotFound();
default:;
}
__builtin_unreachable();
@ -327,8 +329,10 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
unsigned shard_count = transaction->shard_set()->size();
std::vector<MGetResponse> mget_resp(shard_count);
ConnectionContext* dfly_cntx = static_cast<ConnectionContext*>(cntx);
bool fetch_mcflag = cntx->protocol() == Protocol::MEMCACHE;
bool fetch_mcver = fetch_mcflag && (cntx->conn_state.mask & ConnectionState::FETCH_CAS_VER);
bool fetch_mcver =
fetch_mcflag && (dfly_cntx->conn_state.memcache_flag & ConnectionState::FETCH_CAS_VER);
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
@ -436,8 +440,8 @@ OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) {
return OpStatus::OK;
}
OpResult<int64_t> StringFamily::OpIncrBy(const OpArgs& op_args, std::string_view key,
int64_t incr, bool skip_on_missing) {
OpResult<int64_t> StringFamily::OpIncrBy(const OpArgs& op_args, std::string_view key, int64_t incr,
bool skip_on_missing) {
auto& db_slice = op_args.shard->db_slice();
auto [it, expire_it] = db_slice.FindExt(op_args.db_ind, key);

View file

@ -12,6 +12,8 @@ namespace dfly {
class ConnectionContext;
class CommandRegistry;
using facade::OpStatus;
using facade::OpResult;
class SetCmd {
DbSlice* db_slice_;

View file

@ -6,6 +6,7 @@
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/engine_shard_set.h"

View file

@ -1,4 +1,4 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
@ -9,15 +9,14 @@
#include "base/logging.h"
#include "base/stl_util.h"
#include "server/dragonfly_connection.h"
#include "facade/dragonfly_connection.h"
#include "util/uring/uring_pool.h"
namespace dfly {
using namespace testing;
using namespace util;
using namespace std;
using MP = MemcacheParser;
using namespace std;
using namespace util;
using namespace testing;
static vector<string> SplitLines(const std::string& src) {
vector<string> res = absl::StrSplit(src, "\r\n");
@ -29,90 +28,6 @@ static vector<string> SplitLines(const std::string& src) {
return res;
}
bool RespMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const {
if (e.type != type_) {
*listener << "\nWrong type: " << RespExpr::TypeName(e.type);
return false;
}
if (type_ == RespExpr::STRING || type_ == RespExpr::ERROR) {
RespExpr::Buffer ebuf = e.GetBuf();
std::string_view actual{reinterpret_cast<char*>(ebuf.data()), ebuf.size()};
if (type_ == RespExpr::ERROR && !absl::StrContains(actual, exp_str_)) {
*listener << "Actual does not contain '" << exp_str_ << "'";
return false;
}
if (type_ == RespExpr::STRING && exp_str_ != actual) {
*listener << "\nActual string: " << actual;
return false;
}
} else if (type_ == RespExpr::INT64) {
auto actual = get<int64_t>(e.u);
if (exp_int_ != actual) {
*listener << "\nActual : " << actual << " expected: " << exp_int_;
return false;
}
} else if (type_ == RespExpr::ARRAY) {
size_t len = get<RespVec*>(e.u)->size();
if (len != size_t(exp_int_)) {
*listener << "Actual length " << len << ", expected: " << exp_int_;
return false;
}
}
return true;
}
void RespMatcher::DescribeTo(std::ostream* os) const {
*os << "is ";
switch (type_) {
case RespExpr::STRING:
case RespExpr::ERROR:
*os << exp_str_;
break;
case RespExpr::INT64:
*os << exp_str_;
break;
default:
*os << "TBD";
break;
}
}
void RespMatcher::DescribeNegationTo(std::ostream* os) const {
*os << "is not ";
}
bool RespTypeMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const {
if (e.type != type_) {
*listener << "\nWrong type: " << RespExpr::TypeName(e.type);
return false;
}
return true;
}
void RespTypeMatcher::DescribeTo(std::ostream* os) const {
*os << "is " << RespExpr::TypeName(type_);
}
void RespTypeMatcher::DescribeNegationTo(std::ostream* os) const {
*os << "is not " << RespExpr::TypeName(type_);
}
void PrintTo(const RespExpr::Vec& vec, std::ostream* os) {
*os << "Vec: [";
if (!vec.empty()) {
for (size_t i = 0; i < vec.size() - 1; ++i) {
*os << vec[i] << ",";
}
*os << vec.back();
}
*os << "]\n";
}
vector<int64_t> ToIntArr(const RespVec& vec) {
vector<int64_t> res;
for (auto a : vec) {
@ -126,7 +41,8 @@ vector<int64_t> ToIntArr(const RespVec& vec) {
}
BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto)
: dummy_conn(new Connection(proto, nullptr, nullptr)), cmd_cntx(&sink, dummy_conn.get()) {
: dummy_conn(new facade::Connection(proto, nullptr, nullptr, nullptr)),
cmd_cntx(&sink, dummy_conn.get()) {
}
BaseFamilyTest::TestConnWrapper::~TestConnWrapper() {

View file

@ -6,84 +6,15 @@
#include <gmock/gmock.h>
#include "facade/memcache_parser.h"
#include "facade/redis_parser.h"
#include "io/io.h"
#include "server/conn_context.h"
#include "server/main_service.h"
#include "server/memcache_parser.h"
#include "server/redis_parser.h"
#include "util/proactor_pool.h"
namespace dfly {
class RespMatcher {
public:
RespMatcher(std::string_view val, RespExpr::Type t = RespExpr::STRING) : type_(t), exp_str_(val) {
}
RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64) : type_(t), exp_int_(val) {
}
using is_gtest_matcher = void;
bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const;
void DescribeTo(std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;
private:
RespExpr::Type type_;
std::string exp_str_;
int64_t exp_int_;
};
class RespTypeMatcher {
public:
RespTypeMatcher(RespExpr::Type type) : type_(type) {
}
using is_gtest_matcher = void;
bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const;
void DescribeTo(std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;
private:
RespExpr::Type type_;
};
inline ::testing::PolymorphicMatcher<RespMatcher> StrArg(std::string_view str) {
return ::testing::MakePolymorphicMatcher(RespMatcher(str));
}
inline ::testing::PolymorphicMatcher<RespMatcher> ErrArg(std::string_view str) {
return ::testing::MakePolymorphicMatcher(RespMatcher(str, RespExpr::ERROR));
}
inline ::testing::PolymorphicMatcher<RespMatcher> IntArg(int64_t ival) {
return ::testing::MakePolymorphicMatcher(RespMatcher(ival));
}
inline ::testing::PolymorphicMatcher<RespMatcher> ArrLen(size_t len) {
return ::testing::MakePolymorphicMatcher(RespMatcher(len, RespExpr::ARRAY));
}
inline ::testing::PolymorphicMatcher<RespTypeMatcher> ArgType(RespExpr::Type t) {
return ::testing::MakePolymorphicMatcher(RespTypeMatcher(t));
}
inline bool operator==(const RespExpr& left, const char* s) {
return left.type == RespExpr::STRING && ToSV(left.GetBuf()) == s;
}
void PrintTo(const RespExpr::Vec& vec, std::ostream* os);
MATCHER_P(RespEq, val, "") {
return ::testing::ExplainMatchResult(::testing::ElementsAre(StrArg(val)), arg, result_listener);
}
using namespace facade;
std::vector<int64_t> ToIntArr(const RespVec& vec);
@ -99,7 +30,7 @@ class BaseFamilyTest : public ::testing::Test {
struct TestConnWrapper {
::io::StringSink sink; // holds the response blob
std::unique_ptr<Connection> dummy_conn;
std::unique_ptr<facade::Connection> dummy_conn;
ConnectionContext cmd_cntx;
std::vector<std::unique_ptr<std::string>> tmp_str_vec;

View file

@ -306,6 +306,7 @@ void Transaction::SetExecCmd(const CommandId* cid) {
}
unique_shard_cnt_ = 0;
args_.clear();
cid_ = cid;
cb_ = nullptr;
}

View file

@ -13,8 +13,8 @@
#include <vector>
#include "core/intent_lock.h"
#include "core/op_status.h"
#include "core/tx_queue.h"
#include "facade/op_status.h"
#include "server/common_types.h"
#include "server/table.h"
#include "util/fibers/fibers_ext.h"
@ -25,6 +25,9 @@ class DbSlice;
class EngineShardSet;
class EngineShard;
using facade::OpStatus;
using facade::OpResult;
class Transaction {
Transaction(const Transaction&);
void operator=(const Transaction&) = delete;
@ -48,7 +51,7 @@ class Transaction {
using time_point = ::std::chrono::steady_clock::time_point;
enum LocalMask : uint16_t {
ARMED = 1, // Transaction was armed with the callback
ARMED = 1, // Transaction was armed with the callback
OUT_OF_ORDER = 2,
KEYLOCK_ACQUIRED = 4,
SUSPENDED_Q = 0x10, // added by the coordination flow (via WaitBlocked()).

View file

@ -6,7 +6,7 @@
#include <variant>
#include "core/op_status.h"
#include "facade/op_status.h"
#include "server/common_types.h"
namespace dfly {