fix(server): json.set should add missing keys and add missing cmd... (#1065)

fix(server): json.set should add missing keys and add missing command flags

Signed-off-by: iko1 <me@remotecpp.dev>
This commit is contained in:
iko1 2023-04-20 10:20:00 +02:00 committed by GitHub
parent 6632261a2d
commit c2d32f9d68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 122 additions and 10 deletions

View file

@ -10,6 +10,7 @@ extern "C" {
#include <absl/strings/match.h>
#include <absl/strings/str_join.h>
#include <absl/strings/str_split.h>
#include <jsoncons/json.hpp>
#include <jsoncons_ext/jsonpatch/jsonpatch.hpp>
@ -35,14 +36,14 @@ using OptLong = optional<long>;
using OptSizeT = optional<size_t>;
using OptString = optional<string>;
using JsonReplaceCb = function<void(const string&, JsonType&)>;
using JsonReplaceVerify = std::function<OpStatus()>;
using JsonReplaceVerify = std::function<OpStatus(JsonType&)>;
using CI = CommandId;
static const char DefaultJsonPath[] = "$";
namespace {
inline OpStatus JsonReplaceVerifyNoOp() {
inline OpStatus JsonReplaceVerifyNoOp(JsonType&) {
return OpStatus::OK;
}
@ -148,7 +149,7 @@ OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_vi
}
// Make sure that we don't have other internal issue with the operation
OpStatus res = verify_op();
OpStatus res = verify_op(json_entry);
if (res == OpStatus::OK) {
db_slice.PostUpdate(db_index, entry_it, key);
}
@ -280,6 +281,39 @@ string ConvertToJsonPointer(string_view json_path) {
return result;
}
string ConvertExpressionToJsonPointer(string_view json_path) {
if (json_path.empty() || !absl::StartsWith(json_path, "$.")) {
VLOG(1) << "retrieved malformed JSON path expression: " << json_path;
return {};
}
// remove prefix
json_path.remove_prefix(2);
std::string pointer;
vector<string> splitted = absl::StrSplit(json_path, '.');
for (auto& it : splitted) {
if (it.front() == '[' && it.back() == ']') {
std::string index = it.substr(1, it.size() - 2);
if (index.empty()) {
return {};
}
for (char ch : index) {
if (!std::isdigit(ch)) {
return {};
}
}
pointer += '/' + index;
} else {
pointer += '/' + it;
}
}
return pointer;
}
size_t CountJsonFields(const JsonType& j) {
size_t res = 0;
json_type type = j.type();
@ -487,7 +521,7 @@ OpResult<string> OpDoubleArithmetic(const OpArgs& op_args, string_view key, stri
}
};
auto verifier = [&is_result_overflow]() {
auto verifier = [&is_result_overflow](JsonType&) {
if (is_result_overflow) {
return OpStatus::INVALID_NUMERIC_RESULT;
}
@ -964,7 +998,7 @@ OpResult<vector<JsonType>> OpResp(const OpArgs& op_args, string_view key,
// Returns boolean that represents the result of the operation.
OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
std::string_view json_str) {
std::string_view json_str, bool is_nx_condition, bool is_xx_condition) {
std::optional<JsonType> parsed_json = JsonFromString(json_str);
if (!parsed_json) {
LOG(WARNING) << "got invalid JSON string '" << json_str << "' cannot be saved";
@ -976,6 +1010,19 @@ OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
// this is regardless of the current key type. In redis if the key exists
// and its not JSON, it would return an error.
if (path == "." || path == "$") {
if (is_nx_condition || is_xx_condition) {
OpResult<PrimeIterator> it_res =
op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON);
bool key_exists = (it_res.status() != OpStatus::KEY_NOTFOUND);
if (is_nx_condition && key_exists) {
return false;
}
if (is_xx_condition && !key_exists) {
return false;
}
}
SetJson(op_args, key, std::move(parsed_json.value()));
return true;
}
@ -986,17 +1033,46 @@ OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
// an array that this expression will match each entry in it
// then the assign here is called N times, where N == array.size().
bool path_exists = false;
bool operation_result = false;
const JsonType& new_json = parsed_json.value();
auto cb = [&](const string& path, JsonType& val) {
path_exists = true;
val = new_json;
if (!is_nx_condition) {
operation_result = true;
val = new_json;
}
};
OpStatus status = UpdateEntry(op_args, key, path, cb);
auto inserter = [&](JsonType& json) {
// Set a new value if the path doesn't exist and the nx condition is not set.
if (!path_exists && !is_xx_condition) {
string pointer = ConvertExpressionToJsonPointer(path);
if (pointer.empty()) {
VLOG(1) << "Failed to convert the following expression path to a valid JSON pointer: "
<< path;
return OpStatus::SYNTAX_ERR;
}
error_code ec;
jsonpointer::add(json, pointer, new_json, ec);
if (ec) {
VLOG(1) << "Failed to add a JSON value to the following path: " << path
<< " with the error: " << ec.message();
return OpStatus::SYNTAX_ERR;
}
operation_result = true;
}
return OpStatus::OK;
};
OpStatus status = UpdateEntry(op_args, key, path, cb, inserter);
if (status != OpStatus::OK) {
return status;
}
return path_exists;
return operation_result;
}
} // namespace
@ -1005,9 +1081,23 @@ void JsonFamily::Set(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
string_view path = ArgS(args, 1);
string_view json_str = ArgS(args, 2);
bool is_nx_condition = false;
bool is_xx_condition = false;
string_view operation_opts;
if (args.size() > 3) {
operation_opts = ArgS(args, 3);
if (absl::EqualsIgnoreCase(operation_opts, "NX")) {
is_nx_condition = true;
} else if (absl::EqualsIgnoreCase(operation_opts, "XX")) {
is_xx_condition = true;
} else {
(*cntx)->SendError(kSyntaxErr);
return;
}
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpSet(t->GetOpArgs(shard), key, path, json_str);
return OpSet(t->GetOpArgs(shard), key, path, json_str, is_nx_condition, is_xx_condition);
};
Transaction* trans = cntx->transaction;
@ -1707,7 +1797,7 @@ void JsonFamily::Register(CommandRegistry* registry) {
*registry << CI{"JSON.ARRINDEX", CO::READONLY | CO::FAST, -4, 1, 1, 1}.HFUNC(ArrIndex);
*registry << CI{"JSON.DEBUG", CO::READONLY | CO::FAST, -2, 1, 1, 1}.HFUNC(Debug);
*registry << CI{"JSON.RESP", CO::READONLY | CO::FAST, -2, 1, 1, 1}.HFUNC(Resp);
*registry << CI{"JSON.SET", CO::WRITE | CO::DENYOOM | CO::FAST, 4, 1, 1, 1}.HFUNC(Set);
*registry << CI{"JSON.SET", CO::WRITE | CO::DENYOOM | CO::FAST, -4, 1, 1, 1}.HFUNC(Set);
}
} // namespace dfly