Lua script async calls (#1070)

Introduces squashing for scripts and a new `redis.acall` command for async commands
This commit is contained in:
Vladislav 2023-04-12 23:37:25 +03:00 committed by GitHub
parent 282c168d34
commit 70cf436c05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 127 additions and 29 deletions

View file

@ -4,6 +4,7 @@
#include "core/interpreter.h"
#include <absl/container/fixed_array.h>
#include <absl/strings/str_cat.h>
#include <absl/time/clock.h>
#include <openssl/evp.h>
@ -367,6 +368,11 @@ Interpreter::Interpreter() {
lua_pushcfunction(lua_, RedisPCallCommand);
lua_settable(lua_, -3);
/* redis.acall */
lua_pushstring(lua_, "acall");
lua_pushcfunction(lua_, RedisACallCommand);
lua_settable(lua_, -3);
lua_pushstring(lua_, "sha1hex");
lua_pushcfunction(lua_, RedisSha1Command);
lua_settable(lua_, -3);
@ -614,7 +620,7 @@ void Interpreter::ResetStack() {
// Returns number of results, which is always 1 in this case.
// Please note that lua resets the stack once the function returns so no need
// to unwind the stack manually in the function (though lua allows doing this).
int Interpreter::RedisGenericCommand(bool raise_error) {
int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
/* By using Lua debug hooks it is possible to trigger a recursive call
* to luaRedisGenericCommand(), which normally should never happen.
* To make this function reentrant is futile and makes it slower, but
@ -646,7 +652,9 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
size_t blob_len = 0;
char tmpbuf[64];
for (int idx = 1; idx <= argc; idx++) {
// Determine size required for backing storage for all args.
// Skip command name (idx=1), as its stored in a separate buffer.
for (int idx = 2; idx <= argc; idx++) {
switch (lua_type(lua_, idx)) {
case LUA_TNUMBER:
if (lua_isinteger(lua_, idx)) {
@ -667,14 +675,20 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
}
}
// backing storage.
unique_ptr<char[]> blob(new char[blob_len + 8]); // 8 safety.
vector<absl::Span<char>> cmdargs(argc);
char* cur = blob.get();
char name_buffer[32]; // backing storage for cmd name
string buffer(blob_len + 4, '\0'); // backing storage for args
absl::FixedArray<absl::Span<char>, 4> args(argc);
char* cur = buffer.data();
char* end = cur + blob_len;
for (int j = 0; j < argc; j++) {
unsigned idx = j + 1;
// Copy command name to name_buffer and set it as first arg.
unsigned len = lua_rawlen(lua_, 1);
DCHECK_LT(len, ABSL_ARRAYSIZE(name_buffer));
memcpy(name_buffer, lua_tostring(lua_, 1), len);
args[0] = {name_buffer, len};
for (int idx = 2; idx <= argc; idx++) {
size_t len = 0;
switch (lua_type(lua_, idx)) {
case LUA_TNUMBER:
@ -694,7 +708,7 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
memcpy(cur, lua_tostring(lua_, idx), len);
};
cmdargs[j] = {cur, len};
args[idx - 1] = {cur, len};
cur += len;
}
@ -702,8 +716,10 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
* and this way we guaranty we will have room on the stack for the result. */
lua_pop(lua_, argc);
RedisTranslator translator(lua_);
redis_func_(MutSliceSpan{cmdargs}, &translator);
DCHECK_EQ(1, lua_gettop(lua_));
redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async});
if (!async)
DCHECK_EQ(1, lua_gettop(lua_));
cmd_depth_--;
@ -712,12 +728,17 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
int Interpreter::RedisCallCommand(lua_State* lua) {
void** ptr = static_cast<void**>(lua_getextraspace(lua));
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(true);
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(true, false);
}
int Interpreter::RedisPCallCommand(lua_State* lua) {
void** ptr = static_cast<void**>(lua_getextraspace(lua));
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(false);
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(false, false);
}
int Interpreter::RedisACallCommand(lua_State* lua) {
void** ptr = static_cast<void**>(lua_getextraspace(lua));
return reinterpret_cast<Interpreter*>(*ptr)->RedisGenericCommand(false, true);
}
Interpreter* InterpreterManager::Get() {

View file

@ -32,7 +32,21 @@ class ObjectExplorer {
class Interpreter {
public:
using RedisFunc = std::function<void(MutSliceSpan, ObjectExplorer*)>;
// Arguments received from redis.call
struct CallArgs {
// Full arguments, including cmd name.
MutSliceSpan args;
// Pointer to backing storage for args (excluding cmd name).
// Moving can invalidate arg slice pointers. Moved by async to re-use buffer.
std::string* buffer;
ObjectExplorer* translator;
bool async; // async by redis.acall
};
using RedisFunc = std::function<void(CallArgs)>;
Interpreter();
~Interpreter();
@ -97,10 +111,11 @@ class Interpreter {
bool AddInternal(const char* f_id, std::string_view body, std::string* error);
bool IsTableSafe() const;
int RedisGenericCommand(bool raise_error);
int RedisGenericCommand(bool raise_error, bool async);
static int RedisCallCommand(lua_State* lua);
static int RedisPCallCommand(lua_State* lua);
static int RedisACallCommand(lua_State* lua);
lua_State* lua_;
unsigned cmd_depth_ = 0;

View file

@ -255,7 +255,9 @@ TEST_F(InterpreterTest, Execute) {
}
TEST_F(InterpreterTest, Call) {
auto cb = [](MutSliceSpan span, ObjectExplorer* reply) {
auto cb = [](auto ca) {
auto* reply = ca.translator;
auto span = ca.args;
CHECK_GE(span.size(), 1u);
string_view cmd{span[0].data(), span[0].size()};
if (cmd == "string") {
@ -291,7 +293,8 @@ TEST_F(InterpreterTest, Call) {
}
TEST_F(InterpreterTest, CallArray) {
auto cb = [](MutSliceSpan span, ObjectExplorer* reply) {
auto cb = [](auto ca) {
auto* reply = ca.translator;
reply->OnArrayStart(2);
reply->OnArrayStart(1);
reply->OnArrayStart(2);

View file

@ -31,6 +31,15 @@ StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args)
}
}
StoredCmd::StoredCmd(string&& buffer, const CommandId* cid, CmdArgList args)
: cid_{cid}, buffer_{move(buffer)}, sizes_(args.size()) {
for (unsigned i = 0; i < args.size(); i++) {
// Assume tightly packed list.
DCHECK(i + 1 == args.size() || args[i].data() + args[i].size() == args[i + 1].data());
sizes_[i] = args[i].size();
}
}
void StoredCmd::Fill(CmdArgList args) {
CHECK_GE(args.size(), sizes_.size());
unsigned offset = 0;

View file

@ -24,6 +24,9 @@ class StoredCmd {
public:
StoredCmd(const CommandId* cid, CmdArgList args);
// Create on top of already filled tightly-packed buffer.
StoredCmd(std::string&& buffer, const CommandId* cid, CmdArgList args);
size_t NumArgs() const;
// Fill the arg list with stored arguments, it should be at least of size NumArgs().
@ -71,7 +74,8 @@ struct ConnectionState {
// Lua-script related data.
struct ScriptInfo {
bool is_write = true;
absl::flat_hash_set<std::string_view> keys;
absl::flat_hash_set<std::string_view> keys; // declared keys
std::vector<StoredCmd> async_cmds; // aggregated by acall
};
// PUB-SUB messaging related data.

View file

@ -1016,14 +1016,59 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendOk();
}
void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) {
DCHECK(cntx->transaction);
DVLOG(1) << "CallFromScript " << cntx->transaction->DebugId() << " " << ArgS(args, 0);
template <typename F> void WithoutReplies(ConnectionContext* cntx, F&& f) {
io::NullSink null_sink;
facade::RedisReplyBuilder rrb{&null_sink};
auto* old_rrb = cntx->Inject(&rrb);
InterpreterReplier replier(reply);
f();
cntx->Inject(old_rrb);
}
void Service::FlushEvalAsyncCmds(ConnectionContext* cntx, bool force) {
const int kMaxAsyncCmds = 100;
auto& info = cntx->conn_state.script_info;
if ((!force && info->async_cmds.size() <= kMaxAsyncCmds) || info->async_cmds.empty())
return;
auto* eval_cid = registry_.Find("EVAL");
DCHECK(eval_cid);
cntx->transaction->MultiSwitchCmd(eval_cid);
WithoutReplies(cntx,
[&] { MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx); });
info->async_cmds.clear();
}
void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) {
DCHECK(cntx->transaction);
DVLOG(1) << "CallFromScript " << cntx->transaction->DebugId() << " " << ArgS(ca.args, 0);
if (ca.async) {
auto& info = cntx->conn_state.script_info;
auto* cid = registry_.Find(facade::ToSV(ca.args[0]));
bool valid = true;
WithoutReplies(cntx, [&] { valid = VerifyCommand(cid, ca.args, cntx); });
if (!valid) // TODO: collect errors with capturing reply builder.
return;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1));
FlushEvalAsyncCmds(cntx, false);
return;
}
FlushEvalAsyncCmds(cntx, true);
InterpreterReplier replier(ca.translator);
facade::SinkReplyBuilder* orig = cntx->Inject(&replier);
DispatchCommand(std::move(args), cntx);
DispatchCommand(ca.args, cntx);
cntx->Inject(orig);
}
@ -1185,12 +1230,13 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
interpreter->SetGlobalArray("KEYS", eval_args.keys);
interpreter->SetGlobalArray("ARGV", eval_args.args);
interpreter->SetRedisFunc(
[cntx, this](CmdArgList args, ObjectExplorer* reply) { CallFromScript(args, reply, cntx); });
interpreter->SetRedisFunc([cntx, this](auto args) { CallFromScript(cntx, args); });
Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error);
absl::Cleanup clean = [interpreter]() { interpreter->ResetStack(); };
FlushEvalAsyncCmds(cntx, true);
cntx->conn_state.script_info.reset(); // reset script_info
// Conclude the transaction.

View file

@ -5,6 +5,7 @@
#pragma once
#include "base/varz_value.h"
#include "core/interpreter.h"
#include "facade/service_interface.h"
#include "server/command_registry.h"
#include "server/engine_shard_set.h"
@ -119,7 +120,8 @@ class Service : public facade::ServiceInterface {
void EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, ConnectionContext* cntx);
void CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx);
void FlushEvalAsyncCmds(ConnectionContext* cntx, bool force = false);
void CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& args);
void RegisterCommands();

View file

@ -139,8 +139,6 @@ void MultiCommandSquasher::ExecuteSquashed() {
if (order_.empty())
return;
VLOG(1) << "Executing " << order_.size() << " commands squashed";
Transaction* tx = cntx_->transaction;
if (track_keys_) {

View file

@ -258,7 +258,7 @@ void Transaction::InitByKeys(KeyIndex key_index) {
shard_data_.front().local_mask |= ACTIVE;
unique_shard_cnt_ = 1;
unique_shard_id_ = Shard(args_.front(), shard_set->size());
unique_shard_id_ = Shard(args_.front(), shard_set->size()); // TODO: Squashed bug
return;
}