From 64bbfc70632a8a2f2b93da4e5ab1f6bf6883629b Mon Sep 17 00:00:00 2001 From: Yue Li <61070669+theyueli@users.noreply.github.com> Date: Fri, 8 Dec 2023 23:13:55 -0800 Subject: [PATCH] feat(server): Support CLIENT TRACKING subcommand (1/2) (#2277) The client tracking state is set by CLIENT TRACKING subcommand as well as upon client disconnection. Track the keys of a readonly command by maintaining mapping that maps keys to the sets of tracking clients. --- src/facade/dragonfly_connection.cc | 12 +++++++++++- src/facade/dragonfly_connection.h | 9 ++++++++- src/facade/reply_builder.cc | 4 ++++ src/facade/reply_builder.h | 1 + src/server/db_slice.cc | 15 +++++++++++++++ src/server/db_slice.h | 14 ++++++++++++++ src/server/main_service.cc | 21 ++++++++++++++++++++- src/server/server_family.cc | 26 ++++++++++++++++++++++++++ src/server/server_family.h | 1 + src/server/transaction.cc | 6 ++++++ src/server/transaction.h | 2 ++ 11 files changed, 108 insertions(+), 3 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index c9b720736..e21e3b6bb 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1353,6 +1353,16 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) { migration_request_ = dest; } +void Connection::SetClientTrackingSwitch(bool is_on) { + tracking_enabled_ = is_on; + if (tracking_enabled_) + cc_->subscriptions++; +} + +bool Connection::IsTrackingOn() const { + return tracking_enabled_; +} + Connection::MemoryUsage Connection::GetMemoryUsage() const { size_t mem = sizeof(*this) + dfly::HeapSize(dispatch_q_) + dfly::HeapSize(name_) + dfly::HeapSize(tmp_parse_args_) + dfly::HeapSize(tmp_cmd_vec_) + @@ -1421,7 +1431,7 @@ bool Connection::WeakRef::operator<(const WeakRef& other) { return client_id_ < other.client_id_; } -bool Connection::WeakRef::operator==(const WeakRef& other) { +bool Connection::WeakRef::operator==(const WeakRef& other) const { return client_id_ == other.client_id_; } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 3b218b75f..4d502e57e 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -169,7 +169,7 @@ class Connection : public util::Connection { bool EnsureMemoryBudget() const; bool operator<(const WeakRef& other); - bool operator==(const WeakRef& other); + bool operator==(const WeakRef& other) const; private: friend class Connection; @@ -263,6 +263,10 @@ class Connection : public util::Connection { // Connections will migrate at most once, and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest); + void SetClientTrackingSwitch(bool is_on); + + bool IsTrackingOn() const; + protected: void OnShutdown() override; void OnPreMigrateThread() override; @@ -402,6 +406,9 @@ class Connection : public util::Connection { // Per-thread queue backpressure structs. static thread_local QueueBackpressure tl_queue_backpressure_; + + // a flag indicating whether the client has turned on client tracking. + bool tracking_enabled_ = false; }; } // namespace facade diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index ff83838dd..2acc276b6 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -269,6 +269,10 @@ void RedisReplyBuilder::SetResp3(bool is_resp3) { is_resp3_ = is_resp3; } +bool RedisReplyBuilder::IsResp3() const { + return is_resp3_; +} + void RedisReplyBuilder::SendError(string_view str, string_view err_type) { VLOG(1) << "Error: " << str; diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index a6e35aa82..c3cc437c8 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -219,6 +219,7 @@ class RedisReplyBuilder : public SinkReplyBuilder { RedisReplyBuilder(::io::Sink* stream); void SetResp3(bool is_resp3); + bool IsResp3() const; void SendError(std::string_view str, std::string_view type = {}) override; using SinkReplyBuilder::SendError; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 5e9ab5389..556191e6e 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1340,4 +1340,19 @@ void DbSlice::ResetUpdateEvents() { events_.update = 0; } +void DbSlice::TrackKeys(const facade::Connection::WeakRef& conn, const ArgSlice& keys) { + if (conn.IsExpired()) { + DVLOG(2) << "Connection expired, exiting TrackKey function."; + return; + } + + DVLOG(2) << "Start tracking keys for client ID: " << conn.GetClientId() + << " with thread ID: " << conn.Thread(); + for (auto key : keys) { + DVLOG(2) << "Inserting client ID " << conn.GetClientId() + << " into the tracking client set of key " << key; + client_tracking_map_[key].insert(conn); + } +} + } // namespace dfly diff --git a/src/server/db_slice.h b/src/server/db_slice.h index b19107982..b2ce97add 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -4,6 +4,7 @@ #pragma once +#include "facade/dragonfly_connection.h" #include "facade/op_status.h" #include "server/common.h" #include "server/conn_context.h" @@ -334,6 +335,9 @@ class DbSlice { expire_allowed_ = is_allowed; } + // Track keys for the client represented by the the weak reference to its connection. + void TrackKeys(const facade::Connection::WeakRef&, const ArgSlice&); + private: // Releases a single key. `key` must have been normalized by GetLockKey(). void ReleaseNormalized(IntentLock::Mode m, DbIndex db_index, std::string_view key, @@ -385,6 +389,16 @@ class DbSlice { // Registered by shard indices on when first document index is created. DocDeletionCallback doc_del_cb_; + + struct Hash { + size_t operator()(const facade::Connection::WeakRef& c) const { + return std::hash()(c.GetClientId()); + } + }; + + // the table that maps keys to the clients that are tracking them. + absl::flat_hash_map> + client_tracking_map_; }; } // namespace dfly diff --git a/src/server/main_service.cc b/src/server/main_service.cc index ca8065d32..3a293cd8c 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1042,6 +1042,12 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); } +OpResult OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, const ArgSlice& keys) { + auto& db_slice = op_args.shard->db_slice(); + db_slice.TrackKeys(cntx->conn()->Borrow(), keys); + return OpStatus::OK; +} + void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { CHECK(!args.empty()); DCHECK_NE(0u, shard_set->size()) << "Init was not called"; @@ -1149,6 +1155,18 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->reply_builder()->CloseConnection(); } + // if this is a read command, and client tracking has enabled, + // start tracking all the updates to the keys in this read command + if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn()) { + auto cb = [&](Transaction* t, EngineShard* shard) { + auto keys = t->GetShardArgs(shard->shard_id()); + return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, keys); + }; + + dfly_cntx->transaction->Refurbish(); + dfly_cntx->transaction->ScheduleSingleHopT(cb); + } + if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } @@ -1466,7 +1484,6 @@ void Service::Quit(CmdArgList args, ConnectionContext* cntx) { if (cntx->protocol() == facade::Protocol::REDIS) cntx->SendOk(); using facade::SinkReplyBuilder; - SinkReplyBuilder* builder = cntx->reply_builder(); builder->CloseConnection(); @@ -2364,6 +2381,8 @@ void Service::OnClose(facade::ConnectionContext* cntx) { DeactivateMonitoring(server_cntx); server_family_.OnClose(server_cntx); + + cntx->conn()->SetClientTrackingSwitch(false); } string Service::GetContextInfo(facade::ConnectionContext* cntx) { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index fe65fef14..92e1bba8f 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1238,6 +1238,8 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { return ClientList(sub_args, cntx); } else if (sub_cmd == "PAUSE") { return ClientPause(sub_args, cntx); + } else if (sub_cmd == "TRACKING") { + return ClientTracking(sub_args, cntx); } if (sub_cmd == "SETINFO") { @@ -1357,6 +1359,30 @@ void ServerFamily::ClientPause(CmdArgList args, ConnectionContext* cntx) { cntx->SendOk(); } +void ServerFamily::ClientTracking(CmdArgList args, ConnectionContext* cntx) { + if (args.size() != 1) + return cntx->SendError(kSyntaxErr); + + auto* rb = static_cast(cntx->reply_builder()); + if (!rb->IsResp3()) + return cntx->SendError( + "Client tracking is currently not supported for RESP2. Please use RESP3."); + + ToUpper(&args[0]); + string_view state = ArgS(args, 0); + bool is_on; + if (state == "ON") { + is_on = true; + } else if (state == "OFF") { + is_on = false; + } else { + return cntx->SendError(kSyntaxErr); + } + + cntx->conn()->SetClientTrackingSwitch(is_on); + return cntx->SendOk(); +} + void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) { ToUpper(&args[0]); string_view sub_cmd = ArgS(args, 0); diff --git a/src/server/server_family.h b/src/server/server_family.h index 124fedc62..192869b04 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -214,6 +214,7 @@ class ServerFamily { void ClientGetName(CmdArgList args, ConnectionContext* cntx); void ClientList(CmdArgList args, ConnectionContext* cntx); void ClientPause(CmdArgList args, ConnectionContext* cntx); + void ClientTracking(CmdArgList args, ConnectionContext* cntx); void Config(CmdArgList args, ConnectionContext* cntx); void DbSize(CmdArgList args, ConnectionContext* cntx); void Debug(CmdArgList args, ConnectionContext* cntx); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 71b127451..b7f61de86 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -906,6 +906,12 @@ void Transaction::Conclude() { Execute(std::move(cb), true); } +void Transaction::Refurbish() { + txid_ = 0; + coordinator_state_ = 0; + cb_ptr_ = nullptr; +} + void Transaction::EnableShard(ShardId sid) { unique_shard_cnt_ = 1; unique_shard_id_ = sid; diff --git a/src/server/transaction.h b/src/server/transaction.h index 6d76e24cf..64d175e28 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -323,6 +323,8 @@ class Transaction { // Utility to run a single hop on a no-key command static void RunOnceAsCommand(const CommandId* cid, RunnableType cb); + void Refurbish(); + private: // Holds number of locks for each IntentLock::Mode: shared and exlusive. struct LockCnt {