feat(zset family): implement a variant of ZUNION command issue #356 (#686)

Signed-off-by: ATM SALEH <saleh.cse08@gmail.com>
Signed-off-by: adi_holden <adi@dragonflydb.io>
Co-authored-by: ATM SALEH <saleh.cse08@gmail.com>
This commit is contained in:
adiholden 2023-01-17 12:20:22 +02:00 committed by GitHub
parent 1f5811fb78
commit b2edf9c848
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 279 additions and 126 deletions

View file

@ -690,24 +690,39 @@ void InterScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
dest->swap(*src);
}
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
using KeyIterWeightVec = vector<pair<PrimeIterator, double>>;
unsigned start = 0;
ScoredMap UnionShardKeysWithScore(const KeyIterWeightVec& key_iter_weight_vec, AggType agg_type) {
ScoredMap result;
for (const auto& key_iter_wieght : key_iter_weight_vec) {
if (key_iter_wieght.first.is_done()) {
continue;
}
if (keys.front() == dest) {
++start;
ScoredMap sm = FromObject(key_iter_wieght.first->second, key_iter_wieght.second);
if (result.empty()) {
result.swap(sm);
} else {
UnionScoredMap(&result, &sm, agg_type);
}
}
return result;
}
// Looks up the weight that belongs to the key at position `key_index` of this
// shard's argument slice.
//
// The shard-local position is first mapped back to a command-argument
// position through Transaction::ReverseArgIndex; `cmdargs_keys_offset` is then
// subtracted to obtain the index into `weights`.
double GetKeyWeight(EngineShard* shard, Transaction* t, const vector<double>& weights,
                    unsigned key_index, unsigned cmdargs_keys_offset) {
  const auto cmd_arg_pos = t->ReverseArgIndex(shard->shard_id(), key_index);
  const unsigned weight_pos = cmd_arg_pos - cmdargs_keys_offset;

  // Debug-only bounds check on the computed weight slot.
  DCHECK_LT(weight_pos, weights.size());
  return weights[weight_pos];
}
OpResult<KeyIterWeightVec> FindShardKeysAndWeights(EngineShard* shard, Transaction* t,
ArgSlice keys, const vector<double>& weights,
unsigned src_keys_offset,
unsigned cmdargs_keys_offset) {
auto& db_slice = shard->db_slice();
vector<pair<PrimeIterator, double>> it_arr(keys.size() - start);
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
return OpStatus::OK; // return empty map
for (unsigned j = start; j < keys.size(); ++j) {
KeyIterWeightVec key_weight_vec(keys.size() - src_keys_offset);
for (unsigned j = src_keys_offset; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->db_context(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
@ -715,25 +730,41 @@ OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest
continue;
// first global index is 2 after {destkey, numkeys}
unsigned src_indx = j - start;
unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2;
DCHECK_LT(windex, weights.size());
it_arr[src_indx] = {*it_res, weights[windex]};
key_weight_vec[j - src_keys_offset] = {*it_res,
GetKeyWeight(shard, t, weights, j, cmdargs_keys_offset)};
}
return key_weight_vec;
}
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
unsigned src_keys_offset = 0;
unsigned cmdargs_keys_offset = 0;
if (!dest.empty()) {
// first global index is 2 after {destkey, numkeys}
cmdargs_keys_offset = 2;
if (keys.front() == dest) {
++src_keys_offset;
}
// In case ONLY the destination key is hosted in this shard no work on this shard should be done
// in this step
if (src_keys_offset >= keys.size()) {
return OpStatus::OK;
}
}
auto keys_and_weights =
FindShardKeysAndWeights(shard, t, keys, weights, src_keys_offset, cmdargs_keys_offset);
if (!keys_and_weights) {
return keys_and_weights.status();
}
ScoredMap result;
for (auto it = it_arr.begin(); it != it_arr.end(); ++it) {
if (it->first.is_done())
continue;
ScoredMap sm = FromObject(it->first->second, it->second);
if (result.empty())
result.swap(sm);
else
UnionScoredMap(&result, &sm, agg_type);
}
return result;
return UnionShardKeysWithScore(*keys_and_weights, agg_type);
}
OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
@ -869,61 +900,180 @@ OpResult<AddResult> OpAdd(const OpArgs& op_args, const ZParams& zparams, string_
return aresult;
}
struct StoreArgs {
struct SetOpArgs {
AggType agg_type = AggType::SUM;
unsigned num_keys;
vector<double> weights;
bool with_scores = false;
};
OpResult<StoreArgs> ParseStoreArgs(CmdArgList args) {
string_view num_str = ArgS(args, 2);
StoreArgs store_args;
// Maps an aggregate name onto op_args->agg_type.
//
// Accepted names are exactly "SUM", "MIN" and "MAX" (the comparison is
// case-sensitive). Any other name yields OpStatus::SYNTAX_ERR and leaves
// op_args unmodified.
OpResult<void> FillAggType(string_view agg, SetOpArgs* op_args) {
  if (agg == "SUM") {
    op_args->agg_type = AggType::SUM;
    return OpStatus::OK;
  }

  if (agg == "MIN") {
    op_args->agg_type = AggType::MIN;
    return OpStatus::OK;
  }

  if (agg == "MAX") {
    op_args->agg_type = AggType::MAX;
    return OpStatus::OK;
  }

  return OpStatus::SYNTAX_ERR;
}
// Parse functions return the number of arguments read from CmdArgList
// Parses the operand of an AGGREGATE option.
//
// `args` is expected to start at the option keyword, so the aggregate name is
// args[1]; it is upper-cased in place before being matched by FillAggType.
// Returns 1 (the single operand consumed) on success, or a failure status:
// OpStatus::SYNTAX_ERR when the operand is missing, otherwise whatever
// FillAggType reports. `store` is currently not consulted.
OpResult<unsigned> ParseAggregate(CmdArgList args, bool store, SetOpArgs* op_args) {
  constexpr size_t kOperandPos = 1;

  if (args.size() <= kOperandPos) {
    return OpStatus::SYNTAX_ERR;
  }

  ToUpper(&args[kOperandPos]);

  if (auto agg_status = FillAggType(ArgS(args, kOperandPos), op_args); !agg_status) {
    return agg_status.status();
  }

  return 1;
}
// Parses the operands of a WEIGHTS option: one floating point value per key.
//
// `args` is expected to start at the option keyword, so the weights occupy
// args[1] .. args[num_keys]. Each value is written into op_args->weights at
// the matching index; the vector is indexed directly, so it must already hold
// at least num_keys elements.
// Returns num_keys (the number of operands consumed) on success,
// OpStatus::SYNTAX_ERR when fewer than num_keys operands follow the keyword,
// or OpStatus::INVALID_FLOAT when an operand fails to parse as a double.
OpResult<unsigned> ParseWeights(CmdArgList args, SetOpArgs* op_args) {
  const unsigned expected = op_args->num_keys;

  if (args.size() <= expected) {
    return OpStatus::SYNTAX_ERR;
  }

  unsigned pos = 0;
  while (pos < expected) {
    // +1 skips the keyword stored in args[0].
    string_view raw_weight = ArgS(args, pos + 1);
    double* slot = &(op_args->weights[pos]);
    if (!absl::SimpleAtod(raw_weight, slot)) {
      return OpStatus::INVALID_FLOAT;
    }
    ++pos;
  }

  return expected;
}
OpResult<void> ParseKeyCount(string_view arg_num_keys, SetOpArgs* op_args) {
// we parsed the structure before, when transaction has been initialized.
CHECK(absl::SimpleAtoi(num_str, &store_args.num_keys));
DCHECK_GE(args.size(), 3 + store_args.num_keys);
if (!absl::SimpleAtoi(arg_num_keys, &op_args->num_keys)) {
return OpStatus::SYNTAX_ERR;
}
return OpStatus::OK;
}
store_args.weights.resize(store_args.num_keys, 1);
for (size_t i = 3 + store_args.num_keys; i < args.size(); ++i) {
// Handles the WITHSCORES option: it takes no operands, so the only effect is
// to raise op_args->with_scores. Returns 0, the number of operands consumed.
// `args` is not inspected.
OpResult<unsigned> ParseWithScores(CmdArgList args, SetOpArgs* op_args) {
  const unsigned consumed_operands = 0;
  op_args->with_scores = true;
  return consumed_operands;
}
OpResult<SetOpArgs> ParseSetOpArgs(CmdArgList args, bool store) {
// TODO: support variadic key for ZUNION command (now fixed to 3)
string_view num_keys_str = store ? ArgS(args, 2) : "3";
SetOpArgs op_args;
auto parsed = ParseKeyCount(num_keys_str, &op_args);
if (!parsed) {
return parsed.status();
}
unsigned opt_args_start = store ? 3 + op_args.num_keys : 1 + op_args.num_keys;
// TODO: modify this check when there is variadic key support for ZUNION command
DCHECK_GE(args.size(), opt_args_start);
op_args.weights.resize(op_args.num_keys, 1);
for (size_t i = opt_args_start; i < args.size(); ++i) {
ToUpper(&args[i]);
string_view arg = ArgS(args, i);
if (arg == "WEIGHTS") {
if (args.size() <= i + store_args.num_keys) {
return OpStatus::SYNTAX_ERR;
auto parsed_cnt = ParseWeights(args.subspan(i), &op_args);
if (!parsed_cnt) {
return parsed_cnt.status();
}
for (unsigned j = 0; j < store_args.num_keys; ++j) {
string_view weight = ArgS(args, i + j + 1);
if (!absl::SimpleAtod(weight, &store_args.weights[j])) {
return OpStatus::INVALID_FLOAT;
}
}
i += store_args.num_keys;
i += *parsed_cnt;
} else if (arg == "AGGREGATE") {
if (i + 2 != args.size()) {
auto parsed_cnt = ParseAggregate(args.subspan(i), store, &op_args);
if (!parsed_cnt) {
return parsed_cnt.status();
}
i += *parsed_cnt;
} else if (arg == "WITHSCORES") {
// Commands with store capability do not offer the WITHSCORES option
if (store) {
return OpStatus::SYNTAX_ERR;
}
ToUpper(&args[i + 1]);
string_view agg = ArgS(args, i + 1);
if (agg == "SUM") {
store_args.agg_type = AggType::SUM;
} else if (agg == "MIN") {
store_args.agg_type = AggType::MIN;
} else if (agg == "MAX") {
store_args.agg_type = AggType::MAX;
} else {
return OpStatus::SYNTAX_ERR;
auto parsed_cnt = ParseWithScores(args.subspan(i), &op_args);
if (!parsed_cnt) {
return parsed_cnt.status();
}
break;
i += *parsed_cnt;
} else {
return OpStatus::SYNTAX_ERR;
}
}
return op_args;
}
return store_args;
};
// Shared implementation of the union commands.
//
// args  - full command argument list; when `store` is true args[1] is the
//         destination key.
// store - true: write the union into the destination key and reply with the
//         number of members; false: reply with the members themselves
//         (followed by their scores when WITHSCORES was given).
// cntx  - connection context providing the transaction and the reply channel.
//
// Flow: parse options -> run OpUnion on every shard -> merge the per-shard
// maps -> either store the result (second hop) or send it back sorted.
void ZUnionFamilyInternal(CmdArgList args, bool store, ConnectionContext* cntx) {
  OpResult<SetOpArgs> parsed = ParseSetOpArgs(args, store);
  if (!parsed) {
    const auto parse_status = parsed.status();
    if (parse_status == OpStatus::INVALID_FLOAT) {
      // A bad weight gets a dedicated message instead of the generic one.
      (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
    } else {
      (*cntx)->SendError(parse_status);
    }
    return;
  }

  const auto& op_args = *parsed;
  if (op_args.num_keys == 0) {
    return SendAtLeastOneKeyError(cntx);
  }

  // One result slot per shard, filled by the first hop.
  vector<OpResult<ScoredMap>> shard_maps(shard_set->size());
  string_view dest_key = store ? ArgS(args, 1) : "";

  auto union_cb = [&](Transaction* t, EngineShard* shard) {
    shard_maps[shard->shard_id()] =
        OpUnion(shard, t, dest_key, op_args.agg_type, op_args.weights, false);
    return OpStatus::OK;
  };

  cntx->transaction->Schedule();

  // Without a store step this is the only hop, hence the `!store` flag.
  cntx->transaction->Execute(std::move(union_cb), !store);

  ScoredMap merged;
  for (auto& shard_res : shard_maps) {
    if (!shard_res) {
      // NOTE(review): when `store` is true the hop above was executed with a
      // false second argument and this return skips the second hop — confirm
      // the transaction does not need to be finalized on this path.
      return (*cntx)->SendError(shard_res.status());
    }
    UnionScoredMap(&merged, &shard_res.value(), op_args.agg_type);
  }

  // Flatten into (score, member) pairs; the views refer to keys of `merged`.
  vector<ScoredMemberView> smvec;
  for (const auto& [member, score] : merged) {
    smvec.emplace_back(score, member);
  }

  if (!store) {
    std::sort(std::begin(smvec), std::end(smvec));

    const size_t entries_per_member = op_args.with_scores ? 2 : 1;
    (*cntx)->StartArray(smvec.size() * entries_per_member);
    for (const auto& [score, member] : smvec) {
      (*cntx)->SendBulkString(member);
      if (op_args.with_scores) {
        (*cntx)->SendDouble(score);
      }
    }
    return;
  }

  // Second hop: only the shard owning the destination key performs the write.
  ShardId dest_shard = Shard(dest_key, shard_maps.size());
  AddResult add_result;

  auto store_cb = [&](Transaction* t, EngineShard* shard) {
    if (shard->shard_id() == dest_shard) {
      ZParams zparams;
      zparams.override = true;
      add_result = OpAdd(t->GetOpArgs(shard), zparams, dest_key, ScoredMemberSpan{smvec}).value();
    }
    return OpStatus::OK;
  };

  cntx->transaction->Execute(std::move(store_cb), true);
  (*cntx)->SendLong(smvec.size());
}
} // namespace
@ -1108,26 +1258,25 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
string_view dest_key = ArgS(args, 1);
OpResult<StoreArgs> store_args_res = ParseStoreArgs(args);
OpResult<SetOpArgs> op_args_res = ParseSetOpArgs(args, true);
if (!store_args_res) {
switch (store_args_res.status()) {
if (!op_args_res) {
switch (op_args_res.status()) {
case OpStatus::INVALID_FLOAT:
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
default:
return (*cntx)->SendError(store_args_res.status());
return (*cntx)->SendError(op_args_res.status());
}
}
const auto& store_args = *store_args_res;
if (store_args.num_keys == 0) {
const auto& op_args = *op_args_res;
if (op_args.num_keys == 0) {
return SendAtLeastOneKeyError(cntx);
}
vector<OpResult<ScoredMap>> maps(shard_set->size(), OpStatus::SKIPPED);
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] =
OpInter(shard, t, dest_key, store_args.agg_type, store_args.weights, false);
maps[shard->shard_id()] = OpInter(shard, t, dest_key, op_args.agg_type, op_args.weights, false);
return OpStatus::OK;
};
@ -1145,7 +1294,7 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
if (result.empty())
result.swap(op_res.value());
else
InterScoredMap(&result, &op_res.value(), store_args.agg_type);
InterScoredMap(&result, &op_res.value(), op_args.agg_type);
if (result.empty())
break;
}
@ -1455,60 +1604,12 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
}
}
// ZUNION entry point: delegates to the shared union implementation with the
// store step disabled, so the result is sent back to the client.
void ZSetFamily::ZUnion(CmdArgList args, ConnectionContext* cntx) {
  constexpr bool kStoreResult = false;
  ZUnionFamilyInternal(args, kStoreResult, cntx);
}
void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
string_view dest_key = ArgS(args, 1);
OpResult<StoreArgs> store_args_res = ParseStoreArgs(args);
if (!store_args_res) {
switch (store_args_res.status()) {
case OpStatus::INVALID_FLOAT:
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
default:
return (*cntx)->SendError(store_args_res.status());
}
}
const auto& store_args = *store_args_res;
if (store_args.num_keys == 0) {
return SendAtLeastOneKeyError(cntx);
}
vector<OpResult<ScoredMap>> maps(shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] =
OpUnion(shard, t, dest_key, store_args.agg_type, store_args.weights, false);
return OpStatus::OK;
};
cntx->transaction->Schedule();
cntx->transaction->Execute(std::move(cb), false);
ScoredMap result;
for (auto& op_res : maps) {
if (!op_res)
return (*cntx)->SendError(op_res.status());
UnionScoredMap(&result, &op_res.value(), store_args.agg_type);
}
ShardId dest_shard = Shard(dest_key, maps.size());
AddResult add_result;
vector<ScoredMemberView> smvec;
for (const auto& elem : result) {
smvec.emplace_back(elem.second, elem.first);
}
auto store_cb = [&](Transaction* t, EngineShard* shard) {
if (shard->shard_id() == dest_shard) {
ZParams zparams;
zparams.override = true;
add_result = OpAdd(t->GetOpArgs(shard), zparams, dest_key, ScoredMemberSpan{smvec}).value();
}
return OpStatus::OK;
};
cntx->transaction->Execute(std::move(store_cb), true);
(*cntx)->SendLong(smvec.size());
ZUnionFamilyInternal(args, true, cntx);
}
void ZSetFamily::ZRangeByScoreInternal(CmdArgList args, bool reverse, ConnectionContext* cntx) {
@ -2034,6 +2135,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRevRangeByScore)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank)
<< CI{"ZSCAN", CO::READONLY, -3, 1, 1, 1}.HFUNC(ZScan)
<< CI{"ZUNION", CO::READONLY | CO::REVERSE_MAPPING, -4, 1, 3, 1}.HFUNC(ZUnion)
<< CI{"ZUNIONSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZUnionStore);
}

View file

@ -79,6 +79,7 @@ class ZSetFamily {
static void ZRevRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRevRank(CmdArgList args, ConnectionContext* cntx);
static void ZScan(CmdArgList args, ConnectionContext* cntx);
static void ZUnion(CmdArgList args, ConnectionContext* cntx);
static void ZUnionStore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScoreInternal(CmdArgList args, bool reverse, ConnectionContext* cntx);

View file

@ -264,6 +264,56 @@ TEST_F(ZSetFamilyTest, ZScan) {
EXPECT_EQ(100 * 2, scan_len);
}
// Covers ZUNION argument validation and the resulting member/score output.
// Every successful call below passes exactly three source keys.
TEST_F(ZSetFamilyTest, ZUnion) {
RespExpr resp;
// Too few arguments for the command.
resp = Run({"zunion", "0"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
// Fixture: z1 = {a:1, b:3}, z2 = {c:3, b:2}, z3 = {c:1, d:1}.
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "3", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z3", "1", "c", "1", "d"}));
// Third weight is not numeric.
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "k"});
EXPECT_THAT(resp, ErrArg("weight value is not a float"));
// Unknown aggregate name.
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "something"});
EXPECT_THAT(resp, ErrArg("syntax error"));
// Only two numeric weights are given, so "aggregate" is read as the third weight.
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "2", "aggregate", "something"});
EXPECT_THAT(resp, ErrArg("weight value is not a float"));
// Unrecognized token after a complete AGGREGATE option.
resp = Run({"zunion", "z1", "z2", "z3", "aggregate", "sum", "somescore"});
EXPECT_THAT(resp, ErrArg("syntax error"));
// Unrecognized token after WITHSCORES.
resp = Run({"zunion", "z1", "z2", "z3", "withscores", "someargs"});
EXPECT_THAT(resp, ErrArg("syntax error"));
// No options: scores are summed (a=1, d=1, c=4, b=5) and only members are returned.
resp = Run({"zunion", "z1", "z2", "z3"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "d", "c", "b"));
// z3 weighted by 2: a=1, d=2, b=5, c=5.
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "d", "b", "c"));
// Same as above, with each member followed by its score.
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "d", "2", "b", "5", "c", "5"));
// MIN aggregation over the weighted scores: a=1, b=2, c=2, d=2.
resp =
Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "min", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "2", "d", "2"));
// Option order does not matter: WITHSCORES first gives the same reply.
resp =
Run({"zunion", "z1", "z2", "z3", "withscores", "weights", "1", "1", "2", "aggregate", "min"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "2", "d", "2"));
// Missing source keys contribute nothing; only weighted z3 remains.
resp = Run({"zunion", "none1", "none2", "z3", "withscores", "weights", "1", "1", "2"});
EXPECT_THAT(resp.GetVec(), ElementsAre("c", "2", "d", "2"));
// MAX aggregation over the weighted scores: a=1, d=2, b=3, c=3.
resp =
Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "max", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "d", "2", "b", "3", "c", "3"));
}
TEST_F(ZSetFamilyTest, ZUnionStore) {
RespExpr resp;