feat(zset family): Implement ZDiff command issue #1311 (#1333)

Signed-off-by: Kostas <kostaskyrim@gmail.com>
This commit is contained in:
Kostas Kyrimis 2023-06-05 18:26:01 +03:00 committed by GitHub
parent bf44b56667
commit 42116fa012
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 141 additions and 0 deletions

View file

@ -1432,6 +1432,12 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0)
return OpStatus::INVALID_INT;
// TODO Fix this for Z family functions.
// Examples that crash: ZUNION 0 myset
if (name == "ZDIFF" && num_custom_keys == 0) {
return OpStatus::INVALID_INT;
}
if (args.size() < size_t(num_custom_keys) + num_keys_index + 1)
return OpStatus::SYNTAX_ERR;
}

View file

@ -1351,6 +1351,86 @@ void ZSetFamily::ZCount(CmdArgList args, ConnectionContext* cntx) {
}
}
vector<ScoredMap> OpFetch(EngineShard* shard, Transaction* t) {
ArgSlice keys = t->GetShardArgs(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
vector<ScoredMap> results;
results.reserve(keys.size());
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < keys.size(); ++i) {
auto it = db_slice.Find(t->GetDbContext(), keys[i], OBJ_ZSET);
if (!it) {
results.push_back({});
continue;
}
ScoredMap sm = FromObject((*it)->second, 1);
results.push_back(std::move(sm));
}
return results;
}
void ZSetFamily::ZDiff(CmdArgList args, ConnectionContext* cntx) {
vector<vector<ScoredMap>> maps(shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] = OpFetch(shard, t);
return OpStatus::OK;
};
cntx->transaction->ScheduleSingleHop(std::move(cb));
const string_view key = ArgS(args, 1);
const ShardId sid = Shard(key, maps.size());
// Extract the ScoredMap of the first key
auto& sm = maps[sid];
if (sm.empty()) {
(*cntx)->SendEmptyArray();
return;
}
auto result = std::move(sm[0]);
sm.erase(sm.begin());
auto filter = [&result](const auto& key) mutable {
auto it = result.find(key);
if (it != result.end()) {
result.erase(it);
}
};
// Total O(L)
// Iterate over the results of each shard
for (auto& vsm : maps) {
// Iterate over each fetched set
for (auto& sm : vsm) {
// Iterate over each key in the fetched set and filter
for (auto& [key, value] : sm) {
filter(key);
}
}
}
vector<ScoredMemberView> smvec;
for (const auto& elem : result) {
smvec.emplace_back(elem.second, elem.first);
}
// Total O(KlogK)
std::sort(std::begin(smvec), std::end(smvec));
const bool with_scores = ArgS(args, args.size() - 1) == "WITHSCORES";
(*cntx)->StartArray(result.size() * (with_scores ? 2 : 1));
for (const auto& [score, key] : smvec) {
(*cntx)->SendBulkString(key);
if (with_scores) {
(*cntx)->SendDouble(score);
}
}
}
void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
string_view score_arg = ArgS(args, 1);
@ -2301,6 +2381,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
.HFUNC(BZPopMax)
<< CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard)
<< CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount)
<< CI{"ZDIFF", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, 1}.HFUNC(ZDiff)
<< CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy)
<< CI{"ZINTERSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZInterStore)
<< CI{"ZINTERCARD", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 2, 2, 1}

View file

@ -61,6 +61,7 @@ class ZSetFamily {
static void ZAdd(CmdArgList args, ConnectionContext* cntx);
static void ZCard(CmdArgList args, ConnectionContext* cntx);
static void ZCount(CmdArgList args, ConnectionContext* cntx);
static void ZDiff(CmdArgList args, ConnectionContext* cntx);
static void ZIncrBy(CmdArgList args, ConnectionContext* cntx);
static void ZInterStore(CmdArgList args, ConnectionContext* cntx);
static void ZInterCard(CmdArgList args, ConnectionContext* cntx);

View file

@ -578,4 +578,57 @@ TEST_F(ZSetFamilyTest, BlockingTimeout) {
EXPECT_THAT(resp0, ArgType(RespExpr::NIL_ARRAY));
}
TEST_F(ZSetFamilyTest, ZDiffError) {
RespExpr resp;
resp = Run({"zdiff", "-1", "z1"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
resp = Run({"zdiff", "0"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"zdiff", "0", "z1"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
resp = Run({"zdiff", "0", "z1", "z2"});
EXPECT_THAT(resp, ErrArg("value is not an integer or out of range"));
}
TEST_F(ZSetFamilyTest, ZDiff) {
RespExpr resp;
EXPECT_EQ(4, CheckedInt({"zadd", "z1", "1", "one", "2", "two", "3", "three", "4", "four"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "1", "one", "5", "five"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z3", "2", "two", "3", "three"}));
EXPECT_EQ(1, CheckedInt({"zadd", "z4", "4", "four"}));
resp = Run({"zdiff", "1", "z1"});
EXPECT_THAT(resp.GetVec(), ElementsAre("one", "two", "three", "four"));
resp = Run({"zdiff", "2", "z1", "z1"});
EXPECT_THAT(resp.GetVec().empty(), true);
resp = Run({"zdiff", "2", "z1", "doesnt_exist"});
EXPECT_THAT(resp.GetVec(), ElementsAre("one", "two", "three", "four"));
resp = Run({"zdiff", "2", "z1", "z2"});
EXPECT_THAT(resp.GetVec(), ElementsAre("two", "three", "four"));
resp = Run({"zdiff", "2", "z1", "z3"});
EXPECT_THAT(resp.GetVec(), ElementsAre("one", "four"));
resp = Run({"zdiff", "4", "z1", "z2", "z3", "z4"});
EXPECT_THAT(resp.GetVec().empty(), true);
resp = Run({"zdiff", "2", "doesnt_exist", "key1"});
EXPECT_THAT(resp.GetVec().empty(), true);
// WITHSCORES
resp = Run({"zdiff", "1", "z1", "WITHSCORES"});
EXPECT_THAT(resp.GetVec(), ElementsAre("one", "1", "two", "2", "three", "3", "four", "4"));
resp = Run({"zdiff", "2", "z1", "z2", "WITHSCORES"});
EXPECT_THAT(resp.GetVec(), ElementsAre("two", "2", "three", "3", "four", "4"));
}
} // namespace dfly