From be59b5eeb4d87ab335b2f1308538b4a0f04bc7a7 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 19 Jul 2024 14:23:46 +0300 Subject: [PATCH] chore: Make KeyIndex iterable (#3326) --- src/facade/facade_types.h | 4 + src/server/acl/validator.cc | 22 ++--- src/server/cluster/cluster_family_test.cc | 1 + src/server/main_service.cc | 28 ++----- src/server/multi_command_squasher.cc | 35 +++----- src/server/transaction.cc | 97 ++++++++++------------- src/server/tx_base.cc | 18 +++++ src/server/tx_base.h | 39 +++++---- 8 files changed, 107 insertions(+), 137 deletions(-) diff --git a/src/facade/facade_types.h b/src/facade/facade_types.h index 877047941..5b8e0eff1 100644 --- a/src/facade/facade_types.h +++ b/src/facade/facade_types.h @@ -91,6 +91,10 @@ struct ArgRange { return Range().second; } + std::string_view operator[](size_t idx) const { + return std::visit([idx](const auto& span) { return facade::ToSV(span[idx]); }, span); + } + std::variant span; }; struct ConnectionStats { diff --git a/src/server/acl/validator.cc b/src/server/acl/validator.cc index 9482230e5..9dcb1dba6 100644 --- a/src/server/acl/validator.cc +++ b/src/server/acl/validator.cc @@ -76,23 +76,11 @@ namespace dfly::acl { bool keys_allowed = true; if (!keys.all_keys && id.first_key_pos() != 0 && (is_read_command || is_write_command)) { - const auto keys_index = DetermineKeys(&id, tail_args).value(); - const size_t end = keys_index.end; - if (keys_index.bonus) { - auto target = facade::ToSV(tail_args[*keys_index.bonus]); - if (!iterate_globs(target)) { - keys_allowed = false; - } - } - if (keys_allowed) { - for (size_t i = keys_index.start; i < end; i += keys_index.step) { - auto target = facade::ToSV(tail_args[i]); - if (!iterate_globs(target)) { - keys_allowed = false; - break; - } - } - } + auto keys_index = DetermineKeys(&id, tail_args); + DCHECK(keys_index); + + for (std::string_view key : keys_index->Range(tail_args)) + keys_allowed &= iterate_globs(key); } return {keys_allowed, AclLog::Reason::KEY}; diff --git a/src/server/cluster/cluster_family_test.cc b/src/server/cluster/cluster_family_test.cc index 11224d035..d523bb5ae 100644 --- a/src/server/cluster/cluster_family_test.cc +++ b/src/server/cluster/cluster_family_test.cc @@ -686,6 +686,7 @@ TEST_F(ClusterFamilyTest, ClusterCrossSlot) { EXPECT_THAT(Run({"MSET", "key", "value", "key2", "value2"}), ErrArg("CROSSSLOT")); EXPECT_THAT(Run({"MGET", "key", "key2"}), ErrArg("CROSSSLOT")); + EXPECT_THAT(Run({"ZINTERSTORE", "key", "2", "key1", "key2"}), ErrArg("CROSSSLOT")); EXPECT_EQ(Run({"MSET", "key{tag}", "value", "key2{tag}", "value2"}), "OK"); EXPECT_THAT(Run({"MGET", "key{tag}", "key2{tag}"}), RespArray(ElementsAre("value", "value2"))); diff --git a/src/server/main_service.cc b/src/server/main_service.cc index c88cc15ba..c690d1a94 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -714,7 +714,7 @@ Transaction::MultiMode DeduceExecMode(ExecEvalState state, StoredCmd cmd = scmd; cmd.Fill(&arg_vec); auto keys = DetermineKeys(scmd.Cid(), absl::MakeSpan(arg_vec)); - transactional |= (keys && keys.value().num_args() > 0); + transactional |= (keys && keys.value().NumArgs() > 0); } else { transactional |= scmd.Cid()->IsTransactional(); } @@ -942,10 +942,8 @@ optional Service::CheckKeysOwnership(const CommandId* cid, CmdArgLis optional keys_slot; bool cross_slot = false; // Iterate keys and check to which slot they belong. - for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) { - string_view key = ArgS(args, i); - cluster::SlotId slot = cluster::KeySlot(key); - if (keys_slot && slot != *keys_slot) { + for (string_view key : key_index.Range(args)) { + if (cluster::SlotId slot = cluster::KeySlot(key); keys_slot && slot != *keys_slot) { cross_slot = true; // keys belong to different slots break; } else { @@ -984,18 +982,7 @@ optional CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_i // TODO: Switch to transaction internal locked keys once single hop multi transactions are merged // const auto& locked_keys = trans->GetMultiKeys(); const auto& locked_tags = eval_info.lock_tags; - - const auto& key_index = *key_index_res; - for (unsigned i = key_index.start; i < key_index.end; ++i) { - string_view key = ArgS(args, i); - LockTag tag{key}; - if (!locked_tags.contains(tag)) { - return ErrorReply(absl::StrCat(kUndeclaredKeyErr, ", key: ", key)); - } - } - - if (key_index.bonus) { - string_view key = ArgS(args, *key_index.bonus); + for (string_view key : key_index_res->Range(args)) { if (!locked_tags.contains(LockTag{key})) { return ErrorReply(absl::StrCat(kUndeclaredKeyErr, ", key: ", key)); } @@ -2118,13 +2105,8 @@ template void IterateAllKeys(ConnectionState::ExecInfo* exec_info, if (!key_res.ok()) continue; - auto key_index = key_res.value(); - - for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) + for (unsigned i : key_res->Range()) f(arg_vec[i]); - - if (key_index.bonus) - f(arg_vec[*key_index.bonus]); } } diff --git a/src/server/multi_command_squasher.cc b/src/server/multi_command_squasher.cc index c25680f3b..80d518fbf 100644 --- a/src/server/multi_command_squasher.cc +++ b/src/server/multi_command_squasher.cc @@ -13,6 +13,7 @@ #include "server/conn_context.h" #include "server/engine_shard_set.h" #include "server/transaction.h" +#include "server/tx_base.h" namespace dfly { @@ -22,14 +23,6 @@ using namespace util; namespace { -template void IterateKeys(CmdArgList args, KeyIndex keys, F&& f) { - for (unsigned i = keys.start; i < keys.end; i += keys.step) - f(args[i]); - - if (keys.bonus) - f(args[*keys.bonus]); -} - void CheckConnStateClean(const ConnectionState& state) { DCHECK_EQ(state.exec_info.state, ConnectionState::ExecInfo::EXEC_INACTIVE); DCHECK(state.exec_info.body.empty()); @@ -90,29 +83,21 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm auto keys = DetermineKeys(cmd->Cid(), args); if (!keys.ok()) return SquashResult::ERROR; + if (keys->NumArgs() == 0) + return SquashResult::NOT_SQUASHED; // Check if all commands belong to one shard - bool found_more = false; cluster::UniqueSlotChecker slot_checker; ShardId last_sid = kInvalidSid; - IterateKeys(args, *keys, [&last_sid, &found_more, &slot_checker](MutableSlice key) { - if (found_more) - return; - string_view key_sv = facade::ToSV(key); - - slot_checker.Add(key_sv); - - ShardId sid = Shard(key_sv, shard_set->size()); - if (last_sid == kInvalidSid || last_sid == sid) { + for (string_view key : keys->Range(args)) { + slot_checker.Add(key); + ShardId sid = Shard(key, shard_set->size()); + if (last_sid == kInvalidSid || last_sid == sid) last_sid = sid; - return; - } - found_more = true; - }); - - if (found_more || last_sid == kInvalidSid) - return SquashResult::NOT_SQUASHED; + else + return SquashResult::NOT_SQUASHED; // at least two shards + } auto& sinfo = PrepareShardInfo(last_sid, slot_checker.GetUniqueSlotId()); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 393a08b9a..1ff6c8045 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -191,32 +191,23 @@ void Transaction::InitGlobal() { } void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vector* out) { + // Because of the way we iterate in InitShardData + DCHECK(!key_index.bonus || key_index.step == 1); + auto& shard_index = *out; - - auto add = [&shard_index](uint32_t sid, uint32_t b, uint32_t e) { - auto& slices = shard_index[sid].slices; - if (!slices.empty() && slices.back().second == b) { - slices.back().second = e; - } else { - slices.emplace_back(b, e); - } - }; - - if (key_index.bonus) { - DCHECK(key_index.step == 1); - string_view key = ArgS(full_args_, *key_index.bonus); - unique_slot_checker_.Add(key); - uint32_t sid = Shard(key, shard_data_.size()); - add(sid, *key_index.bonus, *key_index.bonus + 1); - } - - for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) { + for (unsigned i : key_index.Range()) { string_view key = ArgS(full_args_, i); unique_slot_checker_.Add(key); - uint32_t sid = Shard(key, shard_data_.size()); - shard_index[sid].key_step = key_index.step; + ShardId sid = Shard(key, shard_data_.size()); - add(sid, i, i + key_index.step); + unsigned step = key_index.bonus ? 1 : key_index.step; + shard_index[sid].key_step = step; + auto& slices = shard_index[sid].slices; + if (!slices.empty() && slices.back().second == i) { + slices.back().second = i + step; + } else { + slices.emplace_back(i, i + step); + } } } @@ -247,11 +238,9 @@ void Transaction::InitShardData(absl::Span shard_index, siz unique_shard_cnt_++; unique_shard_id_ = i; - for (size_t j = 0; j < src.slices.size(); ++j) { - IndexSlice slice = src.slices[j]; - args_slices_.push_back(slice); - for (uint32_t k = slice.first; k < slice.second; k += src.key_step) { - string_view key = ArgS(full_args_, k); + for (const auto& [start, end] : src.slices) { + args_slices_.emplace_back(start, end); + for (string_view key : KeyIndex(start, end, src.key_step).Range(full_args_)) { kv_fp_.push_back(LockTag(key).Fingerprint()); sd.fp_count++; } @@ -279,10 +268,8 @@ void Transaction::StoreKeysInArgs(const KeyIndex& key_index) { // even for a single key we may have multiple arguments per key (MSET). args_slices_.emplace_back(key_index.start, key_index.end); - for (unsigned j = key_index.start; j < key_index.end; j += key_index.step) { - string_view key = ArgS(full_args_, j); + for (string_view key : key_index.Range(full_args_)) kv_fp_.push_back(LockTag(key).Fingerprint()); - } } void Transaction::InitByKeys(const KeyIndex& key_index) { @@ -296,14 +283,14 @@ void Transaction::InitByKeys(const KeyIndex& key_index) { // Stub transactions always operate only on single shard. bool is_stub = multi_ && multi_->role == SQUASHED_STUB; - if ((key_index.HasSingleKey() && !IsAtomicMulti()) || is_stub) { + if ((key_index.NumArgs() == 1 && !IsAtomicMulti()) || is_stub) { DCHECK(!IsActiveMulti() || multi_->mode == NON_ATOMIC); // We don't have to split the arguments by shards, so we can copy them directly. StoreKeysInArgs(key_index); unique_shard_cnt_ = 1; - string_view akey = ArgS(full_args_, key_index.start); + string_view akey = *key_index.Range(full_args_).begin(); if (is_stub) // stub transactions don't migrate DCHECK_EQ(unique_shard_id_, Shard(akey, shard_set->size())); else { @@ -329,7 +316,7 @@ void Transaction::InitByKeys(const KeyIndex& key_index) { BuildShardIndex(key_index, &shard_index); // Initialize shard data based on distributed arguments. - InitShardData(shard_index, key_index.num_args()); + InitShardData(shard_index, key_index.NumArgs()); DCHECK(!multi_ || multi_->mode != LOCK_AHEAD || !multi_->tag_fps.empty()); @@ -441,7 +428,7 @@ void Transaction::StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList PrepareMultiFps(keys); InitBase(ns, dbid, keys); - InitByKeys(KeyIndex::Range(0, keys.size())); + InitByKeys(KeyIndex(0, keys.size())); if (!skip_scheduling) ScheduleInternal(); @@ -1433,23 +1420,24 @@ bool Transaction::CanRunInlined() const { } OpResult DetermineKeys(const CommandId* cid, CmdArgList args) { - KeyIndex key_index; - if (cid->opt_mask() & (CO::GLOBAL_TRANS | CO::NO_KEY_TRANSACTIONAL)) - return key_index; + return KeyIndex{}; int num_custom_keys = -1; - if (cid->opt_mask() & CO::VARIADIC_KEYS) { + unsigned start = 0, end = 0, step = 0; + std::optional bonus = std::nullopt; + + if (cid->opt_mask() & CO::VARIADIC_KEYS) { // number of keys is not trivially deducable // ZUNION/INTER [ ...] // EVAL