feat: Refactor command verification before execution (#1652)

* feat: Refactor verifications

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-08-08 12:36:31 +03:00 committed by GitHub
parent 6b29a642bb
commit 16e512c60d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 169 additions and 97 deletions

View file

@ -60,18 +60,18 @@ void CommandId::Invoke(CmdArgList args, ConnectionContext* cntx) const {
ent.second += (after - before) / 1000;
}
optional<facade::ErrorReply> CommandId::Validate(CmdArgList args) const {
if ((arity() > 0 && args.size() != size_t(arity())) ||
(arity() < 0 && args.size() < size_t(-arity()))) {
optional<facade::ErrorReply> CommandId::Validate(CmdArgList tail_args) const {
if ((arity() > 0 && tail_args.size() + 1 != size_t(arity())) ||
(arity() < 0 && tail_args.size() + 1 < size_t(-arity()))) {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}
if (key_arg_step() == 2 && (args.size() % 2) == 0) {
if (key_arg_step() == 2 && (tail_args.size() % 2) != 0) {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}
if (validator_)
return validator_(args.subspan(1));
return validator_(tail_args);
return nullopt;
}

View file

@ -72,7 +72,7 @@ class CommandId : public facade::CommandId {
void Invoke(CmdArgList args, ConnectionContext* cntx) const;
// Returns error if validation failed, otherwise nullopt
std::optional<facade::ErrorReply> Validate(CmdArgList args) const;
std::optional<facade::ErrorReply> Validate(CmdArgList tail_args) const;
bool IsTransactional() const;

View file

@ -750,25 +750,18 @@ OpStatus CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_info, const C
return OpStatus::OK;
}
optional<ErrorReply> Service::VerifyCommandArguments(const CommandId* cid, CmdArgList args) {
string_view cmd_str = ArgS(args, 0);
if (cid == nullptr) {
lock_guard lk(mu_);
if (unknown_cmds_.size() < 1024)
unknown_cmds_[cmd_str]++;
return ErrorReply{StrCat("unknown command `", cmd_str, "`"), "unknown_cmd"};
}
return cid->Validate(args);
// Hook for checks that have to pass right before a command is invoked.
// At present it performs no checks: cid is unused and nullopt (no error) is always returned.
optional<ErrorReply> Service::VerifyCommandExecution(const CommandId* cid) {
// TODO: Move OOM check here
return nullopt;
}
std::optional<ErrorReply> Service::VerifyCommand(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx) {
std::optional<ErrorReply> Service::VerifyCommandState(const CommandId* cid, CmdArgList tail_args,
const ConnectionContext& dfly_cntx) {
DCHECK(cid);
ServerState& etl = *ServerState::tlocal();
if (auto err = VerifyCommandArguments(cid, args); err)
if (auto err = cid->Validate(tail_args); err)
return err;
bool is_trans_cmd = CO::IsTransKind(cid->name());
@ -824,13 +817,13 @@ std::optional<ErrorReply> Service::VerifyCommand(const CommandId* cid, CmdArgLis
}
if (ClusterConfig::IsEnabled()) {
if (auto err = CheckKeysOwnership(cid, args.subspan(1), dfly_cntx); err)
if (auto err = CheckKeysOwnership(cid, tail_args, dfly_cntx); err)
return err;
}
if (under_script && cid->IsTransactional()) {
OpStatus status = CheckKeysDeclared(*dfly_cntx.conn_state.script_info, cid, args.subspan(1),
dfly_cntx.transaction);
OpStatus status =
CheckKeysDeclared(*dfly_cntx.conn_state.script_info, cid, tail_args, dfly_cntx.transaction);
if (status == OpStatus::KEY_NOTFOUND)
return ErrorReply{"script tried accessing undeclared key"};
@ -846,7 +839,14 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
CHECK(!args.empty());
DCHECK_NE(0u, shard_set->size()) << "Init was not called";
ServerState& etl = *ServerState::tlocal();
ToUpper(&args[0]);
const CommandId* cid = FindCmd(args);
if (cid == nullptr) {
return (*cntx)->SendError(ReportUnknownCmd(ArgS(args, 0)));
}
ConnectionContext* dfly_cntx = static_cast<ConnectionContext*>(cntx);
bool under_script = bool(dfly_cntx->conn_state.script_info);
@ -859,12 +859,11 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
<< " in dbid=" << dfly_cntx->conn_state.db_index;
}
const CommandId* cid = FindCmd(args);
ServerState& etl = *ServerState::tlocal();
etl.RecordCmd();
if (auto err = VerifyCommand(cid, args, *dfly_cntx); err) {
auto args_no_cmd = args.subspan(1);
if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
if (auto& exec_info = dfly_cntx->conn_state.exec_info; exec_info.IsCollecting())
exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR;
@ -872,8 +871,6 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
return;
}
auto args_no_cmd = args.subspan(1);
bool is_trans_cmd = CO::IsTransKind(cid->name());
if (dfly_cntx->conn_state.exec_info.IsCollecting() && !is_trans_cmd) {
// TODO: protect against aggregating huge transactions.
@ -933,7 +930,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
// itself. EXEC does not use DispatchCommand for dispatching.
bool collect_stats =
dfly_cntx->transaction && (!dfly_cntx->transaction->IsMulti() || dispatching_in_multi);
if (!InvokeCmd(args.subspan(1), cid, dfly_cntx, collect_stats)) {
if (!InvokeCmd(cid, args_no_cmd, dfly_cntx, collect_stats)) {
dfly_cntx->reply_builder()->SendError("Internal Error");
dfly_cntx->reply_builder()->CloseConnection();
}
@ -946,10 +943,18 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
}
}
bool Service::InvokeCmd(CmdArgList args, const CommandId* cid, ConnectionContext* cntx,
bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionContext* cntx,
bool record_stats) {
DCHECK(cid);
DCHECK(!cid->Validate(tail_args));
if (auto err = VerifyCommandExecution(cid); err) {
(*cntx)->SendError(move(*err));
return true; // return false only for internal error aborts
}
try {
cid->Invoke(args, cntx);
cid->Invoke(tail_args, cntx);
} catch (std::exception& e) {
LOG(ERROR) << "Internal error, system probably unstable " << e.what();
return false;
@ -1068,6 +1073,14 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
dfly_cntx->conn_state.memcache_flag = 0;
}
// Counts an unknown command name in unknown_cmds_ and builds the matching error reply.
// mu_ is held for the duration of the call.
// The counter map is only touched while it holds fewer than 1024 entries, so once that size
// is reached neither new names nor already recorded ones are counted any more.
// Returns an ErrorReply with the message "unknown command `<cmd_name>`" and kind "unknown_cmd".
ErrorReply Service::ReportUnknownCmd(string_view cmd_name) {
lock_guard lk(mu_);
if (unknown_cmds_.size() < 1024)
unknown_cmds_[cmd_name]++;
return ErrorReply{StrCat("unknown command `", cmd_name, "`"), "unknown_cmd"};
}
// Returns true unless the admin_nopass flag is set.
bool RequireAdminAuth() {
return !GetFlag(FLAGS_admin_nopass);
}
@ -1196,8 +1209,9 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
cntx->transaction->MultiSwitchCmd(eval_cid);
CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
WithReplies(&crb, cntx,
[&] { MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, true); });
WithReplies(&crb, cntx, [&] {
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, this, true, true);
});
info->async_cmds_heap_mem = 0;
info->async_cmds.clear();
@ -1214,28 +1228,32 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
facade::SinkReplyBuilder* orig = cntx->Inject(&replier);
absl::Cleanup clean = [orig, cntx] { cntx->Inject(orig); };
optional<ErrorReply> findcmd_err;
if (ca.async) {
auto& info = cntx->conn_state.script_info;
ToUpper(&ca.args[0]);
auto* cid = registry_.Find(facade::ToSV(ca.args[0]));
if (auto err = VerifyCommand(cid, ca.args, *cntx); err) {
(*cntx)->SendError(move(*err));
return;
// Full command verification happens during squashed execution
if (auto* cid = registry_.Find(ArgS(ca.args, 0)); cid != nullptr) {
auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies);
info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory();
} else if (ca.error_abort) { // If we don't abort on errors, we can ignore it completely
findcmd_err = ReportUnknownCmd(ArgS(ca.args, 0));
}
auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies);
info->async_cmds_heap_mem += info->async_cmds.back().UsedHeapMemory();
}
if (auto err = FlushEvalAsyncCmds(cntx, !ca.async); err) {
if (auto err = FlushEvalAsyncCmds(cntx, !ca.async || findcmd_err.has_value()); err) {
CapturingReplyBuilder::Apply(move(*err), &replier); // forward error to lua
*ca.requested_abort = true;
return;
}
if (findcmd_err.has_value()) {
replier.RedisReplyBuilder::SendError(move(*findcmd_err));
*ca.requested_abort |= ca.error_abort;
}
if (ca.async)
return;
@ -1600,7 +1618,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
if (!exec_info.body.empty()) {
if (absl::GetFlag(FLAGS_multi_exec_squash) && state == ExecEvalState::NONE) {
MultiCommandSquasher::Execute(absl::MakeSpan(exec_info.body), cntx);
MultiCommandSquasher::Execute(absl::MakeSpan(exec_info.body), cntx, this);
} else {
for (auto& scmd : exec_info.body) {
VLOG(2) << "TX CMD " << scmd.Cid()->name() << " " << scmd.NumArgs();
@ -1621,7 +1639,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
}
}
bool ok = InvokeCmd(args, scmd.Cid(), cntx, true);
bool ok = InvokeCmd(scmd.Cid(), args, cntx, true);
if (!ok || rb->GetError()) // checks for i/o error, not logical error.
break;
}

View file

@ -39,10 +39,22 @@ class Service : public facade::ServiceInterface {
void Shutdown();
// Prepare command execution, verify and execute, reply to context
void DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) final;
// Returns true if command was executed successfully.
bool InvokeCmd(CmdArgList args, const CommandId* cid, ConnectionContext* cntx, bool record_stats);
// Check VerifyCommandExecution and invoke command with args
bool InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionContext* reply_cntx,
bool record_stats = false);
// Verify command can be executed now (check out of memory), always called immediately before
// execution
std::optional<facade::ErrorReply> VerifyCommandExecution(const CommandId* cid);
// Verify that the command prepares execution in a correct state.
// It's usually called before command execution. Only for multi/exec transactions it's checked
// when the command is queued for execution, not before the execution itself.
std::optional<facade::ErrorReply> VerifyCommandState(const CommandId* cid, CmdArgList tail_args,
const ConnectionContext& cntx);
void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
facade::ConnectionContext* cntx) final;
@ -52,6 +64,24 @@ class Service : public facade::ServiceInterface {
facade::ConnectionStats* GetThreadLocalConnectionStats() final;
const CommandId* FindCmd(std::string_view cmd) const {
return registry_.Find(cmd);
}
facade::ErrorReply ReportUnknownCmd(std::string_view cmd_name);
// Returns: the new state.
// If from equals the old state, then the switch is performed and "to" is returned.
// Otherwise, does not switch and returns the current state in the system.
// Upon switch, updates cached global state in threadlocal ServerState struct.
GlobalState SwitchState(GlobalState from, GlobalState to);
GlobalState GetGlobalState() const;
void ConfigureHttpHandlers(util::HttpListenerBase* base) final;
void OnClose(facade::ConnectionContext* cntx) final;
std::string GetContextInfo(facade::ConnectionContext* cntx) final;
uint32_t shard_count() const {
return shard_set->size();
}
@ -66,10 +96,6 @@ class Service : public facade::ServiceInterface {
absl::flat_hash_map<std::string, unsigned> UknownCmdMap() const;
const CommandId* FindCmd(std::string_view cmd) const {
return registry_.Find(cmd);
}
ScriptMgr* script_mgr() {
return server_family_.script_mgr();
}
@ -78,18 +104,6 @@ class Service : public facade::ServiceInterface {
return server_family_;
}
// Returns: the new state.
// if from equals the old state then the switch is performed "to" is returned.
// Otherwise, does not switch and returns the current state in the system.
// Upon switch, updates cached global state in threadlocal ServerState struct.
GlobalState SwitchState(GlobalState from, GlobalState to);
GlobalState GetGlobalState() const;
void ConfigureHttpHandlers(util::HttpListenerBase* base) final;
void OnClose(facade::ConnectionContext* cntx) final;
std::string GetContextInfo(facade::ConnectionContext* cntx) final;
private:
static void Quit(CmdArgList args, ConnectionContext* cntx);
static void Multi(CmdArgList args, ConnectionContext* cntx);
@ -119,13 +133,6 @@ class Service : public facade::ServiceInterface {
CmdArgList keys, args;
};
// Verify command exists and has no obvious formatting errors
std::optional<facade::ErrorReply> VerifyCommandArguments(const CommandId* cid, CmdArgList args);
// Verify command can be executed
std::optional<facade::ErrorReply> VerifyCommand(const CommandId* cid, CmdArgList args,
const ConnectionContext& cntx);
// Return error if not all keys are owned by the server when running in cluster mode
std::optional<facade::ErrorReply> CheckKeysOwnership(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx);

View file

@ -25,8 +25,9 @@ template <typename F> void IterateKeys(CmdArgList args, KeyIndex keys, F&& f) {
} // namespace
MultiCommandSquasher::MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx,
bool error_abort)
: cmds_{cmds}, cntx_{cntx}, base_cid_{nullptr}, error_abort_{error_abort} {
Service* service, bool verify_commands, bool error_abort)
: cmds_{cmds}, cntx_{cntx}, service_{service}, base_cid_{nullptr},
verify_commands_{verify_commands}, error_abort_{error_abort} {
auto mode = cntx->transaction->GetMultiMode();
base_cid_ = mode == Transaction::NON_ATOMIC ? nullptr : cntx->transaction->GetCId();
}
@ -50,6 +51,8 @@ MultiCommandSquasher::ShardExecInfo& MultiCommandSquasher::PrepareShardInfo(Shar
}
MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cmd) {
DCHECK(cmd->Cid());
if (!cmd->Cid()->IsTransactional() || (cmd->Cid()->opt_mask() & CO::BLOCKING) ||
(cmd->Cid()->opt_mask() & CO::GLOBAL_TRANS))
return SquashResult::NOT_SQUASHED;
@ -90,19 +93,28 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm
return need_flush ? SquashResult::SQUASHED_FULL : SquashResult::SQUASHED;
}
void MultiCommandSquasher::ExecuteStandalone(StoredCmd* cmd) {
bool MultiCommandSquasher::ExecuteStandalone(StoredCmd* cmd) {
DCHECK(order_.empty()); // check no squashed chain is interrupted
cmd->Fill(&tmp_keylist_);
auto args = absl::MakeSpan(tmp_keylist_);
if (verify_commands_) {
if (auto err = service_->VerifyCommandState(cmd->Cid(), args, *cntx_); err) {
(*cntx_)->SendError(move(*err));
return !error_abort_;
}
}
auto* tx = cntx_->transaction;
tx->MultiSwitchCmd(cmd->Cid());
cntx_->cid = cmd->Cid();
cmd->Fill(&tmp_keylist_);
auto args = absl::MakeSpan(tmp_keylist_);
if (cmd->Cid()->IsTransactional())
tx->InitByArgs(cntx_->conn_state.db_index, args);
cmd->Cid()->Invoke(args, cntx_);
service_->InvokeCmd(cmd->Cid(), args, cntx_);
return true;
}
OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard* es) {
@ -116,16 +128,25 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
absl::InlinedVector<MutableSlice, 4> arg_vec;
for (auto* cmd : sinfo.cmds) {
local_tx->MultiSwitchCmd(cmd->Cid());
local_cntx.cid = cmd->Cid();
crb.SetReplyMode(cmd->ReplyMode());
arg_vec.resize(cmd->NumArgs());
auto args = absl::MakeSpan(arg_vec);
cmd->Fill(args);
if (verify_commands_) {
// The shared context is used for state verification, the local one is only for replies
if (auto err = service_->VerifyCommandState(cmd->Cid(), args, *cntx_); err) {
crb.SendError(*move(err));
sinfo.replies.emplace_back(crb.Take());
continue;
}
}
local_tx->MultiSwitchCmd(cmd->Cid());
local_cntx.cid = cmd->Cid();
crb.SetReplyMode(cmd->ReplyMode());
local_tx->InitByArgs(parent_tx->GetDbIndex(), args);
cmd->Cid()->Invoke(args, &local_cntx);
service_->InvokeCmd(cmd->Cid(), args, &local_cntx);
sinfo.replies.emplace_back(crb.Take());
}
@ -179,7 +200,7 @@ bool MultiCommandSquasher::ExecuteSquashed() {
sinfo.cmds.clear();
order_.clear();
return aborted;
return !aborted;
}
void MultiCommandSquasher::Run() {
@ -190,12 +211,14 @@ void MultiCommandSquasher::Run() {
break;
if (res == SquashResult::NOT_SQUASHED || res == SquashResult::SQUASHED_FULL) {
if (ExecuteSquashed())
if (!ExecuteSquashed())
break;
}
if (res == SquashResult::NOT_SQUASHED)
ExecuteStandalone(&cmd);
if (res == SquashResult::NOT_SQUASHED) {
if (!ExecuteStandalone(&cmd))
break;
}
}
ExecuteSquashed(); // Flush leftover

View file

@ -8,6 +8,7 @@
#include "core/fibers.h"
#include "facade/reply_capture.h"
#include "server/conn_context.h"
#include "server/main_service.h"
namespace dfly {
@ -23,9 +24,9 @@ namespace dfly {
// contains a non-atomic multi transaction to execute squashed commands.
class MultiCommandSquasher {
public:
static void Execute(absl::Span<StoredCmd> cmds, ConnectionContext* cntx,
bool error_abort = false) {
MultiCommandSquasher{cmds, cntx, error_abort}.Run();
static void Execute(absl::Span<StoredCmd> cmds, ConnectionContext* cntx, Service* service,
bool verify_commands = false, bool error_abort = false) {
MultiCommandSquasher{cmds, cntx, service, verify_commands, error_abort}.Run();
}
private:
@ -45,7 +46,8 @@ class MultiCommandSquasher {
static constexpr int kMaxSquashing = 32;
private:
MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx, bool error_abort);
// Constructs a squasher over cmds for the given context. service is used to verify and invoke
// the commands; verify_commands and error_abort select whether commands are verified before
// execution and whether the run stops on the first error.
MultiCommandSquasher(absl::Span<StoredCmd> cmds, ConnectionContext* cntx, Service* service,
bool verify_commands, bool error_abort);
// Lazy initialize shard info.
ShardExecInfo& PrepareShardInfo(ShardId sid);
@ -53,13 +55,13 @@ class MultiCommandSquasher {
// Return squash flags
SquashResult TrySquash(StoredCmd* cmd);
// Execute separate non-squashed cmd.
void ExecuteStandalone(StoredCmd* cmd);
// Execute separate non-squashed cmd. Return false if aborting on error.
bool ExecuteStandalone(StoredCmd* cmd);
// Callback that runs on shards during squashed hop.
facade::OpStatus SquashedHopCb(Transaction* parent_tx, EngineShard* es);
// Execute all currently squashed commands. Return true if aborting on error.
// Execute all currently squashed commands. Return false if aborting on error.
bool ExecuteSquashed();
// Run all commands until completion.
@ -70,11 +72,13 @@ class MultiCommandSquasher {
private:
absl::Span<StoredCmd> cmds_; // Input range of stored commands
ConnectionContext* cntx_; // Underlying context
Service* service_;
// underlying cid (exec or eval) for executing batch hops, nullptr for non-atomic
const CommandId* base_cid_;
bool error_abort_ = false; // Abort upon receiving error
bool verify_commands_ = false; // Whether commands need to be verified before execution
bool error_abort_ = false; // Abort upon receiving error
std::vector<ShardExecInfo> sharded_;
std::vector<ShardId> order_; // reply order for squashed cmds

View file

@ -18,6 +18,7 @@
ABSL_DECLARE_FLAG(uint32_t, multi_exec_mode);
ABSL_DECLARE_FLAG(bool, multi_exec_squash);
ABSL_DECLARE_FLAG(bool, lua_auto_async);
ABSL_DECLARE_FLAG(std::string, default_lua_flags);
namespace dfly {
@ -861,4 +862,23 @@ TEST_F(MultiEvalTest, MultiSomeEval) {
EXPECT_THAT(brpop_resp, ArgType(RespExpr::NIL_ARRAY));
}
// Checks how unknown commands are handled inside a script when lua_auto_async is enabled:
// one issued through redis.pcall must not surface as an error, while one issued through
// redis.call must fail the script and stop the commands that follow it.
TEST_F(MultiEvalTest, ScriptSquashingUknownCmd) {
// Restores the flag values when the test ends.
absl::FlagSaver fs;
absl::SetFlag(&FLAGS_lua_auto_async, true);
// The script below contains two commands for which execution can't even be prepared
// (FIRST/SECOND WRONG). The first is issued with pcall, so its error should be completely
// ignored, the second one should cause an abort and no further commands should be executed
string_view s = R"(
redis.pcall('INCR', 'A')
redis.pcall('FIRST WRONG')
redis.pcall('INCR', 'A')
redis.call('SECOND WRONG')
redis.pcall('INCR', 'A')
)";
EXPECT_THAT(Run({"EVAL", s, "1", "A"}), ErrArg("unknown command `SECOND WRONG`"));
// Only the two INCR calls that precede 'SECOND WRONG' are expected to have taken effect.
EXPECT_EQ(Run({"get", "A"}), "2");
}
} // namespace dfly