Handle GET parameter for SET command. (#1023)

* feat(server): Handle GET parameter for SET command.

Return previous value, as per Redis, in case GET was specified on SET.

---------

Signed-off-by: chakaz <chakaz@chakaz>
Co-authored-by: chakaz <chakaz@chakaz>
This commit is contained in:
Chaka 2023-04-04 15:52:27 +03:00 committed by GitHub
parent d6e6bcf5a9
commit ee7705a84d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 15 deletions

View file

@ -46,6 +46,12 @@ constexpr uint32_t kMinTieredLen = TieredStorage::kMinBlobLen;
string GetString(EngineShard* shard, const PrimeValue& pv) { string GetString(EngineShard* shard, const PrimeValue& pv) {
string res; string res;
if (pv.ObjType() != OBJ_STRING) {
// An attempt to read a non-string's string value can happen when overriding a non-string value
// with a string value.
return "";
}
if (pv.IsExternal()) { if (pv.IsExternal()) {
auto* tiered = shard->tiered_storage(); auto* tiered = shard->tiered_storage();
auto [offset, size] = pv.GetExternalSlice(); auto [offset, size] = pv.GetExternalSlice();
@ -340,16 +346,17 @@ OpStatus OpMSet(const OpArgs& op_args, ArgSlice args) {
for (size_t i = 0; i < args.size(); i += 2) { for (size_t i = 0; i < args.size(); i += 2) {
DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1]; DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1];
OpStatus res = sg.Set(params, args[i], args[i + 1]); OpResult<optional<string>> res = sg.Set(params, args[i], args[i + 1]);
if (res != OpStatus::OK) { // OOM for example. if (res.status() != OpStatus::OK) { // OOM for example.
return res; return res.status();
} }
} }
return OpStatus::OK; return OpStatus::OK;
} }
OpResult<void> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& sparams, // See comment for SetCmd::Set() for when and how OpResult's value (i.e. optional<string>) is set.
OpResult<optional<string>> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& sparams,
string_view key, string_view value, bool manual_journal) { string_view key, string_view value, bool manual_journal) {
DCHECK(cntx->transaction); DCHECK(cntx->transaction);
@ -357,7 +364,7 @@ OpResult<void> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar
SetCmd sg(t->GetOpArgs(shard), manual_journal); SetCmd sg(t->GetOpArgs(shard), manual_journal);
return sg.Set(sparams, key, value); return sg.Set(sparams, key, value);
}; };
return cntx->transaction->ScheduleSingleHop(std::move(cb)); return cntx->transaction->ScheduleSingleHopT(std::move(cb));
} }
// emission_interval_ms assumed to be positive // emission_interval_ms assumed to be positive
@ -470,9 +477,37 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_ms, reset_after_ms}; return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_ms, reset_after_ms};
} }
class SetResultBuilder {
public:
explicit SetResultBuilder(bool return_prev_value) : return_prev_value_(return_prev_value) {
}
void CachePrevValueIfNeeded(string_view value) {
if (return_prev_value_) {
prev_value_ = value;
}
}
// Returns either the previous value or `status`, depending on return_prev_value_.
OpResult<optional<string>> Return(OpStatus status) && {
if (return_prev_value_) {
return std::move(prev_value_);
} else {
return status;
}
}
private:
bool return_prev_value_;
std::optional<string> prev_value_;
};
} // namespace } // namespace
OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) { OpResult<optional<string>> SetCmd::Set(const SetParams& params, string_view key,
string_view value) {
SetResultBuilder result_builder(params.flags & SET_GET);
EngineShard* shard = op_args_.shard; EngineShard* shard = op_args_.shard;
auto& db_slice = shard->db_slice(); auto& db_slice = shard->db_slice();
@ -482,16 +517,20 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
if (params.IsConditionalSet()) { if (params.IsConditionalSet()) {
const auto [it, expire_it] = db_slice.FindExt(op_args_.db_cntx, key); const auto [it, expire_it] = db_slice.FindExt(op_args_.db_cntx, key);
if (IsValid(it)) {
result_builder.CachePrevValueIfNeeded(GetString(shard, it->second));
}
// Make sure that we have this key, and only add it if it does exists // Make sure that we have this key, and only add it if it does exists
if (params.flags & SET_IF_EXISTS) { if (params.flags & SET_IF_EXISTS) {
if (IsValid(it)) { if (IsValid(it)) {
return SetExisting(params, it, expire_it, key, value); return std::move(result_builder).Return(SetExisting(params, it, expire_it, key, value));
} else { } else {
return OpStatus::SKIPPED; return std::move(result_builder).Return(OpStatus::SKIPPED);
} }
} else { } else {
if (IsValid(it)) { // if the policy is not to overide and have the key, just return if (IsValid(it)) { // if the policy is not to overide and have the key, just return
return OpStatus::SKIPPED; return std::move(result_builder).Return(OpStatus::SKIPPED);
} }
} }
} }
@ -507,7 +546,8 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
PrimeIterator it = get<0>(add_res); PrimeIterator it = get<0>(add_res);
if (!get<2>(add_res)) { // Existing. if (!get<2>(add_res)) { // Existing.
return SetExisting(params, it, get<1>(add_res), key, value); result_builder.CachePrevValueIfNeeded(GetString(shard, it->second));
return std::move(result_builder).Return(SetExisting(params, it, get<1>(add_res), key, value));
} }
// Adding new value. // Adding new value.
@ -535,7 +575,7 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
RecordJournal(params, key, value); RecordJournal(params, key, value);
} }
return OpStatus::OK; return std::move(result_builder).Return(OpStatus::OK);
} }
OpStatus SetCmd::SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it, OpStatus SetCmd::SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it,
@ -692,6 +732,16 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
const auto result{SetGeneric(cntx, sparams, key, value, true)}; const auto result{SetGeneric(cntx, sparams, key, value, true)};
if (sparams.flags & SetCmd::SET_GET) {
// When SET_GET is used, the reply is not affected by whether anything was set.
if (result->has_value()) {
(*cntx)->SendBulkString(result->value());
} else {
(*cntx)->SendNull();
}
return;
}
if (result == OpStatus::OK) { if (result == OpStatus::OK) {
return builder->SendStored(); return builder->SendStored();
} }
@ -803,7 +853,7 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) { auto cb = [&](Transaction* t, EngineShard* shard) {
SetCmd cmd(t->GetOpArgs(shard), false); SetCmd cmd(t->GetOpArgs(shard), false);
return cmd.Set(sparams, key, value); return cmd.Set(sparams, key, value).status();
}; };
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
@ -1054,7 +1104,7 @@ void StringFamily::SetExGeneric(bool seconds, CmdArgList args, ConnectionContext
auto cb = [&](Transaction* t, EngineShard* shard) { auto cb = [&](Transaction* t, EngineShard* shard) {
SetCmd sg(t->GetOpArgs(shard), true); SetCmd sg(t->GetOpArgs(shard), true);
return sg.Set(sparams, key, value); return sg.Set(sparams, key, value).status();
}; };
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb)); OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));

View file

@ -45,7 +45,11 @@ class SetCmd {
} }
}; };
OpStatus Set(const SetParams& params, std::string_view key, std::string_view value); // OpResult's value (i.e. optional<string>) is set in the case `params.flags` has SET_GET bit on,
// in which case the previous value (or nullopt if none) is returned. Otherwise, OpResult only
// contains a status.
OpResult<std::optional<std::string>> Set(const SetParams& params, std::string_view key,
std::string_view value);
private: private:
OpStatus SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it, OpStatus SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it,

View file

@ -673,4 +673,17 @@ TEST_F(StringFamilyTest, SetMGetWithNilResp3) {
EXPECT_THAT(resp.GetVec(), ElementsAre("val", ArgType(RespExpr::NIL))); EXPECT_THAT(resp.GetVec(), ElementsAre("val", ArgType(RespExpr::NIL)));
} }
TEST_F(StringFamilyTest, SetWithGetParam) {
EXPECT_THAT(Run({"set", "key1", "val1", "get"}), ArgType(RespExpr::NIL));
EXPECT_EQ(Run({"set", "key1", "val2", "get"}), "val1");
EXPECT_THAT(Run({"set", "key2", "val2", "nx", "get"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"set", "key2", "not used", "nx", "get"}), "val2");
EXPECT_EQ(Run({"get", "key2"}), "val2");
EXPECT_THAT(Run({"set", "key3", "not used", "xx", "get"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"set", "key2", "val3", "xx", "get"}), "val2");
EXPECT_EQ(Run({"get", "key2"}), "val3");
}
} // namespace dfly } // namespace dfly