diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 58faf7ad7..bff4ebff4 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -1,7 +1,8 @@ add_executable(dragonfly dfly_main.cc) cxx_link(dragonfly base dragonfly_lib) -add_library(dragonfly_lib db_slice.cc dragonfly_listener.cc dragonfly_connection.cc +add_library(dragonfly_lib command_registry.cc db_slice.cc dragonfly_listener.cc + dragonfly_connection.cc main_service.cc engine_shard_set.cc redis_parser.cc resp_expr.cc reply_builder.cc) diff --git a/server/command_registry.cc b/server/command_registry.cc new file mode 100644 index 000000000..41c210410 --- /dev/null +++ b/server/command_registry.cc @@ -0,0 +1,87 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#include "server/command_registry.h" + +#include "absl/strings/str_cat.h" +#include "base/bits.h" +#include "base/logging.h" +#include "server/conn_context.h" + +using namespace std; + +namespace dfly { + +using absl::StrAppend; +using absl::StrCat; + +CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, + int8_t last_key, int8_t step) + : name_(name), opt_mask_(mask), arity_(arity), first_key_(first_key), last_key_(last_key), + step_key_(step) { +} + +uint32_t CommandId::OptCount(uint32_t mask) { + return absl::popcount(mask); +} + +CommandRegistry::CommandRegistry() { + CommandId cd("COMMAND", CO::RANDOM | CO::LOADING | CO::STALE, 0, 0, 0, 0); + cd.AssignCallback([this](const auto& args, auto* cntx) { return Command(args, cntx); }); + const char* nm = cd.name(); + cmd_map_.emplace(nm, std::move(cd)); +} + +void CommandRegistry::Command(CmdArgList args, ConnectionContext* cntx) { + size_t sz = cmd_map_.size(); + string resp = absl::StrCat("*", sz, "\r\n"); + + for (const auto& val : cmd_map_) { + const CommandId& cd = val.second; + StrAppend(&resp, "*6\r\n$", strlen(cd.name()), "\r\n", cd.name(), "\r\n"); + StrAppend(&resp, ":", int(cd.arity()), "\r\n", "*", CommandId::OptCount(cd.opt_mask()), "\r\n"); + uint32_t opt_bit = 1; + + for (uint32_t i = 1; i < 32; ++i, opt_bit <<= 1) { + if (cd.opt_mask() & opt_bit) { + const char* name = CO::OptName(CO::CommandOpt{opt_bit}); + StrAppend(&resp, "+", name, "\r\n"); + } + } + + StrAppend(&resp, ":", cd.first_key_pos(), "\r\n"); + StrAppend(&resp, ":", cd.last_key_pos(), "\r\n"); + StrAppend(&resp, ":", cd.key_arg_step(), "\r\n"); + } + + cntx->SendDirect(resp); +} + +namespace CO { + +const char* OptName(CO::CommandOpt fl) { + using namespace CO; + + switch (fl) { + case WRITE: + return "write"; + case READONLY: + return "readonly"; + case DENYOOM: + return "denyoom"; + case FAST: + return "fast"; + case STALE: + return "stale"; + case LOADING: + return "loading"; + case RANDOM: + return "random"; + } + return ""; +} + +} // namespace CO + +} // namespace dfly diff --git a/server/command_registry.h b/server/command_registry.h new file mode 100644 index 000000000..16ff2935e --- /dev/null +++ b/server/command_registry.h @@ -0,0 +1,151 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#pragma once + +#include +#include + +#include + +#include "base/function2.hpp" + +namespace dfly { + +class ConnectionContext; + +namespace CO { + +enum CommandOpt : uint32_t { + READONLY = 1, + FAST = 2, + WRITE = 4, + LOADING = 8, + DENYOOM = 0x10, // use-memory in redis. + STALE = 0x20, + RANDOM = 0x40, +}; + +const char* OptName(CommandOpt fl); + +}; // namespace CO + +using MutableStrSpan = absl::Span; +using CmdArgList = absl::Span; + +class CommandId { + public: + using CmdFunc = std::function; + + /** + * @brief Construct a new Command Id object + * + * @param name + * @param mask + * @param arity - positive if command has fixed number of required arguments + * negative if command has minimum number of required arguments, but may have + * more. + * @param first_key - position of first key in argument list + * @param last_key - position of last key in argument list, + * -1 means the last key index is (arg_length - 1), -2 means that the last key + * index is (arg_length - 2). + * @param step - step count for locating repeating keys + */ + CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, int8_t last_key, + int8_t step); + + const char* name() const { + return name_; + } + + int arity() const { + return arity_; + } + + uint32_t opt_mask() const { + return opt_mask_; + } + + int8_t first_key_pos() const { + return first_key_; + } + + int8_t last_key_pos() const { + return last_key_; + } + + bool is_multi_key() const { + return last_key_ != first_key_; + } + + int8_t key_arg_step() const { + return step_key_; + } + + CommandId& AssignCallback(CmdFunc f) { + func_ = std::move(f); + return *this; + } + + void Invoke(CmdArgList args, ConnectionContext* cntx) const { + func_(std::move(args), cntx); + } + + static const char* OptName(CO::CommandOpt fl); + static uint32_t OptCount(uint32_t mask); + + private: + const char* name_; + + uint32_t opt_mask_; + int8_t arity_; + int8_t first_key_; + int8_t last_key_; + int8_t step_key_; + + CmdFunc func_; +}; + +class CommandRegistry { + absl::flat_hash_map cmd_map_; + + public: + CommandRegistry(); + + CommandRegistry& operator<<(CommandId cmd) { + const char* k = cmd.name(); + cmd_map_.emplace(k, std::move(cmd)); + + return *this; + } + + const CommandId* Find(std::string_view cmd) const { + auto it = cmd_map_.find(cmd); + return it == cmd_map_.end() ? nullptr : &it->second; + } + + CommandId* Find(std::string_view cmd) { + auto it = cmd_map_.find(cmd); + return it == cmd_map_.end() ? nullptr : &it->second; + } + + using TraverseCb = std::function; + + void Traverse(TraverseCb cb) { + for (const auto& k_v : cmd_map_) { + cb(k_v.first, k_v.second); + } + } + + private: + // Implements COMMAND functionality. + void Command(CmdArgList args, ConnectionContext* cntx); +}; + +inline std::string_view ArgS(CmdArgList args, size_t i) { + auto arg = args[i]; + return std::string_view(arg.data(), arg.size()); +} + +} // namespace dfly diff --git a/server/conn_context.h b/server/conn_context.h index cd51a9997..56bd91281 100644 --- a/server/conn_context.h +++ b/server/conn_context.h @@ -10,6 +10,7 @@ namespace dfly { class Connection; class EngineShardSet; +class CommandId; class ConnectionContext : public ReplyBuilder { public: @@ -17,6 +18,7 @@ class ConnectionContext : public ReplyBuilder { } // TODO: to introduce proper accessors. + const CommandId* cid = nullptr; EngineShardSet* shard_set = nullptr; Connection* owner() { return owner_;} diff --git a/server/dragonfly_connection.cc b/server/dragonfly_connection.cc index 951a7df46..5ebc9225a 100644 --- a/server/dragonfly_connection.cc +++ b/server/dragonfly_connection.cc @@ -11,6 +11,7 @@ #include "server/main_service.h" #include "server/redis_parser.h" #include "server/conn_context.h" +#include "server/command_registry.h" #include "util/fiber_sched_algo.h" using namespace util; @@ -21,7 +22,33 @@ namespace fibers = boost::fibers; namespace dfly { namespace { +using CmdArgVec = std::vector; +void SendProtocolError(RedisParser::Result pres, FiberSocketBase* peer) { + string res("-ERR Protocol error: "); + if (pres == RedisParser::BAD_BULKLEN) { + res.append("invalid bulk length\r\n"); + } else { + CHECK_EQ(RedisParser::BAD_ARRAYLEN, pres); + res.append("invalid multibulk length\r\n"); + } + + auto size_res = peer->Send(::io::Buffer(res)); + if (!size_res) { + LOG(WARNING) << "Error " << size_res.error(); + } +} + +inline MutableStrSpan ToMSS(absl::Span span) { + return MutableStrSpan{reinterpret_cast(span.data()), span.size()}; +} + +void RespToArgList(const RespVec& src, CmdArgVec* dest) { + dest->resize(src.size()); + for (size_t i = 0; i < src.size(); ++i) { + (*dest)[i] = ToMSS(src[i].GetBuf()); + } +} constexpr size_t kMinReadSize = 256; @@ -110,7 +137,25 @@ void Connection::InputLoop(FiberSocketBase* peer) { } else if (status != OK) { break; } - } while (peer->IsOpen()); + } while (peer->IsOpen() && !cc_->ec()); + + if (cc_->ec()) { + ec = cc_->ec(); + } else { + if (status == ERROR) { + VLOG(1) << "Error stats " << status; + if (redis_parser_) { + SendProtocolError(RedisParser::Result(parser_error_), peer); + } else { + string_view sv{"CLIENT_ERROR bad command line format\r\n"}; + auto size_res = peer->Send(::io::Buffer(sv)); + if (!size_res) { + LOG(WARNING) << "Error " << size_res.error(); + ec = size_res.error(); + } + } + } + } if (ec && !FiberSocketBase::IsConnClosed(ec)) { LOG(WARNING) << "Socket error " << ec; @@ -119,6 +164,7 @@ void Connection::InputLoop(FiberSocketBase* peer) { auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { RespVec args; + CmdArgVec arg_vec; uint32_t consumed = 0; RedisParser::Result result = RedisParser::OK; @@ -132,15 +178,8 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { DVLOG(2) << "Got Args with first token " << ToSV(first.GetBuf()); } - CHECK_EQ(RespExpr::STRING, first.type); // TODO - string_view sv = ToSV(first.GetBuf()); - if (sv == "PING") { - cc_->SendSimpleString("PONG"); - } else if (sv == "SET") { - CHECK_EQ(3u, args.size()); - service_->Set(ToSV(args[1].GetBuf()), ToSV(args[2].GetBuf())); - cc_->SendOk(); - } + RespToArgList(args, &arg_vec); + service_->DispatchCommand(CmdArgList{arg_vec.data(), arg_vec.size()}, cc_.get()); } io_buf->ConsumeInput(consumed); } while (RedisParser::OK == result && !cc_->ec()); @@ -154,4 +193,5 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { return ERROR; } + } // namespace dfly diff --git a/server/main_service.cc b/server/main_service.cc index 5a2f67938..a980a72ce 100644 --- a/server/main_service.cc +++ b/server/main_service.cc @@ -4,16 +4,35 @@ #include "server/main_service.h" -#include -#include +#include #include +#include +#include + #include "base/logging.h" +#include "server/conn_context.h" #include "util/uring/uring_fiber_algo.h" #include "util/varz.h" DEFINE_uint32(port, 6380, "Redis port"); +namespace std { + +ostream& operator<<(ostream& os, dfly::CmdArgList args) { + os << "["; + if (!args.empty()) { + for (size_t i = 0; i < args.size() - 1; ++i) { + os << dfly::ArgS(args, i) << ","; + } + os << dfly::ArgS(args, args.size() - 1); + } + os << "]"; + + return os; +} + +} // namespace std namespace dfly { @@ -27,6 +46,8 @@ namespace this_fiber = ::boost::this_fiber; namespace { DEFINE_VARZ(VarzMapAverage, request_latency_usec); +DEFINE_VARZ(VarzQps, ping_qps); +DEFINE_VARZ(VarzQps, set_qps); std::optional engine_varz; @@ -35,11 +56,21 @@ inline ShardId Shard(string_view sv, ShardId shard_num) { return hash % shard_num; } +inline void ToUpper(const MutableStrSpan* val) { + for (auto& c : *val) { + c = absl::ascii_toupper(c); + } +} + +string WrongNumArgsError(string_view cmd) { + return absl::StrCat("wrong number of arguments for '", cmd, "' command"); +} + } // namespace -Service::Service(ProactorPool* pp) - : shard_set_(pp), pp_(*pp) { +Service::Service(ProactorPool* pp) : shard_set_(pp), pp_(*pp) { CHECK(pp); + RegisterCommands(); engine_varz.emplace("engine", [this] { return GetVarzStats(); }); } @@ -57,38 +88,103 @@ void Service::Init(util::AcceptServer* acceptor) { }); request_latency_usec.Init(&pp_); + ping_qps.Init(&pp_); + set_qps.Init(&pp_); } void Service::Shutdown() { engine_varz.reset(); request_latency_usec.Shutdown(); - + ping_qps.Shutdown(); + set_qps.Shutdown(); shard_set_.RunBriefInParallel([&](EngineShard*) { EngineShard::DestroyThreadLocal(); }); } +void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) { + CHECK(!args.empty()); + DCHECK_NE(0u, shard_set_.size()) << "Init was not called"; + + ToUpper(&args[0]); + + VLOG(2) << "Got: " << args; + + string_view cmd_str = ArgS(args, 0); + const CommandId* cid = registry_.Find(cmd_str); + + if (cid == nullptr) { + return cntx->SendError(absl::StrCat("unknown command `", cmd_str, "`")); + } + + if ((cid->arity() > 0 && args.size() != size_t(cid->arity())) || + (cid->arity() < 0 && args.size() < size_t(-cid->arity()))) { + return cntx->SendError(WrongNumArgsError(cmd_str)); + } + uint64_t start_usec = ProactorBase::GetMonotonicTimeNs(), end_usec; + cntx->cid = cid; + cid->Invoke(args, cntx); + end_usec = ProactorBase::GetMonotonicTimeNs(); + + request_latency_usec.IncBy(cmd_str, (end_usec - start_usec) / 1000); +} + void Service::RegisterHttp(HttpListenerBase* listener) { CHECK_NOTNULL(listener); } -void Service::Set(std::string_view key, std::string_view val) { +void Service::Ping(CmdArgList args, ConnectionContext* cntx) { + if (args.size() > 2) { + return cntx->SendError("wrong number of arguments for 'ping' command"); + } + ping_qps.Inc(); + + if (args.size() == 1) { + return cntx->SendSimpleString("PONG"); + } + std::string_view arg = ArgS(args, 1); + DVLOG(2) << "Ping " << arg; + + return cntx->SendSimpleString(arg); +} + +void Service::Set(CmdArgList args, ConnectionContext* cntx) { + set_qps.Inc(); + + std::string_view key = ArgS(args, 1); + std::string_view val = ArgS(args, 2); + VLOG(2) << "Set " << key << " " << val; + ShardId sid = Shard(key, shard_count()); shard_set_.Await(sid, [&] { EngineShard* es = EngineShard::tlocal(); auto [it, res] = es->db_slice.AddOrFind(0, key); it->second = val; }); + cntx->SendOk(); } + VarzValue::Map Service::GetVarzStats() { VarzValue::Map res; atomic_ulong num_keys{0}; - shard_set_.RunBriefInParallel([&](EngineShard* es) { - num_keys += es->db_slice.DbSize(0); - }); + shard_set_.RunBriefInParallel([&](EngineShard* es) { num_keys += es->db_slice.DbSize(0); }); res.emplace_back("keys", VarzValue::FromInt(num_keys.load())); return res; } +using ServiceFunc = void (Service::*)(CmdArgList args, ConnectionContext* cntx); +inline CommandId::CmdFunc HandlerFunc(Service* se, ServiceFunc f) { + return [=](CmdArgList args, ConnectionContext* cntx) { return (se->*f)(args, cntx); }; +} + +#define HFUNC(x) AssignCallback(HandlerFunc(this, &Service::x)) + +void Service::RegisterCommands() { + using CI = CommandId; + + registry_ << CI{"PING", CO::STALE | CO::FAST, -1, 0, 0, 0}.HFUNC(Ping) + << CI{"SET", CO::WRITE | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(Set); +} + } // namespace dfly diff --git a/server/main_service.h b/server/main_service.h index 62d4fc64c..40edc4925 100644 --- a/server/main_service.h +++ b/server/main_service.h @@ -5,8 +5,9 @@ #pragma once #include "base/varz_value.h" -#include "util/http/http_handler.h" +#include "server/command_registry.h" #include "server/engine_shard_set.h" +#include "util/http/http_handler.h" namespace util { class AcceptServer; @@ -27,6 +28,8 @@ class Service { void Shutdown(); + void DispatchCommand(CmdArgList args, ConnectionContext* cntx); + uint32_t shard_count() const { return shard_set_.size(); } @@ -39,11 +42,15 @@ class Service { return pp_; } - void Set(std::string_view key, std::string_view val); private: + void Ping(CmdArgList args, ConnectionContext* cntx); + void Set(CmdArgList args, ConnectionContext* cntx); + + void RegisterCommands(); base::VarzValue::Map GetVarzStats(); + CommandRegistry registry_; EngineShardSet shard_set_; util::ProactorPool& pp_; };