chore: improve compatibility of set and ping commands (#3569)

* chore: improve compatibility of set and ping commands

smismember should return an array of longs and not array of strings.
ping in subscribe mode returns an array for resp2.
Also, fix double rounding for legacy float mode.
---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2024-08-26 13:33:03 +03:00 committed by GitHub
parent 10816b500f
commit 908290a268
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 63 additions and 42 deletions

View file

@ -755,6 +755,8 @@ TEST_F(DflyEngineTest, EvalBug2664) {
auto resp = Run({"eval", "return 42.9", "0"});
EXPECT_THAT(resp, IntArg(42));
resp = Run({"eval", "return -3.8", "0"});
EXPECT_THAT(resp, IntArg(-3));
resp = Run({"hello", "3"});
ASSERT_THAT(resp, ArrLen(14));

View file

@ -767,16 +767,27 @@ void GenericFamily::Ping(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(facade::WrongNumArgsError("ping"), kSyntaxErrType);
}
// We synchronously block here until the engine sends us the payload and notifies that
// the I/O operation has been processed.
string_view msg;
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
// If a client in the subscribe state and in resp2 mode, it returns an array for some reason.
if (cntx->conn_state.subscribe_info && !rb->IsResp3()) {
if (args.size() == 1) {
msg = ArgS(args, 0);
}
string_view resp[2] = {"pong", msg};
return rb->SendStringArr(resp);
}
if (args.size() == 0) {
return cntx->SendSimpleString("PONG");
} else {
string_view arg = ArgS(args, 0);
DVLOG(2) << "Ping " << arg;
msg = ArgS(args, 0);
DVLOG(2) << "Ping " << msg;
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
return rb->SendBulkString(arg);
return rb->SendBulkString(msg);
}
}

View file

@ -330,7 +330,7 @@ class EvalSerializer : public ObjectExplorer {
if (rb_->IsResp3() || !absl::GetFlag(FLAGS_lua_resp2_legacy_float)) {
rb_->SendDouble(d);
} else {
long val = static_cast<long>(floor(d));
long val = d >= 0 ? static_cast<long>(floor(d)) : static_cast<long>(ceil(d));
rb_->SendLong(val);
}
}

View file

@ -227,14 +227,6 @@ int32_t GetExpiry(const DbContext& db_context, const SetType& st, string_view me
}
}
void FindInSet(StringVec& memberships, const DbContext& db_context, const SetType& st,
facade::ArgRange members) {
for (string_view member : members) {
bool status = IsInSet(db_context, st, member);
memberships.emplace_back(to_string(status));
}
}
// Removes arg from result.
void DiffStrSet(const DbContext& db_context, const SetType& st,
absl::flat_hash_set<string>* result) {
@ -975,6 +967,8 @@ void SIsMember(CmdArgList args, ConnectionContext* cntx) {
switch (result.status()) {
case OpStatus::OK:
return cntx->SendLong(1);
case OpStatus::WRONG_TYPE:
return cntx->SendError(OpStatus::WRONG_TYPE);
default:
return cntx->SendLong(0);
}
@ -982,16 +976,20 @@ void SIsMember(CmdArgList args, ConnectionContext* cntx) {
void SMIsMember(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
auto vals = args.subspan(1);
auto members = args.subspan(1);
StringVec memberships;
memberships.reserve(vals.size());
vector<bool> memberships(members.size());
auto cb = [&](Transaction* t, EngineShard* shard) {
auto find_res = t->GetDbSlice(shard->shard_id()).FindReadOnly(t->GetDbContext(), key, OBJ_SET);
DbContext db_cntx = t->GetDbContext();
auto find_res = t->GetDbSlice(shard->shard_id()).FindReadOnly(db_cntx, key, OBJ_SET);
if (find_res) {
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FindInSet(memberships, t->GetDbContext(), st, vals);
SetType st{(*find_res)->second.RObjPtr(), find_res.value()->second.Encoding()};
for (size_t i = 0; i < members.size(); ++i) {
auto member = members[i];
bool status = IsInSet(db_cntx, st, ToSV(member));
memberships[i] = status;
}
return OpStatus::OK;
}
return find_res.status();
@ -999,11 +997,11 @@ void SMIsMember(CmdArgList args, ConnectionContext* cntx) {
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));
if (result == OpStatus::KEY_NOTFOUND) {
memberships.assign(vals.size(), "0");
return rb->SendStringArr(memberships);
} else if (result == OpStatus::OK) {
return rb->SendStringArr(memberships);
if (result || result == OpStatus::KEY_NOTFOUND) {
rb->StartArray(memberships.size());
for (bool b : memberships)
rb->SendLong(int(b));
return;
}
cntx->SendError(result.status());
}
@ -1225,17 +1223,17 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) {
OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
rb->SendStringArr(*result, RedisReplyBuilder::SET);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (result || result == OpStatus::KEY_NOTFOUND) {
if (is_count) {
rb->SendStringArr(StringVec(), RedisReplyBuilder::SET);
rb->SendStringArr(*result, RedisReplyBuilder::SET);
} else if (result->size()) {
rb->SendBulkString(result->front());
} else {
rb->SendNull();
}
} else {
cntx->SendError(result.status());
return;
}
cntx->SendError(result.status());
}
void SInter(CmdArgList args, ConnectionContext* cntx) {
@ -1322,7 +1320,10 @@ void SInterCard(CmdArgList args, ConnectionContext* cntx) {
OpResult<SvArray> result =
InterResultVec(result_set, cntx->transaction->GetUniqueShardCnt(), limit);
return cntx->SendLong(result->size());
if (result) {
return cntx->SendLong(result->size());
}
cntx->SendError(result.status());
}
void SUnion(CmdArgList args, ConnectionContext* cntx) {

View file

@ -286,26 +286,22 @@ TEST_F(SetFamilyTest, SMIsMember) {
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"smismember", "foo1", "a", "b"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("0", "0"));
EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(0), IntArg(0))));
resp = Run({"smismember", "foo", "a", "c"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "0"));
EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(1), IntArg(0))));
resp = Run({"smismember", "foo", "a", "b"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "1"));
EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(1), IntArg(1))));
resp = Run({"smismember", "foo", "d", "e"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("0", "0"));
EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(0), IntArg(0))));
resp = Run({"smismember", "foo", "b"});
EXPECT_THAT(resp, "1");
EXPECT_THAT(resp, IntArg(1));
resp = Run({"smismember", "foo", "x"});
EXPECT_THAT(resp, "0");
EXPECT_THAT(resp, IntArg(0));
}
TEST_F(SetFamilyTest, Empty) {

View file

@ -480,6 +480,7 @@ def test_pubsub_numsub(r: redis.Redis):
@pytest.mark.min_server("7")
@testtools.run_test_if_redispy_ver("gte", "5.0.0rc2")
@pytest.mark.unsupported_server_types("dragonfly")
def test_published_message_to_shard_channel(r: redis.Redis):
p = r.pubsub()
p.ssubscribe("foo")
@ -493,6 +494,7 @@ def test_published_message_to_shard_channel(r: redis.Redis):
@pytest.mark.min_server("7")
@testtools.run_test_if_redispy_ver("gte", "5.0.0rc2")
@pytest.mark.unsupported_server_types("dragonfly")
def test_subscribe_property_with_shard_channels_cluster(r: redis.Redis):
p = r.pubsub()
keys = ["foo", "bar", "uni" + chr(4456) + "code"]
@ -539,6 +541,7 @@ def test_subscribe_property_with_shard_channels_cluster(r: redis.Redis):
@pytest.mark.min_server("7")
@testtools.run_test_if_redispy_ver("gte", "5.0.0")
@pytest.mark.unsupported_server_types("dragonfly")
def test_pubsub_shardnumsub(r: redis.Redis):
channels = {b"foo", b"bar", b"baz"}
p1 = r.pubsub()
@ -559,6 +562,7 @@ def test_pubsub_shardnumsub(r: redis.Redis):
@pytest.mark.min_server("7")
@testtools.run_test_if_redispy_ver("gte", "5.0.0rc2")
@pytest.mark.unsupported_server_types("dragonfly")
def test_pubsub_shardchannels(r: redis.Redis):
p = r.pubsub()
p.ssubscribe("foo", "bar", "baz", "quux")

View file

@ -320,6 +320,8 @@ def test_eval_global_and_return_ok(r: redis.Redis):
)
# Dragonfly uses lua5.4, so it natively supports doubles.
# To use legacy rounding of doubles to integers run dragonfly with --lua_resp2_legacy_float
def test_eval_convert_number(r: redis.Redis):
# Redis forces all Lua numbers to integer
val = r.eval("return 3.2", 0)
@ -349,6 +351,7 @@ def test_eval_call_bool6(r: redis.Redis):
@pytest.mark.min_server("7")
@pytest.mark.unsupported_server_types("dragonfly") # dragonfly allows this
def test_eval_call_bool7(r: redis.Redis):
# Redis doesn't allow Lua bools to be passed to [p]call
with pytest.raises(
@ -429,6 +432,7 @@ def test_eval_exists(r: redis.Redis):
assert val == 1
@pytest.mark.unsupported_server_types("dragonfly")
def test_eval_flushdb(r: redis.Redis):
r.set("foo", "bar")
val = r.eval(
@ -441,6 +445,7 @@ def test_eval_flushdb(r: redis.Redis):
assert val == 1
@pytest.mark.unsupported_server_types("dragonfly")
def test_eval_flushall(r, create_redis):
r1 = create_redis(db=2)
r2 = create_redis(db=3)
@ -461,6 +466,8 @@ def test_eval_flushall(r, create_redis):
assert "r2" not in r2
# Dragonfly lua supports doubles
@pytest.mark.unsupported_server_types("dragonfly")
def test_eval_incrbyfloat(r: redis.Redis):
r.set("foo", 0.5)
val = r.eval(