diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 822aa1c5c..7e10978d0 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -162,8 +162,14 @@ char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len) RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) { } -void RedisReplyBuilder::SendError(string_view str, std::string_view type) { - err_count_[type.empty() ? str : type]++; +void RedisReplyBuilder::SendError(string_view str, string_view err_type) { + if (err_type.empty()) { + err_type = str; + if (err_type == kSyntaxErr) + err_type = kSyntaxErrType; + } + + err_count_[err_type]++; if (str[0] == '-') { iovec v[] = {IoVec(str), IoVec(kCRLF)}; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 6a1c126f0..fd3ceb720 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -486,12 +486,11 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) if (IsTransactional(cid)) { dist_trans.reset(new Transaction{cid, &shard_set_}); - dfly_cntx->transaction = dist_trans.get(); - OpStatus st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args); if (st != OpStatus::OK) return (*cntx)->SendError(st); + dfly_cntx->transaction = dist_trans.get(); dfly_cntx->last_command_debug.shards_count = dfly_cntx->transaction->unique_shard_cnt(); } else { dfly_cntx->transaction = nullptr; diff --git a/src/server/set_family.cc b/src/server/set_family.cc index af74f4a44..6d1345b9f 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -305,8 +305,9 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, ArgSlice v auto* es = op_args.shard; auto& db_slice = es->db_slice(); - // overwrite - meaning we run in the context of 2-hop operation and we had already - // ensured that the key exists. + // overwrite - meaning we run in the context of 2-hop operation and we want + // to overwrite the key. However, if the set is empty it means we should delete the + // key if it exists. if (overwrite && vals.empty()) { auto it = db_slice.FindExt(op_args.db_ind, key).first; db_slice.Del(op_args.db_ind, it); diff --git a/src/server/set_family_test.cc b/src/server/set_family_test.cc index d2c8fcb74..79b3a95bb 100644 --- a/src/server/set_family_test.cc +++ b/src/server/set_family_test.cc @@ -90,6 +90,8 @@ TEST_F(SetFamilyTest, SInter) { resp = Run({"sinter", "x", "y"}); ASSERT_EQ(1, GetDebugInfo("IO0").shards_count); EXPECT_THAT(resp, ErrArg("WRONGTYPE Operation against a key")); + resp = Run({"sinterstore", "none1", "none2"}); + EXPECT_THAT(resp, IntArg(0)); } TEST_F(SetFamilyTest, SMove) { diff --git a/src/server/transaction.cc b/src/server/transaction.cc index fa53c89ab..7a812c584 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -103,6 +103,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { bool incremental_locking = multi_ && multi_->incremental; bool single_key = !multi_ && key_index.HasSingleKey(); + bool needs_reverse_mapping = cid_->opt_mask() & CO::REVERSE_MAPPING; if (single_key) { DCHECK_GT(key_index.step, 0u); @@ -118,6 +119,12 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { unique_shard_cnt_ = 1; unique_shard_id_ = Shard(key, ess_->size()); + if (needs_reverse_mapping) { + reverse_index_.resize(args_.size()); + for (unsigned j = 0; j < reverse_index_.size(); ++j) { + reverse_index_[j] = j + key_index.start - 1; + } + } return OpStatus::OK; } @@ -137,7 +144,6 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { // and regular commands. IntentLock::Mode mode = IntentLock::EXCLUSIVE; bool should_record_locks = false; - bool needs_reverse_mapping = cid_->opt_mask() & CO::REVERSE_MAPPING; if (multi_) { mode = Mode(); @@ -148,11 +154,12 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { if (key_index.bonus) { // additional one-of key. DCHECK(key_index.step == 1); - DCHECK(!needs_reverse_mapping); string_view key = ArgS(args, key_index.bonus); uint32_t sid = Shard(key, shard_data_.size()); shard_index[sid].args.push_back(key); + if (needs_reverse_mapping) + shard_index[sid].original_index.push_back(key_index.bonus - 1); } for (unsigned i = key_index.start; i < key_index.end; ++i) { @@ -183,7 +190,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { args_.resize(key_index.num_args()); - // we need reverse index only for blocking commands or commands like MSET. + // we need reverse index only for some commands (MSET etc). if (needs_reverse_mapping) reverse_index_.resize(args_.size()); @@ -213,20 +220,25 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { ++unique_shard_cnt_; unique_shard_id_ = i; - uint32_t orig_indx = 0; for (size_t j = 0; j < si.args.size(); ++j) { *next_arg = si.args[j]; if (needs_reverse_mapping) { - *rev_indx_it++ = si.original_index[orig_indx]; + *rev_indx_it++ = si.original_index[j]; } ++next_arg; - ++orig_indx; } } CHECK(next_arg == args_.end()); DVLOG(1) << "InitByArgs " << DebugId() << " " << args_.front(); + // validation + if (needs_reverse_mapping) { + for (size_t i = 0; i < args_.size(); ++i) { + DCHECK_EQ(args_[i], ArgS(args, 1 + reverse_index_[i])); // 1 for the commandname. + } + } + if (unique_shard_cnt_ == 1) { PerShardData* sd; if (multi_) { @@ -892,11 +904,14 @@ ArgSlice Transaction::ShardArgsInShard(ShardId sid) const { return ArgSlice{args_.data() + sd.arg_start, sd.arg_count}; } +// from local index back to original arg index skipping the command. +// i.e. returns (first_key_pos -1) or bigger. size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const { - if (unique_shard_cnt_ == 1) - return arg_index; + if (unique_shard_cnt_ == 1) // mget: 0->0, 1->1. zunionstore has 0->2 + return reverse_index_[arg_index]; - return reverse_index_[shard_data_[shard_id].arg_start + arg_index]; + const auto& sd = shard_data_[shard_id]; + return reverse_index_[sd.arg_start + arg_index]; } bool Transaction::WaitOnWatch(const time_point& tp) { @@ -1164,10 +1179,13 @@ OpResult DetermineKeys(const CommandId* cid, CmdArgList args) { if (args.size() < 3) { return OpStatus::SYNTAX_ERR; } + string_view num(ArgS(args, 2)); - if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0 || - size_t(num_custom_keys) + 3 > args.size()) + if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0) return OpStatus::INVALID_INT; + + if (size_t(num_custom_keys) + 3 > args.size()) + return OpStatus::SYNTAX_ERR; } if (cid->first_key_pos() > 0) { diff --git a/src/server/transaction.h b/src/server/transaction.h index bc0331c6a..7a72eb12a 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -70,9 +70,8 @@ class Transaction { // Runs in engine thread ArgSlice ShardArgsInShard(ShardId sid) const; - // Maps the index in ShardKeys(shard_id) slice back to the index in the original array passed to - // InitByArgs. - // Runs in the coordinator thread. + // Maps the index in ShardArgsInShard(shard_id) slice back to the index + // in the original array passed to InitByArgs. size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const; //! Returns true if the transaction spans this shard_id. diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index aa2ac87e0..e05596d92 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -14,6 +14,7 @@ extern "C" { #include #include "base/logging.h" +#include "base/stl_util.h" #include "facade/error.h" #include "server/command_registry.h" #include "server/conn_context.h" @@ -78,15 +79,21 @@ zlexrangespec GetLexRange(bool reverse, const ZSetFamily::LexInterval& li) { return range; } -OpResult FindZEntry(unsigned flags, const OpArgs& op_args, string_view key, +struct ZParams { + unsigned flags = 0; // mask of ZADD_IN_ macros. + bool ch = false; // Corresponds to CH option. + bool override = false; +}; + +OpResult FindZEntry(const ZParams& zparams, const OpArgs& op_args, string_view key, size_t member_len) { auto& db_slice = op_args.shard->db_slice(); - if (flags & ZADD_IN_XX) { + if (zparams.flags & ZADD_IN_XX) { return db_slice.Find(op_args.db_ind, key, OBJ_ZSET); } auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key); - if (inserted) { + if (inserted || zparams.override) { robj* zobj = nullptr; if (member_len > kMaxListPackValue) { @@ -96,12 +103,16 @@ OpResult FindZEntry(unsigned flags, const OpArgs& op_args, string } DVLOG(2) << "Created zset " << zobj->ptr; + if (!inserted) { + db_slice.PreUpdate(op_args.db_ind, it); + } it->second.ImportRObj(zobj); } else { if (it->second.ObjType() != OBJ_ZSET) return OpStatus::WRONG_TYPE; db_slice.PreUpdate(op_args.db_ind, it); } + return it; } @@ -562,6 +573,193 @@ bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) { return true; } +void SendAtLeastOneKeyError(ConnectionContext* cntx) { + string name = cntx->cid->name(); + absl::AsciiStrToLower(&name); + (*cntx)->SendError(absl::StrCat("at least 1 input key is needed for ", name)); +} + +enum class AggType : uint8_t { SUM, MIN, MAX }; +using ScoredMap = absl::flat_hash_map; + +ScoredMap FromObject(const CompactObj& co, double weight) { + robj* obj = co.AsRObj(); + ZSetFamily::RangeParams params; + params.with_scores = true; + IntervalVisitor vis(Action::RANGE, params, obj); + vis(ZSetFamily::IndexInterval(0, -1)); + + ZSetFamily::ScoredArray arr = vis.PopResult(); + ScoredMap res; + res.reserve(arr.size()); + + for (auto& elem : arr) { + elem.second *= weight; + res.emplace(move(elem)); + } + + return res; +} + +double Aggregate(double v1, double v2, AggType atype) { + switch (atype) { + case AggType::SUM: + return v1 + v2; + case AggType::MAX: + return max(v1, v2); + case AggType::MIN: + return min(v1, v2); + } + return 0; +} + +// the result is in the destination. +void UnionScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) { + ScoredMap* target = dest; + ScoredMap* iter = src; + + if (iter->size() > target->size()) + swap(target, iter); + + for (const auto& elem : *iter) { + auto [it, inserted] = target->emplace(elem); + if (!inserted) { + it->second = Aggregate(it->second, elem.second, agg_type); + } + } + + if (target != dest) + dest->swap(*src); +} + +OpResult OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type, + const vector& weights, bool store) { + ArgSlice keys = t->ShardArgsInShard(shard->shard_id()); + DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end()); + DCHECK(!keys.empty()); + + unsigned start = 0; + + if (keys.front() == dest) { + ++start; + } + + auto& db_slice = shard->db_slice(); + vector> it_arr(keys.size() - start); + if (it_arr.empty()) // could be when only the dest key is hosted in this shard + return OpStatus::OK; // return empty map + + for (unsigned j = start; j < keys.size(); ++j) { + auto it_res = db_slice.Find(t->db_index(), keys[j], OBJ_ZSET); + if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1. + return it_res.status(); + if (!it_res) + continue; + + // first global index is 2 after {destkey, numkeys} + unsigned src_indx = j - start; + unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2; + DCHECK_LT(windex, weights.size()); + it_arr[src_indx] = {*it_res, weights[windex]}; + } + + ScoredMap result; + for (auto it = it_arr.begin(); it != it_arr.end(); ++it) { + if (it->first.is_done()) + continue; + + ScoredMap sm = FromObject(it->first->second, it->second); + if (result.empty()) + result.swap(sm); + else + UnionScoredMap(&result, &sm, agg_type); + } + + return result; +} + +using ScoredMemberView = std::pair; +using ScoredMemberSpan = absl::Span; + +struct AddResult { + double new_score; + unsigned num_updated = 0; + + bool is_nan = false; +}; + +OpResult OpAdd(const OpArgs& op_args, const ZParams& zparams, string_view key, + ScoredMemberSpan members) { + DCHECK(!members.empty() || zparams.override); + auto& db_slice = op_args.shard->db_slice(); + + if (zparams.override && members.empty()) { + auto it = db_slice.FindExt(op_args.db_ind, key).first; + db_slice.Del(op_args.db_ind, it); + return OpStatus::OK; + } + + OpResult res_it = FindZEntry(zparams, op_args, key, members.front().second.size()); + + if (!res_it) + return res_it.status(); + + robj* zobj = res_it.value()->second.AsRObj(); + + unsigned added = 0; + unsigned updated = 0; + unsigned processed = 0; + + sds& tmp_str = op_args.shard->tmp_str1; + double new_score = 0; + int retflags = 0; + + OpStatus op_status = OpStatus::OK; + AddResult aresult; + + for (size_t j = 0; j < members.size(); j++) { + const auto& m = members[j]; + tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size()); + + int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score); + + if (zparams.flags & ZADD_IN_INCR) { + if (retval == 0) { + CHECK_EQ(1u, members.size()); + + aresult.is_nan = true; + break; + } + + if (retflags & ZADD_OUT_NOP) { + op_status = OpStatus::SKIPPED; + } + } + + if (retflags & ZADD_OUT_ADDED) + added++; + if (retflags & ZADD_OUT_UPDATED) + updated++; + if (!(retflags & ZADD_OUT_NOP)) + processed++; + } + + DVLOG(2) << "ZAdd " << zobj->ptr; + + res_it.value()->second.SyncRObj(); + op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it); + + if (zparams.flags & ZADD_IN_INCR) { + aresult.new_score = new_score; + } else { + aresult.num_updated = zparams.ch ? added + updated : added; + } + + if (op_status != OpStatus::OK) + return op_status; + return aresult; +} + } // namespace void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) { @@ -631,33 +829,31 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) { DCHECK(cntx->transaction); absl::Span memb_sp{members.data(), members.size()}; - AddResult add_result; - - auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus { + auto cb = [&](Transaction* t, EngineShard* shard) { OpArgs op_args{shard, t->db_index()}; - return OpAdd(zparams, op_args, key, memb_sp, &add_result); + return OpAdd(op_args, zparams, key, memb_sp); }; - OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); - if (status == OpStatus::WRONG_TYPE) { + OpResult add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (add_result.status() == OpStatus::WRONG_TYPE) { return (*cntx)->SendError(kWrongTypeErr); } // KEY_NOTFOUND may happen in case of XX flag. - if (status == OpStatus::KEY_NOTFOUND) { + if (add_result.status() == OpStatus::KEY_NOTFOUND) { if (zparams.flags & ZADD_IN_INCR) (*cntx)->SendNull(); else (*cntx)->SendLong(0); - } else if (status == OpStatus::SKIPPED) { + } else if (add_result.status() == OpStatus::SKIPPED) { (*cntx)->SendNull(); - } else if (add_result.is_nan) { + } else if (add_result->is_nan) { (*cntx)->SendError(kScoreNaN); } else { if (zparams.flags & ZADD_IN_INCR) { - (*cntx)->SendDouble(add_result.new_score); + (*cntx)->SendDouble(add_result->new_score); } else { - (*cntx)->SendLong(add_result.num_updated); + (*cntx)->SendLong(add_result->num_updated); } } } @@ -725,31 +921,28 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) { ZParams zparams; zparams.flags = ZADD_IN_INCR; - AddResult add_result; - - auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus { + auto cb = [&](Transaction* t, EngineShard* shard) { OpArgs op_args{shard, t->db_index()}; - return OpAdd(zparams, op_args, key, ScoredMemberSpan{&scored_member, 1}, &add_result); + return OpAdd(op_args, zparams, key, ScoredMemberSpan{&scored_member, 1}); }; - OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); - if (status == OpStatus::WRONG_TYPE) { + OpResult add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (add_result.status() == OpStatus::WRONG_TYPE) { return (*cntx)->SendError(kWrongTypeErr); } - if (status == OpStatus::SKIPPED) { + if (add_result.status() == OpStatus::SKIPPED) { return (*cntx)->SendNull(); } - if (add_result.is_nan) { + if (add_result->is_nan) { return (*cntx)->SendError(kScoreNaN); } - (*cntx)->SendDouble(add_result.new_score); + (*cntx)->SendDouble(add_result->new_score); } void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) { - } void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) { @@ -984,16 +1177,93 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) { } void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) { - auto cb = [&](Transaction* t, EngineShard* es) { - auto args = t->ShardArgsInShard(es->shard_id()); - for (auto x : args) { - LOG(INFO) << "arg " << x; + string_view dest_key = ArgS(args, 1); + string_view num_str = ArgS(args, 2); + uint32_t num_keys; + AggType agg_type = AggType::SUM; + + // we parsed the structure before, when transaction has been initialized. + CHECK(absl::SimpleAtoi(num_str, &num_keys)); + if (num_keys == 0) { + return SendAtLeastOneKeyError(cntx); + } + + DCHECK_GE(args.size(), 3 + num_keys); + + vector weights(num_keys, 1); + for (size_t i = 3 + num_keys; i < args.size(); ++i) { + ToUpper(&args[i]); + string_view arg = ArgS(args, i); + if (arg == "WEIGHTS") { + if (args.size() <= i + num_keys) { + return (*cntx)->SendError(kSyntaxErr); + } + for (unsigned j = 0; j < num_keys; ++j) { + string_view weight = ArgS(args, i + j + 1); + if (!absl::SimpleAtod(weight, &weights[j])) { + return (*cntx)->SendError("weight value is not a float", kSyntaxErrType); + } + } + i += num_keys; + } else if (arg == "AGGREGATE") { + if (i + 2 != args.size()) { + return (*cntx)->SendError(kSyntaxErr); + } + ToUpper(&args[i + 1]); + + string_view agg = ArgS(args, i + 1); + if (agg == "SUM") { + agg_type = AggType::SUM; + } else if (agg == "MIN") { + agg_type = AggType::MIN; + } else if (agg == "MAX") { + agg_type = AggType::MAX; + } else { + return (*cntx)->SendError(kSyntaxErr); + } + break; + } else { + return (*cntx)->SendError(kSyntaxErr); + } + } + + vector> maps(cntx->shard_set->size()); + + auto cb = [&](Transaction* t, EngineShard* shard) { + maps[shard->shard_id()] = OpUnion(shard, t, dest_key, agg_type, weights, false); + return OpStatus::OK; + }; + + cntx->transaction->Schedule(); + + cntx->transaction->Execute(std::move(cb), false); + ScoredMap result; + + for (auto& op_res : maps) { + if (!op_res) + return (*cntx)->SendError(op_res.status()); + UnionScoredMap(&result, &op_res.value(), agg_type); + } + ShardId dest_shard = Shard(dest_key, maps.size()); + AddResult add_result; + vector smvec; + for (const auto& elem : result) { + smvec.emplace_back(elem.second, elem.first); + } + + auto store_cb = [&](Transaction* t, EngineShard* shard) { + if (shard->shard_id() == dest_shard) { + ZParams zparams; + zparams.override = true; + add_result = + OpAdd(OpArgs{shard, t->db_index()}, zparams, dest_key, ScoredMemberSpan{smvec}).value(); } return OpStatus::OK; }; - OpStatus result = cntx->transaction->ScheduleSingleHop(std::move(cb)); - (*cntx)->SendOk(); + cntx->transaction->Execute(std::move(store_cb), true); + + (*cntx)->SendLong(smvec.size()); } void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, string_view max_s, @@ -1202,68 +1472,6 @@ OpResult ZSetFamily::OpScan(const OpArgs& op_args, std::string_view k return res; } -OpStatus ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key, - ScoredMemberSpan members, AddResult* add_result) { - DCHECK(!members.empty()); - OpResult res_it = - FindZEntry(zparams.flags, op_args, key, members.front().second.size()); - - if (!res_it) - return res_it.status(); - - robj* zobj = res_it.value()->second.AsRObj(); - - unsigned added = 0; - unsigned updated = 0; - unsigned processed = 0; - - sds& tmp_str = op_args.shard->tmp_str1; - double new_score = 0; - int retflags = 0; - - OpStatus res = OpStatus::OK; - - for (size_t j = 0; j < members.size(); j++) { - const auto& m = members[j]; - tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size()); - - int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score); - - if (zparams.flags & ZADD_IN_INCR) { - if (retval == 0) { - CHECK_EQ(1u, members.size()); - - add_result->is_nan = true; - break; - } - - if (retflags & ZADD_OUT_NOP) { - res = OpStatus::SKIPPED; - } - } - - if (retflags & ZADD_OUT_ADDED) - added++; - if (retflags & ZADD_OUT_UPDATED) - updated++; - if (!(retflags & ZADD_OUT_NOP)) - processed++; - } - - DVLOG(2) << "ZAdd " << zobj->ptr; - - res_it.value()->second.SyncRObj(); - op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it); - - if (zparams.flags & ZADD_IN_INCR) { - add_result->new_score = new_score; - } else { - add_result->num_updated = zparams.ch ? added + updated : added; - } - - return res; -} - OpResult ZSetFamily::OpRem(const OpArgs& op_args, string_view key, ArgSlice members) { auto& db_slice = op_args.shard->db_slice(); OpResult res_it = db_slice.Find(op_args.db_ind, key, OBJ_ZSET); @@ -1496,11 +1704,13 @@ OpResult ZSetFamily::OpLexCount(const OpArgs& op_args, string_view key #define HFUNC(x) SetHandler(&ZSetFamily::x) void ZSetFamily::Register(CommandRegistry* registry) { + constexpr uint32_t kUnionMask = CO::WRITE | CO::DESTINATION_KEY | CO::REVERSE_MAPPING; + *registry << CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd) << CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard) << CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount) << CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy) - << CI{"ZINTERSTORE", CO::WRITE | CO::DESTINATION_KEY, -4, 1, 1, 1}.HFUNC(ZInterStore) + << CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore) << CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount) << CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem) << CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange) @@ -1515,7 +1725,7 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRevRangeByScore) << CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank) << CI{"ZSCAN", CO::READONLY | CO::RANDOM, -3, 1, 1, 1}.HFUNC(ZScan) - << CI{"ZUNIONSTORE", CO::WRITE | CO::DESTINATION_KEY, -4, 3, 3, 1}.HFUNC(ZUnionStore); + << CI{"ZUNIONSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZUnionStore); } } // namespace dfly diff --git a/src/server/zset_family.h b/src/server/zset_family.h index dda433e7c..801e3561e 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -86,23 +86,6 @@ class ZSetFamily { static OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor); - struct ZParams { - unsigned flags = 0; // mask of ZADD_IN_ macros. - bool ch = false; // Corresponds to CH option. - }; - - using ScoredMemberView = std::pair; - using ScoredMemberSpan = absl::Span; - - struct AddResult { - double new_score; - unsigned num_updated = 0; - - bool is_nan = false; - }; - - static facade::OpStatus OpAdd(const ZParams& zparams, const OpArgs& op_args, std::string_view key, - ScoredMemberSpan members, AddResult* add_result); static OpResult OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members); static OpResult OpScore(const OpArgs& op_args, std::string_view key, std::string_view member); diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 233c4a698..1acdf426e 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -167,4 +167,63 @@ TEST_F(ZSetFamilyTest, ZScan) { EXPECT_EQ(100 * 2, scan_len); } +TEST_F(ZSetFamilyTest, ZUnionStore) { + RespExpr resp; + + resp = Run({"zunionstore", "key", "0"}); + EXPECT_THAT(resp, ErrArg("wrong number of arguments")); + + resp = Run({"zunionstore", "key", "0", "aggregate", "sum"}); + EXPECT_THAT(resp, ErrArg("at least 1 input key is needed")); + resp = Run({"zunionstore", "key", "-1", "aggregate", "sum"}); + EXPECT_THAT(resp, ErrArg("out of range")); + resp = Run({"zunionstore", "key", "2", "foo", "bar", "weights", "1"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"})); + EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"})); + + resp = Run({"zunionstore", "key", "2", "z1", "z2"}); + EXPECT_THAT(resp, IntArg(3)); + resp = Run({"zrange", "key", "0", "-1", "withscores"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "c", "3", "b", "4")); + + resp = Run({"zunionstore", "z1", "1", "z1"}); + EXPECT_THAT(resp, IntArg(2)); + + resp = Run({"zunionstore", "z1", "2", "z1", "z2"}); + EXPECT_THAT(resp, IntArg(3)); + resp = Run({"zrange", "z1", "0", "-1", "withscores"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "c", "3", "b", "4")); + + Run({"set", "foo", "bar"}); + resp = Run({"zunionstore", "foo", "1", "z2"}); + EXPECT_THAT(resp, IntArg(2)); + resp = Run({"zrange", "foo", "0", "-1", "withscores"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("b", "2", "c", "3")); +} + +TEST_F(ZSetFamilyTest, ZUnionStoreOpts) { + EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"})); + EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"})); + RespExpr resp; + + EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"})); + resp = Run({"zrange", "a", "0", "-1", "withscores"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "8", "c", "9")); + + resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"}); + EXPECT_THAT(resp, IntArg(2)); + resp = Run({"zrange", "z1", "0", "-1", "withscores"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("a", "2", "b", "4")); + + resp = Run({"zunionstore", "max", "2", "z1", "z2", "weights", "1", "0", "aggregate", "max"}); + ASSERT_THAT(resp, IntArg(3)); + resp = Run({"zrange", "max", "0", "-1", "withscores"}); + EXPECT_THAT(resp.GetVec(), ElementsAre("c", "0", "a", "2", "b", "4")); +} + } // namespace dfly