feat(server): implement json.numincrby (#240) (#252)

* feat(server): implement json.numincrby (#240)

Signed-off-by: iko1 <me@remotecpp.dev>
This commit is contained in:
iko1 2022-08-26 16:13:47 +03:00 committed by Roman Gershman
parent 3e3496c4bf
commit 3acb1bb704
10 changed files with 356 additions and 76 deletions

View file

@ -30,5 +30,6 @@ extern const char kSyntaxErrType[];
extern const char kScriptErrType[];
extern const char kIndexOutOfRange[];
extern const char kOutOfMemory[];
extern const char kInvalidNumericResult[];
} // namespace dfly

View file

@ -80,6 +80,7 @@ const char kSyntaxErrType[] = "syntax_error";
const char kScriptErrType[] = "script_error";
const char kIndexOutOfRange[] = "index out of range";
const char kOutOfMemory[] = "Out of memory";
const char kInvalidNumericResult[] = "result is not a number";
const char* RespExpr::TypeName(Type t) {
switch (t) {

View file

@ -24,6 +24,7 @@ enum class OpStatus : uint16_t {
BUSY_GROUP,
STREAM_ID_SMALL,
ENTRIES_ADDED_SMALL,
INVALID_NUMERIC_RESULT,
};
class OpResultBase {

View file

@ -247,6 +247,9 @@ void RedisReplyBuilder::SendError(OpStatus status) {
case OpStatus::BUSY_GROUP:
SendError("-BUSYGROUP Consumer Group name already exists");
break;
case OpStatus::INVALID_NUMERIC_RESULT:
SendError(kInvalidNumericResult);
break;
default:
LOG(ERROR) << "Unsupported status " << status;
SendError("Internal error");
@ -340,7 +343,6 @@ void RedisReplyBuilder::SendStringArr(StrPtr str_ptr, uint32_t len) {
unsigned vec_indx = 1;
string_view src;
for (unsigned i = 0; i < len; ++i) {
if (holds_alternative<const string_view*>(str_ptr)) {
src = get<const string_view*>(str_ptr)[i];
} else {

View file

@ -4,6 +4,7 @@
#include "server/common.h"
#include <absl/strings/charconv.h>
#include <absl/strings/str_cat.h>
#include <mimalloc.h>
@ -163,6 +164,22 @@ bool ParseHumanReadableBytes(std::string_view str, int64_t* num_bytes) {
return true;
}
bool ParseDouble(string_view src, double* value) {
if (src.empty())
return false;
if (src == "-inf") {
*value = -HUGE_VAL;
} else if (src == "+inf") {
*value = HUGE_VAL;
} else {
absl::from_chars_result result = absl::from_chars(src.data(), src.end(), *value);
if (int(result.ec) != 0 || result.ptr != src.end() || isnan(*value))
return false;
}
return true;
}
#define ADD(x) (x) += o.x
TieredStats& TieredStats::operator+=(const TieredStats& o) {

View file

@ -112,6 +112,7 @@ inline void ToLower(const MutableSlice* val) {
}
bool ParseHumanReadableBytes(std::string_view str, int64_t* num_bytes);
bool ParseDouble(std::string_view src, double* value);
const char* ObjTypeName(int type);
const char* RdbTypeName(unsigned type);

View file

@ -6,7 +6,6 @@
extern "C" {
#include "redis/object.h"
#include "redis/util.h"
}
#include <absl/strings/str_join.h>
@ -29,7 +28,7 @@ using namespace jsoncons;
using JsonExpression = jsonpath::jsonpath_expression<json>;
using OptBool = optional<bool>;
using OptSizeT = optional<size_t>;
using JsonReplaceCb = std::function<void(const string&, json&)>;
using JsonReplaceCb = function<void(const string&, json&)>;
using CI = CommandId;
namespace {
@ -84,6 +83,27 @@ string JsonType(const json& val) {
return "";
}
template <typename T>
void PrintOptVec(ConnectionContext* cntx, const OpResult<vector<optional<T>>>& result) {
if (result->empty()) {
(*cntx)->SendNullArray();
} else {
(*cntx)->StartArray(result->size());
for (auto& it : *result) {
if (it.has_value()) {
if constexpr (is_floating_point_v<T>) {
(*cntx)->SendDouble(*it);
} else {
static_assert(is_integral_v<T>, "Integral required.");
(*cntx)->SendLong(*it);
}
} else {
(*cntx)->SendNull();
}
}
}
}
error_code JsonReplace(json& instance, string_view& path, JsonReplaceCb callback) {
using evaluator_t = jsoncons::jsonpath::detail::jsonpath_evaluator<json, json&>;
using value_type = evaluator_t::value_type;
@ -259,8 +279,109 @@ OpResult<vector<OptBool>> OpToggle(const OpArgs& op_args, string_view key, strin
return vec;
}
template <typename Op>
OpResult<string> OpDoubleArithmetic(const OpArgs& op_args, string_view key, string_view path,
double num, Op arithmetic_op) {
OpResult<json> result = GetJson(op_args, key);
if (!result) {
return result.status();
}
bool is_result_overflow = false;
double int_part;
bool has_fractional_part = (modf(num, &int_part) != 0);
json output(json_array_arg);
auto cb = [&](const string& path, json& val) {
if (val.is_number()) {
double result = arithmetic_op(val.as<double>(), num);
if (isinf(result)) {
is_result_overflow = true;
return;
}
if (val.is_double() || has_fractional_part) {
val = result;
} else {
val = (uint64_t)result;
}
output.push_back(val);
} else {
output.push_back(json::null());
}
};
json j = result.value();
error_code ec = JsonReplace(j, path, cb);
if (ec) {
VLOG(1) << "Failed to evaulate expression on json with error: " << ec.message();
return OpStatus::SYNTAX_ERR;
}
if (is_result_overflow) {
return OpStatus::INVALID_NUMERIC_RESULT;
}
SetString(op_args, key, j.as_string());
return output.as_string();
}
} // namespace
void JsonFamily::NumIncrBy(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view path = ArgS(args, 2);
string_view num = ArgS(args, 3);
double dnum;
if (!ParseDouble(num, &dnum)) {
(*cntx)->SendError(kWrongTypeErr);
return;
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpDoubleArithmetic(t->GetOpArgs(shard), key, path, dnum, plus<double>{});
};
DVLOG(1) << "Before Get::ScheduleSingleHopT " << key;
Transaction* trans = cntx->transaction;
OpResult<string> result = trans->ScheduleSingleHopT(move(cb));
if (result) {
DVLOG(1) << "JSON.NUMINCRBY " << trans->DebugId() << ": " << key;
(*cntx)->SendSimpleString(*result);
} else {
(*cntx)->SendError(result.status());
}
}
void JsonFamily::NumMultBy(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view path = ArgS(args, 2);
string_view num = ArgS(args, 3);
double dnum;
if (!ParseDouble(num, &dnum)) {
(*cntx)->SendError(kWrongTypeErr);
return;
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpDoubleArithmetic(t->GetOpArgs(shard), key, path, dnum, multiplies<double>{});
};
DVLOG(1) << "Before Get::ScheduleSingleHopT " << key;
Transaction* trans = cntx->transaction;
OpResult<string> result = trans->ScheduleSingleHopT(move(cb));
if (result) {
DVLOG(1) << "JSON.NUMMULTBY " << trans->DebugId() << ": " << key;
(*cntx)->SendSimpleString(*result);
} else {
(*cntx)->SendError(result.status());
}
}
void JsonFamily::Toggle(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view path = ArgS(args, 2);
@ -275,18 +396,7 @@ void JsonFamily::Toggle(CmdArgList args, ConnectionContext* cntx) {
if (result) {
DVLOG(1) << "JSON.TOGGLE " << trans->DebugId() << ": " << key;
if (result->empty()) {
(*cntx)->SendNullArray();
} else {
(*cntx)->StartArray(result->size());
for (auto& it : *result) {
if (it.has_value()) {
(*cntx)->SendLong(*it);
} else {
(*cntx)->SendNull();
}
}
}
PrintOptVec(cntx, result);
} else {
(*cntx)->SendError(result.status());
}
@ -353,18 +463,7 @@ void JsonFamily::ArrLen(CmdArgList args, ConnectionContext* cntx) {
if (result) {
DVLOG(1) << "JSON.ARRLEN " << trans->DebugId() << ": " << key;
if (result->empty()) {
(*cntx)->SendNullArray();
} else {
(*cntx)->StartArray(result->size());
for (auto& it : *result) {
if (it.has_value()) {
(*cntx)->SendLong(*it);
} else {
(*cntx)->SendNull();
}
}
}
PrintOptVec(cntx, result);
} else {
(*cntx)->SendError(result.status());
}
@ -393,18 +492,7 @@ void JsonFamily::ObjLen(CmdArgList args, ConnectionContext* cntx) {
if (result) {
DVLOG(1) << "JSON.OBJLEN " << trans->DebugId() << ": " << key;
if (result->empty()) {
(*cntx)->SendNullArray();
} else {
(*cntx)->StartArray(result->size());
for (auto& it : *result) {
if (it.has_value()) {
(*cntx)->SendLong(*it);
} else {
(*cntx)->SendNull();
}
}
}
PrintOptVec(cntx, result);
} else {
(*cntx)->SendError(result.status());
}
@ -433,18 +521,7 @@ void JsonFamily::StrLen(CmdArgList args, ConnectionContext* cntx) {
if (result) {
DVLOG(1) << "JSON.STRLEN " << trans->DebugId() << ": " << key;
if (result->empty()) {
(*cntx)->SendNullArray();
} else {
(*cntx)->StartArray(result->size());
for (auto& it : *result) {
if (it.has_value()) {
(*cntx)->SendLong(*it);
} else {
(*cntx)->SendNull();
}
}
}
PrintOptVec(cntx, result);
} else {
(*cntx)->SendError(result.status());
}
@ -495,6 +572,10 @@ void JsonFamily::Register(CommandRegistry* registry) {
*registry << CI{"JSON.OBJLEN", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ObjLen);
*registry << CI{"JSON.ARRLEN", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ArrLen);
*registry << CI{"JSON.TOGGLE", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(Toggle);
*registry << CI{"JSON.NUMINCRBY", CO::WRITE | CO::DENYOOM | CO::FAST, 4, 1, 1, 1}.HFUNC(
NumIncrBy);
*registry << CI{"JSON.NUMMULTBY", CO::WRITE | CO::DENYOOM | CO::FAST, 4, 1, 1, 1}.HFUNC(
NumMultBy);
}
} // namespace dfly

View file

@ -13,6 +13,7 @@ class ConnectionContext;
class CommandRegistry;
using facade::OpResult;
using facade::OpStatus;
using facade::RedisReplyBuilder;
class JsonFamily {
public:
@ -25,6 +26,8 @@ class JsonFamily {
static void ObjLen(CmdArgList args, ConnectionContext* cntx);
static void ArrLen(CmdArgList args, ConnectionContext* cntx);
static void Toggle(CmdArgList args, ConnectionContext* cntx);
static void NumIncrBy(CmdArgList args, ConnectionContext* cntx);
static void NumMultBy(CmdArgList args, ConnectionContext* cntx);
};
} // namespace dfly

View file

@ -262,4 +262,198 @@ TEST_F(JsonFamilyTest, Toggle) {
EXPECT_EQ(resp, R"([true,false,1,null,"foo",[],{}])");
}
TEST_F(JsonFamilyTest, NumIncrBy) {
string json = R"(
{"e":1.5,"a":1}
)";
auto resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMINCRBY", "json", "$.a", "1.1"});
EXPECT_EQ(resp, "[2.1]");
resp = Run({"JSON.NUMINCRBY", "json", "$.e", "1"});
EXPECT_EQ(resp, "[2.5]");
resp = Run({"JSON.NUMINCRBY", "json", "$.e", "inf"});
EXPECT_THAT(resp, ErrArg("ERR result is not a number"));
json = R"(
{"e":1.5,"a":1}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMINCRBY", "json", "$.e", "1.7e308"});
EXPECT_EQ(resp, "[1.7e+308]");
resp = Run({"JSON.NUMINCRBY", "json", "$.e", "1.7e308"});
EXPECT_THAT(resp, ErrArg("ERR result is not a number"));
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([1,1.7e+308])");
json = R"(
{"a":[], "b":[1], "c":[1,2], "d":[1,2,3]}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMINCRBY", "json", "$.d[*]", "10"});
EXPECT_EQ(resp, "[11,12,13]");
resp = Run({"JSON.GET", "json", "$.d[*]"});
EXPECT_EQ(resp, "[11,12,13]");
json = R"(
{"a":[], "b":[1], "c":[1,2], "d":[1,2,3]}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMINCRBY", "json", "$.a[*]", "1"});
EXPECT_EQ(resp, "[]");
resp = Run({"JSON.NUMINCRBY", "json", "$.b[*]", "1"});
EXPECT_EQ(resp, "[2]");
resp = Run({"JSON.NUMINCRBY", "json", "$.c[*]", "1"});
EXPECT_EQ(resp, "[2,3]");
resp = Run({"JSON.NUMINCRBY", "json", "$.d[*]", "1"});
EXPECT_EQ(resp, "[2,3,4]");
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([[],[2],[2,3],[2,3,4]])");
json = R"(
{"a":{}, "b":{"a":1}, "c":{"a":1, "b":2}, "d":{"a":1, "b":2, "c":3}}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMINCRBY", "json", "$.a.*", "1"});
EXPECT_EQ(resp, "[]");
resp = Run({"JSON.NUMINCRBY", "json", "$.b.*", "1"});
EXPECT_EQ(resp, "[2]");
resp = Run({"JSON.NUMINCRBY", "json", "$.c.*", "1"});
EXPECT_EQ(resp, "[2,3]");
resp = Run({"JSON.NUMINCRBY", "json", "$.d.*", "1"});
EXPECT_EQ(resp, "[2,3,4]");
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([{},{"a":2},{"a":2,"b":3},{"a":2,"b":3,"c":4}])");
json = R"(
{"a":{"a":"a"}, "b":{"a":"a", "b":1}, "c":{"a":"a", "b":"b"}, "d":{"a":1, "b":"b", "c":3}}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMINCRBY", "json", "$.a.*", "1"});
EXPECT_EQ(resp, "[null]");
resp = Run({"JSON.NUMINCRBY", "json", "$.b.*", "1"});
EXPECT_EQ(resp, "[null,2]");
resp = Run({"JSON.NUMINCRBY", "json", "$.c.*", "1"});
EXPECT_EQ(resp, "[null,null]");
resp = Run({"JSON.NUMINCRBY", "json", "$.d.*", "1"});
EXPECT_EQ(resp, "[2,null,4]");
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([{"a":"a"},{"a":"a","b":2},{"a":"a","b":"b"},{"a":2,"b":"b","c":4}])");
}
TEST_F(JsonFamilyTest, NumMultBy) {
string json = R"(
{"a":[], "b":[1], "c":[1,2], "d":[1,2,3]}
)";
auto resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMMULTBY", "json", "$.d[*]", "2"});
EXPECT_EQ(resp, "[2,4,6]");
resp = Run({"JSON.GET", "json", "$.d[*]"});
EXPECT_EQ(resp, R"([2,4,6])");
json = R"(
{"a":[], "b":[1], "c":[1,2], "d":[1,2,3]}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMMULTBY", "json", "$.a[*]", "2"});
EXPECT_EQ(resp, "[]");
resp = Run({"JSON.NUMMULTBY", "json", "$.b[*]", "2"});
EXPECT_EQ(resp, "[2]");
resp = Run({"JSON.NUMMULTBY", "json", "$.c[*]", "2"});
EXPECT_EQ(resp, "[2,4]");
resp = Run({"JSON.NUMMULTBY", "json", "$.d[*]", "2"});
EXPECT_EQ(resp, "[2,4,6]");
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([[],[2],[2,4],[2,4,6]])");
json = R"(
{"a":{}, "b":{"a":1}, "c":{"a":1, "b":2}, "d":{"a":1, "b":2, "c":3}}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMMULTBY", "json", "$.a.*", "2"});
EXPECT_EQ(resp, "[]");
resp = Run({"JSON.NUMMULTBY", "json", "$.b.*", "2"});
EXPECT_EQ(resp, "[2]");
resp = Run({"JSON.NUMMULTBY", "json", "$.c.*", "2"});
EXPECT_EQ(resp, "[2,4]");
resp = Run({"JSON.NUMMULTBY", "json", "$.d.*", "2"});
EXPECT_EQ(resp, "[2,4,6]");
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([{},{"a":2},{"a":2,"b":4},{"a":2,"b":4,"c":6}])");
json = R"(
{"a":{"a":"a"}, "b":{"a":"a", "b":1}, "c":{"a":"a", "b":"b"}, "d":{"a":1, "b":"b", "c":3}}
)";
resp = Run({"set", "json", json});
ASSERT_THAT(resp, "OK");
resp = Run({"JSON.NUMMULTBY", "json", "$.a.*", "2"});
EXPECT_EQ(resp, "[null]");
resp = Run({"JSON.NUMMULTBY", "json", "$.b.*", "2"});
EXPECT_EQ(resp, "[null,2]");
resp = Run({"JSON.NUMMULTBY", "json", "$.c.*", "2"});
EXPECT_EQ(resp, "[null,null]");
resp = Run({"JSON.NUMMULTBY", "json", "$.d.*", "2"});
EXPECT_EQ(resp, "[2,null,6]");
resp = Run({"JSON.GET", "json", "$.*"});
EXPECT_EQ(resp, R"([{"a":"a"},{"a":"a","b":2},{"a":"a","b":"b"},{"a":2,"b":"b","c":6}])");
}
} // namespace dfly

View file

@ -11,8 +11,6 @@ extern "C" {
#include "redis/zset.h"
}
#include <absl/strings/charconv.h>
#include "base/logging.h"
#include "base/stl_util.h"
#include "facade/error.h"
@ -531,22 +529,6 @@ void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vl
}
}
bool ParseScore(string_view src, double* score) {
if (src.empty())
return false;
if (src == "-inf") {
*score = -HUGE_VAL;
} else if (src == "+inf") {
*score = HUGE_VAL;
} else {
absl::from_chars_result result = absl::from_chars(src.data(), src.end(), *score);
if (int(result.ec) != 0 || result.ptr != src.end() || isnan(*score))
return false;
}
return true;
};
bool ParseBound(string_view src, ZSetFamily::Bound* bound) {
if (src.empty())
return false;
@ -556,7 +538,7 @@ bool ParseBound(string_view src, ZSetFamily::Bound* bound) {
src.remove_prefix(1);
}
return ParseScore(src, &bound->val);
return ParseDouble(src, &bound->val);
}
bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) {
@ -956,7 +938,7 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
string_view cur_arg = ArgS(args, i);
double val = 0;
if (!ParseScore(cur_arg, &val)) {
if (!ParseDouble(cur_arg, &val)) {
VLOG(1) << "Bad score:" << cur_arg << "|";
return (*cntx)->SendError(kInvalidFloatErr);
}
@ -1135,8 +1117,7 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
if (shard->shard_id() == dest_shard) {
ZParams zparams;
zparams.override = true;
add_result =
OpAdd(t->GetOpArgs(shard), zparams, dest_key, ScoredMemberSpan{smvec}).value();
add_result = OpAdd(t->GetOpArgs(shard), zparams, dest_key, ScoredMemberSpan{smvec}).value();
}
return OpStatus::OK;
};
@ -1419,8 +1400,7 @@ void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
if (shard->shard_id() == dest_shard) {
ZParams zparams;
zparams.override = true;
add_result =
OpAdd(t->GetOpArgs(shard), zparams, dest_key, ScoredMemberSpan{smvec}).value();
add_result = OpAdd(t->GetOpArgs(shard), zparams, dest_key, ScoredMemberSpan{smvec}).value();
}
return OpStatus::OK;
};
@ -1534,7 +1514,6 @@ void ZSetFamily::ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext*
string_view member = ArgS(args, 2);
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRank(t->GetOpArgs(shard), key, member, reverse);
};