diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index a4b641897..e61734e1f 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -1,6 +1,8 @@ -add_library(dfly_core compact_object.cc dragonfly_core.cc tx_queue.cc) -cxx_link(dfly_core base absl::flat_hash_map redis_lib) +add_library(dfly_core compact_object.cc dragonfly_core.cc interpreter.cc + tx_queue.cc) +cxx_link(dfly_core base absl::flat_hash_map redis_lib TRDP::lua crypto) cxx_test(dfly_core_test dfly_core LABELS DFLY) cxx_test(compact_object_test dfly_core LABELS DFLY) -cxx_test(dash_test dfly_core LABELS DFLY) \ No newline at end of file +cxx_test(dash_test dfly_core LABELS DFLY) +cxx_test(interpreter_test dfly_core LABELS DFLY) diff --git a/core/interpreter.cc b/core/interpreter.cc new file mode 100644 index 000000000..7b1a97305 --- /dev/null +++ b/core/interpreter.cc @@ -0,0 +1,170 @@ +// Copyright 2022, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "core/interpreter.h" + +#include +#include + +#include + +extern "C" { +#include +#include +#include +} + +#include "base/logging.h" + +namespace dfly { +using namespace std; + +namespace { + +void RunSafe(lua_State* lua, string_view buf, const char* name) { + CHECK_EQ(0, luaL_loadbuffer(lua, buf.data(), buf.size(), name)); + int err = lua_pcall(lua, 0, 0, 0); + if (err) { + const char* errstr = lua_tostring(lua, -1); + LOG(FATAL) << "Error running " << name << " " << errstr; + } +} + +void Require(lua_State* lua, const char* name, lua_CFunction openf) { + luaL_requiref(lua, name, openf, 1); + lua_pop(lua, 1); /* remove lib */ +} + +void InitLua(lua_State* lua) { + Require(lua, "", luaopen_base); + Require(lua, LUA_TABLIBNAME, luaopen_table); + Require(lua, LUA_STRLIBNAME, luaopen_string); + Require(lua, LUA_MATHLIBNAME, luaopen_math); + Require(lua, LUA_DBLIBNAME, luaopen_debug); + + /* Add a helper function we use for pcall error reporting. + * Note that when the error is in the C function we want to report the + * information about the caller, that's what makes sense from the point + * of view of the user debugging a script. */ + { + const char errh_func[] = + "local dbg = debug\n" + "function __redis__err__handler(err)\n" + " local i = dbg.getinfo(2,'nSl')\n" + " if i and i.what == 'C' then\n" + " i = dbg.getinfo(3,'nSl')\n" + " end\n" + " if i then\n" + " return i.source .. ':' .. i.currentline .. ': ' .. err\n" + " else\n" + " return err\n" + " end\n" + "end\n"; + RunSafe(lua, errh_func, "@err_handler_def"); + } + + { + const char code[] = R"( +local dbg=debug +local mt = {} + +setmetatable(_G, mt) +mt.__newindex = function (t, n, v) + if dbg.getinfo(2) then + local w = dbg.getinfo(2, "S").what + if w ~= "main" and w ~= "C" then + error("Script attempted to create global variable '"..tostring(n).."'", 2) + end + end + rawset(t, n, v) +end +mt.__index = function (t, n) + if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then + error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2) + end + return rawget(t, n) +end +debug = nil +)"; + RunSafe(lua, code, "@enable_strict_lua"); + } +} + +void ToHex(const uint8_t* src, char* dest) { + const char cset[] = "0123456789abcdef"; + for (size_t j = 0; j < 20; j++) { + dest[j * 2] = cset[((src[j] & 0xF0) >> 4)]; + dest[j * 2 + 1] = cset[(src[j] & 0xF)]; + } + dest[40] = '\0'; +} + +} // namespace + +Interpreter::Interpreter() { + lua_ = luaL_newstate(); + InitLua(lua_); +} + +Interpreter::~Interpreter() { + lua_close(lua_); +} + +void Interpreter::Fingerprint(string_view body, char* fp) { + SHA_CTX ctx; + uint8_t buf[20]; + + SHA1_Init(&ctx); + SHA1_Update(&ctx, body.data(), body.size()); + SHA1_Final(buf, &ctx); + fp[0] = 'f'; + fp[1] = '_'; + ToHex(buf, fp + 2); +} + +bool Interpreter::AddFunction(string_view body, string* result) { + char funcname[43]; + Fingerprint(body, funcname); + + string script = absl::StrCat("function ", funcname, "() \n"); + absl::StrAppend(&script, body, "\nend"); + + int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script"); + if (res == 0) { + res = lua_pcall(lua_, 0, 0, 0); // run func definition code + } + + if (res) { + result->assign(lua_tostring(lua_, -1)); + lua_pop(lua_, 1); // Remove the error. + + return false; + } + + result->assign(funcname); + + return true; +} + +bool Interpreter::RunFunction(const char* f_id, std::string* error) { + lua_getglobal(lua_, "__redis__err__handler"); + int type = lua_getglobal(lua_, f_id); + if (type != LUA_TFUNCTION) { + error->assign("function not found"); // TODO: noscripterr. + lua_pop(lua_, 2); + + return false; + } + + /* We have zero arguments and expect + * a single return value. */ + int err = lua_pcall(lua_, 0, 1, -2); + + if (err) { + *error = lua_tostring(lua_, -1); + } + return err == 0; +} + +} // namespace dfly diff --git a/core/interpreter.h b/core/interpreter.h new file mode 100644 index 000000000..f7ee4ab68 --- /dev/null +++ b/core/interpreter.h @@ -0,0 +1,43 @@ +// Copyright 2022, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include + +typedef struct lua_State lua_State; + +namespace dfly { + +class Interpreter { + public: + Interpreter(); + ~Interpreter(); + + Interpreter(const Interpreter&) = delete; + void operator=(const Interpreter&) = delete; + + // Note: We leak the state for now. + // Production code should not access this method. + lua_State* lua() { + return lua_; + } + + // returns false if an error happenned, sets error string into result. + // otherwise, returns true and sets result to function id. + bool AddFunction(std::string_view body, std::string* result); + + // Runs already added function f_id returned by a successful call to AddFunction(). + // Returns: true if the call succeeded, otherwise fills error and returns false. + bool RunFunction(const char* f_id, std::string* err); + + // fp must point to buffer with at least 43 chars. + // fp[42] will be set to '\0'. + static void Fingerprint(std::string_view body, char* fp); + + private: + lua_State* lua_; +}; + +} // namespace dfly diff --git a/core/interpreter_test.cc b/core/interpreter_test.cc new file mode 100644 index 000000000..e3cb37bf5 --- /dev/null +++ b/core/interpreter_test.cc @@ -0,0 +1,67 @@ +// Copyright 2022, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "core/interpreter.h" + +extern "C" { +#include +#include +} + +#include + +#include "base/gtest.h" +#include "base/logging.h" + +namespace dfly { +using namespace std; + +class InterpreterTest : public ::testing::Test { + protected: + InterpreterTest() { + } + + lua_State* lua() { + return intptr_.lua(); + } + + void RunInline(string_view buf, const char* name) { + CHECK_EQ(0, luaL_loadbuffer(lua(), buf.data(), buf.size(), name)); + CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0)); + } + + Interpreter intptr_; +}; + +TEST_F(InterpreterTest, Basic) { + RunInline(R"( + function foo(n) + return n,n+1 + end)", + "code1"); + + int type = lua_getglobal(lua(), "foo"); + ASSERT_EQ(LUA_TFUNCTION, type); + lua_pushnumber(lua(), 42); + lua_pcall(lua(), 1, 2, 0); + int val1 = lua_tointeger(lua(), -1); + int val2 = lua_tointeger(lua(), -2); + lua_pop(lua(), 2); + + EXPECT_EQ(43, val1); + EXPECT_EQ(42, val2); + EXPECT_EQ(0, lua_gettop(lua())); +} + +TEST_F(InterpreterTest, Add) { + string res1, res2; + + EXPECT_TRUE(intptr_.AddFunction("return 0", &res1)); + EXPECT_EQ(0, lua_gettop(lua())); + EXPECT_FALSE(intptr_.AddFunction("foobar", &res2)); + EXPECT_THAT(res2, testing::HasSubstr("syntax error")); + EXPECT_EQ(0, lua_gettop(lua())); +} + +} // namespace dfly diff --git a/helio b/helio index 3ee017cce..a300a704e 160000 --- a/helio +++ b/helio @@ -1 +1 @@ -Subproject commit 3ee017cce280493c845a010a28caa4cf1d0f4e9b +Subproject commit a300a704e193d115333f41a81438ba74d3df8c51 diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 2947714de..892c47cd0 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -24,4 +24,5 @@ cxx_test(rdb_test dfly_test_lib DATA testdata/empty.rdb testdata/small.rdb LABEL add_custom_target(check_dfly WORKING_DIRECTORY .. COMMAND ctest -L DFLY) add_dependencies(check_dfly dragonfly_test list_family_test - generic_family_test memcache_parser_test redis_parser_test string_family_test) + generic_family_test memcache_parser_test rdb_test + redis_parser_test string_family_test) diff --git a/server/common.cc b/server/common.cc index fe5e5b1e7..ecd404079 100644 --- a/server/common.cc +++ b/server/common.cc @@ -7,7 +7,6 @@ #include "base/logging.h" #include "server/common_types.h" #include "server/error.h" -#include "server/global_state.h" #include "server/server_state.h" namespace dfly { @@ -22,6 +21,23 @@ ServerState::ServerState() { ServerState::~ServerState() { } +void ServerState::Init() { + gstate_ = GlobalState::IDLE; +} + +void ServerState::Shutdown() { + gstate_ = GlobalState::SHUTTING_DOWN; + interpreter_.reset(); +} + +Interpreter& ServerState::GetInterpreter() { + if (!interpreter_) { + interpreter_.emplace(); + } + + return interpreter_.value(); +} + #define ADD(x) (x) += o.x ConnectionStats& ConnectionStats::operator+=(const ConnectionStats& o) { diff --git a/server/conn_context.h b/server/conn_context.h index 31ee31fd9..ad4ac6bfe 100644 --- a/server/conn_context.h +++ b/server/conn_context.h @@ -31,6 +31,10 @@ struct ConnectionState { enum Mask : uint32_t { ASYNC_DISPATCH = 1, // whether a command is handled via async dispatch. CONN_CLOSING = 2, // could be because of unrecoverable error or planned action. + + // Whether this connection belongs to replica, i.e. a dragonfly slave is connected to this + // host (master) via this connection to sync from it. + REPL_CONNECTION = 2, }; uint32_t mask = 0; // A bitmask of Mask values. diff --git a/server/dragonfly_connection.cc b/server/dragonfly_connection.cc index b033c66b0..c5fb4779d 100644 --- a/server/dragonfly_connection.cc +++ b/server/dragonfly_connection.cc @@ -280,6 +280,11 @@ finish: stats->read_buf_capacity -= io_buf_.Capacity(); + // Update num_replicas if this was a replica connection. + if (cc_->conn_state.mask & ConnectionState::REPL_CONNECTION) { + --stats->num_replicas; + } + if (cc_->ec()) { ec = cc_->ec(); } else { diff --git a/server/main_service.cc b/server/main_service.cc index 023a77e03..01da07431 100644 --- a/server/main_service.cc +++ b/server/main_service.cc @@ -72,6 +72,8 @@ void Service::Init(util::AcceptServer* acceptor, const InitOpts& opts) { shard_set_.Init(shard_num); pp_.Await([&](uint32_t index, ProactorBase* pb) { + ServerState::tlocal()->Init(); + if (index < shard_count()) { shard_set_.InitThreadLocal(pb, !opts.disable_time_update); } @@ -96,6 +98,8 @@ void Service::Shutdown() { request_latency_usec.Shutdown(); ping_qps.Shutdown(); + pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Shutdown(); }); + // to shutdown all the runtime components that depend on EngineShard. server_family_.Shutdown(); StringFamily::Shutdown(); @@ -129,7 +133,20 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) { return cntx->SendError(absl::StrCat("unknown command `", cmd_str, "`")); } + if (etl.gstate() == GlobalState::LOADING || etl.gstate() == GlobalState::SHUTTING_DOWN) { + string err = absl::StrCat("Can not execute during ", GlobalState::Name(etl.gstate())); + cntx->SendError(err); + return; + } + + bool is_write_cmd = cid->opt_mask() & CO::WRITE; bool under_multi = cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd; + + if (!etl.is_master && is_write_cmd) { + cntx->SendError("-READONLY You can't write against a read only replica."); + return; + } + if ((cid->arity() > 0 && args.size() != size_t(cid->arity())) || (cid->arity() < 0 && args.size() < size_t(-cid->arity()))) { return cntx->SendError(WrongNumArgsError(cmd_str)); @@ -272,6 +289,12 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) { return cntx->SendOk(); } +void Service::Eval(CmdArgList args, ConnectionContext* cntx) { + Interpreter& script = ServerState::tlocal()->GetInterpreter(); + script.lua(); + return cntx->SendOk(); +} + void Service::Exec(CmdArgList args, ConnectionContext* cntx) { if (cntx->conn_state.exec_state == ConnectionState::EXEC_INACTIVE) { return cntx->SendError("EXEC without MULTI"); @@ -315,9 +338,12 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { VarzValue::Map Service::GetVarzStats() { VarzValue::Map res; - atomic_ulong num_keys{0}; - shard_set_.RunBriefInParallel([&](EngineShard* es) { num_keys += es->db_slice().DbSize(0); }); - res.emplace_back("keys", VarzValue::FromInt(num_keys.load())); + Metrics m = server_family_.GetMetrics(); + + res.emplace_back("keys", VarzValue::FromInt(m.db.key_count)); + res.emplace_back("obj_mem_usage", VarzValue::FromInt(m.db.obj_memory_usage)); + double load = double(m.db.key_count) / (1 + m.db.bucket_count); + res.emplace_back("table_load_factor", VarzValue::FromDouble(load)); return res; } @@ -338,6 +364,7 @@ void Service::RegisterCommands() { registry_ << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0}.HFUNC(Quit) << CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING | CO::STALE, 1, 0, 0, 0}.HFUNC( Multi) + << CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.HFUNC(Eval) << CI{"EXEC", kExecMask, 1, 0, 0, 0}.SetHandler(cb_exec); StringFamily::Register(®istry_); diff --git a/server/main_service.h b/server/main_service.h index 253b06cfa..a1d16ca66 100644 --- a/server/main_service.h +++ b/server/main_service.h @@ -64,6 +64,8 @@ class Service { private: static void Quit(CmdArgList args, ConnectionContext* cntx); static void Multi(CmdArgList args, ConnectionContext* cntx); + static void Eval(CmdArgList args, ConnectionContext* cntx); + void Exec(CmdArgList args, ConnectionContext* cntx); diff --git a/server/server_family.cc b/server/server_family.cc index e73305257..9c8797064 100644 --- a/server/server_family.cc +++ b/server/server_family.cc @@ -63,6 +63,7 @@ error_code CreateDirs(fs::path dir_path) { } return ec; } + } // namespace ServerFamily::ServerFamily(Service* engine) @@ -80,6 +81,7 @@ void ServerFamily::Init(util::AcceptServer* acceptor) { void ServerFamily::Shutdown() { VLOG(1) << "ServerFamily::Shutdown"; + pp_.GetNextProactor()->Await([this] { unique_lock lk(replica_of_mu_); if (replica_) { @@ -176,7 +178,7 @@ void ServerFamily::Save(CmdArgList args, ConnectionContext* cntx) { return; } - pp_.Await([](auto*) { ServerState::tlocal()->state = GlobalState::SAVING; }); + pp_.Await([](auto*) { ServerState::tlocal()->set_gstate(GlobalState::SAVING); }); unique_ptr<::io::WriteFile> wf(*res); auto start = absl::Now(); @@ -200,7 +202,7 @@ void ServerFamily::Save(CmdArgList args, ConnectionContext* cntx) { return; } - pp_.Await([](auto*) { ServerState::tlocal()->state = GlobalState::IDLE; }); + pp_.Await([](auto*) { ServerState::tlocal()->set_gstate(GlobalState::IDLE); }); CHECK_EQ(GlobalState::SAVING, global_state_.Clear()); absl::Duration dur = absl::Now() - start; @@ -243,7 +245,7 @@ Metrics ServerFamily::GetMetrics() const { void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) { const char kInfo1[] = R"(# Server -redis_version:6.2.0 +redis_version:1.9.9 redis_mode:standalone arch_bits:64 multiplexing_api:iouring @@ -292,6 +294,9 @@ tcp_port:)"; absl::StrAppend(&info, "master_last_io_seconds_ago:", rinfo.master_last_io_sec, "\n"); absl::StrAppend(&info, "master_sync_in_progress:", rinfo.sync_in_progress, "\n"); } + absl::StrAppend(&info, "\n# Keyspace\n"); + absl::StrAppend(&info, "db0:keys=xxx,expires=yyy,avg_ttl=zzz\n"); // TODO + cntx->SendBulkString(info); } @@ -308,7 +313,6 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) { auto repl_ptr = replica_; CHECK(repl_ptr); - pp_.AwaitFiberOnAll([&](util::ProactorBase* pb) { ServerState::tlocal()->is_master = true; }); replica_->Stop(); replica_.reset(); @@ -378,6 +382,8 @@ void ServerFamily::SyncGeneric(std::string_view repl_master_id, uint64_t offs, return; } + cntx->conn_state.mask |= ConnectionState::REPL_CONNECTION; + ServerState::tl_connection_stats()->num_replicas += 1; // TBD. } diff --git a/server/server_state.h b/server/server_state.h index 002bb9f70..de6d8c578 100644 --- a/server/server_state.h +++ b/server/server_state.h @@ -4,10 +4,12 @@ #pragma once +#include #include #include "server/common_types.h" #include "server/global_state.h" +#include "core/interpreter.h" namespace dfly { @@ -32,7 +34,9 @@ class ServerState { // public struct - to allow initialization. ServerState(); ~ServerState(); - GlobalState::S state = GlobalState::IDLE; + void Init(); + void Shutdown(); + bool is_master = true; ConnectionStats connection_stats; @@ -49,8 +53,15 @@ class ServerState { // public struct - to allow initialization. return live_transactions_; } + GlobalState::S gstate() const { return gstate_;} + void set_gstate(GlobalState::S s) { gstate_ = s; } + + Interpreter& GetInterpreter(); + private: int64_t live_transactions_ = 0; + std::optional interpreter_; + GlobalState::S gstate_ = GlobalState::IDLE; static thread_local ServerState state_; };