From c2d32f9d68672899d02434c8f687056aa3bdd59d Mon Sep 17 00:00:00 2001 From: iko1 Date: Thu, 20 Apr 2023 10:20:00 +0200 Subject: [PATCH] fix(server): json.set should add missing keys and add missing cmd... (#1065) fix(server): json.set should add missing keys and add missing command flags Signed-off-by: iko1 --- src/server/json_family.cc | 110 ++++++++++++++++++++++++++++++--- src/server/json_family_test.cc | 22 +++++++ 2 files changed, 122 insertions(+), 10 deletions(-) diff --git a/src/server/json_family.cc b/src/server/json_family.cc index b83d02535..82ec6b580 100644 --- a/src/server/json_family.cc +++ b/src/server/json_family.cc @@ -10,6 +10,7 @@ extern "C" { #include #include +#include #include #include @@ -35,14 +36,14 @@ using OptLong = optional; using OptSizeT = optional; using OptString = optional; using JsonReplaceCb = function; -using JsonReplaceVerify = std::function; +using JsonReplaceVerify = std::function; using CI = CommandId; static const char DefaultJsonPath[] = "$"; namespace { -inline OpStatus JsonReplaceVerifyNoOp() { +inline OpStatus JsonReplaceVerifyNoOp(JsonType&) { return OpStatus::OK; } @@ -148,7 +149,7 @@ OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_vi } // Make sure that we don't have other internal issue with the operation - OpStatus res = verify_op(); + OpStatus res = verify_op(json_entry); if (res == OpStatus::OK) { db_slice.PostUpdate(db_index, entry_it, key); } @@ -280,6 +281,39 @@ string ConvertToJsonPointer(string_view json_path) { return result; } +string ConvertExpressionToJsonPointer(string_view json_path) { + if (json_path.empty() || !absl::StartsWith(json_path, "$.")) { + VLOG(1) << "retrieved malformed JSON path expression: " << json_path; + return {}; + } + + // remove prefix + json_path.remove_prefix(2); + + std::string pointer; + vector splitted = absl::StrSplit(json_path, '.'); + for (auto& it : splitted) { + if (it.front() == '[' && it.back() == ']') { + std::string index = it.substr(1, it.size() - 2); + if (index.empty()) { + return {}; + } + + for (char ch : index) { + if (!std::isdigit(ch)) { + return {}; + } + } + + pointer += '/' + index; + } else { + pointer += '/' + it; + } + } + + return pointer; +} + size_t CountJsonFields(const JsonType& j) { size_t res = 0; json_type type = j.type(); @@ -487,7 +521,7 @@ OpResult OpDoubleArithmetic(const OpArgs& op_args, string_view key, stri } }; - auto verifier = [&is_result_overflow]() { + auto verifier = [&is_result_overflow](JsonType&) { if (is_result_overflow) { return OpStatus::INVALID_NUMERIC_RESULT; } @@ -964,7 +998,7 @@ OpResult> OpResp(const OpArgs& op_args, string_view key, // Returns boolean that represents the result of the operation. OpResult OpSet(const OpArgs& op_args, string_view key, string_view path, - std::string_view json_str) { + std::string_view json_str, bool is_nx_condition, bool is_xx_condition) { std::optional parsed_json = JsonFromString(json_str); if (!parsed_json) { LOG(WARNING) << "got invalid JSON string '" << json_str << "' cannot be saved"; @@ -976,6 +1010,19 @@ OpResult OpSet(const OpArgs& op_args, string_view key, string_view path, // this is regardless of the current key type. In redis if the key exists // and its not JSON, it would return an error. if (path == "." || path == "$") { + if (is_nx_condition || is_xx_condition) { + OpResult it_res = + op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON); + bool key_exists = (it_res.status() != OpStatus::KEY_NOTFOUND); + if (is_nx_condition && key_exists) { + return false; + } + + if (is_xx_condition && !key_exists) { + return false; + } + } + SetJson(op_args, key, std::move(parsed_json.value())); return true; } @@ -986,17 +1033,46 @@ OpResult OpSet(const OpArgs& op_args, string_view key, string_view path, // an array that this expression will match each entry in it // then the assign here is called N times, where N == array.size(). bool path_exists = false; + bool operation_result = false; const JsonType& new_json = parsed_json.value(); auto cb = [&](const string& path, JsonType& val) { path_exists = true; - val = new_json; + if (!is_nx_condition) { + operation_result = true; + val = new_json; + } }; - OpStatus status = UpdateEntry(op_args, key, path, cb); + auto inserter = [&](JsonType& json) { + // Set a new value if the path doesn't exist and the nx condition is not set. + if (!path_exists && !is_xx_condition) { + string pointer = ConvertExpressionToJsonPointer(path); + if (pointer.empty()) { + VLOG(1) << "Failed to convert the following expression path to a valid JSON pointer: " + << path; + return OpStatus::SYNTAX_ERR; + } + + error_code ec; + jsonpointer::add(json, pointer, new_json, ec); + if (ec) { + VLOG(1) << "Failed to add a JSON value to the following path: " << path + << " with the error: " << ec.message(); + return OpStatus::SYNTAX_ERR; + } + + operation_result = true; + } + + return OpStatus::OK; + }; + + OpStatus status = UpdateEntry(op_args, key, path, cb, inserter); if (status != OpStatus::OK) { return status; } - return path_exists; + + return operation_result; } } // namespace @@ -1005,9 +1081,23 @@ void JsonFamily::Set(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); string_view path = ArgS(args, 1); string_view json_str = ArgS(args, 2); + bool is_nx_condition = false; + bool is_xx_condition = false; + string_view operation_opts; + if (args.size() > 3) { + operation_opts = ArgS(args, 3); + if (absl::EqualsIgnoreCase(operation_opts, "NX")) { + is_nx_condition = true; + } else if (absl::EqualsIgnoreCase(operation_opts, "XX")) { + is_xx_condition = true; + } else { + (*cntx)->SendError(kSyntaxErr); + return; + } + } auto cb = [&](Transaction* t, EngineShard* shard) { - return OpSet(t->GetOpArgs(shard), key, path, json_str); + return OpSet(t->GetOpArgs(shard), key, path, json_str, is_nx_condition, is_xx_condition); }; Transaction* trans = cntx->transaction; @@ -1707,7 +1797,7 @@ void JsonFamily::Register(CommandRegistry* registry) { *registry << CI{"JSON.ARRINDEX", CO::READONLY | CO::FAST, -4, 1, 1, 1}.HFUNC(ArrIndex); *registry << CI{"JSON.DEBUG", CO::READONLY | CO::FAST, -2, 1, 1, 1}.HFUNC(Debug); *registry << CI{"JSON.RESP", CO::READONLY | CO::FAST, -2, 1, 1, 1}.HFUNC(Resp); - *registry << CI{"JSON.SET", CO::WRITE | CO::DENYOOM | CO::FAST, 4, 1, 1, 1}.HFUNC(Set); + *registry << CI{"JSON.SET", CO::WRITE | CO::DENYOOM | CO::FAST, -4, 1, 1, 1}.HFUNC(Set); } } // namespace dfly diff --git a/src/server/json_family_test.cc b/src/server/json_family_test.cc index b3e97c7e5..b55710f47 100644 --- a/src/server/json_family_test.cc +++ b/src/server/json_family_test.cc @@ -981,6 +981,28 @@ TEST_F(JsonFamilyTest, Set) { resp = Run({"JSON.GET", "json2", "$"}); EXPECT_EQ(resp, R"([{"a":[0,0,0,0,0]}])"); + + json = R"( + {"a": 2} + )"; + + resp = Run({"JSON.SET", "json3", "$", json}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"JSON.SET", "json3", "$.b", "8"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"JSON.SET", "json3", "$.c", "[1,2,3]"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"JSON.SET", "json3", "$.z", "3", "XX"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL)); + + resp = Run({"JSON.SET", "json3", "$.b", "4", "NX"}); + EXPECT_THAT(resp, ArgType(RespExpr::NIL)); + + resp = Run({"JSON.GET", "json3", "$"}); + EXPECT_EQ(resp, R"([{"a":2,"b":8,"c":[1,2,3]}])"); } } // namespace dfly