diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc
index 373733ff5..14bfac339 100644
--- a/src/facade/dragonfly_connection.cc
+++ b/src/facade/dragonfly_connection.cc
@@ -831,9 +831,10 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_builder)
     io_buf_.CommitWrite(*recv_sz);
     stats_->io_read_bytes += *recv_sz;
     ++stats_->io_read_cnt;
+    phase_ = PROCESS;
     bool is_iobuf_full = io_buf_.AppendLen() == 0;
-
+    service_->AwaitOnPauseDispatch();
     if (redis_parser_) {
       parse_status = ParseRedis(orig_builder);
     } else {
diff --git a/src/facade/ok_main.cc b/src/facade/ok_main.cc
index 0d4a5803b..b00844591 100644
--- a/src/facade/ok_main.cc
+++ b/src/facade/ok_main.cc
@@ -44,6 +44,10 @@ class OkService : public ServiceInterface {
   ConnectionStats* GetThreadLocalConnectionStats() final {
     return &tl_stats;
   }
+
+  void AwaitOnPauseDispatch() final {
+    return;
+  }
 };
 
 void RunEngine(ProactorPool* pool, AcceptServer* acceptor) {
diff --git a/src/facade/service_interface.h b/src/facade/service_interface.h
index 7f02604cc..da666e0c9 100644
--- a/src/facade/service_interface.h
+++ b/src/facade/service_interface.h
@@ -35,6 +35,7 @@ class ServiceInterface {
   virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0;
 
   virtual ConnectionStats* GetThreadLocalConnectionStats() = 0;
+  virtual void AwaitOnPauseDispatch() = 0;
 
   virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) {
   }
diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc
index 78530893c..b5c9aee34 100644
--- a/src/server/conn_context.cc
+++ b/src/server/conn_context.cc
@@ -235,6 +235,7 @@ void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
 
 void ConnectionState::ExecInfo::Clear() {
   state = EXEC_INACTIVE;
   body.clear();
+  is_write = false;
   ClearWatched();
 }
diff --git a/src/server/conn_context.h b/src/server/conn_context.h
index f962864a6..f0bde27b0 100644
--- a/src/server/conn_context.h
+++ b/src/server/conn_context.h
@@ -82,6 +82,7 @@ struct ConnectionState {
 
     ExecState state = EXEC_INACTIVE;
     std::vector<StoredCmd> body;
+    bool is_write = false;
 
     std::vector<std::pair<DbIndex, std::string>> watched_keys;  // List of keys registered by WATCH
     std::atomic_bool watched_dirty = false;  // Set if a watched key was changed before EXEC
diff --git a/src/server/main_service.cc b/src/server/main_service.cc
index ba575e849..b128fa70a 100644
--- a/src/server/main_service.cc
+++ b/src/server/main_service.cc
@@ -1067,6 +1067,14 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
             << " in dbid=" << dfly_cntx->conn_state.db_index;
   }
 
+  string_view cmd_name(cid->name());
+  bool is_write = (cid->opt_mask() & CO::WRITE) || cmd_name == "PUBLISH" || cmd_name == "EVAL" ||
+                  cmd_name == "EVALSHA";
+  if (cmd_name == "EXEC" && dfly_cntx->conn_state.exec_info.is_write) {
+    is_write = true;
+  }
+  etl.AwaitPauseState(is_write);
+
   etl.RecordCmd();
 
   if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
@@ -1082,7 +1090,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
     // TODO: protect against aggregating huge transactions.
     StoredCmd stored_cmd{cid, args_no_cmd};
     dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd));
-
+    if (cid->opt_mask() & CO::WRITE) {
+      dfly_cntx->conn_state.exec_info.is_write = true;
+    }
     return cntx->SendSimpleString("QUEUED");
   }
 
@@ -1254,8 +1264,9 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
     // invocations, we can potentially execute multiple eval in parallel, which is very powerful
     // paired with shardlocal eval
     const bool is_eval = CO::IsEvalKind(ArgS(args, 0));
+    const bool is_pause = dfly::ServerState::tlocal()->IsPaused();
 
-    if (!is_multi && !is_eval && cid != nullptr) {
+    if (!is_multi && !is_eval && cid != nullptr && !is_pause) {
       stored_cmds.reserve(args_list.size());
       stored_cmds.emplace_back(cid, tail_args);
       continue;
@@ -1410,6 +1421,10 @@ facade::ConnectionStats* Service::GetThreadLocalConnectionStats() {
   return ServerState::tl_connection_stats();
 }
 
+void Service::AwaitOnPauseDispatch() {
+  ServerState::tlocal()->AwaitOnPauseDispatch();
+}
+
 const CommandId* Service::FindCmd(std::string_view cmd) const {
   return registry_.Find(cmd);
 }
diff --git a/src/server/main_service.h b/src/server/main_service.h
index 190cf40c7..84830cc1c 100644
--- a/src/server/main_service.h
+++ b/src/server/main_service.h
@@ -73,6 +73,7 @@ class Service : public facade::ServiceInterface {
                                            facade::Connection* owner) final;
 
   facade::ConnectionStats* GetThreadLocalConnectionStats() final;
+  void AwaitOnPauseDispatch() final;
 
   std::pair<const CommandId*, CmdArgList> FindCmd(CmdArgList args) const;
   const CommandId* FindCmd(std::string_view) const;
diff --git a/src/server/multi_command_squasher.h b/src/server/multi_command_squasher.h
index df08ed700..61c150628 100644
--- a/src/server/multi_command_squasher.h
+++ b/src/server/multi_command_squasher.h
@@ -30,7 +30,7 @@ class MultiCommandSquasher {
   }
 
  private:
-  // Per-shard exection info.
+  // Per-shard execution info.
   struct ShardExecInfo {
     ShardExecInfo() : had_writes{false}, cmds{}, replies{}, local_tx{nullptr} {
     }
@@ -74,7 +74,7 @@ class MultiCommandSquasher {
   ConnectionContext* cntx_;  // Underlying context
   Service* service_;
 
-  bool atomic_;  // Wheter working in any of the atomic modes
+  bool atomic_;  // Whether working in any of the atomic modes
   const CommandId* base_cid_;  // underlying cid (exec or eval) for executing batch hops
 
   bool verify_commands_ = false;  // Whether commands need to be verified before execution
diff --git a/src/server/server_family.cc b/src/server/server_family.cc
index defc19f3b..79d6e0f30 100644
--- a/src/server/server_family.cc
+++ b/src/server/server_family.cc
@@ -27,6 +27,7 @@ extern "C" {
 
 #include "base/flags.h"
 #include "base/logging.h"
+#include "facade/cmd_arg_parser.h"
#include "facade/dragonfly_connection.h"
 #include "facade/reply_builder.h"
 #include "io/file_util.h"
@@ -1213,41 +1214,16 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
 void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) {
   ToUpper(&args[0]);
   string_view sub_cmd = ArgS(args, 0);
+  CmdArgList sub_args = args.subspan(1);
 
-  if (sub_cmd == "SETNAME" && args.size() == 2) {
-    cntx->conn()->SetName(string{ArgS(args, 1)});
-    return (*cntx)->SendOk();
-  }
-
-  if (sub_cmd == "GETNAME") {
-    auto name = cntx->conn()->GetName();
-    if (!name.empty()) {
-      return (*cntx)->SendBulkString(name);
-    } else {
-      return (*cntx)->SendNull();
-    }
-  }
-
-  if (sub_cmd == "LIST") {
-    vector<string> client_info;
-    absl::base_internal::SpinLock mu;
-
-    // we can not preempt the connection traversal, so we need to use a spinlock.
-    // alternatively we could lock when mutating the connection list, but it seems not important.
-    auto cb = [&](unsigned thread_index, util::Connection* conn) {
-      facade::Connection* dcon = static_cast<facade::Connection*>(conn);
-      string info = dcon->GetClientInfo(thread_index);
-      absl::base_internal::SpinLockHolder l(&mu);
-      client_info.push_back(move(info));
-    };
-
-    for (auto* listener : listeners_) {
-      listener->TraverseConnections(cb);
-    }
-
-    string result = absl::StrJoin(move(client_info), "\n");
-    result.append("\n");
-    return (*cntx)->SendBulkString(result);
+  if (sub_cmd == "SETNAME") {
+    return ClientSetName(sub_args, cntx);
+  } else if (sub_cmd == "GETNAME") {
+    return ClientGetName(sub_args, cntx);
+  } else if (sub_cmd == "LIST") {
+    return ClientList(sub_args, cntx);
+  } else if (sub_cmd == "PAUSE") {
+    return ClientPause(sub_args, cntx);
   }
 
   if (sub_cmd == "SETINFO") {
@@ -1258,6 +1234,121 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) {
   return (*cntx)->SendError(UnknownSubCmd(sub_cmd, "CLIENT"), kSyntaxErrType);
 }
 
+void ServerFamily::ClientSetName(CmdArgList args, ConnectionContext* cntx) {
+  if (args.size() == 1) {
+    cntx->conn()->SetName(string{ArgS(args, 0)});
+    return (*cntx)->SendOk();
+  } else {
+    return (*cntx)->SendError(facade::kSyntaxErr);
+  }
+}
+
+void ServerFamily::ClientGetName(CmdArgList args, ConnectionContext* cntx) {
+  if (!args.empty()) {
+    return (*cntx)->SendError(facade::kSyntaxErr);
+  }
+  auto name = cntx->conn()->GetName();
+  if (!name.empty()) {
+    return (*cntx)->SendBulkString(name);
+  } else {
+    return (*cntx)->SendNull();
+  }
+}
+
+void ServerFamily::ClientList(CmdArgList args, ConnectionContext* cntx) {
+  if (!args.empty()) {
+    return (*cntx)->SendError(facade::kSyntaxErr);
+  }
+
+  vector<string> client_info;
+  absl::base_internal::SpinLock mu;
+
+  // we cannot preempt the connection traversal, so we need to use a spinlock.
+  // alternatively we could lock when mutating the connection list, but it seems not important.
+  auto cb = [&](unsigned thread_index, util::Connection* conn) {
+    facade::Connection* dcon = static_cast<facade::Connection*>(conn);
+    string info = dcon->GetClientInfo(thread_index);
+    absl::base_internal::SpinLockHolder l(&mu);
+    client_info.push_back(std::move(info));
+  };
+
+  for (auto* listener : listeners_) {
+    listener->TraverseConnections(cb);
+  }
+
+  string result = absl::StrJoin(client_info, "\n");
+  result.append("\n");
+  return (*cntx)->SendBulkString(result);
+}
+
+void ServerFamily::ClientPause(CmdArgList args, ConnectionContext* cntx) {
+  CmdArgParser parser(args);
+
+  auto timeout = parser.Next().Int<uint64_t>();
+  enum ClientPause pause_state = ClientPause::ALL;
+  if (parser.HasNext()) {
+    pause_state =
+        parser.ToUpper().Next().Case("WRITE", ClientPause::WRITE).Case("ALL", ClientPause::ALL);
+  }
+  if (auto err = parser.Error(); err) {
+    return (*cntx)->SendError(err->MakeReply());
+  }
+
+  // Pause command dispatch before updating the client pause state, and re-enable dispatch after
+  // the state is updated. This ensures that all commands running after the state change read the
+  // new pause state, and that we do not pause a client in the middle of a transaction.
+  service_.proactor_pool().Await([](util::ProactorBase* pb) {
+    ServerState& etl = *ServerState::tlocal();
+    etl.SetPauseDispatch(true);
+  });
+
+  // TODO handle blocking commands
+  const absl::Duration kDispatchTimeout = absl::Seconds(1);
+  if (!AwaitDispatches(kDispatchTimeout, [self = cntx->conn()](util::Connection* conn) {
+        // Wait until the only command dispatching is the client pause command.
+        return conn != self;
+      })) {
+    LOG(WARNING) << "Couldn't wait for commands to finish dispatching within " << kDispatchTimeout;
+    service_.proactor_pool().Await([](util::ProactorBase* pb) {
+      ServerState& etl = *ServerState::tlocal();
+      etl.SetPauseDispatch(false);
+    });
+    return (*cntx)->SendError("Failed to pause all running clients");
+  }
+
+  service_.proactor_pool().AwaitFiberOnAll([pause_state](util::ProactorBase* pb) {
+    ServerState& etl = *ServerState::tlocal();
+    etl.SetPauseState(pause_state, true);
+    etl.SetPauseDispatch(false);
+  });
+
+  // We should not expire/evict keys while clients are paused.
+  shard_set->RunBriefInParallel(
+      [](EngineShard* shard) { shard->db_slice().SetExpireAllowed(false); });
+
+  fb2::Fiber("client_pause", [this, timeout, pause_state]() mutable {
+    // On server shutdown we sleep for 10ms to let all running tasks finish, so sleeping in 10ms
+    // steps ensures this fiber will not be left hanging.
+    auto step = 10ms;
+    auto timeout_ms = timeout * 1ms;
+    int64_t steps = timeout_ms.count() / step.count();
+    ServerState& etl = *ServerState::tlocal();
+    do {
+      ThisFiber::SleepFor(step);
+    } while (etl.gstate() != GlobalState::SHUTTING_DOWN && --steps > 0);
+
+    if (etl.gstate() != GlobalState::SHUTTING_DOWN) {
+      service_.proactor_pool().AwaitFiberOnAll([pause_state](util::ProactorBase* pb) {
+        ServerState::tlocal()->SetPauseState(pause_state, false);
+      });
+      shard_set->RunBriefInParallel(
+          [](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); });
+    }
+  }).Detach();
+
+  (*cntx)->SendOk();
+}
+
 void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
   ToUpper(&args[0]);
   string_view sub_cmd = ArgS(args, 0);
diff --git a/src/server/server_family.h b/src/server/server_family.h
index a615a60e8..e5b6ca6a7 100644
--- a/src/server/server_family.h
+++ b/src/server/server_family.h
@@ -211,6 +211,10 @@ class ServerFamily {
   void Auth(CmdArgList args, ConnectionContext* cntx);
   void Client(CmdArgList args, ConnectionContext* cntx);
+  void ClientSetName(CmdArgList args, ConnectionContext* cntx);
+  void ClientGetName(CmdArgList args, ConnectionContext* cntx);
+  void ClientList(CmdArgList args, ConnectionContext* cntx);
+  void ClientPause(CmdArgList args, ConnectionContext* cntx);
   void Config(CmdArgList args, ConnectionContext* cntx);
   void DbSize(CmdArgList args, ConnectionContext* cntx);
   void Debug(CmdArgList args, ConnectionContext* cntx);
diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc
index 07bdb956f..e0a12c269 100644
--- a/src/server/server_family_test.cc
+++ b/src/server/server_family_test.cc
@@ -4,6 +4,8 @@
 
 #include "server/server_family.h"
 
+#include <absl/time/clock.h>
+
 #include "base/gtest.h"
 #include "base/logging.h"
 #include "facade/facade_test.h"
@@ -181,4 +183,21 @@ TEST_F(ServerFamilyTest, SlowLogMinusOneDisabled) {
   EXPECT_THAT(resp.GetInt(), 0);
 }
 
+TEST_F(ServerFamilyTest, ClientPause) {
+  auto start = absl::Now();
+  Run({"CLIENT", "PAUSE", "50"});
+
+  Run({"get", "key"});
+  EXPECT_GT((absl::Now() - start), absl::Milliseconds(50));
+
+  start = absl::Now();
+
+  Run({"CLIENT", "PAUSE", "50", "WRITE"});
+
+  Run({"get", "key"});
+  EXPECT_LT((absl::Now() - start), absl::Milliseconds(10));
+  Run({"set", "key", "value2"});
+  EXPECT_GT((absl::Now() - start), absl::Milliseconds(50));
+}
+
 }  // namespace dfly
diff --git a/src/server/server_state.cc b/src/server/server_state.cc
index 7d820c5a4..9419c3f30 100644
--- a/src/server/server_state.cc
+++ b/src/server/server_state.cc
@@ -112,6 +112,45 @@ bool ServerState::AllowInlineScheduling() const {
   return true;
 }
 
+void ServerState::SetPauseState(ClientPause state, bool start) {
+  client_pauses_[int(state)] += (start ? 1 : -1);
+  if (!client_pauses_[int(state)]) {
+    client_pause_ec_.notifyAll();
+  }
+}
+
+bool ServerState::IsPaused() const {
+  return client_pauses_[0] || client_pauses_[1];
+}
+
+void ServerState::AwaitPauseState(bool is_write) {
+  client_pause_ec_.await([is_write, this]() {
+    if (client_pauses_[int(ClientPause::ALL)]) {
+      return false;
+    }
+    if (is_write && client_pauses_[int(ClientPause::WRITE)]) {
+      return false;
+    }
+    return true;
+  });
+}
+
+void ServerState::AwaitOnPauseDispatch() {
+  pause_dispatch_ec_.await([this]() {
+    if (pause_dispatch_) {
+      return false;
+    }
+    return true;
+  });
+}
+
+void ServerState::SetPauseDispatch(bool pause) {
+  pause_dispatch_ = pause;
+  if (!pause_dispatch_) {
+    pause_dispatch_ec_.notifyAll();
+  }
+}
+
 Interpreter* ServerState::BorrowInterpreter() {
   return interpreter_mgr_.Get();
 }
diff --git a/src/server/server_state.h b/src/server/server_state.h
index fcd015b0e..2a2478257 100644
--- a/src/server/server_state.h
+++ b/src/server/server_state.h
@@ -80,6 +80,8 @@ class MonitorsRepo {
   unsigned int global_count_ = 0;  // by global its means that we count the monitor for all threads
 };
 
+enum class ClientPause { WRITE, ALL };
+
 // Present in every server thread. This class differs from EngineShard. The latter manages
 // state around engine shards while the former represents coordinator/connection state.
 // There may be threads that handle engine shards but not IO, there may be threads that handle IO
@@ -220,6 +222,24 @@ class ServerState {  // public struct - to allow initialization.
 
   acl::AclLog acl_log;
 
+  // Starts or ends a `CLIENT PAUSE` command. @state controls whether
+  // this is pausing only writes or every command, @start controls
+  // whether this is starting or ending the pause.
+  void SetPauseState(ClientPause state, bool start);
+
+  // Returns whether any type of command is currently paused.
+  bool IsPaused() const;
+
+  // Awaits until the pause is over and the command can execute.
+  // @is_write controls whether the command is a write command or not.
+  void AwaitPauseState(bool is_write);
+
+  // Toggles a boolean indicating whether the server should temporarily pause or allow dispatching
+  // new commands.
+  void SetPauseDispatch(bool pause);
+  // Awaits until dispatching new commands is allowed, as determined by SetPauseDispatch.
+  void AwaitOnPauseDispatch();
+
   SlowLogShard& GetSlowLog() {
     return slow_log_shard_;
   };
@@ -237,6 +257,15 @@ class ServerState {  // public struct - to allow initialization.
 
   GlobalState gstate_ = GlobalState::ACTIVE;
 
+  // To support concurrent `CLIENT PAUSE` commands correctly, we store the number
+  // of CLIENT PAUSE commands currently in effect. Blocked execution fibers
+  // should subscribe to `client_pause_ec_` through `AwaitPauseState` to be
+  // notified when the pause is over.
+  int client_pauses_[2] = {};
+  EventCount client_pause_ec_;
+  bool pause_dispatch_ = false;
+  EventCount pause_dispatch_ec_;
+
   using Counter = util::SlidingCounter<7>;
   Counter qps_;
diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py
index 2b68c511e..309d9936f 100644
--- a/tests/dragonfly/replication_test.py
+++ b/tests/dragonfly/replication_test.py
@@ -1761,3 +1761,49 @@ async def test_search(df_local_factory):
     assert (await c_replica.ft("idx-m2").search(Query("*").sort_by("f2").paging(0, 1))).docs[
         0
     ].id == "k0"
+
+
+# @pytest.mark.slow
+@pytest.mark.asyncio
+async def test_client_pause_with_replica(df_local_factory, df_seeder_factory):
+    master = df_local_factory.create(proactor_threads=4)
+    replica = df_local_factory.create(proactor_threads=4)
+    df_local_factory.start_all([master, replica])
+
+    seeder = df_seeder_factory.create(port=master.port)
+
+    c_master = master.client()
+    c_replica = replica.client()
+
+    await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
+    await wait_available_async(c_replica)
+
+    fill_task = asyncio.create_task(seeder.run())
+
+    # Give the seeder a bit of time.
+    await asyncio.sleep(1)
+    # Block the seeder's writes for 4 seconds.
+    await c_master.execute_command("client pause 4000 write")
+    stats = await c_master.info("CommandStats")
+    info = await c_master.info("Stats")
+    await asyncio.sleep(0.5)
+    stats_after_sleep = await c_master.info("CommandStats")
+    # Check that no commands are executed except INFO and REPLCONF (called from the replica).
+    for cmd, cmd_stats in stats_after_sleep.items():
+        if cmd != "cmdstat_INFO" and cmd != "cmdstat_REPLCONF":
+            assert stats[cmd] == cmd_stats
+
+    await asyncio.sleep(6)
+    seeder.stop()
+    await fill_task
+    stats_after_pause_finish = await c_master.info("CommandStats")
+    more_executed = False
+    for cmd, cmd_stats in stats_after_pause_finish.items():
+        if cmd != "cmdstat_INFO" and cmd != "cmdstat_REPLCONF" and stats.get(cmd) != cmd_stats:
+            more_executed = True
+    assert more_executed
+
+    capture = await seeder.capture(port=master.port)
+    assert await seeder.compare(capture, port=replica.port)
+
+    await disconnect_clients(c_master, c_replica)
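
Note: below is a minimal, self-contained sketch of the counting pause pattern that SetPauseState/AwaitPauseState implement in server_state.cc above. It is illustrative only: std::condition_variable stands in for the fiber-aware EventCount used in the patch (an assumption made so the sketch compiles outside Dragonfly; the real code must suspend only the calling fiber, never block a proactor thread). The per-kind counters are what make overlapping CLIENT PAUSE calls compose: a pause ends only when every active call has expired.

// Standalone sketch of the counted pause/notify pattern (not Dragonfly's code).
#include <condition_variable>
#include <mutex>

enum class ClientPause { WRITE, ALL };

class PauseState {
 public:
  // Starts (start=true) or ends (start=false) one CLIENT PAUSE.
  // One counter per pause kind supports concurrent CLIENT PAUSE calls.
  void SetPauseState(ClientPause state, bool start) {
    std::lock_guard lk(mu_);
    pauses_[int(state)] += start ? 1 : -1;
    if (pauses_[int(state)] == 0)
      cv_.notify_all();  // plays the role of EventCount::notifyAll()
  }

  // Blocks a command until no pause applies to it: read-only commands
  // pass through a WRITE-only pause, while an ALL pause blocks everything.
  void AwaitPauseState(bool is_write) {
    std::unique_lock lk(mu_);
    cv_.wait(lk, [&] {
      if (pauses_[int(ClientPause::ALL)])
        return false;
      if (is_write && pauses_[int(ClientPause::WRITE)])
        return false;
      return true;
    });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int pauses_[2] = {};  // indexed by int(ClientPause)
};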
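A second sketch, for the detached "client_pause" fiber's timed unpause: rather than one long sleep for the whole pause interval, it sleeps in 10ms steps and re-checks the global state after each step, so a server shutdown interrupts the wait within roughly one step instead of leaving the fiber hanging. The std::thread/std::atomic stand-ins and the names here are hypothetical; the patch itself uses ThisFiber::SleepFor and GlobalState::SHUTTING_DOWN.

// Stepped-sleep sketch of the timed unpause (illustrative, not Dragonfly's code).
#include <atomic>
#include <chrono>
#include <thread>

std::atomic<bool> shutting_down{false};  // stands in for gstate() checks

void TimedUnpause(std::chrono::milliseconds timeout) {
  using namespace std::chrono_literals;
  const auto step = 10ms;
  auto steps = timeout / step;  // number of 10ms slices to wait
  do {
    std::this_thread::sleep_for(step);  // stands in for ThisFiber::SleepFor
  } while (!shutting_down.load() && --steps > 0);

  if (!shutting_down.load()) {
    // ...end the pause here: SetPauseState(state, /*start=*/false) on all
    // threads and re-enable key expiration, as ClientPause() does above.
  }
}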