feat(server): Implement CLIENT PAUSE (#1875)

* feat(server): Implement CLIENT PAUSE

Signed-off-by: adi_holden <adi@dragonflydb.io>
Roy Jacobson authored on 2023-11-15 08:56:49 +02:00; committed by GitHub
parent 1ec1b997a0
commit c3a2da559e
14 changed files with 291 additions and 39 deletions


@@ -831,9 +831,10 @@ auto Connection::IoLoop(util::FiberSocketBase* peer, SinkReplyBuilder* orig_buil
io_buf_.CommitWrite(*recv_sz);
stats_->io_read_bytes += *recv_sz;
++stats_->io_read_cnt;
phase_ = PROCESS;
bool is_iobuf_full = io_buf_.AppendLen() == 0;
service_->AwaitOnPauseDispatch();
if (redis_parser_) {
parse_status = ParseRedis(orig_builder);
} else {

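The AwaitOnPauseDispatch() call above is the choke point for CLIENT PAUSE: every connection fiber blocks there, after reading bytes but before parsing them, while a pause is being installed. A minimal self-contained sketch of such a gate, using standard threads in place of the fiber-aware EventCount the server actually uses (all names below are illustrative, not from the patch):

#include <condition_variable>
#include <mutex>

// Illustrative stand-in for the pause_dispatch_ / pause_dispatch_ec_ pair in
// ServerState; the real code awaits on a fiber EventCount, not a cv.
class DispatchGate {
 public:
  // Flips the gate, waking blocked waiters only when dispatch is re-enabled.
  void SetPaused(bool pause) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      paused_ = pause;
    }
    if (!pause)
      cv_.notify_all();
  }

  // Called by every connection before parsing a new request.
  void Await() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !paused_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool paused_ = false;
};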

@@ -44,6 +44,10 @@ class OkService : public ServiceInterface {
ConnectionStats* GetThreadLocalConnectionStats() final {
return &tl_stats;
}
void AwaitOnPauseDispatch() {
return;
}
};
void RunEngine(ProactorPool* pool, AcceptServer* acceptor) {


@@ -35,6 +35,7 @@ class ServiceInterface {
virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0;
virtual ConnectionStats* GetThreadLocalConnectionStats() = 0;
virtual void AwaitOnPauseDispatch() = 0;
virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) {
}


@@ -235,6 +235,7 @@ void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
void ConnectionState::ExecInfo::Clear() {
state = EXEC_INACTIVE;
body.clear();
is_write = false;
ClearWatched();
}


@@ -82,6 +82,7 @@ struct ConnectionState {
ExecState state = EXEC_INACTIVE;
std::vector<StoredCmd> body;
bool is_write = false;
std::vector<std::pair<DbIndex, std::string>> watched_keys; // List of keys registered by WATCH
std::atomic_bool watched_dirty = false; // Set if a watched key was changed before EXEC


@@ -1067,6 +1067,14 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
<< " in dbid=" << dfly_cntx->conn_state.db_index;
}
string_view cmd_name(cid->name());
bool is_write = (cid->opt_mask() & CO::WRITE) || cmd_name == "PUBLISH" || cmd_name == "EVAL" ||
cmd_name == "EVALSHA";
if (cmd_name == "EXEC" && dfly_cntx->conn_state.exec_info.is_write) {
is_write = true;
}
etl.AwaitPauseState(is_write);
etl.RecordCmd();
if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
@@ -1082,7 +1090,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
// TODO: protect against aggregating huge transactions.
StoredCmd stored_cmd{cid, args_no_cmd};
dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd));
if (stored_cmd.Cid()->opt_mask() & CO::WRITE) {
dfly_cntx->conn_state.exec_info.is_write = true;
}
return cntx->SendSimpleString("QUEUED");
}
@@ -1254,8 +1264,9 @@ void Service::DispatchManyCommands(absl::Span<CmdArgList> args_list,
// invocations, we can potentially execute multiple eval in parallel, which is very powerful
// paired with shardlocal eval
const bool is_eval = CO::IsEvalKind(ArgS(args, 0));
const bool is_pause = dfly::ServerState::tlocal()->IsPaused();
if (!is_multi && !is_eval && cid != nullptr) {
if (!is_multi && !is_eval && cid != nullptr && !is_pause) {
stored_cmds.reserve(args_list.size());
stored_cmds.emplace_back(cid, tail_args);
continue;
@@ -1410,6 +1421,10 @@ facade::ConnectionStats* Service::GetThreadLocalConnectionStats() {
return ServerState::tl_connection_stats();
}
void Service::AwaitOnPauseDispatch() {
ServerState::tlocal()->AwaitOnPauseDispatch();
}
const CommandId* Service::FindCmd(std::string_view cmd) const {
return registry_.Find(cmd);
}

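Restated compactly, DispatchCommand treats a command as a write for pause purposes if its opt mask declares CO::WRITE, if it is one of the commands that can produce writes the mask does not declare (PUBLISH, EVAL, EVALSHA), or if it is an EXEC whose queued MULTI body contains a write. A standalone sketch of that predicate (the mask constant and flag are stand-ins for CommandId::opt_mask() and ExecInfo::is_write; the helper name is hypothetical):

#include <cstdint>
#include <string_view>

constexpr uint32_t kWriteMask = 1 << 0;  // stand-in for CO::WRITE

bool IsPauseRelevantWrite(uint32_t opt_mask, std::string_view name,
                          bool exec_body_has_write) {
  if (opt_mask & kWriteMask)
    return true;  // declared write command
  if (name == "PUBLISH" || name == "EVAL" || name == "EVALSHA")
    return true;  // may write even though the mask does not say so
  // EXEC counts as a write iff a write command was queued into the MULTI body.
  return name == "EXEC" && exec_body_has_write;
}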

@@ -73,6 +73,7 @@ class Service : public facade::ServiceInterface {
facade::Connection* owner) final;
facade::ConnectionStats* GetThreadLocalConnectionStats() final;
void AwaitOnPauseDispatch() final;
std::pair<const CommandId*, CmdArgList> FindCmd(CmdArgList args) const;
const CommandId* FindCmd(std::string_view) const;


@@ -30,7 +30,7 @@ class MultiCommandSquasher {
}
private:
// Per-shard exection info.
// Per-shard execution info.
struct ShardExecInfo {
ShardExecInfo() : had_writes{false}, cmds{}, replies{}, local_tx{nullptr} {
}
@@ -74,7 +74,7 @@ class MultiCommandSquasher {
ConnectionContext* cntx_; // Underlying context
Service* service_;
bool atomic_; // Wheter working in any of the atomic modes
bool atomic_; // Whether working in any of the atomic modes
const CommandId* base_cid_; // underlying cid (exec or eval) for executing batch hops
bool verify_commands_ = false; // Whether commands need to be verified before execution


@@ -27,6 +27,7 @@ extern "C" {
#include "base/flags.h"
#include "base/logging.h"
#include "facade/cmd_arg_parser.h"
#include "facade/dragonfly_connection.h"
#include "facade/reply_builder.h"
#include "io/file_util.h"
@@ -1213,41 +1214,16 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) {
ToUpper(&args[0]);
string_view sub_cmd = ArgS(args, 0);
CmdArgList sub_args = args.subspan(1);
if (sub_cmd == "SETNAME" && args.size() == 2) {
cntx->conn()->SetName(string{ArgS(args, 1)});
return (*cntx)->SendOk();
}
if (sub_cmd == "GETNAME") {
auto name = cntx->conn()->GetName();
if (!name.empty()) {
return (*cntx)->SendBulkString(name);
} else {
return (*cntx)->SendNull();
}
}
if (sub_cmd == "LIST") {
vector<string> client_info;
absl::base_internal::SpinLock mu;
// we can not preempt the connection traversal, so we need to use a spinlock.
// alternatively we could lock when mutating the connection list, but it seems not important.
auto cb = [&](unsigned thread_index, util::Connection* conn) {
facade::Connection* dcon = static_cast<facade::Connection*>(conn);
string info = dcon->GetClientInfo(thread_index);
absl::base_internal::SpinLockHolder l(&mu);
client_info.push_back(move(info));
};
for (auto* listener : listeners_) {
listener->TraverseConnections(cb);
}
string result = absl::StrJoin(move(client_info), "\n");
result.append("\n");
return (*cntx)->SendBulkString(result);
if (sub_cmd == "SETNAME") {
return ClientSetName(sub_args, cntx);
} else if (sub_cmd == "GETNAME") {
return ClientGetName(sub_args, cntx);
} else if (sub_cmd == "LIST") {
return ClientList(sub_args, cntx);
} else if (sub_cmd == "PAUSE") {
return ClientPause(sub_args, cntx);
}
if (sub_cmd == "SETINFO") {
@@ -1258,6 +1234,121 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(UnknownSubCmd(sub_cmd, "CLIENT"), kSyntaxErrType);
}
void ServerFamily::ClientSetName(CmdArgList args, ConnectionContext* cntx) {
if (args.size() == 1) {
cntx->conn()->SetName(string{ArgS(args, 0)});
return (*cntx)->SendOk();
} else {
return (*cntx)->SendError(facade::kSyntaxErr);
}
}
void ServerFamily::ClientGetName(CmdArgList args, ConnectionContext* cntx) {
if (!args.empty()) {
return (*cntx)->SendError(facade::kSyntaxErr);
}
auto name = cntx->conn()->GetName();
if (!name.empty()) {
return (*cntx)->SendBulkString(name);
} else {
return (*cntx)->SendNull();
}
}
void ServerFamily::ClientList(CmdArgList args, ConnectionContext* cntx) {
if (!args.empty()) {
return (*cntx)->SendError(facade::kSyntaxErr);
}
vector<string> client_info;
absl::base_internal::SpinLock mu;
// we can not preempt the connection traversal, so we need to use a spinlock.
// alternatively we could lock when mutating the connection list, but it seems not important.
auto cb = [&](unsigned thread_index, util::Connection* conn) {
facade::Connection* dcon = static_cast<facade::Connection*>(conn);
string info = dcon->GetClientInfo(thread_index);
absl::base_internal::SpinLockHolder l(&mu);
client_info.push_back(std::move(info));
};
for (auto* listener : listeners_) {
listener->TraverseConnections(cb);
}
string result = absl::StrJoin(client_info, "\n");
result.append("\n");
return (*cntx)->SendBulkString(result);
}
void ServerFamily::ClientPause(CmdArgList args, ConnectionContext* cntx) {
CmdArgParser parser(args);
auto timeout = parser.Next().Int<uint64_t>();
enum ClientPause pause_state = ClientPause::ALL;
if (parser.HasNext()) {
pause_state =
parser.ToUpper().Next().Case("WRITE", ClientPause::WRITE).Case("ALL", ClientPause::ALL);
}
if (auto err = parser.Error(); err) {
return (*cntx)->SendError(err->MakeReply());
}
// Pause command dispatch before updating the client pause state, and re-enable dispatch after
// the update. This ensures that once the state changes, all running commands observe the new
// pause state, and that we do not pause a client in the middle of a transaction.
service_.proactor_pool().Await([](util::ProactorBase* pb) {
ServerState& etl = *ServerState::tlocal();
etl.SetPauseDispatch(true);
});
// TODO handle blocking commands
const absl::Duration kDispatchTimeout = absl::Seconds(1);
if (!AwaitDispatches(kDispatchTimeout, [self = cntx->conn()](util::Connection* conn) {
// Wait until the only command dispatching is the client pause command.
return conn != self;
})) {
LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << kDispatchTimeout;
service_.proactor_pool().Await([](util::ProactorBase* pb) {
ServerState& etl = *ServerState::tlocal();
etl.SetPauseDispatch(false);
});
return (*cntx)->SendError("Failed to pause all running clients");
}
service_.proactor_pool().AwaitFiberOnAll([pause_state](util::ProactorBase* pb) {
ServerState& etl = *ServerState::tlocal();
etl.SetPauseState(pause_state, true);
etl.SetPauseDispatch(false);
});
// We should not expire/evict keys while clients are paused.
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(false); });
fb2::Fiber("client_pause", [this, timeout, pause_state]() mutable {
// On server shutdown we sleep 10ms to make sure all running tasks finish, so sleeping in 10ms
// steps ensures this fiber will not be left hanging.
auto step = 10ms;
auto timeout_ms = timeout * 1ms;
int64_t steps = timeout_ms.count() / step.count();
ServerState& etl = *ServerState::tlocal();
do {
ThisFiber::SleepFor(step);
} while (etl.gstate() != GlobalState::SHUTTING_DOWN && --steps > 0);
if (etl.gstate() != GlobalState::SHUTTING_DOWN) {
service_.proactor_pool().AwaitFiberOnAll([pause_state](util::ProactorBase* pb) {
ServerState::tlocal()->SetPauseState(pause_state, false);
});
shard_set->RunBriefInParallel(
[](EngineShard* shard) { shard->db_slice().SetExpireAllowed(true); });
}
}).Detach();
(*cntx)->SendOk();
}
void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
ToUpper(&args[0]);
string_view sub_cmd = ArgS(args, 0);

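End to end, the subcommand mirrors Redis' CLIENT PAUSE: a timeout in milliseconds plus an optional WRITE/ALL mode. An illustrative session (not taken from the patch):

127.0.0.1:6379> CLIENT PAUSE 5000 WRITE
OK
127.0.0.1:6379> GET key
(nil)
127.0.0.1:6379> SET key value
OK

In WRITE mode the GET returns immediately, while the SET blocks until the 5000 ms pause elapses; concurrent pauses stack, since each one increments the per-kind counter in ServerState.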

@@ -211,6 +211,10 @@ class ServerFamily {
void Auth(CmdArgList args, ConnectionContext* cntx);
void Client(CmdArgList args, ConnectionContext* cntx);
void ClientSetName(CmdArgList args, ConnectionContext* cntx);
void ClientGetName(CmdArgList args, ConnectionContext* cntx);
void ClientList(CmdArgList args, ConnectionContext* cntx);
void ClientPause(CmdArgList args, ConnectionContext* cntx);
void Config(CmdArgList args, ConnectionContext* cntx);
void DbSize(CmdArgList args, ConnectionContext* cntx);
void Debug(CmdArgList args, ConnectionContext* cntx);


@@ -4,6 +4,8 @@
#include "server/server_family.h"
#include <absl/strings/match.h>
#include "base/gtest.h"
#include "base/logging.h"
#include "facade/facade_test.h"
@@ -181,4 +183,21 @@ TEST_F(ServerFamilyTest, SlowLogMinusOneDisabled) {
EXPECT_THAT(resp.GetInt(), 0);
}
TEST_F(ServerFamilyTest, ClientPause) {
auto start = absl::Now();
Run({"CLIENT", "PAUSE", "50"});
Run({"get", "key"});
EXPECT_GT((absl::Now() - start), absl::Milliseconds(50));
start = absl::Now();
Run({"CLIENT", "PAUSE", "50", "WRITE"});
Run({"get", "key"});
EXPECT_LT((absl::Now() - start), absl::Milliseconds(10));
Run({"set", "key", "value2"});
EXPECT_GT((absl::Now() - start), absl::Milliseconds(50));
}
} // namespace dfly


@@ -112,6 +112,45 @@ bool ServerState::AllowInlineScheduling() const {
return true;
}
void ServerState::SetPauseState(ClientPause state, bool start) {
client_pauses_[int(state)] += (start ? 1 : -1);
if (!client_pauses_[int(state)]) {
client_pause_ec_.notifyAll();
}
}
bool ServerState::IsPaused() const {
return client_pauses_[0] || client_pauses_[1];
}
void ServerState::AwaitPauseState(bool is_write) {
client_pause_ec_.await([is_write, this]() {
if (client_pauses_[int(ClientPause::ALL)]) {
return false;
}
if (is_write && client_pauses_[int(ClientPause::WRITE)]) {
return false;
}
return true;
});
}
void ServerState::AwaitOnPauseDispatch() {
pause_dispatch_ec_.await([this]() {
if (pause_dispatch_) {
return false;
}
return true;
});
}
void ServerState::SetPauseDispatch(bool pause) {
pause_dispatch_ = pause;
if (!pause_dispatch_) {
pause_dispatch_ec_.notifyAll();
}
}
Interpreter* ServerState::BorrowInterpreter() {
return interpreter_mgr_.Get();
}

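The two counters make concurrent pauses compose: every CLIENT PAUSE increments the counter for its kind, and a blocked command resumes only when every counter that applies to it is back at zero. A self-contained restatement with standard primitives (a simplification of client_pauses_ / client_pause_ec_ above; the server uses a fiber EventCount rather than a condition variable):

#include <condition_variable>
#include <mutex>

enum class ClientPause { WRITE, ALL };

class PauseState {
 public:
  // Starts (start == true) or ends a pause of the given kind.
  void Set(ClientPause kind, bool start) {
    std::lock_guard<std::mutex> lk(mu_);
    counts_[int(kind)] += start ? 1 : -1;
    if (counts_[int(kind)] == 0)
      cv_.notify_all();  // the last pause of this kind ended
  }

  // Reads proceed unless ALL is paused; writes also require WRITE == 0.
  void Await(bool is_write) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] {
      if (counts_[int(ClientPause::ALL)])
        return false;
      if (is_write && counts_[int(ClientPause::WRITE)])
        return false;
      return true;
    });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int counts_[2] = {};  // indexed by ClientPause
};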

@@ -80,6 +80,8 @@ class MonitorsRepo {
unsigned int global_count_ = 0; // 'global' means that we count the monitor for all threads
};
enum class ClientPause { WRITE, ALL };
// Present in every server thread. This class differs from EngineShard. The latter manages
// state around engine shards while the former represents coordinator/connection state.
// There may be threads that handle engine shards but not IO, there may be threads that handle IO
@@ -220,6 +222,24 @@ class ServerState { // public struct - to allow initialization.
acl::AclLog acl_log;
// Starts or ends a `CLIENT PAUSE` command. @state controls whether
// this is pausing only writes or every command, @start controls
// whether this is starting or ending the pause.
void SetPauseState(ClientPause state, bool start);
// Returns whether any type of commands is paused.
bool IsPaused() const;
// Awaits until the pause is over and the command can execute.
// @is_write controls whether the command is a write command or not.
void AwaitPauseState(bool is_write);
// Toggles a boolean indicating whether the server should temporarily pause or allow
// dispatching new commands.
void SetPauseDispatch(bool pause);
// Awaits until dispatching new commands is allowed, as determined by SetPauseDispatch.
void AwaitOnPauseDispatch();
SlowLogShard& GetSlowLog() {
return slow_log_shard_;
};
@@ -237,6 +257,15 @@ class ServerState { // public struct - to allow initialization.
GlobalState gstate_ = GlobalState::ACTIVE;
// To support concurrent `CLIENT PAUSE` commands correctly, we store the number
// of CLIENT PAUSE commands currently in effect. Blocked execution fibers
// should subscribe to `client_pause_ec_` through `AwaitPauseState` to be
// notified when the pause is over.
int client_pauses_[2] = {};
EventCount client_pause_ec_;
bool pause_dispatch_ = false;
EventCount pause_dispatch_ec_;
using Counter = util::SlidingCounter<7>;
Counter qps_;


@@ -1761,3 +1761,49 @@ async def test_search(df_local_factory):
assert (await c_replica.ft("idx-m2").search(Query("*").sort_by("f2").paging(0, 1))).docs[
0
].id == "k0"
# @pytest.mark.slow
@pytest.mark.asyncio
async def test_client_pause_with_replica(df_local_factory, df_seeder_factory):
master = df_local_factory.create(proactor_threads=4)
replica = df_local_factory.create(proactor_threads=4)
df_local_factory.start_all([master, replica])
seeder = df_seeder_factory.create(port=master.port)
c_master = master.client()
c_replica = replica.client()
await c_replica.execute_command(f"REPLICAOF localhost {master.port}")
await wait_available_async(c_replica)
fill_task = asyncio.create_task(seeder.run())
# Give the seeder a bit of time.
await asyncio.sleep(1)
# block the seeder for 4 seconds
await c_master.execute_command("client pause 4000 write")
stats = await c_master.info("CommandStats")
info = await c_master.info("Stats")
await asyncio.sleep(0.5)
stats_after_sleep = await c_master.info("CommandStats")
# Check that no commands are executed except INFO and REPLCONF issued by the replica
for cmd, cmd_stats in stats_after_sleep.items():
if cmd != "cmdstat_INFO" and cmd != "cmdstat_REPLCONF":
assert stats.get(cmd) == cmd_stats
await asyncio.sleep(6)
seeder.stop()
await fill_task
stats_after_pause_finish = await c_master.info("CommandStats")
more_executed = False
for cmd, cmd_stats in stats_after_pause_finish.items():
if cmd != "cmdstat_INFO" and cmd != "cmdstat_REPLCONF" and stats.get(cmd) != cmd_stats:
more_executed = True
assert more_executed
capture = await seeder.capture(port=master.port)
assert await seeder.compare(capture, port=replica.port)
await disconnect_clients(c_master, c_replica)