diff --git a/src/server/dfly_bench.cc b/src/server/dfly_bench.cc index 6102e8295..f91ec7a0f 100644 --- a/src/server/dfly_bench.cc +++ b/src/server/dfly_bench.cc @@ -75,6 +75,7 @@ ABSL_FLAG(bool, ascii, true, "If true, use ascii characters for values"); ABSL_FLAG(bool, connect_only, false, "If true, will only connect to the server, without sending " "loadtest commands"); +ABSL_FLAG(string, password, "", "password to authenticate the client"); using namespace std; using namespace util; @@ -423,6 +424,7 @@ class Driver { void ReceiveFb(); void ParseRESP(); void ParseMC(); + void RunCommandAndCheckResultIs(std::string_view cmd, std::string_view expected_res); struct Req { uint64_t start; @@ -588,6 +590,16 @@ void KeyGenerator::EnableClusterMode() { ++i; } } +void Driver::RunCommandAndCheckResultIs(std::string_view cmd, std::string_view expected_res) { + auto ec = socket_->Write(io::Buffer(cmd)); + CHECK(!ec); + + uint8_t buf[128]; + auto res_sz = socket_->Recv(io::MutableBytes(buf)); + CHECK(res_sz) << res_sz.error().message(); + string_view resp = io::View(io::Bytes(buf, *res_sz)); + CHECK_EQ(resp, expected_res) << resp; +} void Driver::Connect(unsigned index, const tcp::endpoint& ep) { VLOG(2) << "Connecting " << index << " to " << ep; @@ -598,18 +610,15 @@ void Driver::Connect(unsigned index, const tcp::endpoint& ep) { CHECK_EQ(0, setsockopt(socket_->native_handle(), IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes))); } - if (absl::GetFlag(FLAGS_greet)) { + auto password = absl::GetFlag(FLAGS_password); + if (!password.empty()) { + auto command = absl::StrCat("AUTH ", password, "\r\n"); + RunCommandAndCheckResultIs(command, "+OK\r\n"); + } else if (absl::GetFlag(FLAGS_greet)) { // TCP Connect does not ensure that the connection was indeed accepted by the server. // if server backlog is too short the connection will get stuck in the accept queue. // Therefore, we send a ping command to ensure that every connection got connected. - ec = socket_->Write(io::Buffer("ping\r\n")); - CHECK(!ec); - - uint8_t buf[128]; - auto res_sz = socket_->Recv(io::MutableBytes(buf)); - CHECK(res_sz) << res_sz.error().message(); - string_view resp = io::View(io::Bytes(buf, *res_sz)); - CHECK(absl::EndsWith(resp, "\r\n")) << resp; + RunCommandAndCheckResultIs("PING\r\n", "+PONG\r\n"); } ep_ = ep; receive_fb_ = MakeFiber(fb2::Launch::dispatch, [this] { ReceiveFb(); });