diff --git a/core/interpreter_test.cc b/core/interpreter_test.cc index 25bf758a8..dc8b0ad2d 100644 --- a/core/interpreter_test.cc +++ b/core/interpreter_test.cc @@ -157,6 +157,18 @@ TEST_F(InterpreterTest, Basic) { } } +TEST_F(InterpreterTest, UnknownFunc) { + string_view code(R"( + function foo(n) + return myunknownfunc(1, n) + end)"); + + CHECK_EQ(0, luaL_loadbuffer(lua(), code.data(), code.size(), "code1")); + CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0)); + int type = lua_getglobal(lua(), "myunknownfunc"); + ASSERT_EQ(LUA_TNIL, type); +} + TEST_F(InterpreterTest, Stack) { RunInline(R"( local x = {} diff --git a/server/command_registry.cc b/server/command_registry.cc index 2f0ac6978..1029ac819 100644 --- a/server/command_registry.cc +++ b/server/command_registry.cc @@ -33,6 +33,7 @@ CommandRegistry::CommandRegistry() { CommandId cd("COMMAND", CO::RANDOM | CO::LOADING, 0, 0, 0, 0); cd.SetHandler([this](const auto& args, auto* cntx) { return Command(args, cntx); }); + const char* nm = cd.name(); cmd_map_.emplace(nm, std::move(cd)); } @@ -70,6 +71,38 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) { return *this; } +KeyIndex DetermineKeys(const CommandId* cid, const CmdArgList& args) { + DCHECK_EQ(0u, cid->opt_mask() & CO::GLOBAL_TRANS); + + KeyIndex key_index; + + if (cid->first_key_pos() > 0) { + key_index.start = cid->first_key_pos(); + int last = cid->last_key_pos(); + key_index.end = last > 0 ? last + 1 : (int(args.size()) + 1 + last); + key_index.step = cid->key_arg_step(); + + return key_index; + } + + string_view name{cid->name()}; + if (name == "EVAL" || name == "EVALSHA") { + DCHECK_GE(args.size(), 3u); + uint32_t num_keys; + + CHECK(absl::SimpleAtoi(ArgS(args, 2), &num_keys)); + key_index.start = 3; + key_index.end = 3 + num_keys; + key_index.step = 1; + + return key_index; + } + + LOG(FATAL) << "TBD: Not supported"; + + return key_index; +} + namespace CO { const char* OptName(CO::CommandOpt fl) { diff --git a/server/command_registry.h b/server/command_registry.h index cf74eade4..1e29a2819 100644 --- a/server/command_registry.h +++ b/server/command_registry.h @@ -37,8 +37,13 @@ const char* OptName(CommandOpt fl); class CommandId { public: - using Handler = std::function; - using ArgValidator = std::function; + using Handler = + fu2::function_base; + + using ArgValidator = fu2::function_base; /** * @brief Construct a new Command Id object @@ -152,4 +157,8 @@ class CommandRegistry { void Command(CmdArgList args, ConnectionContext* cntx); }; + +// Given the command and the arguments determines the keys range (index). +KeyIndex DetermineKeys(const CommandId* cid, const CmdArgList& args); + } // namespace dfly diff --git a/server/common_types.h b/server/common_types.h index d7019e29c..a4d076a22 100644 --- a/server/common_types.h +++ b/server/common_types.h @@ -43,6 +43,12 @@ struct KeyLockArgs { unsigned key_step; }; +// Describes key indices. +struct KeyIndex { + unsigned start; + unsigned end; // does not include this index (open limit). + unsigned step; // 1 for commands like mget. 2 for commands like mset. +}; struct ConnectionStats { uint32_t num_conns = 0; diff --git a/server/conn_context.h b/server/conn_context.h index 39e59cbe9..92e5a5e96 100644 --- a/server/conn_context.h +++ b/server/conn_context.h @@ -4,6 +4,8 @@ #pragma once +#include + #include "server/common_types.h" #include "server/reply_builder.h" @@ -50,6 +52,8 @@ struct ConnectionState { // Lua-script related data. struct Script { bool is_write = true; + + absl::flat_hash_set keys; }; std::optional