mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-10 18:05:44 +02:00
Implement ZUNIONSTORE.
Fix some bugs in transactional framework to support irregular commands. Lay out groundwork for supporting XXX-STORE other commands in zsets.
This commit is contained in:
parent
3a4c36c1f2
commit
c34e7c6d44
9 changed files with 408 additions and 131 deletions
|
@ -162,8 +162,14 @@ char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len)
|
|||
RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
|
||||
}
|
||||
|
||||
void RedisReplyBuilder::SendError(string_view str, std::string_view type) {
|
||||
err_count_[type.empty() ? str : type]++;
|
||||
void RedisReplyBuilder::SendError(string_view str, string_view err_type) {
|
||||
if (err_type.empty()) {
|
||||
err_type = str;
|
||||
if (err_type == kSyntaxErr)
|
||||
err_type = kSyntaxErrType;
|
||||
}
|
||||
|
||||
err_count_[err_type]++;
|
||||
|
||||
if (str[0] == '-') {
|
||||
iovec v[] = {IoVec(str), IoVec(kCRLF)};
|
||||
|
|
|
@ -486,12 +486,11 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
|
|||
|
||||
if (IsTransactional(cid)) {
|
||||
dist_trans.reset(new Transaction{cid, &shard_set_});
|
||||
dfly_cntx->transaction = dist_trans.get();
|
||||
|
||||
OpStatus st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args);
|
||||
if (st != OpStatus::OK)
|
||||
return (*cntx)->SendError(st);
|
||||
|
||||
dfly_cntx->transaction = dist_trans.get();
|
||||
dfly_cntx->last_command_debug.shards_count = dfly_cntx->transaction->unique_shard_cnt();
|
||||
} else {
|
||||
dfly_cntx->transaction = nullptr;
|
||||
|
|
|
@ -305,8 +305,9 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, ArgSlice v
|
|||
auto* es = op_args.shard;
|
||||
auto& db_slice = es->db_slice();
|
||||
|
||||
// overwrite - meaning we run in the context of 2-hop operation and we had already
|
||||
// ensured that the key exists.
|
||||
// overwrite - meaning we run in the context of 2-hop operation and we want
|
||||
// to overwrite the key. However, if the set is empty it means we should delete the
|
||||
// key if it exists.
|
||||
if (overwrite && vals.empty()) {
|
||||
auto it = db_slice.FindExt(op_args.db_ind, key).first;
|
||||
db_slice.Del(op_args.db_ind, it);
|
||||
|
|
|
@ -90,6 +90,8 @@ TEST_F(SetFamilyTest, SInter) {
|
|||
resp = Run({"sinter", "x", "y"});
|
||||
ASSERT_EQ(1, GetDebugInfo("IO0").shards_count);
|
||||
EXPECT_THAT(resp, ErrArg("WRONGTYPE Operation against a key"));
|
||||
resp = Run({"sinterstore", "none1", "none2"});
|
||||
EXPECT_THAT(resp, IntArg(0));
|
||||
}
|
||||
|
||||
TEST_F(SetFamilyTest, SMove) {
|
||||
|
|
|
@ -103,6 +103,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
|
||||
bool incremental_locking = multi_ && multi_->incremental;
|
||||
bool single_key = !multi_ && key_index.HasSingleKey();
|
||||
bool needs_reverse_mapping = cid_->opt_mask() & CO::REVERSE_MAPPING;
|
||||
|
||||
if (single_key) {
|
||||
DCHECK_GT(key_index.step, 0u);
|
||||
|
@ -118,6 +119,12 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
unique_shard_cnt_ = 1;
|
||||
unique_shard_id_ = Shard(key, ess_->size());
|
||||
|
||||
if (needs_reverse_mapping) {
|
||||
reverse_index_.resize(args_.size());
|
||||
for (unsigned j = 0; j < reverse_index_.size(); ++j) {
|
||||
reverse_index_[j] = j + key_index.start - 1;
|
||||
}
|
||||
}
|
||||
return OpStatus::OK;
|
||||
}
|
||||
|
||||
|
@ -137,7 +144,6 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
// and regular commands.
|
||||
IntentLock::Mode mode = IntentLock::EXCLUSIVE;
|
||||
bool should_record_locks = false;
|
||||
bool needs_reverse_mapping = cid_->opt_mask() & CO::REVERSE_MAPPING;
|
||||
|
||||
if (multi_) {
|
||||
mode = Mode();
|
||||
|
@ -148,11 +154,12 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
|
||||
if (key_index.bonus) { // additional one-of key.
|
||||
DCHECK(key_index.step == 1);
|
||||
DCHECK(!needs_reverse_mapping);
|
||||
|
||||
string_view key = ArgS(args, key_index.bonus);
|
||||
uint32_t sid = Shard(key, shard_data_.size());
|
||||
shard_index[sid].args.push_back(key);
|
||||
if (needs_reverse_mapping)
|
||||
shard_index[sid].original_index.push_back(key_index.bonus - 1);
|
||||
}
|
||||
|
||||
for (unsigned i = key_index.start; i < key_index.end; ++i) {
|
||||
|
@ -183,7 +190,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
|
||||
args_.resize(key_index.num_args());
|
||||
|
||||
// we need reverse index only for blocking commands or commands like MSET.
|
||||
// we need reverse index only for some commands (MSET etc).
|
||||
if (needs_reverse_mapping)
|
||||
reverse_index_.resize(args_.size());
|
||||
|
||||
|
@ -213,20 +220,25 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
|
||||
++unique_shard_cnt_;
|
||||
unique_shard_id_ = i;
|
||||
uint32_t orig_indx = 0;
|
||||
for (size_t j = 0; j < si.args.size(); ++j) {
|
||||
*next_arg = si.args[j];
|
||||
if (needs_reverse_mapping) {
|
||||
*rev_indx_it++ = si.original_index[orig_indx];
|
||||
*rev_indx_it++ = si.original_index[j];
|
||||
}
|
||||
++next_arg;
|
||||
++orig_indx;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(next_arg == args_.end());
|
||||
DVLOG(1) << "InitByArgs " << DebugId() << " " << args_.front();
|
||||
|
||||
// validation
|
||||
if (needs_reverse_mapping) {
|
||||
for (size_t i = 0; i < args_.size(); ++i) {
|
||||
DCHECK_EQ(args_[i], ArgS(args, 1 + reverse_index_[i])); // 1 for the commandname.
|
||||
}
|
||||
}
|
||||
|
||||
if (unique_shard_cnt_ == 1) {
|
||||
PerShardData* sd;
|
||||
if (multi_) {
|
||||
|
@ -892,11 +904,14 @@ ArgSlice Transaction::ShardArgsInShard(ShardId sid) const {
|
|||
return ArgSlice{args_.data() + sd.arg_start, sd.arg_count};
|
||||
}
|
||||
|
||||
// from local index back to original arg index skipping the command.
|
||||
// i.e. returns (first_key_pos -1) or bigger.
|
||||
size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const {
|
||||
if (unique_shard_cnt_ == 1)
|
||||
return arg_index;
|
||||
if (unique_shard_cnt_ == 1) // mget: 0->0, 1->1. zunionstore has 0->2
|
||||
return reverse_index_[arg_index];
|
||||
|
||||
return reverse_index_[shard_data_[shard_id].arg_start + arg_index];
|
||||
const auto& sd = shard_data_[shard_id];
|
||||
return reverse_index_[sd.arg_start + arg_index];
|
||||
}
|
||||
|
||||
bool Transaction::WaitOnWatch(const time_point& tp) {
|
||||
|
@ -1164,10 +1179,13 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
|
|||
if (args.size() < 3) {
|
||||
return OpStatus::SYNTAX_ERR;
|
||||
}
|
||||
|
||||
string_view num(ArgS(args, 2));
|
||||
if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0 ||
|
||||
size_t(num_custom_keys) + 3 > args.size())
|
||||
if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0)
|
||||
return OpStatus::INVALID_INT;
|
||||
|
||||
if (size_t(num_custom_keys) + 3 > args.size())
|
||||
return OpStatus::SYNTAX_ERR;
|
||||
}
|
||||
|
||||
if (cid->first_key_pos() > 0) {
|
||||
|
|
|
@ -70,9 +70,8 @@ class Transaction {
|
|||
// Runs in engine thread
|
||||
ArgSlice ShardArgsInShard(ShardId sid) const;
|
||||
|
||||
// Maps the index in ShardKeys(shard_id) slice back to the index in the original array passed to
|
||||
// InitByArgs.
|
||||
// Runs in the coordinator thread.
|
||||
// Maps the index in ShardArgsInShard(shard_id) slice back to the index
|
||||
// in the original array passed to InitByArgs.
|
||||
size_t ReverseArgIndex(ShardId shard_id, size_t arg_index) const;
|
||||
|
||||
//! Returns true if the transaction spans this shard_id.
|
||||
|
|
|
@ -14,6 +14,7 @@ extern "C" {
|
|||
#include <double-conversion/double-to-string.h>
|
||||
|
||||
#include "base/logging.h"
|
||||
#include "base/stl_util.h"
|
||||
#include "facade/error.h"
|
||||
#include "server/command_registry.h"
|
||||
#include "server/conn_context.h"
|
||||
|
@ -78,15 +79,21 @@ zlexrangespec GetLexRange(bool reverse, const ZSetFamily::LexInterval& li) {
|
|||
return range;
|
||||
}
|
||||
|
||||
OpResult<PrimeIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_view key,
|
||||
struct ZParams {
|
||||
unsigned flags = 0; // mask of ZADD_IN_ macros.
|
||||
bool ch = false; // Corresponds to CH option.
|
||||
bool override = false;
|
||||
};
|
||||
|
||||
OpResult<PrimeIterator> FindZEntry(const ZParams& zparams, const OpArgs& op_args, string_view key,
|
||||
size_t member_len) {
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
if (flags & ZADD_IN_XX) {
|
||||
if (zparams.flags & ZADD_IN_XX) {
|
||||
return db_slice.Find(op_args.db_ind, key, OBJ_ZSET);
|
||||
}
|
||||
|
||||
auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key);
|
||||
if (inserted) {
|
||||
if (inserted || zparams.override) {
|
||||
robj* zobj = nullptr;
|
||||
|
||||
if (member_len > kMaxListPackValue) {
|
||||
|
@ -96,12 +103,16 @@ OpResult<PrimeIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string
|
|||
}
|
||||
|
||||
DVLOG(2) << "Created zset " << zobj->ptr;
|
||||
if (!inserted) {
|
||||
db_slice.PreUpdate(op_args.db_ind, it);
|
||||
}
|
||||
it->second.ImportRObj(zobj);
|
||||
} else {
|
||||
if (it->second.ObjType() != OBJ_ZSET)
|
||||
return OpStatus::WRONG_TYPE;
|
||||
db_slice.PreUpdate(op_args.db_ind, it);
|
||||
}
|
||||
|
||||
return it;
|
||||
}
|
||||
|
||||
|
@ -562,6 +573,193 @@ bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void SendAtLeastOneKeyError(ConnectionContext* cntx) {
|
||||
string name = cntx->cid->name();
|
||||
absl::AsciiStrToLower(&name);
|
||||
(*cntx)->SendError(absl::StrCat("at least 1 input key is needed for ", name));
|
||||
}
|
||||
|
||||
enum class AggType : uint8_t { SUM, MIN, MAX };
|
||||
using ScoredMap = absl::flat_hash_map<std::string, double>;
|
||||
|
||||
ScoredMap FromObject(const CompactObj& co, double weight) {
|
||||
robj* obj = co.AsRObj();
|
||||
ZSetFamily::RangeParams params;
|
||||
params.with_scores = true;
|
||||
IntervalVisitor vis(Action::RANGE, params, obj);
|
||||
vis(ZSetFamily::IndexInterval(0, -1));
|
||||
|
||||
ZSetFamily::ScoredArray arr = vis.PopResult();
|
||||
ScoredMap res;
|
||||
res.reserve(arr.size());
|
||||
|
||||
for (auto& elem : arr) {
|
||||
elem.second *= weight;
|
||||
res.emplace(move(elem));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
double Aggregate(double v1, double v2, AggType atype) {
|
||||
switch (atype) {
|
||||
case AggType::SUM:
|
||||
return v1 + v2;
|
||||
case AggType::MAX:
|
||||
return max(v1, v2);
|
||||
case AggType::MIN:
|
||||
return min(v1, v2);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// the result is in the destination.
|
||||
void UnionScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
|
||||
ScoredMap* target = dest;
|
||||
ScoredMap* iter = src;
|
||||
|
||||
if (iter->size() > target->size())
|
||||
swap(target, iter);
|
||||
|
||||
for (const auto& elem : *iter) {
|
||||
auto [it, inserted] = target->emplace(elem);
|
||||
if (!inserted) {
|
||||
it->second = Aggregate(it->second, elem.second, agg_type);
|
||||
}
|
||||
}
|
||||
|
||||
if (target != dest)
|
||||
dest->swap(*src);
|
||||
}
|
||||
|
||||
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
|
||||
const vector<double>& weights, bool store) {
|
||||
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
|
||||
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
|
||||
DCHECK(!keys.empty());
|
||||
|
||||
unsigned start = 0;
|
||||
|
||||
if (keys.front() == dest) {
|
||||
++start;
|
||||
}
|
||||
|
||||
auto& db_slice = shard->db_slice();
|
||||
vector<pair<PrimeIterator, double>> it_arr(keys.size() - start);
|
||||
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
|
||||
return OpStatus::OK; // return empty map
|
||||
|
||||
for (unsigned j = start; j < keys.size(); ++j) {
|
||||
auto it_res = db_slice.Find(t->db_index(), keys[j], OBJ_ZSET);
|
||||
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
|
||||
return it_res.status();
|
||||
if (!it_res)
|
||||
continue;
|
||||
|
||||
// first global index is 2 after {destkey, numkeys}
|
||||
unsigned src_indx = j - start;
|
||||
unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2;
|
||||
DCHECK_LT(windex, weights.size());
|
||||
it_arr[src_indx] = {*it_res, weights[windex]};
|
||||
}
|
||||
|
||||
ScoredMap result;
|
||||
for (auto it = it_arr.begin(); it != it_arr.end(); ++it) {
|
||||
if (it->first.is_done())
|
||||
continue;
|
||||
|
||||
ScoredMap sm = FromObject(it->first->second, it->second);
|
||||
if (result.empty())
|
||||
result.swap(sm);
|
||||
else
|
||||
UnionScoredMap(&result, &sm, agg_type);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
using ScoredMemberView = std::pair<double, std::string_view>;
|
||||
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
|
||||
|
||||
struct AddResult {
|
||||
double new_score;
|
||||
unsigned num_updated = 0;
|
||||
|
||||
bool is_nan = false;
|
||||
};
|
||||
|
||||
OpResult<AddResult> OpAdd(const OpArgs& op_args, const ZParams& zparams, string_view key,
|
||||
ScoredMemberSpan members) {
|
||||
DCHECK(!members.empty() || zparams.override);
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
|
||||
if (zparams.override && members.empty()) {
|
||||
auto it = db_slice.FindExt(op_args.db_ind, key).first;
|
||||
db_slice.Del(op_args.db_ind, it);
|
||||
return OpStatus::OK;
|
||||
}
|
||||
|
||||
OpResult<PrimeIterator> res_it = FindZEntry(zparams, op_args, key, members.front().second.size());
|
||||
|
||||
if (!res_it)
|
||||
return res_it.status();
|
||||
|
||||
robj* zobj = res_it.value()->second.AsRObj();
|
||||
|
||||
unsigned added = 0;
|
||||
unsigned updated = 0;
|
||||
unsigned processed = 0;
|
||||
|
||||
sds& tmp_str = op_args.shard->tmp_str1;
|
||||
double new_score = 0;
|
||||
int retflags = 0;
|
||||
|
||||
OpStatus op_status = OpStatus::OK;
|
||||
AddResult aresult;
|
||||
|
||||
for (size_t j = 0; j < members.size(); j++) {
|
||||
const auto& m = members[j];
|
||||
tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size());
|
||||
|
||||
int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score);
|
||||
|
||||
if (zparams.flags & ZADD_IN_INCR) {
|
||||
if (retval == 0) {
|
||||
CHECK_EQ(1u, members.size());
|
||||
|
||||
aresult.is_nan = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (retflags & ZADD_OUT_NOP) {
|
||||
op_status = OpStatus::SKIPPED;
|
||||
}
|
||||
}
|
||||
|
||||
if (retflags & ZADD_OUT_ADDED)
|
||||
added++;
|
||||
if (retflags & ZADD_OUT_UPDATED)
|
||||
updated++;
|
||||
if (!(retflags & ZADD_OUT_NOP))
|
||||
processed++;
|
||||
}
|
||||
|
||||
DVLOG(2) << "ZAdd " << zobj->ptr;
|
||||
|
||||
res_it.value()->second.SyncRObj();
|
||||
op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it);
|
||||
|
||||
if (zparams.flags & ZADD_IN_INCR) {
|
||||
aresult.new_score = new_score;
|
||||
} else {
|
||||
aresult.num_updated = zparams.ch ? added + updated : added;
|
||||
}
|
||||
|
||||
if (op_status != OpStatus::OK)
|
||||
return op_status;
|
||||
return aresult;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
@ -631,33 +829,31 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
|
|||
DCHECK(cntx->transaction);
|
||||
|
||||
absl::Span memb_sp{members.data(), members.size()};
|
||||
AddResult add_result;
|
||||
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus {
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) {
|
||||
OpArgs op_args{shard, t->db_index()};
|
||||
return OpAdd(zparams, op_args, key, memb_sp, &add_result);
|
||||
return OpAdd(op_args, zparams, key, memb_sp);
|
||||
};
|
||||
|
||||
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
|
||||
if (status == OpStatus::WRONG_TYPE) {
|
||||
OpResult<AddResult> add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
|
||||
if (add_result.status() == OpStatus::WRONG_TYPE) {
|
||||
return (*cntx)->SendError(kWrongTypeErr);
|
||||
}
|
||||
|
||||
// KEY_NOTFOUND may happen in case of XX flag.
|
||||
if (status == OpStatus::KEY_NOTFOUND) {
|
||||
if (add_result.status() == OpStatus::KEY_NOTFOUND) {
|
||||
if (zparams.flags & ZADD_IN_INCR)
|
||||
(*cntx)->SendNull();
|
||||
else
|
||||
(*cntx)->SendLong(0);
|
||||
} else if (status == OpStatus::SKIPPED) {
|
||||
} else if (add_result.status() == OpStatus::SKIPPED) {
|
||||
(*cntx)->SendNull();
|
||||
} else if (add_result.is_nan) {
|
||||
} else if (add_result->is_nan) {
|
||||
(*cntx)->SendError(kScoreNaN);
|
||||
} else {
|
||||
if (zparams.flags & ZADD_IN_INCR) {
|
||||
(*cntx)->SendDouble(add_result.new_score);
|
||||
(*cntx)->SendDouble(add_result->new_score);
|
||||
} else {
|
||||
(*cntx)->SendLong(add_result.num_updated);
|
||||
(*cntx)->SendLong(add_result->num_updated);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -725,31 +921,28 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
|
|||
ZParams zparams;
|
||||
zparams.flags = ZADD_IN_INCR;
|
||||
|
||||
AddResult add_result;
|
||||
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus {
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) {
|
||||
OpArgs op_args{shard, t->db_index()};
|
||||
return OpAdd(zparams, op_args, key, ScoredMemberSpan{&scored_member, 1}, &add_result);
|
||||
return OpAdd(op_args, zparams, key, ScoredMemberSpan{&scored_member, 1});
|
||||
};
|
||||
|
||||
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
|
||||
if (status == OpStatus::WRONG_TYPE) {
|
||||
OpResult<AddResult> add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
|
||||
if (add_result.status() == OpStatus::WRONG_TYPE) {
|
||||
return (*cntx)->SendError(kWrongTypeErr);
|
||||
}
|
||||
|
||||
if (status == OpStatus::SKIPPED) {
|
||||
if (add_result.status() == OpStatus::SKIPPED) {
|
||||
return (*cntx)->SendNull();
|
||||
}
|
||||
|
||||
if (add_result.is_nan) {
|
||||
if (add_result->is_nan) {
|
||||
return (*cntx)->SendError(kScoreNaN);
|
||||
}
|
||||
|
||||
(*cntx)->SendDouble(add_result.new_score);
|
||||
(*cntx)->SendDouble(add_result->new_score);
|
||||
}
|
||||
|
||||
void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
||||
}
|
||||
|
||||
void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
@ -984,16 +1177,93 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
|
||||
auto cb = [&](Transaction* t, EngineShard* es) {
|
||||
auto args = t->ShardArgsInShard(es->shard_id());
|
||||
for (auto x : args) {
|
||||
LOG(INFO) << "arg " << x;
|
||||
string_view dest_key = ArgS(args, 1);
|
||||
string_view num_str = ArgS(args, 2);
|
||||
uint32_t num_keys;
|
||||
AggType agg_type = AggType::SUM;
|
||||
|
||||
// we parsed the structure before, when transaction has been initialized.
|
||||
CHECK(absl::SimpleAtoi(num_str, &num_keys));
|
||||
if (num_keys == 0) {
|
||||
return SendAtLeastOneKeyError(cntx);
|
||||
}
|
||||
|
||||
DCHECK_GE(args.size(), 3 + num_keys);
|
||||
|
||||
vector<double> weights(num_keys, 1);
|
||||
for (size_t i = 3 + num_keys; i < args.size(); ++i) {
|
||||
ToUpper(&args[i]);
|
||||
string_view arg = ArgS(args, i);
|
||||
if (arg == "WEIGHTS") {
|
||||
if (args.size() <= i + num_keys) {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
for (unsigned j = 0; j < num_keys; ++j) {
|
||||
string_view weight = ArgS(args, i + j + 1);
|
||||
if (!absl::SimpleAtod(weight, &weights[j])) {
|
||||
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
|
||||
}
|
||||
}
|
||||
i += num_keys;
|
||||
} else if (arg == "AGGREGATE") {
|
||||
if (i + 2 != args.size()) {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
ToUpper(&args[i + 1]);
|
||||
|
||||
string_view agg = ArgS(args, i + 1);
|
||||
if (agg == "SUM") {
|
||||
agg_type = AggType::SUM;
|
||||
} else if (agg == "MIN") {
|
||||
agg_type = AggType::MIN;
|
||||
} else if (agg == "MAX") {
|
||||
agg_type = AggType::MAX;
|
||||
} else {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
}
|
||||
|
||||
vector<OpResult<ScoredMap>> maps(cntx->shard_set->size());
|
||||
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) {
|
||||
maps[shard->shard_id()] = OpUnion(shard, t, dest_key, agg_type, weights, false);
|
||||
return OpStatus::OK;
|
||||
};
|
||||
|
||||
cntx->transaction->Schedule();
|
||||
|
||||
cntx->transaction->Execute(std::move(cb), false);
|
||||
ScoredMap result;
|
||||
|
||||
for (auto& op_res : maps) {
|
||||
if (!op_res)
|
||||
return (*cntx)->SendError(op_res.status());
|
||||
UnionScoredMap(&result, &op_res.value(), agg_type);
|
||||
}
|
||||
ShardId dest_shard = Shard(dest_key, maps.size());
|
||||
AddResult add_result;
|
||||
vector<ScoredMemberView> smvec;
|
||||
for (const auto& elem : result) {
|
||||
smvec.emplace_back(elem.second, elem.first);
|
||||
}
|
||||
|
||||
auto store_cb = [&](Transaction* t, EngineShard* shard) {
|
||||
if (shard->shard_id() == dest_shard) {
|
||||
ZParams zparams;
|
||||
zparams.override = true;
|
||||
add_result =
|
||||
OpAdd(OpArgs{shard, t->db_index()}, zparams, dest_key, ScoredMemberSpan{smvec}).value();
|
||||
}
|
||||
return OpStatus::OK;
|
||||
};
|
||||
|
||||
OpStatus result = cntx->transaction->ScheduleSingleHop(std::move(cb));
|
||||
(*cntx)->SendOk();
|
||||
cntx->transaction->Execute(std::move(store_cb), true);
|
||||
|
||||
(*cntx)->SendLong(smvec.size());
|
||||
}
|
||||
|
||||
void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, string_view max_s,
|
||||
|
@ -1202,68 +1472,6 @@ OpResult<StringVec> ZSetFamily::OpScan(const OpArgs& op_args, std::string_view k
|
|||
return res;
|
||||
}
|
||||
|
||||
OpStatus ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
|
||||
ScoredMemberSpan members, AddResult* add_result) {
|
||||
DCHECK(!members.empty());
|
||||
OpResult<PrimeIterator> res_it =
|
||||
FindZEntry(zparams.flags, op_args, key, members.front().second.size());
|
||||
|
||||
if (!res_it)
|
||||
return res_it.status();
|
||||
|
||||
robj* zobj = res_it.value()->second.AsRObj();
|
||||
|
||||
unsigned added = 0;
|
||||
unsigned updated = 0;
|
||||
unsigned processed = 0;
|
||||
|
||||
sds& tmp_str = op_args.shard->tmp_str1;
|
||||
double new_score = 0;
|
||||
int retflags = 0;
|
||||
|
||||
OpStatus res = OpStatus::OK;
|
||||
|
||||
for (size_t j = 0; j < members.size(); j++) {
|
||||
const auto& m = members[j];
|
||||
tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size());
|
||||
|
||||
int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score);
|
||||
|
||||
if (zparams.flags & ZADD_IN_INCR) {
|
||||
if (retval == 0) {
|
||||
CHECK_EQ(1u, members.size());
|
||||
|
||||
add_result->is_nan = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (retflags & ZADD_OUT_NOP) {
|
||||
res = OpStatus::SKIPPED;
|
||||
}
|
||||
}
|
||||
|
||||
if (retflags & ZADD_OUT_ADDED)
|
||||
added++;
|
||||
if (retflags & ZADD_OUT_UPDATED)
|
||||
updated++;
|
||||
if (!(retflags & ZADD_OUT_NOP))
|
||||
processed++;
|
||||
}
|
||||
|
||||
DVLOG(2) << "ZAdd " << zobj->ptr;
|
||||
|
||||
res_it.value()->second.SyncRObj();
|
||||
op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it);
|
||||
|
||||
if (zparams.flags & ZADD_IN_INCR) {
|
||||
add_result->new_score = new_score;
|
||||
} else {
|
||||
add_result->num_updated = zparams.ch ? added + updated : added;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
OpResult<unsigned> ZSetFamily::OpRem(const OpArgs& op_args, string_view key, ArgSlice members) {
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_ind, key, OBJ_ZSET);
|
||||
|
@ -1496,11 +1704,13 @@ OpResult<unsigned> ZSetFamily::OpLexCount(const OpArgs& op_args, string_view key
|
|||
#define HFUNC(x) SetHandler(&ZSetFamily::x)
|
||||
|
||||
void ZSetFamily::Register(CommandRegistry* registry) {
|
||||
constexpr uint32_t kUnionMask = CO::WRITE | CO::DESTINATION_KEY | CO::REVERSE_MAPPING;
|
||||
|
||||
*registry << CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd)
|
||||
<< CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard)
|
||||
<< CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount)
|
||||
<< CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy)
|
||||
<< CI{"ZINTERSTORE", CO::WRITE | CO::DESTINATION_KEY, -4, 1, 1, 1}.HFUNC(ZInterStore)
|
||||
<< CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore)
|
||||
<< CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount)
|
||||
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem)
|
||||
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange)
|
||||
|
@ -1515,7 +1725,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
|
|||
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRevRangeByScore)
|
||||
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank)
|
||||
<< CI{"ZSCAN", CO::READONLY | CO::RANDOM, -3, 1, 1, 1}.HFUNC(ZScan)
|
||||
<< CI{"ZUNIONSTORE", CO::WRITE | CO::DESTINATION_KEY, -4, 3, 3, 1}.HFUNC(ZUnionStore);
|
||||
<< CI{"ZUNIONSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZUnionStore);
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -86,23 +86,6 @@ class ZSetFamily {
|
|||
|
||||
static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor);
|
||||
|
||||
struct ZParams {
|
||||
unsigned flags = 0; // mask of ZADD_IN_ macros.
|
||||
bool ch = false; // Corresponds to CH option.
|
||||
};
|
||||
|
||||
using ScoredMemberView = std::pair<double, std::string_view>;
|
||||
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
|
||||
|
||||
struct AddResult {
|
||||
double new_score;
|
||||
unsigned num_updated = 0;
|
||||
|
||||
bool is_nan = false;
|
||||
};
|
||||
|
||||
static facade::OpStatus OpAdd(const ZParams& zparams, const OpArgs& op_args, std::string_view key,
|
||||
ScoredMemberSpan members, AddResult* add_result);
|
||||
static OpResult<unsigned> OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members);
|
||||
static OpResult<double> OpScore(const OpArgs& op_args, std::string_view key,
|
||||
std::string_view member);
|
||||
|
|
|
@ -167,4 +167,63 @@ TEST_F(ZSetFamilyTest, ZScan) {
|
|||
EXPECT_EQ(100 * 2, scan_len);
|
||||
}
|
||||
|
||||
TEST_F(ZSetFamilyTest, ZUnionStore) {
|
||||
RespExpr resp;
|
||||
|
||||
resp = Run({"zunionstore", "key", "0"});
|
||||
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
|
||||
|
||||
resp = Run({"zunionstore", "key", "0", "aggregate", "sum"});
|
||||
EXPECT_THAT(resp, ErrArg("at least 1 input key is needed"));
|
||||
resp = Run({"zunionstore", "key", "-1", "aggregate", "sum"});
|
||||
EXPECT_THAT(resp, ErrArg("out of range"));
|
||||
resp = Run({"zunionstore", "key", "2", "foo", "bar", "weights", "1"});
|
||||
EXPECT_THAT(resp, ErrArg("syntax error"));
|
||||
|
||||
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"}));
|
||||
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
|
||||
|
||||
resp = Run({"zunionstore", "key", "2", "z1", "z2"});
|
||||
EXPECT_THAT(resp, IntArg(3));
|
||||
resp = Run({"zrange", "key", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "c", "3", "b", "4"));
|
||||
|
||||
resp = Run({"zunionstore", "z1", "1", "z1"});
|
||||
EXPECT_THAT(resp, IntArg(2));
|
||||
|
||||
resp = Run({"zunionstore", "z1", "2", "z1", "z2"});
|
||||
EXPECT_THAT(resp, IntArg(3));
|
||||
resp = Run({"zrange", "z1", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "c", "3", "b", "4"));
|
||||
|
||||
Run({"set", "foo", "bar"});
|
||||
resp = Run({"zunionstore", "foo", "1", "z2"});
|
||||
EXPECT_THAT(resp, IntArg(2));
|
||||
resp = Run({"zrange", "foo", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("b", "2", "c", "3"));
|
||||
}
|
||||
|
||||
TEST_F(ZSetFamilyTest, ZUnionStoreOpts) {
|
||||
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"}));
|
||||
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
|
||||
RespExpr resp;
|
||||
|
||||
EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"}));
|
||||
resp = Run({"zrange", "a", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "8", "c", "9"));
|
||||
|
||||
resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"});
|
||||
EXPECT_THAT(resp, ErrArg("syntax error"));
|
||||
|
||||
resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"});
|
||||
EXPECT_THAT(resp, IntArg(2));
|
||||
resp = Run({"zrange", "z1", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "2", "b", "4"));
|
||||
|
||||
resp = Run({"zunionstore", "max", "2", "z1", "z2", "weights", "1", "0", "aggregate", "max"});
|
||||
ASSERT_THAT(resp, IntArg(3));
|
||||
resp = Run({"zrange", "max", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("c", "0", "a", "2", "b", "4"));
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue