feat(server): Implement CLIENT PAUSE (#1875)

* feat(server): Implement CLIENT PAUSE

Signed-off-by: adi_holden <adi@dragonflydb.io>
This commit is contained in:
Roy Jacobson 2023-11-15 08:56:49 +02:00 committed by GitHub
parent 1ec1b997a0
commit c3a2da559e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 291 additions and 39 deletions

View file

@ -831,9 +831,10 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_buil
io_buf_.CommitWrite(*recv_sz); io_buf_.CommitWrite(*recv_sz);
stats_->io_read_bytes += *recv_sz; stats_->io_read_bytes += *recv_sz;
++stats_->io_read_cnt; ++stats_->io_read_cnt;
phase_ = PROCESS; phase_ = PROCESS;
bool is_iobuf_full = io_buf_.AppendLen() == 0; bool is_iobuf_full = io_buf_.AppendLen() == 0;
service_->AwaitOnPauseDispatch();
if (redis_parser_) { if (redis_parser_) {
parse_status = ParseRedis(orig_builder); parse_status = ParseRedis(orig_builder);
} else { } else {

View file

@ -44,6 +44,10 @@ class OkService : public ServiceInterface {
ConnectionStats* GetThreadLocalConnectionStats() final { ConnectionStats* GetThreadLocalConnectionStats() final {
return &tl_stats; return &tl_stats;
} }
void AwaitOnPauseDispatch() {
return;
}
}; };
void RunEngine(ProactorPool* pool, AcceptServer* acceptor) { void RunEngine(ProactorPool* pool, AcceptServer* acceptor) {

View file

@ -35,6 +35,7 @@ class ServiceInterface {
virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0; virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0;
virtual ConnectionStats* GetThreadLocalConnectionStats() = 0; virtual ConnectionStats* GetThreadLocalConnectionStats() = 0;
virtual void AwaitOnPauseDispatch() = 0;
virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) { virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) {
} }

View file

@ -235,6 +235,7 @@ void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
void ConnectionState::ExecInfo::Clear() { void ConnectionState::ExecInfo::Clear() {
state = EXEC_INACTIVE; state = EXEC_INACTIVE;
body.clear(); body.clear();
is_write = false;
ClearWatched(); ClearWatched();
} }

View file

@ -82,6 +82,7 @@ struct ConnectionState {
ExecState state = EXEC_INACTIVE; ExecState state = EXEC_INACTIVE;
std::vector<StoredCmd> body; std::vector<StoredCmd> body;
bool is_write = false;
std::vector<std::pair<DbIndex, std::string>> watched_keys; // List of keys registered by WATCH std::vector<std::pair<DbIndex, std::string>> watched_keys; // List of keys registered by WATCH
std::atomic_bool watched_dirty = false; // Set if a watched key was changed before EXEC std::atomic_bool watched_dirty = false; // Set if a watched key was changed before EXEC

View file

@ -1067,6 +1067,14 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
<< " in dbid=" << dfly_cntx->conn_state.db_index; << " in dbid=" << dfly_cntx->conn_state.db_index;
} }
string_view cmd_name(cid->name());
bool is_write = (cid->opt_mask() & CO::WRITE) || cmd_name == "PUBLISH" || cmd_name == "EVAL" ||
cmd_name == "EVALSHA";
if (cmd_name == "EXEC" && dfly_cntx->conn_state.exec_info.is_write) {
is_write = true;
}
etl.AwaitPauseState(is_write);
etl.RecordCmd(); etl.RecordCmd();
if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) { if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
@ -1082,7 +1090,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
// TODO: protect against aggregating huge transactions. // TODO: protect against aggregating huge transactions.
StoredCmd stored_cmd{cid, args_no_cmd}; StoredCmd stored_cmd{cid, args_no_cmd};
dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd)); dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd));
if (stored_cmd.Cid()->opt_mask() & CO::WRITE) {
dfly_cntx->conn_state.exec_info.is_write = true;
}
return cntx->SendSimpleString("QUEUED"); return cntx->SendSimpleString("QUEUED");
} }
@ -1254,8 +1264,9 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
// invocations, we can potentially execute multiple eval in parallel, which is very powerful // invocations, we can potentially execute multiple eval in parallel, which is very powerful
// paired with shardlocal eval // paired with shardlocal eval
const bool is_eval = CO::IsEvalKind(ArgS(args, 0)); const bool is_eval = CO::IsEvalKind(ArgS(args, 0));
const bool is_pause = dfly::ServerState::tlocal()->IsPaused();
if (!is_multi && !is_eval && cid != nullptr) { if (!is_multi && !is_eval && cid != nullptr && !is_pause) {
stored_cmds.reserve(args_list.size()); stored_cmds.reserve(args_list.size());
stored_cmds.emplace_back(cid, tail_args); stored_cmds.emplace_back(cid, tail_args);
continue; continue;
@ -1410,6 +1421,10 @@ facade::ConnectionStats* Service::GetThreadLocalConnectionStats() {
return ServerState::tl_connection_stats(); return ServerState::tl_connection_stats();
} }
void Service::AwaitOnPauseDispatch() {
ServerState::tlocal()->AwaitOnPauseDispatch();
}
const CommandId* Service::FindCmd(std::string_view cmd) const { const CommandId* Service::FindCmd(std::string_view cmd) const {
return registry_.Find(cmd); return registry_.Find(cmd);
} }

View file

@ -73,6 +73,7 @@ class Service : public facade::ServiceInterface {
facade::Connection* owner) final; facade::Connection* owner) final;
facade::ConnectionStats* GetThreadLocalConnectionStats() final; facade::ConnectionStats* GetThreadLocalConnectionStats() final;
void AwaitOnPauseDispatch() final;
std::pair<const CommandId*, CmdArgList> FindCmd(CmdArgList args) const; std::pair<const CommandId*, CmdArgList> FindCmd(CmdArgList args) const;
const CommandId* FindCmd(std::string_view) const; const CommandId* FindCmd(std::string_view) const;

View file

@ -30,7 +30,7 @@ class MultiCommandSquasher {
} }
private: private:
// Per-shard exection info. // Per-shard execution info.
struct ShardExecInfo { struct ShardExecInfo {
ShardExecInfo() : had_writes{false}, cmds{}, replies{}, local_tx{nullptr} { ShardExecInfo() : had_writes{false}, cmds{}, replies{}, local_tx{nullptr} {
} }
@ -74,7 +74,7 @@ class MultiCommandSquasher {
ConnectionContext* cntx_; // Underlying context ConnectionContext* cntx_; // Underlying context
Service* service_; Service* service_;
bool atomic_; // Wheter working in any of the atomic modes bool atomic_; // Whether working in any of the atomic modes
const CommandId* base_cid_; // underlying cid (exec or eval) for executing batch hops const CommandId* base_cid_; // underlying cid (exec or eval) for executing batch hops
bool verify_commands_ = false; // Whether commands need to be verified before execution bool verify_commands_ = false; // Whether commands need to be verified before execution

View file

@ -27,6 +27,7 @@ extern "C" {
#include "base/flags.h" #include "base/flags.h"
#include "base/logging.h" #include "base/logging.h"
#include "facade/cmd_arg_parser.h"
#include "facade/dragonfly_connection.h" #include "facade/dragonfly_connection.h"
#include "facade/reply_builder.h" #include "facade/reply_builder.h"
#include "io/file_util.h" #include "io/file_util.h"
@ -1213,22 +1214,52 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) {
ToUpper(&args[0]); ToUpper(&args[0]);
string_view sub_cmd = ArgS(args, 0); string_view sub_cmd = ArgS(args, 0);
CmdArgList sub_args = args.subspan(1);
if (sub_cmd == "SETNAME" && args.size() == 2) { if (sub_cmd == "SETNAME") {
cntx->conn()->SetName(string{ArgS(args, 1)}); return ClientSetName(sub_args, cntx);
} else if (sub_cmd == "GETNAME") {
return ClientGetName(sub_args, cntx);
} else if (sub_cmd == "LIST") {
return ClientList(sub_args, cntx);
} else if (sub_cmd == "PAUSE") {
return ClientPause(sub_args, cntx);
}
if (sub_cmd == "SETINFO") {
return (*cntx)->SendOk(); return (*cntx)->SendOk();
} }
if (sub_cmd == "GETNAME") { LOG_FIRST_N(ERROR, 10) << "Subcommand " << sub_cmd << " not supported";
return (*cntx)->SendError(UnknownSubCmd(sub_cmd, "CLIENT"), kSyntaxErrType);
}
void ServerFamily::ClientSetName(CmdArgList args, ConnectionContext* cntx) {
if (args.size() == 1) {
cntx->conn()->SetName(string{ArgS(args, 0)});
return (*cntx)->SendOk();
} else {
return (*cntx)->SendError(facade::kSyntaxErr);
}
}
void ServerFamily::ClientGetName(CmdArgList args, ConnectionContext* cntx) {
if (!args.empty()) {
return (*cntx)->SendError(facade::kSyntaxErr);
}
auto name = cntx->conn()->GetName(); auto name = cntx->conn()->GetName();
if (!name.empty()) { if (!name.empty()) {
return (*cntx)->SendBulkString(name); return (*cntx)->SendBulkString(name);
} else { } else {
return (*cntx)->SendNull(); return (*cntx)->SendNull();
} }
}
void ServerFamily::ClientList(CmdArgList args, ConnectionContext* cntx) {
if (!args.empty()) {
return (*cntx)->SendError(facade::kSyntaxErr);
} }
if (sub_cmd == "LIST") {
vector<string> client_info; vector<string> client_info;
absl::base_internal::SpinLock mu; absl::base_internal::SpinLock mu;
@ -1238,24 +1269,84 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) {
facade::Connection* dcon = static_cast<facade::Connection*>(conn); facade::Connection* dcon = static_cast<facade::Connection*>(conn);
string info = dcon->GetClientInfo(thread_index); string info = dcon->GetClientInfo(thread_index);
absl::base_internal::SpinLockHolder l(&mu); absl::base_internal::SpinLockHolder l(&mu);
client_info.push_back(move(info)); client_info.push_back(std::move(info));
}; };
for (auto* listener : listeners_) { for (auto* listener : listeners_) {
listener->TraverseConnections(cb); listener->TraverseConnections(cb);
} }
string result = absl::StrJoin(move(client_info), "\n"); string result = absl::StrJoin(client_info, "\n");
result.append("\n"); result.append("\n");
return (*cntx)->SendBulkString(result); return (*cntx)->SendBulkString(result);
}
void ServerFamily::ClientPause(CmdArgList args, ConnectionContext* cntx) {
CmdArgParser parser(args);
auto timeout = parser.Next().Int<uint64_t>();
enum ClientPause pause_state = ClientPause::ALL;
if (parser.HasNext()) {
pause_state =
parser.ToUpper().Next().Case("WRITE", ClientPause::WRITE).Case("ALL", ClientPause::ALL);
}
if (auto err = parser.Error(); err) {
return (*cntx)->SendError(err->MakeReply());
} }
if (sub_cmd == "SETINFO") { // Pause dispatch commands before updating client puase state, and enable dispatch after updating
return (*cntx)->SendOk(); // pause state. This will unsure that when we after changing the state all running commands will
// read the new pause state, and we will not pause client in the middle of a transaction.
service_.proactor_pool().Await([](util::ProactorBase* pb) {
ServerState& etl = *ServerState::tlocal();
etl.SetPauseDispatch(true);
});
// TODO handle blocking commands
const absl::Duration kDispatchTimeout = absl::Seconds(1);
if (!AwaitDispatches(kDispatchTimeout, [self = cntx->conn()](util::Connection* conn) {
// Wait until the only command dispatching is the client pause command.
return conn != self;
})) {
LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << kDispatchTimeout;
service_.proactor_pool().Await([](util::ProactorBase* pb) {
ServerState& etl = *ServerState::tlocal();
etl.SetPauseDispatch(false);
});
return (*cntx)->SendError("Failed to pause all running clients");
} }
LOG_FIRST_N(ERROR, 10) << "Subcommand " << sub_cmd << " not supported"; service_.proactor_pool().AwaitFiberOnAll([pause_state](util::ProactorBase* pb) {
return (*cntx)->SendError(UnknownSubCmd(sub_cmd, "CLIENT"), kSyntaxErrType); ServerState& etl = *ServerState::tlocal();
etl.SetPauseState(pause_state, true);
etl.SetPauseDispatch(false);
});
// We should not expire/evict keys while clients are puased.
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(false); });
fb2::Fiber("client_pause", [this, timeout, pause_state]() mutable {
// On server shutdown we sleep 10ms to make sure all running task finish, therefore 10ms steps
// ensure this fiber will not left hanging .
auto step = 10ms;
auto timeout_ms = timeout * 1ms;
int64_t steps = timeout_ms.count() / step.count();
ServerState& etl = *ServerState::tlocal();
do {
ThisFiber::SleepFor(step);
} while (etl.gstate() != GlobalState::SHUTTING_DOWN && --steps > 0);
if (etl.gstate() != GlobalState::SHUTTING_DOWN) {
service_.proactor_pool().AwaitFiberOnAll([pause_state](util::ProactorBase* pb) {
ServerState::tlocal()->SetPauseState(pause_state, false);
});
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); });
}
}).Detach();
(*cntx)->SendOk();
} }
void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {

View file

@ -211,6 +211,10 @@ class ServerFamily {
void Auth(CmdArgList args, ConnectionContext* cntx); void Auth(CmdArgList args, ConnectionContext* cntx);
void Client(CmdArgList args, ConnectionContext* cntx); void Client(CmdArgList args, ConnectionContext* cntx);
void ClientSetName(CmdArgList args, ConnectionContext* cntx);
void ClientGetName(CmdArgList args, ConnectionContext* cntx);
void ClientList(CmdArgList args, ConnectionContext* cntx);
void ClientPause(CmdArgList args, ConnectionContext* cntx);
void Config(CmdArgList args, ConnectionContext* cntx); void Config(CmdArgList args, ConnectionContext* cntx);
void DbSize(CmdArgList args, ConnectionContext* cntx); void DbSize(CmdArgList args, ConnectionContext* cntx);
void Debug(CmdArgList args, ConnectionContext* cntx); void Debug(CmdArgList args, ConnectionContext* cntx);

View file

@ -4,6 +4,8 @@
#include "server/server_family.h" #include "server/server_family.h"
#include <absl/strings/match.h>
#include "base/gtest.h" #include "base/gtest.h"
#include "base/logging.h" #include "base/logging.h"
#include "facade/facade_test.h" #include "facade/facade_test.h"
@ -181,4 +183,21 @@ TEST_F(ServerFamilyTest, SlowLogMinusOneDisabled) {
EXPECT_THAT(resp.GetInt(), 0); EXPECT_THAT(resp.GetInt(), 0);
} }
TEST_F(ServerFamilyTest, ClientPause) {
auto start = absl::Now();
Run({"CLIENT", "PAUSE", "50"});
Run({"get", "key"});
EXPECT_GT((absl::Now() - start), absl::Milliseconds(50));
start = absl::Now();
Run({"CLIENT", "PAUSE", "50", "WRITE"});
Run({"get", "key"});
EXPECT_LT((absl::Now() - start), absl::Milliseconds(10));
Run({"set", "key", "value2"});
EXPECT_GT((absl::Now() - start), absl::Milliseconds(50));
}
} // namespace dfly } // namespace dfly

View file

@ -112,6 +112,45 @@ bool ServerState::AllowInlineScheduling() const {
return true; return true;
} }
void ServerState::SetPauseState(ClientPause state, bool start) {
client_pauses_[int(state)] += (start ? 1 : -1);
if (!client_pauses_[int(state)]) {
client_pause_ec_.notifyAll();
}
}
bool ServerState::IsPaused() const {
return client_pauses_[0] || client_pauses_[1];
}
void ServerState::AwaitPauseState(bool is_write) {
client_pause_ec_.await([is_write, this]() {
if (client_pauses_[int(ClientPause::ALL)]) {
return false;
}
if (is_write && client_pauses_[int(ClientPause::WRITE)]) {
return false;
}
return true;
});
}
void ServerState::AwaitOnPauseDispatch() {
pause_dispatch_ec_.await([this]() {
if (pause_dispatch_) {
return false;
}
return true;
});
}
void ServerState::SetPauseDispatch(bool pause) {
pause_dispatch_ = pause;
if (!pause_dispatch_) {
pause_dispatch_ec_.notifyAll();
}
}
Interpreter* ServerState::BorrowInterpreter() { Interpreter* ServerState::BorrowInterpreter() {
return interpreter_mgr_.Get(); return interpreter_mgr_.Get();
} }

View file

@ -80,6 +80,8 @@ class MonitorsRepo {
unsigned int global_count_ = 0; // by global its means that we count the monitor for all threads unsigned int global_count_ = 0; // by global its means that we count the monitor for all threads
}; };
enum class ClientPause { WRITE, ALL };
// Present in every server thread. This class differs from EngineShard. The latter manages // Present in every server thread. This class differs from EngineShard. The latter manages
// state around engine shards while the former represents coordinator/connection state. // state around engine shards while the former represents coordinator/connection state.
// There may be threads that handle engine shards but not IO, there may be threads that handle IO // There may be threads that handle engine shards but not IO, there may be threads that handle IO
@ -220,6 +222,24 @@ class ServerState { // public struct - to allow initialization.
acl::AclLog acl_log; acl::AclLog acl_log;
// Starts or ends a `CLIENT PAUSE` command. @state controls whether
// this is pausing only writes or every command, @start controls
// whether this is starting or ending the pause.
void SetPauseState(ClientPause state, bool start);
// Returns whether any type of commands is paused.
bool IsPaused() const;
// Awaits until the pause is over and the command can execute.
// @is_write controls whether the command is a write command or not.
void AwaitPauseState(bool is_write);
// Toggle a boolean indicating whether the server should temporarily pause or allow dispatching
// new commands.
void SetPauseDispatch(bool pause);
// Awaits until dispatching new commands is allowed as determinded by SetPauseDispatch function
void AwaitOnPauseDispatch();
SlowLogShard& GetSlowLog() { SlowLogShard& GetSlowLog() {
return slow_log_shard_; return slow_log_shard_;
}; };
@ -237,6 +257,15 @@ class ServerState { // public struct - to allow initialization.
GlobalState gstate_ = GlobalState::ACTIVE; GlobalState gstate_ = GlobalState::ACTIVE;
// To support concurrent `CLIENT PAUSE commands` correctly, we store the amount
// of current CLIENT PAUSE commands that are in effect. Blocked execution fibers
// should subscribe to `client_pause_ec_` through `AwaitPauseState` to be
// notified when the break is over.
int client_pauses_[2] = {};
EventCount client_pause_ec_;
bool pause_dispatch_ = false;
EventCount pause_dispatch_ec_;
using Counter = util::SlidingCounter<7>; using Counter = util::SlidingCounter<7>;
Counter qps_; Counter qps_;

View file

@ -1761,3 +1761,49 @@ async def test_search(df_local_factory):
assert (await c_replica.ft("idx-m2").search(Query("*").sort_by("f2").paging(0, 1))).docs[ assert (await c_replica.ft("idx-m2").search(Query("*").sort_by("f2").paging(0, 1))).docs[
0 0
].id == "k0" ].id == "k0"
# @pytest.mark.slow
@pytest.mark.asyncio
async def test_client_pause_with_replica(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=4)
replica = df_local_factory.create(proactor_threads=4)
df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create(port=master.port)
c_master = master.client()
c_replica = replica.client()
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
fill_task = asyncio.create_task(seeder.run())
# Give the seeder a bit of time.
await asyncio.sleep(1)
# block the seeder for 4 seconds
await c_master.execute_command("client pause 4000 write")
stats = await c_master.info("CommandStats")
info = await c_master.info("Stats")
await asyncio.sleep(0.5)
stats_after_sleep = await c_master.info("CommandStats")
# Check no commands are executed except info and replconf called from replica
for cmd, cmd_stats in stats_after_sleep.items():
if "cmdstat_INFO" != cmd and "cmdstat_REPLCONF" != cmd_stats:
assert stats[cmd] == cmd_stats
await asyncio.sleep(6)
seeder.stop()
await fill_task
stats_after_pause_finish = await c_master.info("CommandStats")
more_exeuted = False
for cmd, cmd_stats in stats_after_pause_finish.items():
if "cmdstat_INFO" != cmd and "cmdstat_REPLCONF" != cmd_stats and stats[cmd] != cmd_stats:
more_exeuted = True
assert more_exeuted
capture = await seeder.capture(port=master.port)
assert await seeder.compare(capture, port=replica.port)
await disconnect_clients(c_master, c_replica)