Handle GET parameter for SET command. (#1023)

* feat(server): Handle GET parameter for SET command.

Return previous value, as per Redis, in case GET was specified on SET.

---------

Signed-off-by: chakaz <chakaz@chakaz>
Co-authored-by: chakaz <chakaz@chakaz>
This commit is contained in:
Chaka 2023-04-04 15:52:27 +03:00 committed by GitHub
parent d6e6bcf5a9
commit ee7705a84d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 15 deletions

View file

@ -46,6 +46,12 @@ constexpr uint32_t kMinTieredLen = TieredStorage::kMinBlobLen;
string GetString(EngineShard* shard, const PrimeValue& pv) {
string res;
if (pv.ObjType() != OBJ_STRING) {
// An attempt to read a non-string's string value can happen when overriding a non-string value
// with a string value.
return "";
}
if (pv.IsExternal()) {
auto* tiered = shard->tiered_storage();
auto [offset, size] = pv.GetExternalSlice();
@ -340,24 +346,25 @@ OpStatus OpMSet(const OpArgs& op_args, ArgSlice args) {
for (size_t i = 0; i < args.size(); i += 2) {
DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1];
OpStatus res = sg.Set(params, args[i], args[i + 1]);
if (res != OpStatus::OK) { // OOM for example.
return res;
OpResult<optional<string>> res = sg.Set(params, args[i], args[i + 1]);
if (res.status() != OpStatus::OK) { // OOM for example.
return res.status();
}
}
return OpStatus::OK;
}
OpResult<void> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& sparams,
string_view key, string_view value, bool manual_journal) {
// See comment for SetCmd::Set() for when and how OpResult's value (i.e. optional<string>) is set.
OpResult<optional<string>> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& sparams,
string_view key, string_view value, bool manual_journal) {
DCHECK(cntx->transaction);
auto cb = [&](Transaction* t, EngineShard* shard) {
SetCmd sg(t->GetOpArgs(shard), manual_journal);
return sg.Set(sparams, key, value);
};
return cntx->transaction->ScheduleSingleHop(std::move(cb));
return cntx->transaction->ScheduleSingleHopT(std::move(cb));
}
// emission_interval_ms assumed to be positive
@ -470,9 +477,37 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_ms, reset_after_ms};
}
class SetResultBuilder {
public:
explicit SetResultBuilder(bool return_prev_value) : return_prev_value_(return_prev_value) {
}
void CachePrevValueIfNeeded(string_view value) {
if (return_prev_value_) {
prev_value_ = value;
}
}
// Returns either the previous value or `status`, depending on return_prev_value_.
OpResult<optional<string>> Return(OpStatus status) && {
if (return_prev_value_) {
return std::move(prev_value_);
} else {
return status;
}
}
private:
bool return_prev_value_;
std::optional<string> prev_value_;
};
} // namespace
OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) {
OpResult<optional<string>> SetCmd::Set(const SetParams& params, string_view key,
string_view value) {
SetResultBuilder result_builder(params.flags & SET_GET);
EngineShard* shard = op_args_.shard;
auto& db_slice = shard->db_slice();
@ -482,16 +517,20 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
if (params.IsConditionalSet()) {
const auto [it, expire_it] = db_slice.FindExt(op_args_.db_cntx, key);
if (IsValid(it)) {
result_builder.CachePrevValueIfNeeded(GetString(shard, it->second));
}
// Make sure that we have this key, and only add it if it does exists
if (params.flags & SET_IF_EXISTS) {
if (IsValid(it)) {
return SetExisting(params, it, expire_it, key, value);
return std::move(result_builder).Return(SetExisting(params, it, expire_it, key, value));
} else {
return OpStatus::SKIPPED;
return std::move(result_builder).Return(OpStatus::SKIPPED);
}
} else {
if (IsValid(it)) { // if the policy is not to overide and have the key, just return
return OpStatus::SKIPPED;
return std::move(result_builder).Return(OpStatus::SKIPPED);
}
}
}
@ -507,7 +546,8 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
PrimeIterator it = get<0>(add_res);
if (!get<2>(add_res)) { // Existing.
return SetExisting(params, it, get<1>(add_res), key, value);
result_builder.CachePrevValueIfNeeded(GetString(shard, it->second));
return std::move(result_builder).Return(SetExisting(params, it, get<1>(add_res), key, value));
}
// Adding new value.
@ -535,7 +575,7 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
RecordJournal(params, key, value);
}
return OpStatus::OK;
return std::move(result_builder).Return(OpStatus::OK);
}
OpStatus SetCmd::SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it,
@ -692,6 +732,16 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
const auto result{SetGeneric(cntx, sparams, key, value, true)};
if (sparams.flags & SetCmd::SET_GET) {
// When SET_GET is used, the reply is not affected by whether anything was set.
if (result->has_value()) {
(*cntx)->SendBulkString(result->value());
} else {
(*cntx)->SendNull();
}
return;
}
if (result == OpStatus::OK) {
return builder->SendStored();
}
@ -803,7 +853,7 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
SetCmd cmd(t->GetOpArgs(shard), false);
return cmd.Set(sparams, key, value);
return cmd.Set(sparams, key, value).status();
};
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
@ -1054,7 +1104,7 @@ void StringFamily::SetExGeneric(bool seconds, CmdArgList args, ConnectionContext
auto cb = [&](Transaction* t, EngineShard* shard) {
SetCmd sg(t->GetOpArgs(shard), true);
return sg.Set(sparams, key, value);
return sg.Set(sparams, key, value).status();
};
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));

View file

@ -45,7 +45,11 @@ class SetCmd {
}
};
OpStatus Set(const SetParams& params, std::string_view key, std::string_view value);
// OpResult's value (i.e. optional<string>) is set in the case `params.flags` has SET_GET bit on,
// in which case the previous value (or nullopt if none) is returned. Otherwise, OpResult only
// contains a status.
OpResult<std::optional<std::string>> Set(const SetParams& params, std::string_view key,
std::string_view value);
private:
OpStatus SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it,

View file

@ -673,4 +673,17 @@ TEST_F(StringFamilyTest, SetMGetWithNilResp3) {
EXPECT_THAT(resp.GetVec(), ElementsAre("val", ArgType(RespExpr::NIL)));
}
TEST_F(StringFamilyTest, SetWithGetParam) {
EXPECT_THAT(Run({"set", "key1", "val1", "get"}), ArgType(RespExpr::NIL));
EXPECT_EQ(Run({"set", "key1", "val2", "get"}), "val1");
EXPECT_THAT(Run({"set", "key2", "val2", "nx", "get"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"set", "key2", "not used", "nx", "get"}), "val2");
EXPECT_EQ(Run({"get", "key2"}), "val2");
EXPECT_THAT(Run({"set", "key3", "not used", "xx", "get"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"set", "key2", "val3", "xx", "get"}), "val2");
EXPECT_EQ(Run({"get", "key2"}), "val3");
}
} // namespace dfly