diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index cfa4cc78a..e8f1c4455 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -3,6 +3,7 @@ * **[Amir Alperin](https://github.com/iko1)** * **[Philipp Born](https://github.com/tamcore)** * Helm Chart +* **[Redha Lhimeur](https://github.com/redhal)** * **[Braydn Moore](https://github.com/braydnm)** * **[Logan Raarup](https://github.com/logandk)** * **[Ryan Russell](https://github.com/ryanrussell)** diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 2d66bd7c4..b55d234de 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -125,6 +125,7 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& op_args enum class Action { RANGE = 0, REMOVE = 1, + POP = 2 }; class IntervalVisitor { @@ -139,6 +140,8 @@ class IntervalVisitor { void operator()(const ZSetFamily::LexInterval& li); + void operator()(ZSetFamily::TopNScored sc); + ZSetFamily::ScoredArray PopResult() { return std::move(result_); } @@ -154,6 +157,9 @@ class IntervalVisitor { void ExtractListPack(const zlexrangespec& range); void ExtractSkipList(const zlexrangespec& range); + void PopListPack(ZSetFamily::TopNScored sc); + void PopSkipList(ZSetFamily::TopNScored sc); + void ActionRange(unsigned start, unsigned end); // rank void ActionRange(const zrangespec& range); // score void ActionRange(const zlexrangespec& range); // lex @@ -162,6 +168,8 @@ class IntervalVisitor { void ActionRem(const zrangespec& range); // score void ActionRem(const zlexrangespec& range); // lex + void ActionPop(ZSetFamily::TopNScored sc); + void Next(uint8_t* zl, uint8_t** eptr, uint8_t** sptr) const { if (params_.reverse) { zzlPrev(zl, eptr, sptr); @@ -214,6 +222,8 @@ void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) { case Action::REMOVE: ActionRem(start, end); break; + default: + break; } } @@ -227,6 +237,8 @@ void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) { case Action::REMOVE: ActionRem(range); break; + default: + break; } } @@ -240,10 +252,22 @@ void IntervalVisitor::operator()(const ZSetFamily::LexInterval& li) { case Action::REMOVE: ActionRem(range); break; + default: + break; } zslFreeLexRange(&range); } +void IntervalVisitor::operator()(ZSetFamily::TopNScored sc) { + switch (action_) { + case Action::POP: + ActionPop(sc); + break; + default: + break; + } +} + void IntervalVisitor::ActionRange(unsigned start, unsigned end) { container_utils::IterateSortedSet(zobj_, [this](container_utils::ContainerEntry ce, double score){ result_.emplace_back(ce.ToString(), score); @@ -311,6 +335,15 @@ void IntervalVisitor::ActionRem(const zlexrangespec& range) { } } +void IntervalVisitor::ActionPop(ZSetFamily::TopNScored sc) { + if (zobj_->encoding == OBJ_ENCODING_LISTPACK) { + PopListPack(sc); + } else { + CHECK_EQ(zobj_->encoding, OBJ_ENCODING_SKIPLIST); + PopSkipList(sc); + } +} + void IntervalVisitor::ExtractListPack(const zrangespec& range) { uint8_t* zl = (uint8_t*)zobj_->ptr; uint8_t *eptr, *sptr; @@ -472,6 +505,67 @@ void IntervalVisitor::ExtractSkipList(const zlexrangespec& range) { } } +void IntervalVisitor::PopListPack(ZSetFamily::TopNScored sc) { + uint8_t* zl = (uint8_t*)zobj_->ptr; + uint8_t *eptr, *sptr; + uint8_t* vstr; + unsigned int vlen = 0; + long long vlong = 0; + + if (params_.reverse) { + eptr = lpSeek(zl,-2); + } else { + eptr = lpSeek(zl,0); + } + + /* Get score pointer for the first element. */ + if (eptr) + sptr = lpNext(zl, eptr); + + /* First we get the entries */ + unsigned int num = sc; + while (eptr && num--) { + double score = zzlGetScore(sptr); + vstr = lpGetValue(eptr, &vlen, &vlong); + AddResult(vstr, vlen, vlong, score); + + /* Move to next node */ + Next(zl, &eptr, &sptr); + } + + int start = 0; + if (params_.reverse) { + /* If the number of elements to delete is greater than the listpack length, + * we set the start to 0 because lpseek fails to search beyond length in reverse */ + start = (2*sc > lpLength(zl)) ? 0 : -2*sc; + } + + /* We can finally delete the elements */ + zobj_->ptr = lpDeleteRange(zl, start, 2*sc); +} + +void IntervalVisitor::PopSkipList(ZSetFamily::TopNScored sc) { + zset* zs = (zset*)zobj_->ptr; + zskiplist* zsl = zs->zsl; + zskiplistNode* ln; + + /* We start from the header, or the tail if reversed. */ + if (params_.reverse) { + ln = zsl->tail; + } else { + ln = zsl->header; + } + + while (ln && sc--) { + result_.emplace_back(string{ln->ele, sdslen(ln->ele)}, ln->score); + + /* we can delete the element now */ + zsetDel(zobj_, ln->ele); + + ln = Next(ln); + } +} + void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vlong, double score) { if (vstr == NULL) { result_.emplace_back(absl::StrCat(vlong), score); @@ -1078,6 +1172,14 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) { (*cntx)->SendLong(smvec.size()); } +void ZSetFamily::ZPopMax(CmdArgList args, ConnectionContext* cntx) { + ZPopMinMax(std::move(args), true, cntx); +} + +void ZSetFamily::ZPopMin(CmdArgList args, ConnectionContext* cntx) { + ZPopMinMax(std::move(args), false, cntx); +} + void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 1); @@ -1532,6 +1634,30 @@ bool ZSetFamily::ParseRangeByScoreParams(CmdArgList args, RangeParams* params) { return true; } +void ZSetFamily::ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cntx) { + string_view key = ArgS(args, 1); + string_view count = ArgS(args, 2); + + RangeParams range_params; + range_params.reverse = reverse; + ZRangeSpec range_spec; + range_spec.params = range_params; + TopNScored sc; + + if (!SimpleAtoi(count, &sc)) { + return (*cntx)->SendError(kUintErr); + } + + range_spec.interval = sc; + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpPopCount(range_spec, t->GetOpArgs(shard), key); + }; + + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + OutputScoredArrayResult(result, range_params, cntx); +} + OpResult ZSetFamily::OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor) { OpResult find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); @@ -1658,6 +1784,30 @@ OpResult ZSetFamily::OpMScore(const OpArgs& op_args, return scores; } +auto ZSetFamily::OpPopCount(const ZRangeSpec& range_spec, const OpArgs& op_args, string_view key) -> OpResult { + auto& db_slice = op_args.shard->db_slice(); + OpResult res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET); + if (!res_it) + return res_it.status(); + + db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it); + + robj* zobj = res_it.value()->second.AsRObj(); + + IntervalVisitor iv{Action::POP, range_spec.params, zobj}; + std::visit(iv, range_spec.interval); + + res_it.value()->second.SyncRObj(); + db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key); + + auto zlen = zsetLength(zobj); + if (zlen == 0) { + CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value())); + } + + return iv.PopResult(); +} + auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, string_view key) -> OpResult { OpResult res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); @@ -1857,6 +2007,8 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy) << CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore) << CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount) + << CI{"ZPOPMAX", CO::READONLY, 3, 1, 1, 1}.HFUNC(ZPopMax) + << CI{"ZPOPMIN", CO::READONLY, 3, 1, 1, 1}.HFUNC(ZPopMin) << CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem) << CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange) << CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRank) diff --git a/src/server/zset_family.h b/src/server/zset_family.h index 9ee85777b..b39ad6bfb 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -34,6 +34,8 @@ class ZSetFamily { using LexInterval = std::pair; + using TopNScored = uint32_t; + struct RangeParams { uint32_t offset = 0; uint32_t limit = UINT32_MAX; @@ -42,7 +44,7 @@ class ZSetFamily { }; struct ZRangeSpec { - std::variant interval; + std::variant interval; RangeParams params; }; @@ -58,6 +60,8 @@ class ZSetFamily { static void ZIncrBy(CmdArgList args, ConnectionContext* cntx); static void ZInterStore(CmdArgList args, ConnectionContext* cntx); static void ZLexCount(CmdArgList args, ConnectionContext* cntx); + static void ZPopMax(CmdArgList args, ConnectionContext* cntx); + static void ZPopMin(CmdArgList args, ConnectionContext* cntx); static void ZRange(CmdArgList args, ConnectionContext* cntx); static void ZRank(CmdArgList args, ConnectionContext* cntx); static void ZRem(CmdArgList args, ConnectionContext* cntx); @@ -84,7 +88,7 @@ class ZSetFamily { static void ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx); static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx); static bool ParseRangeByScoreParams(CmdArgList args, RangeParams* params); - + static void ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cntx); static OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor); static OpResult OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members); @@ -93,6 +97,8 @@ class ZSetFamily { using MScoreResponse = std::vector>; static OpResult OpMScore(const OpArgs& op_args, std::string_view key, ArgSlice members); + static OpResult OpPopCount(const ZRangeSpec& range_spec, const OpArgs& op_args, + std::string_view key); static OpResult OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, std::string_view key); static OpResult OpRemRange(const OpArgs& op_args, std::string_view key, diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 141394cbe..33b1ec125 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -267,4 +267,47 @@ TEST_F(ZSetFamilyTest, ZAddBug148) { EXPECT_THAT(resp, IntArg(1)); } +TEST_F(ZSetFamilyTest, ZPopMin) { + auto resp = Run({"zadd", "key", "1", "a", "2", "b", "3", "c", "4", "d", "5", "e"}); + EXPECT_THAT(resp, IntArg(5)); + + resp = Run({"zpopmin", "key", "2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b")); + + resp = Run({"zpopmin", "key", "-1"}); + ASSERT_THAT(resp, ErrArg("value is out of range, must be positive")); + + resp = Run({"zpopmin", "key", "1"}); + ASSERT_THAT(resp, "c"); + + resp = Run({"zpopmin", "key", "3"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), ElementsAre("d", "e")); + + resp = Run({"zpopmin", "key", "1"}); + ASSERT_THAT(resp, ArrLen(0)); +} + +TEST_F(ZSetFamilyTest, ZPopMax) { + auto resp = Run({"zadd", "key", "1", "a", "2", "b", "3", "c", "4", "d", "5", "e"}); + EXPECT_THAT(resp, IntArg(5)); + + resp = Run({"zpopmax", "key", "2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), ElementsAre("e", "d")); + + resp = Run({"zpopmax", "key", "-1"}); + ASSERT_THAT(resp, ErrArg("value is out of range, must be positive")); + + resp = Run({"zpopmax", "key", "1"}); + ASSERT_THAT(resp, "c"); + + resp = Run({"zpopmax", "key", "3"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), ElementsAre("b", "a")); + + resp = Run({"zpopmax", "key", "1"}); + ASSERT_THAT(resp, ArrLen(0)); +} } // namespace dfly