diff --git a/src/server/blocking_controller.cc b/src/server/blocking_controller.cc index b665ad66f..aa97334fe 100644 --- a/src/server/blocking_controller.cc +++ b/src/server/blocking_controller.cc @@ -4,6 +4,8 @@ #include "server/blocking_controller.h" +#include + #include extern "C" { @@ -20,12 +22,13 @@ using namespace std; struct WatchItem { Transaction* trans; + KeyReadyChecker key_ready_checker; Transaction* get() const { return trans; } - WatchItem(Transaction* t) : trans(t) { + WatchItem(Transaction* t, KeyReadyChecker krc) : trans(t), key_ready_checker(std::move(krc)) { } }; @@ -212,15 +215,7 @@ void BlockingController::NotifyPending() { for (auto key : wt.awakened_keys) { string_view sv_key = static_cast(key); DVLOG(1) << "Processing awakened key " << sv_key; - - // Double verify we still got the item. - auto [it, exp_it] = owner_->db_slice().FindReadOnly(context, sv_key); - // Only LIST, ZSET and STREAM are allowed to block. - if (!IsValid(it) || !(it->second.ObjType() == OBJ_LIST || it->second.ObjType() == OBJ_ZSET || - it->second.ObjType() == OBJ_STREAM)) - continue; - - NotifyWatchQueue(sv_key, &wt.queue_map); + NotifyWatchQueue(sv_key, &wt.queue_map, context); } wt.awakened_keys.clear(); @@ -231,7 +226,7 @@ void BlockingController::NotifyPending() { awakened_indices_.clear(); } -void BlockingController::AddWatched(ArgSlice keys, Transaction* trans) { +void BlockingController::AddWatched(ArgSlice keys, KeyReadyChecker krc, Transaction* trans) { auto [dbit, added] = watched_dbs_.emplace(trans->GetDbIndex(), nullptr); if (added) { dbit->second.reset(new DbWatchTable); @@ -254,7 +249,7 @@ void BlockingController::AddWatched(ArgSlice keys, Transaction* trans) { continue; } DVLOG(2) << "Emplace " << trans->DebugId() << " to watch " << key; - res->second->items.emplace_back(trans); + res->second->items.emplace_back(trans, krc); } } @@ -275,33 +270,40 @@ void BlockingController::AwakeWatched(DbIndex db_index, string_view db_key) { } // Marks the queue as active and notifies the first transaction in the queue. -void BlockingController::NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm) { +void BlockingController::NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm, + const DbContext& context) { auto w_it = wqm->find(key); CHECK(w_it != wqm->end()); DVLOG(1) << "Notify WQ: [" << owner_->shard_id() << "] " << key; WatchQueue* wq = w_it->second.get(); - DCHECK_EQ(wq->state, WatchQueue::SUSPENDED); - wq->state = WatchQueue::ACTIVE; auto& queue = wq->items; ShardId sid = owner_->shard_id(); - do { - WatchItem& wi = queue.front(); + // In the most cases we shouldn't have skipped elements at all + absl::InlinedVector skipped; + while (!queue.empty()) { + auto& wi = queue.front(); Transaction* head = wi.get(); - DVLOG(2) << "WQ-Pop " << head->DebugId() << " from key " << key; - - if (head->NotifySuspended(owner_->committed_txid(), sid, key)) { - // We deliberately keep the notified transaction in the queue to know which queue - // must handled when this transaction finished. - wq->notify_txid = owner_->committed_txid(); - awakened_transactions_.insert(head); - break; + // We check may the transaction be notified otherwise move it to the end of the queue + if (wi.key_ready_checker(owner_, context, head, key)) { + DVLOG(2) << "WQ-Pop " << head->DebugId() << " from key " << key; + if (head->NotifySuspended(owner_->committed_txid(), sid, key)) { + wq->state = WatchQueue::ACTIVE; + // We deliberately keep the notified transaction in the queue to know which queue + // must handled when this transaction finished. + wq->notify_txid = owner_->committed_txid(); + awakened_transactions_.insert(head); + break; + } + } else { + skipped.push_back(std::move(wi)); } queue.pop_front(); - } while (!queue.empty()); + } + std::move(skipped.begin(), skipped.end(), std::back_inserter(queue)); if (wq->items.empty()) { wqm->erase(w_it); diff --git a/src/server/blocking_controller.h b/src/server/blocking_controller.h index fb83fe934..48df48549 100644 --- a/src/server/blocking_controller.h +++ b/src/server/blocking_controller.h @@ -39,7 +39,7 @@ class BlockingController { // TODO: consider moving all watched functions to // EngineShard with separate per db map. //! AddWatched adds a transaction to the blocking queue. - void AddWatched(ArgSlice watch_keys, Transaction* me); + void AddWatched(ArgSlice watch_keys, KeyReadyChecker krc, Transaction* me); // Called from operations that create keys like lpush, rename etc. void AwakeWatched(DbIndex db_index, std::string_view db_key); @@ -54,7 +54,7 @@ class BlockingController { using WatchQueueMap = absl::flat_hash_map>; - void NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm); + void NotifyWatchQueue(std::string_view key, WatchQueueMap* wqm, const DbContext& context); // void NotifyConvergence(Transaction* tx); diff --git a/src/server/blocking_controller_test.cc b/src/server/blocking_controller_test.cc index 18df73ef2..01f0ed8c8 100644 --- a/src/server/blocking_controller_test.cc +++ b/src/server/blocking_controller_test.cc @@ -75,7 +75,8 @@ TEST_F(BlockingControllerTest, Basic) { EngineShard* shard = EngineShard::tlocal(); BlockingController bc(shard); auto keys = trans_->GetShardArgs(shard->shard_id()); - bc.AddWatched(keys, trans_.get()); + bc.AddWatched( + keys, [](auto...) { return true; }, trans_.get()); EXPECT_EQ(1, bc.NumWatched(0)); bc.FinalizeWatched(keys, trans_.get()); @@ -89,7 +90,7 @@ TEST_F(BlockingControllerTest, Timeout) { trans_->Schedule(); auto cb = [&](Transaction* t, EngineShard* shard) { return trans_->GetShardArgs(0); }; - facade::OpStatus status = trans_->WaitOnWatch(tp, cb); + facade::OpStatus status = trans_->WaitOnWatch(tp, cb, [](auto...) { return true; }); EXPECT_EQ(status, facade::OpStatus::TIMED_OUT); unsigned num_watched = shard_set->Await( diff --git a/src/server/common.h b/src/server/common.h index 19644c1d8..582cddf83 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -359,6 +359,10 @@ inline uint32_t MemberTimeSeconds(uint64_t now_ms) { return (now_ms / 1000) - kMemberExpiryBase; } +// Checks whether the touched key is valid for a blocking transaction watching it +using KeyReadyChecker = + std::function; + struct MemoryBytesFlag { uint64_t value = 0; }; diff --git a/src/server/container_utils.cc b/src/server/container_utils.cc index 7c4832e88..90c18a9a2 100644 --- a/src/server/container_utils.cc +++ b/src/server/container_utils.cc @@ -283,7 +283,11 @@ OpResult RunCbOnFirstNonEmptyBlocking(Transaction* trans, int req_obj_ty auto wcb = [](Transaction* t, EngineShard* shard) { return t->GetShardArgs(shard->shard_id()); }; *block_flag = true; - auto status = trans->WaitOnWatch(limit_tp, std::move(wcb)); + const auto key_checker = [req_obj_type](EngineShard* owner, const DbContext& context, + Transaction*, std::string_view key) -> bool { + return owner->db_slice().FindReadOnly(context, key, req_obj_type).ok(); + }; + auto status = trans->WaitOnWatch(limit_tp, std::move(wcb), key_checker); *block_flag = false; if (status != OpStatus::OK) diff --git a/src/server/list_family.cc b/src/server/list_family.cc index 9ada11ecf..2c42f2687 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -881,8 +881,12 @@ OpResult BPopPusher::RunSingle(Transaction* t, time_point tp) { auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; }; + const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*, + std::string_view key) -> bool { + return owner->db_slice().FindReadOnly(context, key, OBJ_LIST).ok(); + }; // Block - if (auto status = t->WaitOnWatch(tp, std::move(wcb)); status != OpStatus::OK) + if (auto status = t->WaitOnWatch(tp, std::move(wcb), key_checker); status != OpStatus::OK) return status; t->Execute(cb_move, true); @@ -906,7 +910,12 @@ OpResult BPopPusher::RunPair(Transaction* t, time_point tp) { // This allows us to run Transaction::Execute on watched transactions in both shards. auto wcb = [&](Transaction* t, EngineShard* shard) { return ArgSlice{&this->pop_key_, 1}; }; - if (auto status = t->WaitOnWatch(tp, std::move(wcb)); status != OpStatus::OK) + const auto key_checker = [](EngineShard* owner, const DbContext& context, Transaction*, + std::string_view key) -> bool { + return owner->db_slice().FindReadOnly(context, key, OBJ_LIST).ok(); + }; + + if (auto status = t->WaitOnWatch(tp, std::move(wcb), key_checker); status != OpStatus::OK) return status; return MoveTwoShards(t, pop_key_, push_key_, popdir_, pushdir_, true); diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc index 9b6defc9c..d27f4ef8a 100644 --- a/src/server/stream_family.cc +++ b/src/server/stream_family.cc @@ -2817,7 +2817,28 @@ void XReadBlock(ReadOpts opts, ConnectionContext* cntx) { auto tp = (opts.timeout) ? chrono::steady_clock::now() + chrono::milliseconds(opts.timeout) : Transaction::time_point::max(); - if (auto status = cntx->transaction->WaitOnWatch(tp, std::move(wcb)); status != OpStatus::OK) + const auto key_checker = [&opts](EngineShard* owner, const DbContext& context, Transaction* tx, + std::string_view key) -> bool { + auto res_it = owner->db_slice().FindReadOnly(context, key, OBJ_STREAM); + if (!res_it.ok()) + return false; + + auto sitem = opts.stream_ids.at(key); + if (sitem.id.val.ms != UINT64_MAX && sitem.id.val.seq != UINT64_MAX) + return true; + + const CompactObj& cobj = (*res_it)->second; + stream* s = GetReadOnlyStream(cobj); + streamID last_id = s->last_id; + if (s->length) { + streamLastValidID(s, &last_id); + } + + return streamCompareID(&last_id, &sitem.group->last_id) > 0; + }; + + if (auto status = cntx->transaction->WaitOnWatch(tp, std::move(wcb), key_checker); + status != OpStatus::OK) return rb->SendNullArray(); // Resolve the entry in the woken key. Note this must not use OpRead since diff --git a/src/server/stream_family_test.cc b/src/server/stream_family_test.cc index bd27ca3d9..f154839b4 100644 --- a/src/server/stream_family_test.cc +++ b/src/server/stream_family_test.cc @@ -342,26 +342,18 @@ TEST_F(StreamFamilyTest, XReadGroupBlock) { ThisFiber::SleepFor(50us); pp_->at(1)->Await([&] { return Run("xadd", {"xadd", "bar", "1-*", "k5", "v5"}); }); // The second one should be unblocked + ThisFiber::SleepFor(50us); fb0.Join(); fb1.Join(); - // temporary incorrect results - if (resp0.GetVec()[1].GetVec().size() == 0) { - EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(0))); - EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1))); - } else { - EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1))); - EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(0))); - } - // correct results - // if (resp0.GetVec()[0].GetString() == "foo") { - // EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1))); - // EXPECT_THAT(resp1.GetVec(), ElementsAre("bar", ArrLen(1))); - // } else { - // EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1))); - // EXPECT_THAT(resp0.GetVec(), ElementsAre("bar", ArrLen(1))); - // } + if (resp0.GetVec()[0].GetString() == "foo") { + EXPECT_THAT(resp0.GetVec(), ElementsAre("foo", ArrLen(1))); + EXPECT_THAT(resp1.GetVec(), ElementsAre("bar", ArrLen(1))); + } else { + EXPECT_THAT(resp1.GetVec(), ElementsAre("foo", ArrLen(1))); + EXPECT_THAT(resp0.GetVec(), ElementsAre("bar", ArrLen(1))); + } } TEST_F(StreamFamilyTest, XReadInvalidArgs) { diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 5f9b226c0..7d62c716c 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -1203,13 +1203,14 @@ size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const { return reverse_index_[sd.arg_start + arg_index]; } -OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_provider) { +OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_provider, + KeyReadyChecker krc) { DVLOG(2) << "WaitOnWatch " << DebugId(); using namespace chrono; auto cb = [&](Transaction* t, EngineShard* shard) { auto keys = wkeys_provider(t, shard); - return t->WatchInShard(keys, shard); + return t->WatchInShard(keys, shard, krc); }; Execute(std::move(cb), true); @@ -1257,14 +1258,14 @@ OpStatus Transaction::WaitOnWatch(const time_point& tp, WaitKeysProvider wkeys_p } // Runs only in the shard thread. -OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard) { +OpStatus Transaction::WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc) { ShardId idx = SidToId(shard->shard_id()); auto& sd = shard_data_[idx]; CHECK_EQ(0, sd.local_mask & SUSPENDED_Q); auto* bc = shard->EnsureBlockingController(); - bc->AddWatched(keys, this); + bc->AddWatched(keys, std::move(krc), this); sd.local_mask |= SUSPENDED_Q; sd.local_mask &= ~OUT_OF_ORDER; diff --git a/src/server/transaction.h b/src/server/transaction.h index 479b56dfa..eb9c6b2a4 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -186,7 +186,7 @@ class Transaction { // or b) tp is reached. If tp is time_point::max() then waits indefinitely. // Expects that the transaction had been scheduled before, and uses Execute(.., true) to register. // Returns false if timeout occurred, true if was notified by one of the keys. - facade::OpStatus WaitOnWatch(const time_point& tp, WaitKeysProvider cb); + facade::OpStatus WaitOnWatch(const time_point& tp, WaitKeysProvider cb, KeyReadyChecker krc); // Returns true if transaction is awaked, false if it's timed-out and can be removed from the // blocking queue. @@ -456,7 +456,7 @@ class Transaction { void ExecuteAsync(); // Adds itself to watched queue in the shard. Must run in that shard thread. - OpStatus WatchInShard(ArgSlice keys, EngineShard* shard); + OpStatus WatchInShard(ArgSlice keys, EngineShard* shard, KeyReadyChecker krc); // Expire blocking transaction, unlock keys and unregister it from the blocking controller void ExpireBlocking(WaitKeysProvider wcb); diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 1e7571057..40bce478c 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -677,6 +677,26 @@ TEST_F(ZSetFamilyTest, BlockingIsReleased) { } } +TEST_F(ZSetFamilyTest, BlockingWithIncorrectType) { + RespExpr resp0; + RespExpr resp1; + auto fb0 = pp_->at(0)->LaunchFiber(Launch::dispatch, [&] { + resp0 = Run({"BLPOP", "list1", "0"}); + }); + auto fb1 = pp_->at(1)->LaunchFiber(Launch::dispatch, [&] { + resp1 = Run({"BZPOPMIN", "list1", "0"}); + }); + + ThisFiber::SleepFor(50us); + pp_->at(2)->Await([&] { return Run({"ZADD", "list1", "1", "a"}); }); + pp_->at(2)->Await([&] { return Run({"LPUSH", "list1", "0"}); }); + fb0.Join(); + fb1.Join(); + + EXPECT_THAT(resp1.GetVec(), ElementsAre("list1", "a", "1")); + EXPECT_THAT(resp0.GetVec(), ElementsAre("list1", "0")); +} + TEST_F(ZSetFamilyTest, BlockingTimeout) { RespExpr resp0;