feat(json): MSET (#3167)

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
Vladislav 2024-06-13 12:33:24 +03:00 committed by GitHub
parent a80063189e
commit c08719117c
7 changed files with 127 additions and 76 deletions


@@ -78,7 +78,9 @@ optional<facade::ErrorReply> CommandId::Validate(CmdArgList tail_args) const {
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}
if ((opt_mask() & CO::INTERLEAVED_KEYS) && (tail_args.size() % 2) != 0) {
if ((opt_mask() & CO::INTERLEAVED_KEYS)) {
if ((name() == "JSON.MSET" && tail_args.size() % 3 != 0) ||
(name() == "MSET" && tail_args.size() % 2 != 0))
return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType};
}
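
To illustrate the arity rule this check now enforces, here is a hedged, test-style sketch in the spirit of the tests added later in this commit (key names and payloads are illustrative): MSET tail arguments must come in key/value pairs, while JSON.MSET tail arguments must come in key/path/value triples.

  auto resp = Run({"MSET", "k1", "v1", "k2", "v2"});            // 4 tail args, pairs   -> accepted
  resp = Run({"JSON.MSET", "j1", "$", "{}", "j2", "$", "{}"});  // 6 tail args, triples -> accepted
  resp = Run({"JSON.MSET", "j1", "$"});                         // 2 tail args          -> rejected here
  EXPECT_THAT(resp, ErrArg("wrong number"));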


@@ -23,6 +23,7 @@
#include "facade/op_status.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/common.h"
#include "server/error.h"
#include "server/journal/journal.h"
#include "server/search/doc_index.h"
@@ -1129,7 +1130,8 @@ OpResult<vector<OptLong>> OpArrIndex(const OpArgs& op_args, string_view key, Jso
}
// Returns string vector that represents the query result of each supplied key.
vector<OptString> OpJsonMGet(JsonPathV2 expression, const Transaction* t, EngineShard* shard) {
vector<OptString> OpJsonMGet(const JsonPathV2& expression, const Transaction* t,
EngineShard* shard) {
ShardArgs args = t->GetShardArgs(shard->shard_id());
DCHECK(!args.Empty());
vector<OptString> response(args.Size());
@@ -1289,6 +1291,40 @@ OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
return operation_result;
}
OpStatus OpMSet(const OpArgs& op_args, const ShardArgs& args) {
DCHECK_EQ(args.Size() % 3, 0u);
OpStatus result = OpStatus::OK;
size_t stored = 0;
for (auto it = args.begin(); it != args.end();) {
string_view key = *(it++);
string_view path = *(it++);
string_view value = *(it++);
if (auto res = OpSet(op_args, key, path, value, false, false); !res.ok()) {
result = res.status();
break;
}
stored++;
}
// Replicate only what was actually stored via a custom journal entry; see the analogous OpMSet for MSET below.
if (auto journal = op_args.shard->journal(); journal) {
if (stored * 3 == args.Size()) {
RecordJournal(op_args, "JSON.MSET", args, op_args.tx->GetUniqueShardCnt());
DCHECK_EQ(result, OpStatus::OK);
return result;
}
string_view cmd = stored == 0 ? "PING" : "JSON.MSET";
vector<string_view> store_args(args.begin(), args.end());
store_args.resize(stored * 3);
RecordJournal(op_args, cmd, store_args, op_args.tx->GetUniqueShardCnt());
}
return result;
}
// Implements the recursive algorithm from
// https://datatracker.ietf.org/doc/html/rfc7386#section-2
void RecursiveMerge(const JsonType& patch, JsonType* dest) {
@@ -1414,16 +1450,19 @@ void JsonFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(facade::WrongNumArgsError("json.mset"));
}
return cntx->SendError("Not implemented");
auto cb = [&](Transaction* t, EngineShard* shard) {
AggregateStatus status;
auto cb = [&status](Transaction* t, EngineShard* shard) {
auto op_args = t->GetOpArgs(shard);
ShardArgs args = t->GetShardArgs(shard->shard_id());
(void)args; // TBD
if (auto result = OpMSet(op_args, args); result != OpStatus::OK)
status = result;
return OpStatus::OK;
};
Transaction* trans = cntx->transaction;
trans->ScheduleSingleHop(cb);
cntx->transaction->ScheduleSingleHop(cb);
if (*status != OpStatus::OK)
return cntx->SendError(*status);
cntx->SendOk();
}
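
For a client-level view of the new handler, a hypothetical round trip in the style of the tests below (key names and documents are made up):

  auto resp = Run({"JSON.MSET", "doc1", "$", R"({"a":1})", "doc2", "$", R"({"a":2})"});
  EXPECT_EQ(resp, "OK");
  resp = Run({"JSON.MGET", "doc1", "doc2", "$"});
  EXPECT_THAT(resp.GetVec(), ElementsAre(R"([{"a":1}])", R"([{"a":2}])"));

Keys may hash to different shards; the single hop above applies each shard's triples and aggregates the first error, if any, across shards.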
@@ -1530,7 +1569,7 @@ void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
mget_resp[sid] = OpJsonMGet(*ParseJsonPath(path), t, shard);
mget_resp[sid] = OpJsonMGet(expression, t, shard);
return OpStatus::OK;
};


@@ -952,6 +952,9 @@ TEST_F(JsonFamilyTest, MGet) {
resp = Run({"JSON.SET", "json2", ".", json[1]});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.MGET", "json1", "??INNNNVALID??"});
EXPECT_THAT(resp, ErrArg("Unknown token"));
resp = Run({"JSON.MGET", "json1", "json2", "json3", "$.address.country"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
EXPECT_THAT(resp.GetVec(),
@@ -1082,18 +1085,20 @@ TEST_F(JsonFamilyTest, Set) {
}
TEST_F(JsonFamilyTest, MSet) {
GTEST_SKIP() << "Not implemented";
string json = R"(
{"a":{"a":1, "b":2, "c":3}}
)";
string json1 = R"({"a":{"a":1,"b":2,"c":3}})";
string json2 = R"({"a":{"a":4,"b":5,"c":6}})";
auto resp = Run({"JSON.MSET", "j1", "$"});
EXPECT_THAT(resp, ErrArg("wrong number"));
resp = Run({"JSON.MSET", "j1", "$", json, "j3", "$"});
resp = Run({"JSON.MSET", "j1", "$", json1, "j3", "$"});
EXPECT_THAT(resp, ErrArg("wrong number"));
resp = Run({"JSON.MSET", "j1", "$", json, "j3", "$", json});
resp = Run({"JSON.MSET", "j1", "$", json1, "j2", "$", json2, "j3", "$", json1, "j4", "$", json2});
EXPECT_EQ(resp, "OK");
resp = Run({"JSON.MGET", "j1", "j2", "j3", "j4", "$"});
EXPECT_THAT(resp.GetVec(), ElementsAre("[" + json1 + "]", "[" + json2 + "]", "[" + json1 + "]",
"[" + json2 + "]"));
}
TEST_F(JsonFamilyTest, Merge) {


@@ -263,54 +263,43 @@ int64_t AbsExpiryToTtl(int64_t abs_expiry_time, bool as_milli) {
}
// Returns OpStatus::OK if all key/value pairs were stored, otherwise the first error status.
void OpMSet(const OpArgs& op_args, const ShardArgs& args, atomic_bool* success) {
OpStatus OpMSet(const OpArgs& op_args, const ShardArgs& args) {
DCHECK(!args.Empty() && args.Size() % 2 == 0);
SetCmd::SetParams params;
SetCmd sg(op_args, false);
size_t index = 0;
bool partial = false;
for (auto it = args.begin(); it != args.end(); ++it) {
string_view key = *it;
++it;
string_view value = *it;
DVLOG(1) << "MSet " << key << ":" << value;
if (sg.Set(params, key, value) != OpStatus::OK) { // OOM for example.
success->store(false);
partial = true;
OpStatus result = OpStatus::OK;
size_t stored = 0;
for (auto it = args.begin(); it != args.end();) {
string_view key = *(it++);
string_view value = *(it++);
if (auto status = sg.Set(params, key, value); status != OpStatus::OK) {
result = status;
break;
}
index += 2;
}
stored++;
}
// Above loop could have partial success (e.g. OOM), so replicate only what was
// changed
if (auto journal = op_args.shard->journal(); journal) {
// We write a custom journal because an OOM in the above loop could lead to partial success, so
// we replicate only what was changed.
if (partial) {
string_view cmd;
ArgSlice cmd_args;
vector<string_view> store_args(index);
if (index == 0) {
// All shards must record the tx was executed for the replica to execute it, so we send a
// PING in case nothing was changed
cmd = "PING";
} else {
// journal [0, i)
cmd = "MSET";
unsigned i = 0;
for (string_view arg : args) {
store_args[i++] = arg;
if (i >= store_args.size())
break;
}
cmd_args = absl::MakeSpan(store_args);
}
RecordJournal(op_args, cmd, cmd_args, op_args.tx->GetUniqueShardCnt());
} else {
if (stored * 2 == args.Size()) {
RecordJournal(op_args, "MSET", args, op_args.tx->GetUniqueShardCnt());
DCHECK_EQ(result, OpStatus::OK);
return result;
}
// Even without changes, we have to send a dummy command like PING for the
// replica to ack
string_view cmd = stored == 0 ? "PING" : "MSET";
vector<string_view> store_args(args.begin(), args.end());
store_args.resize(stored * 2);
RecordJournal(op_args, cmd, store_args, op_args.tx->GetUniqueShardCnt());
}
return result;
}
// emission_interval_ms assumed to be positive
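
The replication rule above is easy to state in isolation: on full success replicate the command verbatim, on partial success replicate only the applied prefix, and when nothing was stored send a PING so every shard still writes a journal entry for the replica to ack. A minimal, self-contained sketch of that rule (JournalEntry is a hypothetical helper, not Dragonfly's journaling API; stride is 2 for MSET key/value pairs and 3 for the JSON.MSET key/path/value triples handled earlier in this diff):

  #include <string>
  #include <utility>
  #include <vector>

  // Returns the command name and arguments that would be recorded in the journal.
  std::pair<std::string, std::vector<std::string>> JournalEntry(std::string cmd,
                                                                const std::vector<std::string>& args,
                                                                size_t stored, size_t stride) {
    if (stored * stride == args.size())
      return {std::move(cmd), args};  // full success: replicate verbatim
    if (stored == 0)
      return {"PING", {}};  // nothing stored: dummy command so the replica still acks the tx
    // partial success: replicate only the prefix that was actually applied
    return {std::move(cmd), std::vector<std::string>(args.begin(), args.begin() + stored * stride)};
  }

For example, MSET k1 v1 k2 v2 k3 v3 with an OOM on the second Set (stored == 1) would journal MSET k1 v1, while stored == 0 would journal a plain PING.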
@@ -451,7 +440,8 @@ SinkReplyBuilder::MGetResponse OpMGet(util::fb2::BlockingCounter wait_bc, bool f
auto& resp = response.resp_arr[i].emplace();
// Copy to buffer or trigger tiered read that will eventually write to buffer
// Copy to buffer or trigger tiered read that will eventually write to
// buffer
if (it->second.IsExternal()) {
wait_bc->Add(1);
auto cb = [next, wait_bc](const string& v) mutable {
@@ -481,7 +471,8 @@ SinkReplyBuilder::MGetResponse OpMGet(util::fb2::BlockingCounter wait_bc, bool f
return response;
}
// Extend key with value, either prepend or append. Return size of stored string after modification
// Extend key with value, either prepend or append. Return size of stored string
// after modification
OpResult<variant<size_t, util::fb2::Future<size_t>>> OpExtend(const OpArgs& op_args,
std::string_view key,
std::string_view value,
@@ -761,13 +752,15 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
bool is_ms = (opt[0] == 'P');
// for []AT we need to take expiration time as absolute from the value given
// check here and if the time is in the past, return OK but don't set it
// Note that the time pass here for PXAT is in milliseconds, we must not change it!
// for []AT we need to take the expiration time as absolute from the value
// given: check here and, if the time is in the past, return OK but don't
// set it. Note that the time passed here for PXAT is in milliseconds; we
// must not change it!
if (absl::EndsWith(opt, "AT")) {
int_arg = AbsExpiryToTtl(int_arg, is_ms);
if (int_arg < 0) {
// this happened in the past, just return, for some reason Redis reports OK in this case
// this happened in the past, just return, for some reason Redis
// reports OK in this case
return builder->SendStored();
}
}
@@ -843,7 +836,8 @@ void StringFamily::SetNx(CmdArgList args, ConnectionContext* cntx) {
// This is the same as calling the "Set" function, only in this case we
// change the value only if the key does not exist; otherwise the function
// will not modify it, in which case it would return 0.
// it would return to the caller 1 in case the key did not exists and was added
// It would return to the caller 1 in case the key did not exist and was
// added.
string_view key = ArgS(args, 0);
string_view value = ArgS(args, 1);
@@ -1168,7 +1162,8 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
// wait for all tiered reads to finish
tiering_bc->Wait();
// reorder the responses back according to the order of their corresponding keys.
// reorder the responses back according to the order of their corresponding
// keys.
SinkReplyBuilder::MGetResponse res(args.size());
for (ShardId sid = 0; sid < mget_resp.size(); ++sid) {
@@ -1208,18 +1203,21 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
LOG(INFO) << "MSET/" << transaction->GetUniqueShardCnt() << str;
}
atomic_bool success = true;
AggregateStatus result;
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardArgs args = t->GetShardArgs(shard->shard_id());
OpMSet(t->GetOpArgs(shard), args, &success);
if (auto status = OpMSet(t->GetOpArgs(shard), args); status != OpStatus::OK)
result = status;
return OpStatus::OK;
};
OpStatus status = transaction->ScheduleSingleHop(std::move(cb));
if (success.load()) {
if (auto status = transaction->ScheduleSingleHop(std::move(cb)); status != OpStatus::OK)
result = status;
if (*result == OpStatus::OK) {
cntx->SendOk();
} else {
cntx->SendError(status);
cntx->SendError(*result);
}
}
@@ -1245,18 +1243,19 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) {
transaction->Execute(std::move(cb), false);
const bool to_skip = exists.load(memory_order_relaxed);
atomic_bool success = true;
AggregateStatus result;
auto epilog_cb = [&](Transaction* t, EngineShard* shard) {
if (to_skip)
return OpStatus::OK;
auto args = t->GetShardArgs(shard->shard_id());
OpMSet(t->GetOpArgs(shard), std::move(args), &success);
if (auto status = OpMSet(t->GetOpArgs(shard), args); status != OpStatus::OK)
result = status;
return OpStatus::OK;
};
transaction->Execute(std::move(epilog_cb), true);
cntx->SendLong(to_skip || !success.load() ? 0 : 1);
cntx->SendLong(to_skip || (*result != OpStatus::OK) ? 0 : 1);
}
void StringFamily::StrLen(CmdArgList args, ConnectionContext* cntx) {
@@ -1343,13 +1342,13 @@ void StringFamily::SetRange(CmdArgList args, ConnectionContext* cntx) {
* 1. Whether the action was limited:
* - 0 indicates the action is allowed.
* - 1 indicates that the action was limited/blocked.
* 2. The total limit of the key (max_burst + 1). This is equivalent to the common
* X-RateLimit-Limit HTTP header.
* 2. The total limit of the key (max_burst + 1). This is equivalent to the
* common X-RateLimit-Limit HTTP header.
* 3. The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
* 4. The number of seconds until the user should retry, and always -1 if the action was allowed.
* Equivalent to Retry-After.
* 5. The number of seconds until the limit will reset to its maximum capacity. Equivalent to
* X-RateLimit-Reset.
* 4. The number of seconds until the user should retry, and always -1 if the
* action was allowed. Equivalent to Retry-After.
* 5. The number of seconds until the limit will reset to its maximum capacity.
* Equivalent to X-RateLimit-Reset.
*/
void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) {
const string_view key = ArgS(args, 0);


@@ -312,7 +312,7 @@ void Transaction::InitByKeys(const KeyIndex& key_index) {
}
shard_data_.resize(shard_set->size()); // shard_data isn't sparse, so we must allocate for all :(
DCHECK_EQ(full_args_.size() % key_index.step, 0u);
DCHECK_EQ(full_args_.size() % key_index.step, 0u) << full_args_;
// Safe, because flow below is not preemptive.
auto& shard_index = tmp_space.GetShardIndex(shard_data_.size());


@@ -153,6 +153,12 @@ class ShardArgs {
return *this;
}
Iterator operator++(int) {
Iterator copy = *this;
operator++();
return copy;
}
size_t index() const {
return index_it_->first + delta_;
}
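
The new post-increment overload is the canonical copy/advance/return-copy pattern; it exists so the OpMSet loops above can consume their arguments with *(it++). A generic, self-contained illustration of that consumption pattern (ConsumeTriples and the plain vector are stand-ins, not the actual ShardArgs type):

  #include <string>
  #include <string_view>
  #include <vector>

  // Walk (key, path, value) triples the same way the JSON.MSET OpMSet loop does.
  void ConsumeTriples(const std::vector<std::string>& args) {
    for (auto it = args.begin(); it != args.end();) {
      std::string_view key = *(it++);  // post-increment: read the element, then advance
      std::string_view path = *(it++);
      std::string_view value = *(it++);
      (void)key, (void)path, (void)value;  // per-triple work would go here
    }
  }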


@@ -233,7 +233,7 @@ class CommandGenerator:
ValueType.SET: "SADD",
ValueType.HSET: "HMSET",
ValueType.ZSET: "ZADD",
ValueType.JSON: "JSON.SET",
ValueType.JSON: "JSON.MSET",
}
def gen_grow_cmd(self):
@@ -242,7 +242,7 @@ class CommandGenerator:
"""
# TODO: Implement COPY in Dragonfly.
t = self.random_type()
if t == ValueType.STRING:
if t in [ValueType.STRING, ValueType.JSON]:
count = random.randint(1, self.max_multikey)
else:
count = 1