diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 0cf424d5d..e1ef7ee92 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -5,6 +5,7 @@ #include "server/string_family.h" #include +#include #include #include @@ -14,6 +15,8 @@ #include "base/flags.h" #include "base/logging.h" +#include "base/stl_util.h" +#include "facade/cmd_arg_parser.h" #include "server/acl/acl_commands_def.h" #include "server/command_registry.h" #include "server/conn_context.h" @@ -693,49 +696,48 @@ void SetCmd::RecordJournal(const SetParams& params, string_view key, string_view } void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) { - string_view key = ArgS(args, 0); - string_view value = ArgS(args, 1); + facade::CmdArgParser parser{args}; + + auto [key, value] = parser.Next(); SetCmd::SetParams sparams; sparams.memcache_flags = cntx->conn_state.memcache_flag; - int64_t int_arg; - SinkReplyBuilder* builder = cntx->reply_builder(); + facade::SinkReplyBuilder* builder = cntx->reply_builder(); - for (size_t i = 2; i < args.size(); ++i) { - ToUpper(&args[i]); + while (parser.HasNext()) { + parser.ToUpper(); + if (base::_in(parser.Peek(), {"EX", "PX", "EXAT", "PXAT"})) { + auto [opt, int_arg] = parser.Next(); - string_view cur_arg = ArgS(args, i); + if (auto err = parser.Error(); err) { + return builder->SendError(err->MakeReply()); + } - if ((cur_arg == "EX" || cur_arg == "PX" || cur_arg == "EXAT" || cur_arg == "PXAT") && - !(sparams.flags & SetCmd::SET_KEEP_EXPIRE) && - !(sparams.flags & SetCmd::SET_EXPIRE_AFTER_MS)) { - sparams.flags |= SetCmd::SET_EXPIRE_AFTER_MS; - bool is_ms = (cur_arg == "PX" || cur_arg == "PXAT"); - ++i; - if (i == args.size()) { + // We can set expiry only once. + if (sparams.flags & SetCmd::SET_EXPIRE_AFTER_MS) return builder->SendError(kSyntaxErr); - } - string_view ex = ArgS(args, i); - if (!absl::SimpleAtoi(ex, &int_arg)) { - return builder->SendError(kInvalidIntErr); - } + sparams.flags |= SetCmd::SET_EXPIRE_AFTER_MS; // Since PXAT/EXAT can change this, we need to check this ahead if (int_arg <= 0) { return builder->SendError(InvalidExpireTime("set")); } + + bool is_ms = (opt[0] == 'P'); + // for []AT we need to take expiration time as absolute from the value given // check here and if the time is in the past, return OK but don't set it // Note that the time pass here for PXAT is in milliseconds, we must not change it! - if (cur_arg == "EXAT" || cur_arg == "PXAT") { + if (absl::EndsWith(opt, "AT")) { int_arg = AbsExpiryToTtl(int_arg, is_ms); if (int_arg < 0) { // this happened in the past, just return, for some reason Redis reports OK in this case return builder->SendStored(); } } + if (is_ms) { if (int_arg > kMaxExpireDeadlineMs) { int_arg = kMaxExpireDeadlineMs; @@ -747,22 +749,26 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) { int_arg *= 1000; } sparams.expire_after_ms = int_arg; - } else if (cur_arg == "NX" && !(sparams.flags & SetCmd::SET_IF_EXISTS)) { - sparams.flags |= SetCmd::SET_IF_NOTEXIST; - } else if (cur_arg == "XX" && !(sparams.flags & SetCmd::SET_IF_NOTEXIST)) { - sparams.flags |= SetCmd::SET_IF_EXISTS; - } else if (cur_arg == "KEEPTTL" && !(sparams.flags & SetCmd::SET_EXPIRE_AFTER_MS)) { - sparams.flags |= SetCmd::SET_KEEP_EXPIRE; - } else if (cur_arg == "GET") { - sparams.flags |= SetCmd::SET_GET; - } else if (cur_arg == "STICK") { - sparams.flags |= SetCmd::SET_STICK; } else { - return builder->SendError(kSyntaxErr); + uint16_t flag = parser.Switch( // + "GET", SetCmd::SET_GET, "STICK", SetCmd::SET_STICK, "KEEPTTL", SetCmd::SET_KEEP_EXPIRE, + "XX", SetCmd::SET_IF_EXISTS, "NX", SetCmd::SET_IF_NOTEXIST); + sparams.flags |= flag; } } - const auto result{SetGeneric(cntx, sparams, key, value, true)}; + if (auto err = parser.Error(); err) { + return builder->SendError(err->MakeReply()); + } + + auto has_mask = [&](uint16_t m) { return (sparams.flags & m) == m; }; + + if (has_mask(SetCmd::SET_IF_EXISTS | SetCmd::SET_IF_NOTEXIST) || + has_mask(SetCmd::SET_KEEP_EXPIRE | SetCmd::SET_EXPIRE_AFTER_MS)) { + return builder->SendError(kSyntaxErr); + } + + OpResult result{SetGeneric(cntx, sparams, key, value, true)}; if (sparams.flags & SetCmd::SET_GET) { auto* rb = static_cast(cntx->reply_builder()); @@ -783,7 +789,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) { return builder->SendError(kOutOfMemory); } - CHECK_EQ(result, OpStatus::SKIPPED); // in case of NX option + DCHECK_EQ(result, OpStatus::SKIPPED); // in case of NX option builder->SendSetSkipped(); }