From 42116fa01216a1d285cd8026b330a3a7310f9b15 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Mon, 5 Jun 2023 18:26:01 +0300 Subject: [PATCH] feat(zset family): Implement ZDiff command issue #1311 (#1333) Signed-off-by: Kostas --- src/server/transaction.cc | 6 +++ src/server/zset_family.cc | 81 ++++++++++++++++++++++++++++++++++ src/server/zset_family.h | 1 + src/server/zset_family_test.cc | 53 ++++++++++++++++++++++ 4 files changed, 141 insertions(+) diff --git a/src/server/transaction.cc b/src/server/transaction.cc index e439ec62f..35f9d4c76 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -1432,6 +1432,12 @@ OpResult DetermineKeys(const CommandId* cid, CmdArgList args) { if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0) return OpStatus::INVALID_INT; + // TODO Fix this for Z family functions. + // Examples that crash: ZUNION 0 myset + if (name == "ZDIFF" && num_custom_keys == 0) { + return OpStatus::INVALID_INT; + } + if (args.size() < size_t(num_custom_keys) + num_keys_index + 1) return OpStatus::SYNTAX_ERR; } diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 82bbae7c5..8ac369452 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -1351,6 +1351,86 @@ void ZSetFamily::ZCount(CmdArgList args, ConnectionContext* cntx) { } } +vector OpFetch(EngineShard* shard, Transaction* t) { + ArgSlice keys = t->GetShardArgs(shard->shard_id()); + DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end()); + DCHECK(!keys.empty()); + + vector results; + results.reserve(keys.size()); + + auto& db_slice = shard->db_slice(); + for (size_t i = 0; i < keys.size(); ++i) { + auto it = db_slice.Find(t->GetDbContext(), keys[i], OBJ_ZSET); + if (!it) { + results.push_back({}); + continue; + } + + ScoredMap sm = FromObject((*it)->second, 1); + results.push_back(std::move(sm)); + } + + return results; +} + +void ZSetFamily::ZDiff(CmdArgList args, ConnectionContext* cntx) { + vector> maps(shard_set->size()); + auto cb = [&](Transaction* t, EngineShard* shard) { + maps[shard->shard_id()] = OpFetch(shard, t); + return OpStatus::OK; + }; + + cntx->transaction->ScheduleSingleHop(std::move(cb)); + + const string_view key = ArgS(args, 1); + const ShardId sid = Shard(key, maps.size()); + // Extract the ScoredMap of the first key + auto& sm = maps[sid]; + if (sm.empty()) { + (*cntx)->SendEmptyArray(); + return; + } + auto result = std::move(sm[0]); + sm.erase(sm.begin()); + + auto filter = [&result](const auto& key) mutable { + auto it = result.find(key); + if (it != result.end()) { + result.erase(it); + } + }; + + // Total O(L) + // Iterate over the results of each shard + for (auto& vsm : maps) { + // Iterate over each fetched set + for (auto& sm : vsm) { + // Iterate over each key in the fetched set and filter + for (auto& [key, value] : sm) { + filter(key); + } + } + } + + vector smvec; + for (const auto& elem : result) { + smvec.emplace_back(elem.second, elem.first); + } + + // Total O(KlogK) + std::sort(std::begin(smvec), std::end(smvec)); + + const bool with_scores = ArgS(args, args.size() - 1) == "WITHSCORES"; + (*cntx)->StartArray(result.size() * (with_scores ? 2 : 1)); + for (const auto& [score, key] : smvec) { + (*cntx)->SendBulkString(key); + if (with_scores) { + (*cntx)->SendDouble(score); + } + } +} + void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); string_view score_arg = ArgS(args, 1); @@ -2301,6 +2381,7 @@ void ZSetFamily::Register(CommandRegistry* registry) { .HFUNC(BZPopMax) << CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard) << CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount) + << CI{"ZDIFF", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, 1}.HFUNC(ZDiff) << CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy) << CI{"ZINTERSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZInterStore) << CI{"ZINTERCARD", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 2, 2, 1} diff --git a/src/server/zset_family.h b/src/server/zset_family.h index 36b8deb52..f4be6558b 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -61,6 +61,7 @@ class ZSetFamily { static void ZAdd(CmdArgList args, ConnectionContext* cntx); static void ZCard(CmdArgList args, ConnectionContext* cntx); static void ZCount(CmdArgList args, ConnectionContext* cntx); + static void ZDiff(CmdArgList args, ConnectionContext* cntx); static void ZIncrBy(CmdArgList args, ConnectionContext* cntx); static void ZInterStore(CmdArgList args, ConnectionContext* cntx); static void ZInterCard(CmdArgList args, ConnectionContext* cntx); diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 7b77967a1..86e7f50b7 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -578,4 +578,57 @@ TEST_F(ZSetFamilyTest, BlockingTimeout) { EXPECT_THAT(resp0, ArgType(RespExpr::NIL_ARRAY)); } +TEST_F(ZSetFamilyTest, ZDiffError) { + RespExpr resp; + + resp = Run({"zdiff", "-1", "z1"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); + + resp = Run({"zdiff", "0"}); + EXPECT_THAT(resp, ErrArg("wrong number of arguments")); + + resp = Run({"zdiff", "0", "z1"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); + + resp = Run({"zdiff", "0", "z1", "z2"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); +} + +TEST_F(ZSetFamilyTest, ZDiff) { + RespExpr resp; + + EXPECT_EQ(4, CheckedInt({"zadd", "z1", "1", "one", "2", "two", "3", "three", "4", "four"})); + EXPECT_EQ(2, CheckedInt({"zadd", "z2", "1", "one", "5", "five"})); + EXPECT_EQ(2, CheckedInt({"zadd", "z3", "2", "two", "3", "three"})); + EXPECT_EQ(1, CheckedInt({"zadd", "z4", "4", "four"})); + + resp = Run({"zdiff", "1", "z1"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("one", "two", "three", "four")); + + resp = Run({"zdiff", "2", "z1", "z1"}); + EXPECT_THAT(resp.GetVec().empty(), true); + + resp = Run({"zdiff", "2", "z1", "doesnt_exist"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("one", "two", "three", "four")); + + resp = Run({"zdiff", "2", "z1", "z2"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("two", "three", "four")); + + resp = Run({"zdiff", "2", "z1", "z3"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("one", "four")); + + resp = Run({"zdiff", "4", "z1", "z2", "z3", "z4"}); + EXPECT_THAT(resp.GetVec().empty(), true); + + resp = Run({"zdiff", "2", "doesnt_exist", "key1"}); + EXPECT_THAT(resp.GetVec().empty(), true); + + // WITHSCORES + resp = Run({"zdiff", "1", "z1", "WITHSCORES"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("one", "1", "two", "2", "three", "3", "four", "4")); + + resp = Run({"zdiff", "2", "z1", "z2", "WITHSCORES"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("two", "2", "three", "3", "four", "4")); +} + } // namespace dfly