fix(server): zunion now supports variadic arguments (#717)

1. Before that we did no support a real syntax with <numkey> argument,
now we do.

2. Fix warnings.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2023-01-23 14:06:17 +02:00 committed by GitHub
parent 7662e03d1f
commit ac44a1f7e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 131 additions and 96 deletions

View file

@ -156,13 +156,19 @@ std::optional<std::string> GetRemoteVersion(ProactorBase* proactor, SSL_CTX* ssl
static bool is_logged{false};
if (!is_logged) {
is_logged = true;
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
const char* func_err = "ssl_internal_error";
#else
const char* func_err = ERR_func_error_string(ec.value());
#endif
// Unfortunately AsioStreamAdapter looses the original error category
// because std::error_code can not be converted into boost::system::error_code.
// It's fixed in later versions of Boost, but for now we assume it's from TLS.
LOG(WARNING) << "Remote version - HTTP GET error [" << host << ":" << service << resource
<< "], error: " << ec.value();
LOG(WARNING) << "ssl error: " << ERR_func_error_string(ec.value()) << "/"
<< ERR_reason_error_string(ec.value());
LOG(WARNING) << "ssl error: " << func_err << "/" << ERR_reason_error_string(ec.value());
}
}

View file

@ -1363,7 +1363,7 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) {
append("role", "master");
append("connected_slaves", m.conn_stats.num_replicas);
auto replicas = dfly_cmd_->GetReplicasRoleInfo();
for (auto i = 0; i < replicas.size(); i++) {
for (size_t i = 0; i < replicas.size(); i++) {
auto& r = replicas[i];
// e.g. slave0:ip=172.19.0.3,port=6379
append(StrCat("slave", i), StrCat("ip=", r.address, ",port=", r.listening_port));

View file

@ -988,7 +988,7 @@ ArgSlice Transaction::GetShardArgs(ShardId sid) const {
// from local index back to original arg index skipping the command.
// i.e. returns (first_key_pos -1) or bigger.
size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const {
if (unique_shard_cnt_ == 1) // mget: 0->0, 1->1. zunionstore has 0->2
if (unique_shard_cnt_ == 1)
return reverse_index_[arg_index];
const auto& sd = shard_data_[shard_id];
@ -1264,20 +1264,25 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
int num_custom_keys = -1;
if (cid->opt_mask() & CO::VARIADIC_KEYS) {
// ZUNION/INTER <num_keys> <key1> [<key2> ...]
// EVAL <script> <num_keys>
if (args.size() < 3) {
return OpStatus::SYNTAX_ERR;
}
string_view name{cid->name()};
if (!absl::StartsWith(name, "EVAL")) {
if (absl::EndsWith(name, "STORE")) {
key_index.bonus = 1; // Z<xxx>STORE commands
}
string_view num(ArgS(args, 2));
unsigned num_keys_index = absl::StartsWith(name, "EVAL") ? 2 : key_index.bonus + 1;
string_view num = ArgS(args, num_keys_index);
if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0)
return OpStatus::INVALID_INT;
if (size_t(num_custom_keys) + 3 > args.size())
if (args.size() < size_t(num_custom_keys) + num_keys_index + 1)
return OpStatus::SYNTAX_ERR;
}

View file

@ -715,62 +715,55 @@ ScoredMap UnionShardKeysWithScore(const KeyIterWeightVec& key_iter_weight_vec, A
return result;
}
double GetKeyWeight(EngineShard* shard, Transaction* t, const vector<double>& weights,
double GetKeyWeight(Transaction* t, ShardId shard_id, const vector<double>& weights,
unsigned key_index, unsigned cmdargs_keys_offset) {
unsigned windex = t->ReverseArgIndex(shard->shard_id(), key_index) - cmdargs_keys_offset;
if (weights.empty()) {
return 1;
}
unsigned windex = t->ReverseArgIndex(shard_id, key_index) - cmdargs_keys_offset;
DCHECK_LT(windex, weights.size());
return weights[windex];
}
OpResult<KeyIterWeightVec> FindShardKeysAndWeights(EngineShard* shard, Transaction* t,
ArgSlice keys, const vector<double>& weights,
unsigned src_keys_offset,
unsigned cmdargs_keys_offset) {
auto& db_slice = shard->db_slice();
KeyIterWeightVec key_weight_vec(keys.size() - src_keys_offset);
for (unsigned j = src_keys_offset; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->GetDbContext(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
if (!it_res)
continue;
// first global index is 2 after {destkey, numkeys}
key_weight_vec[j - src_keys_offset] = {*it_res,
GetKeyWeight(shard, t, weights, j, cmdargs_keys_offset)};
}
return key_weight_vec;
}
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->GetShardArgs(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
unsigned src_keys_offset = 0;
unsigned cmdargs_keys_offset = 0;
unsigned cmdargs_keys_offset = 1; // after {numkeys} for ZUNION
unsigned removed_keys = 0;
if (!dest.empty()) {
// first global index is 2 after {destkey, numkeys}
cmdargs_keys_offset = 2;
if (store) {
// first global index is 2 after {destkey, numkeys}.
++cmdargs_keys_offset;
if (keys.front() == dest) {
++src_keys_offset;
keys.remove_prefix(1);
++removed_keys;
}
// In case ONLY the destination key is hosted in this shard no work on this shard should be done
// in this step
if (src_keys_offset >= keys.size()) {
// In case ONLY the destination key is hosted in this shard no work on this shard should be
// done in this step
if (keys.empty()) {
return OpStatus::OK;
}
}
auto keys_and_weights =
FindShardKeysAndWeights(shard, t, keys, weights, src_keys_offset, cmdargs_keys_offset);
if (!keys_and_weights) {
return keys_and_weights.status();
auto& db_slice = shard->db_slice();
KeyIterWeightVec key_weight_vec(keys.size());
for (unsigned j = 0; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->GetDbContext(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
if (!it_res)
continue;
key_weight_vec[j] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
cmdargs_keys_offset)};
}
return UnionShardKeysWithScore(*keys_and_weights, agg_type);
return UnionShardKeysWithScore(key_weight_vec, agg_type);
}
OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
@ -779,18 +772,31 @@ OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
unsigned start = 0;
unsigned removed_keys = 0;
unsigned cmdargs_keys_offset = 1;
if (keys.front() == dest) {
++start;
if (store) {
// first global index is 2 after {destkey, numkeys}.
++cmdargs_keys_offset;
if (keys.front() == dest) {
keys.remove_prefix(1);
++removed_keys;
}
// In case ONLY the destination key is hosted in this shard no work on this shard should be
// done in this step
if (keys.empty()) {
return OpStatus::OK;
}
}
auto& db_slice = shard->db_slice();
vector<pair<PrimeIterator, double>> it_arr(keys.size() - start);
vector<pair<PrimeIterator, double>> it_arr(keys.size());
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
return OpStatus::SKIPPED; // return noop
for (unsigned j = start; j < keys.size(); ++j) {
for (unsigned j = 0; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->GetDbContext(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
@ -798,11 +804,8 @@ OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest
if (!it_res)
continue; // we exit in the next loop
// first global index is 2 after {destkey, numkeys}
unsigned src_indx = j - start;
unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2;
DCHECK_LT(windex, weights.size());
it_arr[src_indx] = {*it_res, weights[windex]};
it_arr[j] = {*it_res, GetKeyWeight(t, shard->shard_id(), weights, j + removed_keys,
cmdargs_keys_offset)};
}
ScoredMap result;
@ -945,12 +948,14 @@ OpResult<unsigned> ParseWeights(CmdArgList args, SetOpArgs* op_args) {
return OpStatus::SYNTAX_ERR;
}
op_args->weights.resize(op_args->num_keys, 1);
for (unsigned i = 0; i < op_args->num_keys; ++i) {
string_view weight = ArgS(args, i + 1);
if (!absl::SimpleAtod(weight, &(op_args->weights[i]))) {
if (!absl::SimpleAtod(weight, &op_args->weights[i])) {
return OpStatus::INVALID_FLOAT;
}
}
return op_args->num_keys;
}
@ -968,8 +973,7 @@ OpResult<unsigned> ParseWithScores(CmdArgList args, SetOpArgs* op_args) {
}
OpResult<SetOpArgs> ParseSetOpArgs(CmdArgList args, bool store) {
// TODO: support variadic key for ZUNION command (now fixed to 3)
string_view num_keys_str = store ? ArgS(args, 2) : "3";
string_view num_keys_str = store ? ArgS(args, 2) : ArgS(args, 1);
SetOpArgs op_args;
auto parsed = ParseKeyCount(num_keys_str, &op_args);
@ -977,11 +981,8 @@ OpResult<SetOpArgs> ParseSetOpArgs(CmdArgList args, bool store) {
return parsed.status();
}
unsigned opt_args_start = store ? 3 + op_args.num_keys : 1 + op_args.num_keys;
// TODO: modify this check when there is variadic key support for ZUNION command
DCHECK_GE(args.size(), opt_args_start);
op_args.weights.resize(op_args.num_keys, 1);
unsigned opt_args_start = op_args.num_keys + (store ? 3 : 2);
DCHECK_LE(opt_args_start, args.size()); // Checked inside DetermineKeys
for (size_t i = opt_args_start; i < args.size(); ++i) {
ToUpper(&args[i]);
@ -1032,9 +1033,10 @@ void ZUnionFamilyInternal(CmdArgList args, bool store, ConnectionContext* cntx)
vector<OpResult<ScoredMap>> maps(shard_set->size());
string_view dest_key = store ? ArgS(args, 1) : "";
string_view dest_key = ArgS(args, 1);
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] = OpUnion(shard, t, dest_key, op_args.agg_type, op_args.weights, false);
maps[shard->shard_id()] = OpUnion(shard, t, dest_key, op_args.agg_type, op_args.weights, store);
return OpStatus::OK;
};
@ -1292,7 +1294,7 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
vector<OpResult<ScoredMap>> maps(shard_set->size(), OpStatus::SKIPPED);
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] = OpInter(shard, t, dest_key, op_args.agg_type, op_args.weights, false);
maps[shard->shard_id()] = OpInter(shard, t, dest_key, op_args.agg_type, op_args.weights, true);
return OpStatus::OK;
};
@ -1307,10 +1309,12 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
if (!op_res)
return (*cntx)->SendError(op_res.status());
if (result.empty())
if (result.empty()) {
result.swap(op_res.value());
else
} else {
InterScoredMap(&result, &op_res.value(), op_args.agg_type);
}
if (result.empty())
break;
}
@ -2119,13 +2123,13 @@ OpResult<unsigned> ZSetFamily::OpLexCount(const OpArgs& op_args, string_view key
#define HFUNC(x) SetHandler(&ZSetFamily::x)
void ZSetFamily::Register(CommandRegistry* registry) {
constexpr uint32_t kUnionMask = CO::WRITE | CO::VARIADIC_KEYS | CO::REVERSE_MAPPING;
constexpr uint32_t kStoreMask = CO::WRITE | CO::VARIADIC_KEYS | CO::REVERSE_MAPPING;
*registry << CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd)
<< CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard)
<< CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount)
<< CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy)
<< CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore)
<< CI{"ZINTERSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZInterStore)
<< CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount)
<< CI{"ZPOPMAX", CO::READONLY, 3, 1, 1, 1}.HFUNC(ZPopMax)
<< CI{"ZPOPMIN", CO::READONLY, 3, 1, 1, 1}.HFUNC(ZPopMin)
@ -2144,8 +2148,9 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRevRangeByScore)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank)
<< CI{"ZSCAN", CO::READONLY, -3, 1, 1, 1}.HFUNC(ZScan)
<< CI{"ZUNION", CO::READONLY | CO::REVERSE_MAPPING, -4, 1, 3, 1}.HFUNC(ZUnion)
<< CI{"ZUNIONSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZUnionStore);
<< CI{"ZUNION", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 2, 2, 1}
.HFUNC(ZUnion)
<< CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZUnionStore);
}
} // namespace dfly

View file

@ -272,54 +272,73 @@ TEST_F(ZSetFamilyTest, ZScan) {
EXPECT_EQ(100 * 2, scan_len);
}
TEST_F(ZSetFamilyTest, ZUnion) {
TEST_F(ZSetFamilyTest, ZUnionError) {
RespExpr resp;
resp = Run({"zunion", "0"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "1", "k"});
EXPECT_THAT(resp, ErrArg("weight value is not a float"));
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "something"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "2", "aggregate", "something"});
EXPECT_THAT(resp, ErrArg("weight value is not a float"));
resp = Run({"zunion", "3", "z1", "z2", "z3", "aggregate", "sum", "somescore"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "3", "z1", "z2", "z3", "withscores", "someargs"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "1"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"zunion", "2", "z1"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "2", "z1", "z2", "z3"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "2", "z1", "z2", "weights", "1", "2", "3"});
EXPECT_THAT(resp, ErrArg("syntax error"));
}
TEST_F(ZSetFamilyTest, ZUnion) {
RespExpr resp;
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "3", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z3", "1", "c", "1", "d"}));
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "k"});
EXPECT_THAT(resp, ErrArg("weight value is not a float"));
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "something"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "2", "aggregate", "something"});
EXPECT_THAT(resp, ErrArg("weight value is not a float"));
resp = Run({"zunion", "z1", "z2", "z3", "aggregate", "sum", "somescore"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "z1", "z2", "z3", "withscores", "someargs"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunion", "z1", "z2", "z3"});
resp = Run({"zunion", "3", "z1", "z2", "z3"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "d", "c", "b"));
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2"});
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "1", "2"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "d", "b", "c"));
resp = Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "withscores"});
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "1", "2", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "d", "2", "b", "5", "c", "5"));
resp =
Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "min", "withscores"});
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "min",
"withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "2", "d", "2"));
resp =
Run({"zunion", "z1", "z2", "z3", "withscores", "weights", "1", "1", "2", "aggregate", "min"});
resp = Run({"zunion", "3", "z1", "z2", "z3", "withscores", "weights", "1", "1", "2", "aggregate",
"min"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "2", "d", "2"));
resp = Run({"zunion", "none1", "none2", "z3", "withscores", "weights", "1", "1", "2"});
resp = Run({"zunion", "3", "none1", "none2", "z3", "withscores", "weights", "1", "1", "2"});
EXPECT_THAT(resp.GetVec(), ElementsAre("c", "2", "d", "2"));
resp =
Run({"zunion", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "max", "withscores"});
resp = Run({"zunion", "3", "z1", "z2", "z3", "weights", "1", "1", "2", "aggregate", "max",
"withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "d", "2", "b", "3", "c", "3"));
resp = Run({"zunion", "1", "z1", "weights", "2", "aggregate", "max", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "2", "b", "6"));
}
TEST_F(ZSetFamilyTest, ZUnionStore) {