From c08719117c0111725441aee8e1456676a18b4a6b Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 13 Jun 2024 12:33:24 +0300 Subject: [PATCH] feat(json): MSET (#3167) Signed-off-by: Vladislav Oleshko --- src/server/command_registry.cc | 6 +- src/server/json_family.cc | 55 +++++++++++++--- src/server/json_family_test.cc | 17 +++-- src/server/string_family.cc | 113 ++++++++++++++++----------------- src/server/transaction.cc | 2 +- src/server/tx_base.h | 6 ++ tests/dragonfly/utility.py | 4 +- 7 files changed, 127 insertions(+), 76 deletions(-) diff --git a/src/server/command_registry.cc b/src/server/command_registry.cc index 327fc2d61..9f58dfaf7 100644 --- a/src/server/command_registry.cc +++ b/src/server/command_registry.cc @@ -78,8 +78,10 @@ optional CommandId::Validate(CmdArgList tail_args) const { return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType}; } - if ((opt_mask() & CO::INTERLEAVED_KEYS) && (tail_args.size() % 2) != 0) { - return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType}; + if ((opt_mask() & CO::INTERLEAVED_KEYS)) { + if ((name() == "JSON.MSET" && tail_args.size() % 3 != 0) || + (name() == "MSET" && tail_args.size() % 2 != 0)) + return facade::ErrorReply{facade::WrongNumArgsError(name()), kSyntaxErrType}; } if (validator_) diff --git a/src/server/json_family.cc b/src/server/json_family.cc index 3b02c64ba..f12b4a723 100644 --- a/src/server/json_family.cc +++ b/src/server/json_family.cc @@ -23,6 +23,7 @@ #include "facade/op_status.h" #include "server/acl/acl_commands_def.h" #include "server/command_registry.h" +#include "server/common.h" #include "server/error.h" #include "server/journal/journal.h" #include "server/search/doc_index.h" @@ -1129,7 +1130,8 @@ OpResult> OpArrIndex(const OpArgs& op_args, string_view key, Jso } // Returns string vector that represents the query result of each supplied key. -vector OpJsonMGet(JsonPathV2 expression, const Transaction* t, EngineShard* shard) { +vector OpJsonMGet(const JsonPathV2& expression, const Transaction* t, + EngineShard* shard) { ShardArgs args = t->GetShardArgs(shard->shard_id()); DCHECK(!args.Empty()); vector response(args.Size()); @@ -1289,6 +1291,40 @@ OpResult OpSet(const OpArgs& op_args, string_view key, string_view path, return operation_result; } +OpStatus OpMSet(const OpArgs& op_args, const ShardArgs& args) { + DCHECK_EQ(args.Size() % 3, 0u); + + OpStatus result = OpStatus::OK; + size_t stored = 0; + for (auto it = args.begin(); it != args.end();) { + string_view key = *(it++); + string_view path = *(it++); + string_view value = *(it++); + if (auto res = OpSet(op_args, key, path, value, false, false); !res.ok()) { + result = res.status(); + break; + } + + stored++; + } + + // Replicate custom journal, see OpMSet + if (auto journal = op_args.shard->journal(); journal) { + if (stored * 3 == args.Size()) { + RecordJournal(op_args, "JSON.MSET", args, op_args.tx->GetUniqueShardCnt()); + DCHECK_EQ(result, OpStatus::OK); + return result; + } + + string_view cmd = stored == 0 ? "PING" : "JSON.MSET"; + vector store_args(args.begin(), args.end()); + store_args.resize(stored * 3); + RecordJournal(op_args, cmd, store_args, op_args.tx->GetUniqueShardCnt()); + } + + return result; +} + // Implements the recursive algorithm from // https://datatracker.ietf.org/doc/html/rfc7386#section-2 void RecursiveMerge(const JsonType& patch, JsonType* dest) { @@ -1414,16 +1450,19 @@ void JsonFamily::MSet(CmdArgList args, ConnectionContext* cntx) { return cntx->SendError(facade::WrongNumArgsError("json.mset")); } - return cntx->SendError("Not implemented"); - - auto cb = [&](Transaction* t, EngineShard* shard) { + AggregateStatus status; + auto cb = [&status](Transaction* t, EngineShard* shard) { + auto op_args = t->GetOpArgs(shard); ShardArgs args = t->GetShardArgs(shard->shard_id()); - (void)args; // TBD + if (auto result = OpMSet(op_args, args); result != OpStatus::OK) + status = result; return OpStatus::OK; }; - Transaction* trans = cntx->transaction; - trans->ScheduleSingleHop(cb); + cntx->transaction->ScheduleSingleHop(cb); + + if (*status != OpStatus::OK) + return cntx->SendError(*status); cntx->SendOk(); } @@ -1530,7 +1569,7 @@ void JsonFamily::MGet(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { ShardId sid = shard->shard_id(); - mget_resp[sid] = OpJsonMGet(*ParseJsonPath(path), t, shard); + mget_resp[sid] = OpJsonMGet(expression, t, shard); return OpStatus::OK; }; diff --git a/src/server/json_family_test.cc b/src/server/json_family_test.cc index 6a09c9ac3..ed5cbe733 100644 --- a/src/server/json_family_test.cc +++ b/src/server/json_family_test.cc @@ -952,6 +952,9 @@ TEST_F(JsonFamilyTest, MGet) { resp = Run({"JSON.SET", "json2", ".", json[1]}); ASSERT_THAT(resp, "OK"); + resp = Run({"JSON.MGET", "json1", "??INNNNVALID??"}); + EXPECT_THAT(resp, ErrArg("Unknown token")); + resp = Run({"JSON.MGET", "json1", "json2", "json3", "$.address.country"}); ASSERT_EQ(RespExpr::ARRAY, resp.type); EXPECT_THAT(resp.GetVec(), @@ -1082,18 +1085,20 @@ TEST_F(JsonFamilyTest, Set) { } TEST_F(JsonFamilyTest, MSet) { - GTEST_SKIP() << "Not implemented"; - string json = R"( - {"a":{"a":1, "b":2, "c":3}} - )"; + string json1 = R"({"a":{"a":1,"b":2,"c":3}})"; + string json2 = R"({"a":{"a":4,"b":5,"c":6}})"; auto resp = Run({"JSON.MSET", "j1", "$"}); EXPECT_THAT(resp, ErrArg("wrong number")); - resp = Run({"JSON.MSET", "j1", "$", json, "j3", "$"}); + resp = Run({"JSON.MSET", "j1", "$", json1, "j3", "$"}); EXPECT_THAT(resp, ErrArg("wrong number")); - resp = Run({"JSON.MSET", "j1", "$", json, "j3", "$", json}); + resp = Run({"JSON.MSET", "j1", "$", json1, "j2", "$", json2, "j3", "$", json1, "j4", "$", json2}); EXPECT_EQ(resp, "OK"); + + resp = Run({"JSON.MGET", "j1", "j2", "j3", "j4", "$"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("[" + json1 + "]", "[" + json2 + "]", "[" + json1 + "]", + "[" + json2 + "]")); } TEST_F(JsonFamilyTest, Merge) { diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 53d7d8fe9..9772e188a 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -263,54 +263,43 @@ int64_t AbsExpiryToTtl(int64_t abs_expiry_time, bool as_milli) { } // Returns true if keys were set, false otherwise. -void OpMSet(const OpArgs& op_args, const ShardArgs& args, atomic_bool* success) { +OpStatus OpMSet(const OpArgs& op_args, const ShardArgs& args) { DCHECK(!args.Empty() && args.Size() % 2 == 0); SetCmd::SetParams params; SetCmd sg(op_args, false); - size_t index = 0; - bool partial = false; - for (auto it = args.begin(); it != args.end(); ++it) { - string_view key = *it; - ++it; - string_view value = *it; - DVLOG(1) << "MSet " << key << ":" << value; - if (sg.Set(params, key, value) != OpStatus::OK) { // OOM for example. - success->store(false); - partial = true; + OpStatus result = OpStatus::OK; + size_t stored = 0; + for (auto it = args.begin(); it != args.end();) { + string_view key = *(it++); + string_view value = *(it++); + if (auto status = sg.Set(params, key, value); status != OpStatus::OK) { + result = status; break; } - index += 2; + + stored++; } + // Above loop could have parial success (e.g. OOM), so replicate only what was + // changed if (auto journal = op_args.shard->journal(); journal) { - // We write a custom journal because an OOM in the above loop could lead to partial success, so - // we replicate only what was changed. - if (partial) { - string_view cmd; - ArgSlice cmd_args; - vector store_args(index); - if (index == 0) { - // All shards must record the tx was executed for the replica to execute it, so we send a - // PING in case nothing was changed - cmd = "PING"; - } else { - // journal [0, i) - cmd = "MSET"; - unsigned i = 0; - for (string_view arg : args) { - store_args[i++] = arg; - if (i >= store_args.size()) - break; - } - cmd_args = absl::MakeSpan(store_args); - } - RecordJournal(op_args, cmd, cmd_args, op_args.tx->GetUniqueShardCnt()); - } else { + if (stored * 2 == args.Size()) { RecordJournal(op_args, "MSET", args, op_args.tx->GetUniqueShardCnt()); + DCHECK_EQ(result, OpStatus::OK); + return result; } + + // Even without changes, we have to send a dummy command like PING for the + // replica to ack + string_view cmd = stored == 0 ? "PING" : "MSET"; + vector store_args(args.begin(), args.end()); + store_args.resize(stored * 2); + RecordJournal(op_args, cmd, store_args, op_args.tx->GetUniqueShardCnt()); } + + return result; } // emission_interval_ms assumed to be positive @@ -451,7 +440,8 @@ SinkReplyBuilder::MGetResponse OpMGet(util::fb2::BlockingCounter wait_bc, bool f auto& resp = response.resp_arr[i].emplace(); - // Copy to buffer or trigger tiered read that will eventually write to buffer + // Copy to buffer or trigger tiered read that will eventually write to + // buffer if (it->second.IsExternal()) { wait_bc->Add(1); auto cb = [next, wait_bc](const string& v) mutable { @@ -481,7 +471,8 @@ SinkReplyBuilder::MGetResponse OpMGet(util::fb2::BlockingCounter wait_bc, bool f return response; } -// Extend key with value, either prepend or append. Return size of stored string after modification +// Extend key with value, either prepend or append. Return size of stored string +// after modification OpResult>> OpExtend(const OpArgs& op_args, std::string_view key, std::string_view value, @@ -761,13 +752,15 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) { bool is_ms = (opt[0] == 'P'); - // for []AT we need to take expiration time as absolute from the value given - // check here and if the time is in the past, return OK but don't set it - // Note that the time pass here for PXAT is in milliseconds, we must not change it! + // for []AT we need to take expiration time as absolute from the value + // given check here and if the time is in the past, return OK but don't + // set it Note that the time pass here for PXAT is in milliseconds, we + // must not change it! if (absl::EndsWith(opt, "AT")) { int_arg = AbsExpiryToTtl(int_arg, is_ms); if (int_arg < 0) { - // this happened in the past, just return, for some reason Redis reports OK in this case + // this happened in the past, just return, for some reason Redis + // reports OK in this case return builder->SendStored(); } } @@ -843,7 +836,8 @@ void StringFamily::SetNx(CmdArgList args, ConnectionContext* cntx) { // This is the same as calling the "Set" function, only in this case we are // change the value only if the key does not exist. Otherwise the function // will not modify it. in which case it would return 0 - // it would return to the caller 1 in case the key did not exists and was added + // it would return to the caller 1 in case the key did not exists and was + // added string_view key = ArgS(args, 0); string_view value = ArgS(args, 1); @@ -1168,7 +1162,8 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) { // wait for all tiered reads to finish tiering_bc->Wait(); - // reorder the responses back according to the order of their corresponding keys. + // reorder the responses back according to the order of their corresponding + // keys. SinkReplyBuilder::MGetResponse res(args.size()); for (ShardId sid = 0; sid < mget_resp.size(); ++sid) { @@ -1208,18 +1203,21 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) { LOG(INFO) << "MSET/" << transaction->GetUniqueShardCnt() << str; } - atomic_bool success = true; + AggregateStatus result; auto cb = [&](Transaction* t, EngineShard* shard) { ShardArgs args = t->GetShardArgs(shard->shard_id()); - OpMSet(t->GetOpArgs(shard), args, &success); + if (auto status = OpMSet(t->GetOpArgs(shard), args); status != OpStatus::OK) + result = status; return OpStatus::OK; }; - OpStatus status = transaction->ScheduleSingleHop(std::move(cb)); - if (success.load()) { + if (auto status = transaction->ScheduleSingleHop(std::move(cb)); status != OpStatus::OK) + result = status; + + if (*result == OpStatus::OK) { cntx->SendOk(); } else { - cntx->SendError(status); + cntx->SendError(*result); } } @@ -1245,18 +1243,19 @@ void StringFamily::MSetNx(CmdArgList args, ConnectionContext* cntx) { transaction->Execute(std::move(cb), false); const bool to_skip = exists.load(memory_order_relaxed); - atomic_bool success = true; + AggregateStatus result; auto epilog_cb = [&](Transaction* t, EngineShard* shard) { if (to_skip) return OpStatus::OK; auto args = t->GetShardArgs(shard->shard_id()); - OpMSet(t->GetOpArgs(shard), std::move(args), &success); + if (auto status = OpMSet(t->GetOpArgs(shard), args); status != OpStatus::OK) + result = status; return OpStatus::OK; }; transaction->Execute(std::move(epilog_cb), true); - cntx->SendLong(to_skip || !success.load() ? 0 : 1); + cntx->SendLong(to_skip || (*result != OpStatus::OK) ? 0 : 1); } void StringFamily::StrLen(CmdArgList args, ConnectionContext* cntx) { @@ -1343,13 +1342,13 @@ void StringFamily::SetRange(CmdArgList args, ConnectionContext* cntx) { * 1. Whether the action was limited: * - 0 indicates the action is allowed. * - 1 indicates that the action was limited/blocked. - * 2. The total limit of the key (max_burst + 1). This is equivalent to the common - * X-RateLimit-Limit HTTP header. + * 2. The total limit of the key (max_burst + 1). This is equivalent to the + * common X-RateLimit-Limit HTTP header. * 3. The remaining limit of the key. Equivalent to X-RateLimit-Remaining. - * 4. The number of seconds until the user should retry, and always -1 if the action was allowed. - * Equivalent to Retry-After. - * 5. The number of seconds until the limit will reset to its maximum capacity. Equivalent to - * X-RateLimit-Reset. + * 4. The number of seconds until the user should retry, and always -1 if the + * action was allowed. Equivalent to Retry-After. + * 5. The number of seconds until the limit will reset to its maximum capacity. + * Equivalent to X-RateLimit-Reset. */ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { const string_view key = ArgS(args, 0); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index ffabc7436..91eb33394 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -312,7 +312,7 @@ void Transaction::InitByKeys(const KeyIndex& key_index) { } shard_data_.resize(shard_set->size()); // shard_data isn't sparse, so we must allocate for all :( - DCHECK_EQ(full_args_.size() % key_index.step, 0u); + DCHECK_EQ(full_args_.size() % key_index.step, 0u) << full_args_; // Safe, because flow below is not preemptive. auto& shard_index = tmp_space.GetShardIndex(shard_data_.size()); diff --git a/src/server/tx_base.h b/src/server/tx_base.h index 02f4f057b..5410221f2 100644 --- a/src/server/tx_base.h +++ b/src/server/tx_base.h @@ -153,6 +153,12 @@ class ShardArgs { return *this; } + Iterator operator++(int) { + Iterator copy = *this; + operator++(); + return copy; + } + size_t index() const { return index_it_->first + delta_; } diff --git a/tests/dragonfly/utility.py b/tests/dragonfly/utility.py index 8e2cfd61d..e20eba864 100644 --- a/tests/dragonfly/utility.py +++ b/tests/dragonfly/utility.py @@ -233,7 +233,7 @@ class CommandGenerator: ValueType.SET: "SADD", ValueType.HSET: "HMSET", ValueType.ZSET: "ZADD", - ValueType.JSON: "JSON.SET", + ValueType.JSON: "JSON.MSET", } def gen_grow_cmd(self): @@ -242,7 +242,7 @@ class CommandGenerator: """ # TODO: Implement COPY in Dragonfly. t = self.random_type() - if t == ValueType.STRING: + if t in [ValueType.STRING, ValueType.JSON]: count = random.randint(1, self.max_multikey) else: count = 1