Implement ZUNIONSTORE.

Fix some bugs in transactional framework to support irregular commands.
Lay out groundwork for supporting XXX-STORE other commands in zsets.
This commit is contained in:
Roman Gershman 2022-05-07 20:38:18 +03:00
parent 3a4c36c1f2
commit c34e7c6d44
9 changed files with 408 additions and 131 deletions

View file

@ -162,8 +162,14 @@ char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len)
RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) { RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
} }
void RedisReplyBuilder::SendError(string_view str, std::string_view type) { void RedisReplyBuilder::SendError(string_view str, string_view err_type) {
err_count_[type.empty() ? str : type]++; if (err_type.empty()) {
err_type = str;
if (err_type == kSyntaxErr)
err_type = kSyntaxErrType;
}
err_count_[err_type]++;
if (str[0] == '-') { if (str[0] == '-') {
iovec v[] = {IoVec(str), IoVec(kCRLF)}; iovec v[] = {IoVec(str), IoVec(kCRLF)};

View file

@ -486,12 +486,11 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
if (IsTransactional(cid)) { if (IsTransactional(cid)) {
dist_trans.reset(new Transaction{cid, &shard_set_}); dist_trans.reset(new Transaction{cid, &shard_set_});
dfly_cntx->transaction = dist_trans.get();
OpStatus st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args); OpStatus st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args);
if (st != OpStatus::OK) if (st != OpStatus::OK)
return (*cntx)->SendError(st); return (*cntx)->SendError(st);
dfly_cntx->transaction = dist_trans.get();
dfly_cntx->last_command_debug.shards_count = dfly_cntx->transaction->unique_shard_cnt(); dfly_cntx->last_command_debug.shards_count = dfly_cntx->transaction->unique_shard_cnt();
} else { } else {
dfly_cntx->transaction = nullptr; dfly_cntx->transaction = nullptr;

View file

@ -305,8 +305,9 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, ArgSlice v
auto* es = op_args.shard; auto* es = op_args.shard;
auto& db_slice = es->db_slice(); auto& db_slice = es->db_slice();
// overwrite - meaning we run in the context of 2-hop operation and we had already // overwrite - meaning we run in the context of 2-hop operation and we want
// ensured that the key exists. // to overwrite the key. However, if the set is empty it means we should delete the
// key if it exists.
if (overwrite && vals.empty()) { if (overwrite && vals.empty()) {
auto it = db_slice.FindExt(op_args.db_ind, key).first; auto it = db_slice.FindExt(op_args.db_ind, key).first;
db_slice.Del(op_args.db_ind, it); db_slice.Del(op_args.db_ind, it);

View file

@ -90,6 +90,8 @@ TEST_F(SetFamilyTest, SInter) {
resp = Run({"sinter", "x", "y"}); resp = Run({"sinter", "x", "y"});
ASSERT_EQ(1, GetDebugInfo("IO0").shards_count); ASSERT_EQ(1, GetDebugInfo("IO0").shards_count);
EXPECT_THAT(resp, ErrArg("WRONGTYPE Operation against a key")); EXPECT_THAT(resp, ErrArg("WRONGTYPE Operation against a key"));
resp = Run({"sinterstore", "none1", "none2"});
EXPECT_THAT(resp, IntArg(0));
} }
TEST_F(SetFamilyTest, SMove) { TEST_F(SetFamilyTest, SMove) {

View file

@ -103,6 +103,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
bool incremental_locking = multi_ && multi_->incremental; bool incremental_locking = multi_ && multi_->incremental;
bool single_key = !multi_ && key_index.HasSingleKey(); bool single_key = !multi_ && key_index.HasSingleKey();
bool needs_reverse_mapping = cid_->opt_mask() & CO::REVERSE_MAPPING;
if (single_key) { if (single_key) {
DCHECK_GT(key_index.step, 0u); DCHECK_GT(key_index.step, 0u);
@ -118,6 +119,12 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
unique_shard_cnt_ = 1; unique_shard_cnt_ = 1;
unique_shard_id_ = Shard(key, ess_->size()); unique_shard_id_ = Shard(key, ess_->size());
if (needs_reverse_mapping) {
reverse_index_.resize(args_.size());
for (unsigned j = 0; j < reverse_index_.size(); ++j) {
reverse_index_[j] = j + key_index.start - 1;
}
}
return OpStatus::OK; return OpStatus::OK;
} }
@ -137,7 +144,6 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
// and regular commands. // and regular commands.
IntentLock::Mode mode = IntentLock::EXCLUSIVE; IntentLock::Mode mode = IntentLock::EXCLUSIVE;
bool should_record_locks = false; bool should_record_locks = false;
bool needs_reverse_mapping = cid_->opt_mask() & CO::REVERSE_MAPPING;
if (multi_) { if (multi_) {
mode = Mode(); mode = Mode();
@ -148,11 +154,12 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
if (key_index.bonus) { // additional one-of key. if (key_index.bonus) { // additional one-of key.
DCHECK(key_index.step == 1); DCHECK(key_index.step == 1);
DCHECK(!needs_reverse_mapping);
string_view key = ArgS(args, key_index.bonus); string_view key = ArgS(args, key_index.bonus);
uint32_t sid = Shard(key, shard_data_.size()); uint32_t sid = Shard(key, shard_data_.size());
shard_index[sid].args.push_back(key); shard_index[sid].args.push_back(key);
if (needs_reverse_mapping)
shard_index[sid].original_index.push_back(key_index.bonus - 1);
} }
for (unsigned i = key_index.start; i < key_index.end; ++i) { for (unsigned i = key_index.start; i < key_index.end; ++i) {
@ -183,7 +190,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
args_.resize(key_index.num_args()); args_.resize(key_index.num_args());
// we need reverse index only for blocking commands or commands like MSET. // we need reverse index only for some commands (MSET etc).
if (needs_reverse_mapping) if (needs_reverse_mapping)
reverse_index_.resize(args_.size()); reverse_index_.resize(args_.size());
@ -213,20 +220,25 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
++unique_shard_cnt_; ++unique_shard_cnt_;
unique_shard_id_ = i; unique_shard_id_ = i;
uint32_t orig_indx = 0;
for (size_t j = 0; j < si.args.size(); ++j) { for (size_t j = 0; j < si.args.size(); ++j) {
*next_arg = si.args[j]; *next_arg = si.args[j];
if (needs_reverse_mapping) { if (needs_reverse_mapping) {
*rev_indx_it++ = si.original_index[orig_indx]; *rev_indx_it++ = si.original_index[j];
} }
++next_arg; ++next_arg;
++orig_indx;
} }
} }
CHECK(next_arg == args_.end()); CHECK(next_arg == args_.end());
DVLOG(1) << "InitByArgs " << DebugId() << " " << args_.front(); DVLOG(1) << "InitByArgs " << DebugId() << " " << args_.front();
// validation
if (needs_reverse_mapping) {
for (size_t i = 0; i < args_.size(); ++i) {
DCHECK_EQ(args_[i], ArgS(args, 1 + reverse_index_[i])); // 1 for the commandname.
}
}
if (unique_shard_cnt_ == 1) { if (unique_shard_cnt_ == 1) {
PerShardData* sd; PerShardData* sd;
if (multi_) { if (multi_) {
@ -892,11 +904,14 @@ ArgSlice Transaction::ShardArgsInShard(ShardId sid) const {
return ArgSlice{args_.data() + sd.arg_start, sd.arg_count}; return ArgSlice{args_.data() + sd.arg_start, sd.arg_count};
} }
// from local index back to original arg index skipping the command.
// i.e. returns (first_key_pos -1) or bigger.
size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const { size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const {
if (unique_shard_cnt_ == 1) if (unique_shard_cnt_ == 1) // mget: 0->0, 1->1. zunionstore has 0->2
return arg_index; return reverse_index_[arg_index];
return reverse_index_[shard_data_[shard_id].arg_start + arg_index]; const auto& sd = shard_data_[shard_id];
return reverse_index_[sd.arg_start + arg_index];
} }
bool Transaction::WaitOnWatch(const time_point& tp) { bool Transaction::WaitOnWatch(const time_point& tp) {
@ -1164,10 +1179,13 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
if (args.size() < 3) { if (args.size() < 3) {
return OpStatus::SYNTAX_ERR; return OpStatus::SYNTAX_ERR;
} }
string_view num(ArgS(args, 2)); string_view num(ArgS(args, 2));
if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0 || if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0)
size_t(num_custom_keys) + 3 > args.size())
return OpStatus::INVALID_INT; return OpStatus::INVALID_INT;
if (size_t(num_custom_keys) + 3 > args.size())
return OpStatus::SYNTAX_ERR;
} }
if (cid->first_key_pos() > 0) { if (cid->first_key_pos() > 0) {

View file

@ -70,9 +70,8 @@ class Transaction {
// Runs in engine thread // Runs in engine thread
ArgSlice ShardArgsInShard(ShardId sid) const; ArgSlice ShardArgsInShard(ShardId sid) const;
// Maps the index in ShardKeys(shard_id) slice back to the index in the original array passed to // Maps the index in ShardArgsInShard(shard_id) slice back to the index
// InitByArgs. // in the original array passed to InitByArgs.
// Runs in the coordinator thread.
size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const; size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const;
//! Returns true if the transaction spans this shard_id. //! Returns true if the transaction spans this shard_id.

View file

@ -14,6 +14,7 @@ extern "C" {
#include <double-conversion/double-to-string.h> #include <double-conversion/double-to-string.h>
#include "base/logging.h" #include "base/logging.h"
#include "base/stl_util.h"
#include "facade/error.h" #include "facade/error.h"
#include "server/command_registry.h" #include "server/command_registry.h"
#include "server/conn_context.h" #include "server/conn_context.h"
@ -78,15 +79,21 @@ zlexrangespec GetLexRange(bool reverse, const ZSetFamily::LexInterval& li) {
return range; return range;
} }
OpResult<PrimeIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_view key, struct ZParams {
unsigned flags = 0; // mask of ZADD_IN_ macros.
bool ch = false; // Corresponds to CH option.
bool override = false;
};
OpResult<PrimeIterator> FindZEntry(const ZParams& zparams, const OpArgs& op_args, string_view key,
size_t member_len) { size_t member_len) {
auto& db_slice = op_args.shard->db_slice(); auto& db_slice = op_args.shard->db_slice();
if (flags & ZADD_IN_XX) { if (zparams.flags & ZADD_IN_XX) {
return db_slice.Find(op_args.db_ind, key, OBJ_ZSET); return db_slice.Find(op_args.db_ind, key, OBJ_ZSET);
} }
auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key); auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key);
if (inserted) { if (inserted || zparams.override) {
robj* zobj = nullptr; robj* zobj = nullptr;
if (member_len > kMaxListPackValue) { if (member_len > kMaxListPackValue) {
@ -96,12 +103,16 @@ OpResult<PrimeIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string
} }
DVLOG(2) << "Created zset " << zobj->ptr; DVLOG(2) << "Created zset " << zobj->ptr;
if (!inserted) {
db_slice.PreUpdate(op_args.db_ind, it);
}
it->second.ImportRObj(zobj); it->second.ImportRObj(zobj);
} else { } else {
if (it->second.ObjType() != OBJ_ZSET) if (it->second.ObjType() != OBJ_ZSET)
return OpStatus::WRONG_TYPE; return OpStatus::WRONG_TYPE;
db_slice.PreUpdate(op_args.db_ind, it); db_slice.PreUpdate(op_args.db_ind, it);
} }
return it; return it;
} }
@ -562,6 +573,193 @@ bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) {
return true; return true;
} }
void SendAtLeastOneKeyError(ConnectionContext* cntx) {
string name = cntx->cid->name();
absl::AsciiStrToLower(&name);
(*cntx)->SendError(absl::StrCat("at least 1 input key is needed for ", name));
}
enum class AggType : uint8_t { SUM, MIN, MAX };
using ScoredMap = absl::flat_hash_map<std::string, double>;
ScoredMap FromObject(const CompactObj& co, double weight) {
robj* obj = co.AsRObj();
ZSetFamily::RangeParams params;
params.with_scores = true;
IntervalVisitor vis(Action::RANGE, params, obj);
vis(ZSetFamily::IndexInterval(0, -1));
ZSetFamily::ScoredArray arr = vis.PopResult();
ScoredMap res;
res.reserve(arr.size());
for (auto& elem : arr) {
elem.second *= weight;
res.emplace(move(elem));
}
return res;
}
double Aggregate(double v1, double v2, AggType atype) {
switch (atype) {
case AggType::SUM:
return v1 + v2;
case AggType::MAX:
return max(v1, v2);
case AggType::MIN:
return min(v1, v2);
}
return 0;
}
// the result is in the destination.
void UnionScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
ScoredMap* target = dest;
ScoredMap* iter = src;
if (iter->size() > target->size())
swap(target, iter);
for (const auto& elem : *iter) {
auto [it, inserted] = target->emplace(elem);
if (!inserted) {
it->second = Aggregate(it->second, elem.second, agg_type);
}
}
if (target != dest)
dest->swap(*src);
}
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
unsigned start = 0;
if (keys.front() == dest) {
++start;
}
auto& db_slice = shard->db_slice();
vector<pair<PrimeIterator, double>> it_arr(keys.size() - start);
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
return OpStatus::OK; // return empty map
for (unsigned j = start; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->db_index(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
if (!it_res)
continue;
// first global index is 2 after {destkey, numkeys}
unsigned src_indx = j - start;
unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2;
DCHECK_LT(windex, weights.size());
it_arr[src_indx] = {*it_res, weights[windex]};
}
ScoredMap result;
for (auto it = it_arr.begin(); it != it_arr.end(); ++it) {
if (it->first.is_done())
continue;
ScoredMap sm = FromObject(it->first->second, it->second);
if (result.empty())
result.swap(sm);
else
UnionScoredMap(&result, &sm, agg_type);
}
return result;
}
using ScoredMemberView = std::pair<double, std::string_view>;
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
struct AddResult {
double new_score;
unsigned num_updated = 0;
bool is_nan = false;
};
OpResult<AddResult> OpAdd(const OpArgs& op_args, const ZParams& zparams, string_view key,
ScoredMemberSpan members) {
DCHECK(!members.empty() || zparams.override);
auto& db_slice = op_args.shard->db_slice();
if (zparams.override && members.empty()) {
auto it = db_slice.FindExt(op_args.db_ind, key).first;
db_slice.Del(op_args.db_ind, it);
return OpStatus::OK;
}
OpResult<PrimeIterator> res_it = FindZEntry(zparams, op_args, key, members.front().second.size());
if (!res_it)
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
unsigned added = 0;
unsigned updated = 0;
unsigned processed = 0;
sds& tmp_str = op_args.shard->tmp_str1;
double new_score = 0;
int retflags = 0;
OpStatus op_status = OpStatus::OK;
AddResult aresult;
for (size_t j = 0; j < members.size(); j++) {
const auto& m = members[j];
tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size());
int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score);
if (zparams.flags & ZADD_IN_INCR) {
if (retval == 0) {
CHECK_EQ(1u, members.size());
aresult.is_nan = true;
break;
}
if (retflags & ZADD_OUT_NOP) {
op_status = OpStatus::SKIPPED;
}
}
if (retflags & ZADD_OUT_ADDED)
added++;
if (retflags & ZADD_OUT_UPDATED)
updated++;
if (!(retflags & ZADD_OUT_NOP))
processed++;
}
DVLOG(2) << "ZAdd " << zobj->ptr;
res_it.value()->second.SyncRObj();
op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it);
if (zparams.flags & ZADD_IN_INCR) {
aresult.new_score = new_score;
} else {
aresult.num_updated = zparams.ch ? added + updated : added;
}
if (op_status != OpStatus::OK)
return op_status;
return aresult;
}
} // namespace } // namespace
void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) { void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
@ -631,33 +829,31 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
DCHECK(cntx->transaction); DCHECK(cntx->transaction);
absl::Span memb_sp{members.data(), members.size()}; absl::Span memb_sp{members.data(), members.size()};
AddResult add_result; auto cb = [&](Transaction* t, EngineShard* shard) {
auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus {
OpArgs op_args{shard, t->db_index()}; OpArgs op_args{shard, t->db_index()};
return OpAdd(zparams, op_args, key, memb_sp, &add_result); return OpAdd(op_args, zparams, key, memb_sp);
}; };
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); OpResult<AddResult> add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (status == OpStatus::WRONG_TYPE) { if (add_result.status() == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr); return (*cntx)->SendError(kWrongTypeErr);
} }
// KEY_NOTFOUND may happen in case of XX flag. // KEY_NOTFOUND may happen in case of XX flag.
if (status == OpStatus::KEY_NOTFOUND) { if (add_result.status() == OpStatus::KEY_NOTFOUND) {
if (zparams.flags & ZADD_IN_INCR) if (zparams.flags & ZADD_IN_INCR)
(*cntx)->SendNull(); (*cntx)->SendNull();
else else
(*cntx)->SendLong(0); (*cntx)->SendLong(0);
} else if (status == OpStatus::SKIPPED) { } else if (add_result.status() == OpStatus::SKIPPED) {
(*cntx)->SendNull(); (*cntx)->SendNull();
} else if (add_result.is_nan) { } else if (add_result->is_nan) {
(*cntx)->SendError(kScoreNaN); (*cntx)->SendError(kScoreNaN);
} else { } else {
if (zparams.flags & ZADD_IN_INCR) { if (zparams.flags & ZADD_IN_INCR) {
(*cntx)->SendDouble(add_result.new_score); (*cntx)->SendDouble(add_result->new_score);
} else { } else {
(*cntx)->SendLong(add_result.num_updated); (*cntx)->SendLong(add_result->num_updated);
} }
} }
} }
@ -725,31 +921,28 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
ZParams zparams; ZParams zparams;
zparams.flags = ZADD_IN_INCR; zparams.flags = ZADD_IN_INCR;
AddResult add_result; auto cb = [&](Transaction* t, EngineShard* shard) {
auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus {
OpArgs op_args{shard, t->db_index()}; OpArgs op_args{shard, t->db_index()};
return OpAdd(zparams, op_args, key, ScoredMemberSpan{&scored_member, 1}, &add_result); return OpAdd(op_args, zparams, key, ScoredMemberSpan{&scored_member, 1});
}; };
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); OpResult<AddResult> add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (status == OpStatus::WRONG_TYPE) { if (add_result.status() == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr); return (*cntx)->SendError(kWrongTypeErr);
} }
if (status == OpStatus::SKIPPED) { if (add_result.status() == OpStatus::SKIPPED) {
return (*cntx)->SendNull(); return (*cntx)->SendNull();
} }
if (add_result.is_nan) { if (add_result->is_nan) {
return (*cntx)->SendError(kScoreNaN); return (*cntx)->SendError(kScoreNaN);
} }
(*cntx)->SendDouble(add_result.new_score); (*cntx)->SendDouble(add_result->new_score);
} }
void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) { void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
} }
void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) { void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) {
@ -984,16 +1177,93 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
} }
void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) { void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* es) { string_view dest_key = ArgS(args, 1);
auto args = t->ShardArgsInShard(es->shard_id()); string_view num_str = ArgS(args, 2);
for (auto x : args) { uint32_t num_keys;
LOG(INFO) << "arg " << x; AggType agg_type = AggType::SUM;
// we parsed the structure before, when transaction has been initialized.
CHECK(absl::SimpleAtoi(num_str, &num_keys));
if (num_keys == 0) {
return SendAtLeastOneKeyError(cntx);
}
DCHECK_GE(args.size(), 3 + num_keys);
vector<double> weights(num_keys, 1);
for (size_t i = 3 + num_keys; i < args.size(); ++i) {
ToUpper(&args[i]);
string_view arg = ArgS(args, i);
if (arg == "WEIGHTS") {
if (args.size() <= i + num_keys) {
return (*cntx)->SendError(kSyntaxErr);
}
for (unsigned j = 0; j < num_keys; ++j) {
string_view weight = ArgS(args, i + j + 1);
if (!absl::SimpleAtod(weight, &weights[j])) {
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
}
}
i += num_keys;
} else if (arg == "AGGREGATE") {
if (i + 2 != args.size()) {
return (*cntx)->SendError(kSyntaxErr);
}
ToUpper(&args[i + 1]);
string_view agg = ArgS(args, i + 1);
if (agg == "SUM") {
agg_type = AggType::SUM;
} else if (agg == "MIN") {
agg_type = AggType::MIN;
} else if (agg == "MAX") {
agg_type = AggType::MAX;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
break;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
}
vector<OpResult<ScoredMap>> maps(cntx->shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] = OpUnion(shard, t, dest_key, agg_type, weights, false);
return OpStatus::OK;
};
cntx->transaction->Schedule();
cntx->transaction->Execute(std::move(cb), false);
ScoredMap result;
for (auto& op_res : maps) {
if (!op_res)
return (*cntx)->SendError(op_res.status());
UnionScoredMap(&result, &op_res.value(), agg_type);
}
ShardId dest_shard = Shard(dest_key, maps.size());
AddResult add_result;
vector<ScoredMemberView> smvec;
for (const auto& elem : result) {
smvec.emplace_back(elem.second, elem.first);
}
auto store_cb = [&](Transaction* t, EngineShard* shard) {
if (shard->shard_id() == dest_shard) {
ZParams zparams;
zparams.override = true;
add_result =
OpAdd(OpArgs{shard, t->db_index()}, zparams, dest_key, ScoredMemberSpan{smvec}).value();
} }
return OpStatus::OK; return OpStatus::OK;
}; };
OpStatus result = cntx->transaction->ScheduleSingleHop(std::move(cb)); cntx->transaction->Execute(std::move(store_cb), true);
(*cntx)->SendOk();
(*cntx)->SendLong(smvec.size());
} }
void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, string_view max_s, void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, string_view max_s,
@ -1202,68 +1472,6 @@ OpResult<StringVec> ZSetFamily::OpScan(const OpArgs& op_args, std::string_view k
return res; return res;
} }
OpStatus ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members, AddResult* add_result) {
DCHECK(!members.empty());
OpResult<PrimeIterator> res_it =
FindZEntry(zparams.flags, op_args, key, members.front().second.size());
if (!res_it)
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
unsigned added = 0;
unsigned updated = 0;
unsigned processed = 0;
sds& tmp_str = op_args.shard->tmp_str1;
double new_score = 0;
int retflags = 0;
OpStatus res = OpStatus::OK;
for (size_t j = 0; j < members.size(); j++) {
const auto& m = members[j];
tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size());
int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score);
if (zparams.flags & ZADD_IN_INCR) {
if (retval == 0) {
CHECK_EQ(1u, members.size());
add_result->is_nan = true;
break;
}
if (retflags & ZADD_OUT_NOP) {
res = OpStatus::SKIPPED;
}
}
if (retflags & ZADD_OUT_ADDED)
added++;
if (retflags & ZADD_OUT_UPDATED)
updated++;
if (!(retflags & ZADD_OUT_NOP))
processed++;
}
DVLOG(2) << "ZAdd " << zobj->ptr;
res_it.value()->second.SyncRObj();
op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it);
if (zparams.flags & ZADD_IN_INCR) {
add_result->new_score = new_score;
} else {
add_result->num_updated = zparams.ch ? added + updated : added;
}
return res;
}
OpResult<unsigned> ZSetFamily::OpRem(const OpArgs& op_args, string_view key, ArgSlice members) { OpResult<unsigned> ZSetFamily::OpRem(const OpArgs& op_args, string_view key, ArgSlice members) {
auto& db_slice = op_args.shard->db_slice(); auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_ind, key, OBJ_ZSET); OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_ind, key, OBJ_ZSET);
@ -1496,11 +1704,13 @@ OpResult<unsigned> ZSetFamily::OpLexCount(const OpArgs& op_args, string_view key
#define HFUNC(x) SetHandler(&ZSetFamily::x) #define HFUNC(x) SetHandler(&ZSetFamily::x)
void ZSetFamily::Register(CommandRegistry* registry) { void ZSetFamily::Register(CommandRegistry* registry) {
constexpr uint32_t kUnionMask = CO::WRITE | CO::DESTINATION_KEY | CO::REVERSE_MAPPING;
*registry << CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd) *registry << CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd)
<< CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard) << CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard)
<< CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount) << CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount)
<< CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy) << CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy)
<< CI{"ZINTERSTORE", CO::WRITE | CO::DESTINATION_KEY, -4, 1, 1, 1}.HFUNC(ZInterStore) << CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore)
<< CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount) << CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount)
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem) << CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange) << CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange)
@ -1515,7 +1725,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRevRangeByScore) << CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRevRangeByScore)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank) << CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank)
<< CI{"ZSCAN", CO::READONLY | CO::RANDOM, -3, 1, 1, 1}.HFUNC(ZScan) << CI{"ZSCAN", CO::READONLY | CO::RANDOM, -3, 1, 1, 1}.HFUNC(ZScan)
<< CI{"ZUNIONSTORE", CO::WRITE | CO::DESTINATION_KEY, -4, 3, 3, 1}.HFUNC(ZUnionStore); << CI{"ZUNIONSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZUnionStore);
} }
} // namespace dfly } // namespace dfly

View file

@ -86,23 +86,6 @@ class ZSetFamily {
static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor); static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor);
struct ZParams {
unsigned flags = 0; // mask of ZADD_IN_ macros.
bool ch = false; // Corresponds to CH option.
};
using ScoredMemberView = std::pair<double, std::string_view>;
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
struct AddResult {
double new_score;
unsigned num_updated = 0;
bool is_nan = false;
};
static facade::OpStatus OpAdd(const ZParams& zparams, const OpArgs& op_args, std::string_view key,
ScoredMemberSpan members, AddResult* add_result);
static OpResult<unsigned> OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members); static OpResult<unsigned> OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members);
static OpResult<double> OpScore(const OpArgs& op_args, std::string_view key, static OpResult<double> OpScore(const OpArgs& op_args, std::string_view key,
std::string_view member); std::string_view member);

View file

@ -167,4 +167,63 @@ TEST_F(ZSetFamilyTest, ZScan) {
EXPECT_EQ(100 * 2, scan_len); EXPECT_EQ(100 * 2, scan_len);
} }
TEST_F(ZSetFamilyTest, ZUnionStore) {
RespExpr resp;
resp = Run({"zunionstore", "key", "0"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"zunionstore", "key", "0", "aggregate", "sum"});
EXPECT_THAT(resp, ErrArg("at least 1 input key is needed"));
resp = Run({"zunionstore", "key", "-1", "aggregate", "sum"});
EXPECT_THAT(resp, ErrArg("out of range"));
resp = Run({"zunionstore", "key", "2", "foo", "bar", "weights", "1"});
EXPECT_THAT(resp, ErrArg("syntax error"));
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
resp = Run({"zunionstore", "key", "2", "z1", "z2"});
EXPECT_THAT(resp, IntArg(3));
resp = Run({"zrange", "key", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "c", "3", "b", "4"));
resp = Run({"zunionstore", "z1", "1", "z1"});
EXPECT_THAT(resp, IntArg(2));
resp = Run({"zunionstore", "z1", "2", "z1", "z2"});
EXPECT_THAT(resp, IntArg(3));
resp = Run({"zrange", "z1", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "c", "3", "b", "4"));
Run({"set", "foo", "bar"});
resp = Run({"zunionstore", "foo", "1", "z2"});
EXPECT_THAT(resp, IntArg(2));
resp = Run({"zrange", "foo", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("b", "2", "c", "3"));
}
TEST_F(ZSetFamilyTest, ZUnionStoreOpts) {
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
RespExpr resp;
EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"}));
resp = Run({"zrange", "a", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "8", "c", "9"));
resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"});
EXPECT_THAT(resp, IntArg(2));
resp = Run({"zrange", "z1", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "2", "b", "4"));
resp = Run({"zunionstore", "max", "2", "z1", "z2", "weights", "1", "0", "aggregate", "max"});
ASSERT_THAT(resp, IntArg(3));
resp = Run({"zrange", "max", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("c", "0", "a", "2", "b", "4"));
}
} // namespace dfly } // namespace dfly