feat: Use const ConnectionContext in VerifyCommand (#1633)

* feat: Use const ConnectionContext in VerifyCommand

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-08-06 11:02:43 +03:00 committed by GitHub
parent 6faa530d42
commit 3bc1e26050
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 148 additions and 119 deletions

View file

@ -7,8 +7,12 @@
#include <absl/container/flat_hash_map.h>
#include <absl/types/span.h>
#include <optional>
#include <string>
#include <string_view>
#include <variant>
#include "facade/op_status.h"
namespace facade {
@ -56,6 +60,24 @@ struct ConnectionStats {
ConnectionStats& operator+=(const ConnectionStats& o);
};
struct ErrorReply {
explicit ErrorReply(std::string&& msg, std::string_view kind = {})
: message{move(msg)}, kind{kind} {
}
explicit ErrorReply(std::string_view msg, std::string_view kind = {}) : message{msg}, kind{kind} {
}
explicit ErrorReply(const char* msg,
std::string_view kind = {}) // to resolve ambiguity of constructors above
: message{std::string_view{msg}}, kind{kind} {
}
explicit ErrorReply(OpStatus status) : message{}, kind{}, status{status} {
}
std::variant<std::string, std::string_view> message;
std::string_view kind;
std::optional<OpStatus> status{std::nullopt};
};
inline MutableSlice ToMSS(absl::Span<uint8_t> span) {
return MutableSlice{reinterpret_cast<char*>(span.data()), span.size()};
}

View file

@ -220,6 +220,14 @@ void RedisReplyBuilder::SendError(string_view str, string_view err_type) {
}
}
void RedisReplyBuilder::SendError(ErrorReply error) {
if (error.status)
return SendError(*error.status);
string_view message_sv = visit([](auto&& str) -> string_view { return str; }, error.message);
SendError(message_sv, error.kind);
}
void RedisReplyBuilder::SendProtocolError(std::string_view str) {
SendError(absl::StrCat("-ERR Protocol error: ", str), "protocol_error");
}

View file

@ -8,6 +8,7 @@
#include <optional>
#include <string_view>
#include "facade/facade_types.h"
#include "facade/op_status.h"
#include "io/io.h"
@ -174,6 +175,8 @@ class RedisReplyBuilder : public SinkReplyBuilder {
void SetResp3(bool is_resp3);
void SendError(std::string_view str, std::string_view type = {}) override;
virtual void SendError(ErrorReply error);
void SendMGetResponse(absl::Span<const OptResp>) override;
void SendStored() override;

View file

@ -20,6 +20,13 @@ void CapturingReplyBuilder::SendError(std::string_view str, std::string_view typ
Capture(Error{str, type});
}
void CapturingReplyBuilder::SendError(ErrorReply error) {
SKIP_LESS(ReplyMode::ONLY_ERR);
string message = visit([](auto&& str) -> string { return string{move(str)}; }, error.message);
Capture(Error{move(message), error.kind});
}
void CapturingReplyBuilder::SendMGetResponse(absl::Span<const OptResp> arr) {
SKIP_LESS(ReplyMode::FULL);
Capture(vector<OptResp>{arr.begin(), arr.end()});

View file

@ -24,6 +24,7 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
public:
void SendError(std::string_view str, std::string_view type = {}) override;
void SendError(ErrorReply error) override;
void SendMGetResponse(absl::Span<const OptResp>) override;
// SendStored -> SendSimpleString("OK")

View file

@ -60,6 +60,21 @@ void CommandId::Invoke(CmdArgList args, ConnectionContext* cntx) const {
ent.second += (after - before) / 1000;
}
optional<facade::ErrorReply> CommandId::Validate(CmdArgList args) const {
if ((arity() > 0 && args.size() != size_t(arity())) ||
(arity() < 0 && args.size() < size_t(-arity()))) {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}
if (key_arg_step() == 2 && (args.size() % 2) == 0) {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}
if (validator_)
return validator_(args.subspan(1));
return nullopt;
}
CommandRegistry::CommandRegistry() {
vector<string> rename_command = GetFlag(FLAGS_rename_command);

View file

@ -8,6 +8,7 @@
#include <absl/types/span.h>
#include <functional>
#include <optional>
#include "base/function2.hpp"
#include "facade/command_id.h"
@ -66,11 +67,16 @@ class CommandId : public facade::CommandId {
void(CmdArgList, ConnectionContext*) const>;
using ArgValidator = fu2::function_base<true, true, fu2::capacity_default, false, false,
bool(CmdArgList, ConnectionContext*) const>;
std::optional<facade::ErrorReply>(CmdArgList) const>;
bool is_multi_key() const {
return (last_key_ != first_key_) || (opt_mask_ & CO::VARIADIC_KEYS);
}
void Invoke(CmdArgList args, ConnectionContext* cntx) const;
// Returns error if validation failed, otherwise nullopt
std::optional<facade::ErrorReply> Validate(CmdArgList args) const;
bool IsTransactional() const;
static const char* OptName(CO::CommandOpt fl);
CommandId& SetHandler(Handler f) {
handler_ = std::move(f);
@ -79,21 +85,13 @@ class CommandId : public facade::CommandId {
CommandId& SetValidator(ArgValidator f) {
validator_ = std::move(f);
return *this;
}
void Invoke(CmdArgList args, ConnectionContext* cntx) const;
// Returns true if validation succeeded.
bool Validate(CmdArgList args, ConnectionContext* cntx) const {
return !validator_ || validator_(std::move(args), cntx);
bool is_multi_key() const {
return (last_key_ != first_key_) || (opt_mask_ & CO::VARIADIC_KEYS);
}
bool IsTransactional() const;
static const char* OptName(CO::CommandOpt fl);
private:
Handler handler_;
ArgValidator validator_;

View file

@ -46,6 +46,7 @@ extern "C" {
#include "util/varz.h"
using namespace std;
using facade::ErrorReply;
using dfly::operator""_KB;
struct MaxMemoryFlag {
@ -429,21 +430,17 @@ bool IsSHA(string_view str) {
return true;
}
bool EvalValidator(CmdArgList args, ConnectionContext* cntx) {
optional<ErrorReply> EvalValidator(CmdArgList args) {
string_view num_keys_str = ArgS(args, 1);
int32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
(*cntx)->SendError(facade::kInvalidIntErr);
return false;
}
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0)
return ErrorReply{facade::kInvalidIntErr};
if (unsigned(num_keys) > args.size() - 2) {
(*cntx)->SendError("Number of keys can't be greater than number of args", kSyntaxErrType);
return false;
}
if (unsigned(num_keys) > args.size() - 2)
return ErrorReply{"Number of keys can't be greater than number of args", kSyntaxErrType};
return true;
return nullopt;
}
void Topkeys(const http::QueryArgs& args, HttpContext* send) {
@ -672,20 +669,20 @@ void Service::Shutdown() {
ThisFiber::SleepFor(10ms);
}
bool Service::CheckKeysOwnership(const CommandId* cid, CmdArgList args,
ConnectionContext* dfly_cntx) {
if (dfly_cntx->is_replicating) {
optional<ErrorReply> Service::CheckKeysOwnership(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx) {
if (dfly_cntx.is_replicating) {
// Always allow commands on the replication port, as it might be for future-owned keys.
return true;
return nullopt;
}
if (cid->first_key_pos() == 0) {
return true; // No key command.
return nullopt; // No key command.
}
OpResult<KeyIndex> key_index_res = DetermineKeys(cid, args);
if (!key_index_res) {
(*dfly_cntx)->SendError(key_index_res.status());
return false;
return ErrorReply{key_index_res.status()};
}
const auto& key_index = *key_index_res;
@ -704,24 +701,22 @@ bool Service::CheckKeysOwnership(const CommandId* cid, CmdArgList args,
}
if (cross_slot) {
(*dfly_cntx)->SendError("-CROSSSLOT Keys in request don't hash to the same slot");
return false;
return ErrorReply{"-CROSSSLOT Keys in request don't hash to the same slot"};
}
// Check keys slot is in my ownership
const ClusterConfig* cluster_config = cluster_family_.cluster_config();
if (cluster_config == nullptr) {
(*dfly_cntx)->SendError(kClusterNotConfigured);
return false;
return ErrorReply{kClusterNotConfigured};
}
if (keys_slot.has_value() && !cluster_config->IsMySlot(*keys_slot)) {
// See more details here: https://redis.io/docs/reference/cluster-spec/#moved-redirection
ClusterConfig::Node master = cluster_config->GetMasterNodeForSlot(*keys_slot);
(*dfly_cntx)->SendError(absl::StrCat("-MOVED ", *keys_slot, " ", master.ip, ":", master.port));
return false;
return ErrorReply{absl::StrCat("-MOVED ", *keys_slot, " ", master.ip, ":", master.port)};
}
return true;
return nullopt;
}
// Return OK if all keys are allowed to be accessed: either declared in EVAL or
@ -755,32 +750,36 @@ OpStatus CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_info, const C
return OpStatus::OK;
}
bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionContext* dfly_cntx) {
ServerState& etl = *ServerState::tlocal();
optional<ErrorReply> Service::VerifyCommandArguments(const CommandId* cid, CmdArgList args) {
string_view cmd_str = ArgS(args, 0);
absl::Cleanup multi_error([exec_info = &dfly_cntx->conn_state.exec_info] {
if (exec_info->IsCollecting()) {
exec_info->state = ConnectionState::ExecInfo::EXEC_ERROR;
}
});
if (cid == nullptr) {
(*dfly_cntx)->SendError(StrCat("unknown command `", cmd_str, "`"), "unknown_cmd");
lock_guard lk(mu_);
if (unknown_cmds_.size() < 1024)
unknown_cmds_[cmd_str]++;
return false;
return ErrorReply{StrCat("unknown command `", cmd_str, "`"), "unknown_cmd"};
}
return cid->Validate(args);
}
std::optional<ErrorReply> Service::VerifyCommand(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx) {
ServerState& etl = *ServerState::tlocal();
if (auto err = VerifyCommandArguments(cid, args); err)
return err;
bool is_trans_cmd = CO::IsTransKind(cid->name());
bool under_script = dfly_cntx->conn_state.script_info != nullptr;
bool under_script = dfly_cntx.conn_state.script_info != nullptr;
bool is_write_cmd = cid->opt_mask() & CO::WRITE;
bool under_multi = dfly_cntx.conn_state.exec_info.IsCollecting() && !is_trans_cmd;
bool allowed_by_state = true;
switch (etl.gstate()) {
case GlobalState::LOADING:
allowed_by_state = dfly_cntx->journal_emulated || (cid->opt_mask() & CO::LOADING);
allowed_by_state = dfly_cntx.journal_emulated || (cid->opt_mask() & CO::LOADING);
break;
case GlobalState::SHUTTING_DOWN:
allowed_by_state = false;
@ -791,91 +790,56 @@ bool Service::VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionCon
default:
break;
}
if (!allowed_by_state) {
VLOG(1) << "Command " << cid->name() << " not executed because global state is "
<< GlobalStateName(etl.gstate());
string err = StrCat("Can not execute during ", GlobalStateName(etl.gstate()));
(*dfly_cntx)->SendError(err);
return false;
return ErrorReply{StrCat("Can not execute during ", GlobalStateName(etl.gstate()))};
}
string_view cmd_name{cid->name()};
if (dfly_cntx->req_auth && !dfly_cntx->authenticated) {
if (dfly_cntx.req_auth && !dfly_cntx.authenticated) {
if (cmd_name != "AUTH" && cmd_name != "QUIT" && cmd_name != "HELLO") {
(*dfly_cntx)->SendError("-NOAUTH Authentication required.");
return false;
return ErrorReply{"-NOAUTH Authentication required."};
}
}
// only reset and quit are allow if this connection is used for monitoring
if (dfly_cntx->monitor && (cmd_name != "RESET" && cmd_name != "QUIT")) {
(*dfly_cntx)->SendError("Replica can't interact with the keyspace");
return false;
}
if (dfly_cntx.monitor && (cmd_name != "RESET" && cmd_name != "QUIT"))
return ErrorReply{"Replica can't interact with the keyspace"};
if (under_script && (cid->opt_mask() & CO::NOSCRIPT)) {
(*dfly_cntx)->SendError("This Redis command is not allowed from script");
return false;
}
if (under_script && (cid->opt_mask() & CO::NOSCRIPT))
return ErrorReply{"This Redis command is not allowed from script"};
bool is_write_cmd = cid->opt_mask() & CO::WRITE;
bool under_multi = dfly_cntx->conn_state.exec_info.IsCollecting() && !is_trans_cmd;
if (!etl.is_master && is_write_cmd && !dfly_cntx->is_replicating) {
(*dfly_cntx)->SendError("-READONLY You can't write against a read only replica.");
return false;
}
if ((cid->arity() > 0 && args.size() != size_t(cid->arity())) ||
(cid->arity() < 0 && args.size() < size_t(-cid->arity()))) {
(*dfly_cntx)->SendError(facade::WrongNumArgsError(cmd_str), kSyntaxErrType);
return false;
}
if (cid->key_arg_step() == 2 && (args.size() % 2) == 0) {
(*dfly_cntx)->SendError(facade::WrongNumArgsError(cmd_str), kSyntaxErrType);
return false;
}
// Validate more complicated cases with custom validators.
if (!cid->Validate(args.subspan(1), dfly_cntx)) {
return false;
}
if (!etl.is_master && is_write_cmd && !dfly_cntx.is_replicating)
return ErrorReply{"-READONLY You can't write against a read only replica."};
if (under_multi) {
if (cmd_name == "SELECT" || absl::EndsWith(cmd_name, "SUBSCRIBE")) {
(*dfly_cntx)->SendError(absl::StrCat("Can not call ", cmd_name, " within a transaction"));
return false;
}
if (cmd_name == "SELECT" || absl::EndsWith(cmd_name, "SUBSCRIBE"))
return ErrorReply{absl::StrCat("Can not call ", cmd_name, " within a transaction")};
if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB") {
(*dfly_cntx)->SendError(absl::StrCat("'", cmd_name, "' inside MULTI is not allowed"));
return false;
}
if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB")
return ErrorReply{absl::StrCat("'", cmd_name, "' inside MULTI is not allowed")};
}
if (ClusterConfig::IsEnabled() && !CheckKeysOwnership(cid, args.subspan(1), dfly_cntx)) {
return false;
if (ClusterConfig::IsEnabled()) {
if (auto err = CheckKeysOwnership(cid, args.subspan(1), dfly_cntx); err)
return err;
}
if (under_script && cid->IsTransactional()) {
OpStatus status = CheckKeysDeclared(*dfly_cntx->conn_state.script_info, cid, args.subspan(1),
dfly_cntx->transaction);
OpStatus status = CheckKeysDeclared(*dfly_cntx.conn_state.script_info, cid, args.subspan(1),
dfly_cntx.transaction);
if (status == OpStatus::KEY_NOTFOUND) {
(*dfly_cntx)->SendError("script tried accessing undeclared key");
return false;
}
if (status == OpStatus::KEY_NOTFOUND)
return ErrorReply{"script tried accessing undeclared key"};
if (status != OpStatus::OK) {
(*dfly_cntx)->SendError(status);
return false;
}
if (status != OpStatus::OK)
return ErrorReply{status};
}
std::move(multi_error).Cancel();
return true;
return nullopt;
}
void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) {
@ -900,8 +864,13 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
etl.RecordCmd();
if (!VerifyCommand(cid, args, dfly_cntx))
if (auto err = VerifyCommand(cid, args, *dfly_cntx); err) {
if (auto& exec_info = dfly_cntx->conn_state.exec_info; exec_info.IsCollecting())
exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR;
(*dfly_cntx)->SendError(move(*err));
return;
}
auto args_no_cmd = args.subspan(1);
@ -1251,8 +1220,10 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
ToUpper(&ca.args[0]);
auto* cid = registry_.Find(facade::ToSV(ca.args[0]));
if (!VerifyCommand(cid, ca.args, cntx))
if (auto err = VerifyCommand(cid, ca.args, *cntx); err) {
(*cntx)->SendError(move(*err));
return;
}
auto replies = ca.error_abort ? ReplyMode::ONLY_ERR : ReplyMode::NONE;
info->async_cmds.emplace_back(move(*ca.buffer), cid, ca.args.subspan(1), replies);

View file

@ -24,7 +24,6 @@ using facade::MemcacheParser;
class Service : public facade::ServiceInterface {
public:
using error_code = std::error_code;
struct InitOpts {
bool disable_time_update;
@ -120,12 +119,16 @@ class Service : public facade::ServiceInterface {
CmdArgList keys, args;
};
// Return false if command is invalid and reply with error.
bool VerifyCommand(const CommandId* cid, CmdArgList args, ConnectionContext* cntx);
// Verify command exists and has no obvious formatting errors
std::optional<facade::ErrorReply> VerifyCommandArguments(const CommandId* cid, CmdArgList args);
// Return false if not all keys are owned by the server when running in cluster mode.
// If false is returned error was sent to the client.
bool CheckKeysOwnership(const CommandId* cid, CmdArgList args, ConnectionContext* dfly_cntx);
// Verify command can be executed
std::optional<facade::ErrorReply> VerifyCommand(const CommandId* cid, CmdArgList args,
const ConnectionContext& cntx);
// Return error if not all keys are owned by the server when running in cluster mode
std::optional<facade::ErrorReply> CheckKeysOwnership(const CommandId* cid, CmdArgList args,
const ConnectionContext& dfly_cntx);
const CommandId* FindCmd(CmdArgList args) const;
@ -141,6 +144,7 @@ class Service : public facade::ServiceInterface {
base::VarzValue::Map GetVarzStats();
private:
util::ProactorPool& pp_;
ServerFamily server_family_;