Consolidate generic commands under generic_family. Add Del, Echo, Ttl and Select commands

This commit is contained in:
Roman Gershman 2021-12-26 17:25:49 +02:00
parent b1f32e5ebf
commit 55ee0563b0
18 changed files with 429 additions and 64 deletions

View file

@ -3,7 +3,7 @@ cxx_link(dragonfly base dragonfly_lib)
add_library(dragonfly_lib command_registry.cc common.cc config_flags.cc
conn_context.cc db_slice.cc debugcmd.cc dragonfly_listener.cc
dragonfly_connection.cc engine_shard_set.cc
dragonfly_connection.cc engine_shard_set.cc generic_family.cc
main_service.cc memcache_parser.cc
redis_parser.cc reply_builder.cc string_family.cc transaction.cc)

View file

@ -60,6 +60,13 @@ void CommandRegistry::Command(CmdArgList args, ConnectionContext* cntx) {
cntx->SendRespBlob(resp);
}
// Registers `cmd` under its own name and returns the registry so that
// registrations can be chained. Registering the same name twice is treated as
// a programming error: the CHECK fails and reports the offending name.
CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
  // Take the name before `cmd` is moved into the map.
  string_view name = cmd.name();
  const bool inserted = cmd_map_.emplace(name, std::move(cmd)).second;
  CHECK(inserted) << name;
  return *this;
}
namespace CO {
const char* OptName(CO::CommandOpt fl) {

View file

@ -111,12 +111,7 @@ class CommandRegistry {
public:
CommandRegistry();
CommandRegistry& operator<<(CommandId cmd) {
const char* k = cmd.name();
cmd_map_.emplace(k, std::move(cmd));
return *this;
}
CommandRegistry& operator<<(CommandId cmd);
const CommandId* Find(std::string_view cmd) const {
auto it = cmd_map_.find(cmd);

View file

@ -21,6 +21,8 @@ string WrongNumArgsError(std::string_view cmd) {
const char kSyntaxErr[] = "syntax error";
const char kInvalidIntErr[] = "value is not an integer or out of range";
const char kUintErr[] = "value is out of range, must be positive";
const char kDbIndOutOfRangeErr[] = "DB index is out of range";
const char kInvalidDbIndErr[] = "invalid DB index";
} // namespace dfly
@ -39,4 +41,4 @@ ostream& operator<<(ostream& os, dfly::CmdArgList ras) {
return os;
}
} // namespace std
} // namespace std

View file

@ -40,6 +40,11 @@ struct KeyLockArgs {
unsigned key_step;
};
// Bundles the execution context handed to per-shard operations: the shard the
// operation runs on and the index of the database it addresses.
struct OpArgs {
EngineShard* shard;
DbIndex db_ind;
};
inline std::string_view ArgS(CmdArgList args, size_t i) {
auto arg = args[i];
return std::string_view(arg.data(), arg.size());

View file

@ -14,6 +14,8 @@ class EngineShardSet;
class CommandId;
struct ConnectionState {
DbIndex db_index = 0;
enum Mask : uint32_t {
ASYNC_DISPATCH = 1, // whether a command is handled via async dispatch.
CONN_CLOSING = 2, // could be because of unrecoverable error or planned action.
@ -45,6 +47,10 @@ class ConnectionContext : public ReplyBuilder {
Protocol protocol() const;
DbIndex db_index() const {
return conn_state.db_index;
}
ConnectionState conn_state;
private:

View file

@ -106,6 +106,22 @@ void DbSlice::CreateDb(DbIndex index) {
}
}
// Removes the entry referenced by `it` from database `db_ind`.
// Returns false (and does nothing) when `it` is a default-constructed
// iterator; otherwise erases the entry, its expiry record if it has one,
// updates the memory-usage statistic and returns true.
bool DbSlice::Del(DbIndex db_ind, const MainIterator& it) {
  auto& db = db_arr_[db_ind];
  if (it == MainIterator{})
    return false;

  // An entry flagged as expiring must have exactly one record in the expire table.
  if (it->second.HasExpire()) {
    const auto erased = db->expire_table.erase(it->first);
    CHECK_EQ(1u, erased);
  }

  // Account for the storage of both the key and the value string.
  const auto released = it->first.capacity() + it->second.str.capacity();
  db->stats.obj_memory_usage -= released;
  db->main_table.erase(it);
  return true;
}
// Returns true if a state has changed, false otherwise.
bool DbSlice::Expire(DbIndex db_ind, MainIterator it, uint64_t at) {
auto& db = db_arr_[db_ind];

View file

@ -64,6 +64,8 @@ class DbSlice {
// Creates a database with index `db_ind`. If such database exists does nothing.
void ActivateDb(DbIndex db_ind);
bool Del(DbIndex db_ind, const MainIterator& it);
ShardId shard_id() const {
return shard_id_;
}

View file

@ -13,6 +13,8 @@ std::string WrongNumArgsError(std::string_view cmd);
extern const char kSyntaxErr[];
extern const char kInvalidIntErr[];
extern const char kUintErr[];
extern const char kDbIndOutOfRangeErr[];
extern const char kInvalidDbIndErr[];
#ifndef RETURN_ON_ERR
@ -24,4 +26,4 @@ extern const char kUintErr[];
} while (0)
#endif
} // namespace dfly
} // namespace dfly

266
server/generic_family.cc Normal file
View file

@ -0,0 +1,266 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/generic_family.h"
#include "base/logging.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/engine_shard_set.h"
#include "server/error.h"
#include "server/transaction.h"
#include "util/varz.h"
DEFINE_uint32(dbnum, 16, "Number of databases");
namespace dfly {
using namespace std;
namespace {
DEFINE_VARZ(VarzQps, ping_qps);
} // namespace
// Initializes the family-wide metrics (currently only the PING qps counter)
// on the given proactor pool.
void GenericFamily::Init(util::ProactorPool* pp) {
ping_qps.Init(pp);
}
// Shuts down the metrics that Init() set up.
void GenericFamily::Shutdown() {
ping_qps.Shutdown();
}
// Handler for DEL key [key ...].
// Runs OpDel for the keys that belong to each shard and replies with the total
// number of keys that were actually removed.
void GenericFamily::Del(CmdArgList args, ConnectionContext* cntx) {
  Transaction* transaction = cntx->transaction;
  VLOG(1) << "Del " << ArgS(args, 1);

  // The callback is invoked per shard, so the per-shard counts are summed
  // through an atomic counter.
  atomic_uint32_t result{0};
  auto cb = [&result](const Transaction* t, EngineShard* shard) {
    ArgSlice args = t->ShardArgsInShard(shard->shard_id());
    auto res = OpDel(OpArgs{shard, t->db_index()}, args);
    result.fetch_add(res.value_or(0), memory_order_relaxed);
    return OpStatus::OK;
  };

  OpStatus status = transaction->ScheduleSingleHop(std::move(cb));
  CHECK_EQ(OpStatus::OK, status);
  DVLOG(2) << "Del ts " << transaction->txid();

  // memory_order_release is not a valid ordering for a load (undefined
  // behaviour); acquire is the matching ordering for reading the final count.
  cntx->SendLong(result.load(memory_order_acquire));
}
// Handler for PING [message].
// Without a message replies with the simple string "PONG"; with a message
// echoes it back as a bulk string. More than one argument is an error.
void GenericFamily::Ping(CmdArgList args, ConnectionContext* cntx) {
  const auto argc = args.size();
  if (argc > 2) {
    cntx->SendError("wrong number of arguments for 'ping' command");
    return;
  }

  // Only well-formed requests are counted.
  ping_qps.Inc();

  if (argc == 1) {
    cntx->SendSimpleRespString("PONG");
    return;
  }

  std::string_view message = ArgS(args, 1);
  DVLOG(2) << "Ping " << message;
  cntx->SendBulkString(message);
}
// Handler for EXISTS key [key ...].
// Runs OpExists for the keys that belong to each shard and replies with the
// total number of keys found.
void GenericFamily::Exists(CmdArgList args, ConnectionContext* cntx) {
  Transaction* transaction = cntx->transaction;
  VLOG(1) << "Exists " << ArgS(args, 1);

  // The callback is invoked per shard, so the per-shard counts are summed
  // through an atomic counter.
  atomic_uint32_t result{0};
  auto cb = [&result](Transaction* t, EngineShard* shard) {
    ArgSlice args = t->ShardArgsInShard(shard->shard_id());
    auto res = OpExists(OpArgs{shard, t->db_index()}, args);
    result.fetch_add(res.value_or(0), memory_order_relaxed);
    return OpStatus::OK;
  };

  OpStatus status = transaction->ScheduleSingleHop(std::move(cb));
  CHECK_EQ(OpStatus::OK, status);

  // memory_order_release is not a valid ordering for a load (undefined
  // behaviour); acquire is the matching ordering for reading the final count.
  return cntx->SendLong(result.load(memory_order_acquire));
}
// Handler for EXPIRE key seconds.
// Parses the relative timeout, schedules OpExpire on the key's shard and
// replies 1 when the operation returned OK, 0 otherwise. A non-numeric or
// unrepresentably large timeout is rejected with kInvalidIntErr.
void GenericFamily::Expire(CmdArgList args, ConnectionContext* cntx) {
  std::string_view key = ArgS(args, 1);
  std::string_view sec = ArgS(args, 2);
  int64_t int_arg;
  if (!absl::SimpleAtoi(sec, &int_arg)) {
    return cntx->SendError(kInvalidIntErr);
  }

  // OpExpire multiplies the value by 1000 and then adds the current time.
  // Reject inputs for which that arithmetic would overflow int64_t; the bound
  // leaves half of the range for the current timestamp
  // (assumes db_slice.Now() is well below INT64_MAX / 2 - TODO confirm).
  constexpr int64_t kMaxRelativeSec = INT64_MAX / 2000;
  if (int_arg > kMaxRelativeSec) {
    return cntx->SendError(kInvalidIntErr);
  }

  // Any negative timeout means "already expired"; collapse it to -1.
  // The explicit template argument avoids a deduction failure on platforms
  // where int64_t is not `long`.
  int_arg = std::max<int64_t>(int_arg, -1);
  ExpireParams params{.ts = int_arg};

  auto cb = [&](Transaction* t, EngineShard* shard) {
    return OpExpire(OpArgs{shard, t->db_index()}, key, params);
  };
  OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
  cntx->SendLong(status == OpStatus::OK);
}
// Handler for EXPIREAT key timestamp-seconds.
// Parses the absolute timestamp, schedules OpExpire on the key's shard and
// replies 1 when the operation returned OK, 0 otherwise. A non-numeric or
// unrepresentably large timestamp is rejected with kInvalidIntErr.
void GenericFamily::ExpireAt(CmdArgList args, ConnectionContext* cntx) {
  std::string_view key = ArgS(args, 1);
  std::string_view sec = ArgS(args, 2);
  int64_t int_arg;
  if (!absl::SimpleAtoi(sec, &int_arg)) {
    return cntx->SendError(kInvalidIntErr);
  }

  // OpExpire multiplies the value by 1000 to obtain milliseconds; reject
  // inputs for which that multiplication would overflow int64_t.
  constexpr int64_t kMaxAbsoluteSec = INT64_MAX / 1000;
  if (int_arg > kMaxAbsoluteSec) {
    return cntx->SendError(kInvalidIntErr);
  }

  // Negative timestamps are clamped to 0. The explicit template argument
  // avoids a deduction failure on platforms where int64_t is not `long`.
  int_arg = std::max<int64_t>(int_arg, 0);
  ExpireParams params{.ts = int_arg, .absolute = true};

  auto cb = [&](Transaction* t, EngineShard* shard) {
    return OpExpire(OpArgs{shard, t->db_index()}, key, params);
  };
  OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
  cntx->SendLong(status == OpStatus::OK);
}
// Handler for TTL key: delegates to TtlGeneric, asking for the result in seconds.
void GenericFamily::Ttl(CmdArgList args, ConnectionContext* cntx) {
TtlGeneric(args, cntx, TimeUnit::SEC);
}
// Handler for PTTL key: delegates to TtlGeneric, asking for the result in milliseconds.
void GenericFamily::Pttl(CmdArgList args, ConnectionContext* cntx) {
TtlGeneric(args, cntx, TimeUnit::MSEC);
}
// Shared implementation of TTL and PTTL.
// Looks up the key's remaining lifetime via OpTtl and replies with:
//   * the remaining time in `unit` (seconds are rounded to nearest) when the
//     key exists and has an expiry;
//   * -2 when the key does not exist;
//   * -1 when the key exists but has no expiry (OpTtl reports SKIPPED).
// These codes follow the Redis TTL/PTTL contract.
void GenericFamily::TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit) {
  std::string_view key = ArgS(args, 1);
  auto cb = [&](Transaction* t, EngineShard* shard) { return OpTtl(t, shard, key); };
  OpResult<uint64_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));

  if (result) {
    // OpTtl returns milliseconds; +500 rounds to the nearest second.
    long ttl = (unit == TimeUnit::SEC) ? (result.value() + 500) / 1000 : result.value();
    cntx->SendLong(ttl);
    return;
  }

  switch (result.status()) {
    case OpStatus::KEY_NOTFOUND:
      cntx->SendLong(-2);
      break;
    default:
      // OpStatus::SKIPPED: the key exists but carries no expiry.
      cntx->SendLong(-1);
      break;
  }
}
// Handler for SELECT index.
// Validates the requested database index against --dbnum, stores it in the
// connection state, makes sure the database is activated on every shard and
// replies OK.
void GenericFamily::Select(CmdArgList args, ConnectionContext* cntx) {
  std::string_view db_arg = ArgS(args, 1);
  int64_t db_id;
  if (!absl::SimpleAtoi(db_arg, &db_id)) {
    cntx->SendError(kInvalidDbIndErr);
    return;
  }

  const bool in_range = db_id >= 0 && db_id < FLAGS_dbnum;
  if (!in_range) {
    cntx->SendError(kDbIndOutOfRangeErr);
    return;
  }

  cntx->conn_state.db_index = db_id;

  // Activate the selected database on all shards.
  auto activate = [db_id](EngineShard* shard) {
    shard->db_slice().ActivateDb(db_id);
    return OpStatus::OK;
  };
  cntx->shard_set->RunBriefInParallel(std::move(activate));

  cntx->SendOk();
}
// Handler for ECHO message: replies with the argument as a bulk string.
void GenericFamily::Echo(CmdArgList args, ConnectionContext* cntx) {
  cntx->SendBulkString(ArgS(args, 1));
}
// Applies an expiry to `key` in the shard/database described by `op_args`.
// Returns KEY_NOTFOUND when the key is absent, OK otherwise.
// `params.ts` is interpreted in `params.unit` and, unless `params.absolute` is
// set, relative to the slice's current time. A deadline that is not in the
// future deletes the key immediately.
OpStatus GenericFamily::OpExpire(const OpArgs& op_args, std::string_view key,
                                 const ExpireParams& params) {
  auto& slice = op_args.shard->db_slice();
  auto [main_it, exp_it] = slice.FindExt(op_args.db_ind, key);
  if (main_it == MainIterator{}) {
    return OpStatus::KEY_NOTFOUND;
  }

  // Normalise the deadline to absolute milliseconds.
  int64_t deadline_ms = (params.unit == TimeUnit::SEC) ? params.ts * 1000 : params.ts;
  if (!params.absolute) {
    deadline_ms += slice.Now();
  }

  if (deadline_ms <= static_cast<int64_t>(slice.Now())) {
    // Already due: remove the key right away.
    CHECK(slice.Del(op_args.db_ind, main_it));
    return OpStatus::OK;
  }

  if (exp_it != ExpireIterator{}) {
    // The key already has an expiry record; overwrite it in place.
    exp_it->second = deadline_ms;
    return OpStatus::OK;
  }

  slice.Expire(op_args.db_ind, main_it, deadline_ms);
  return OpStatus::OK;
}
// Computes the remaining lifetime of `key` in milliseconds within the
// transaction's database on `shard`.
// Returns KEY_NOTFOUND when the key is absent and SKIPPED when it exists but
// has no expiry record.
OpResult<uint64_t> GenericFamily::OpTtl(Transaction* t, EngineShard* shard, std::string_view key) {
  auto& slice = shard->db_slice();
  auto [main_it, exp_it] = slice.FindExt(t->db_index(), key);

  if (main_it == MainIterator{}) {
    return OpStatus::KEY_NOTFOUND;
  }
  if (exp_it == ExpireIterator{}) {
    return OpStatus::SKIPPED;
  }

  int64_t remaining_ms = exp_it->second - slice.Now();
  DCHECK_GT(remaining_ms, 0);  // Otherwise FindExt would return null.
  return remaining_ms;
}
// Deletes every key of `keys` that exists in the shard/database described by
// `op_args` and returns how many were removed.
OpResult<uint32_t> GenericFamily::OpDel(const OpArgs& op_args, ArgSlice keys) {
  DVLOG(1) << "Del: " << keys[0];
  auto& slice = op_args.shard->db_slice();

  uint32_t deleted = 0;
  for (uint32_t idx = 0; idx < keys.size(); ++idx) {
    auto found = slice.FindExt(op_args.db_ind, keys[idx]);
    // Only existing keys are handed to Del; count the ones it actually removed.
    if (found.first != MainIterator{}) {
      deleted += int(slice.Del(op_args.db_ind, found.first));
    }
  }
  return deleted;
}
// Counts how many of `keys` exist in the shard/database described by
// `op_args`. A key listed several times is counted once per occurrence.
OpResult<uint32_t> GenericFamily::OpExists(const OpArgs& op_args, ArgSlice keys) {
  DVLOG(1) << "Exists: " << keys[0];
  auto& slice = op_args.shard->db_slice();

  uint32_t present = 0;
  for (uint32_t idx = 0; idx < keys.size(); ++idx) {
    auto found = slice.FindExt(op_args.db_ind, keys[idx]);
    if (found.first != MainIterator{}) {
      ++present;
    }
  }
  return present;
}
using CI = CommandId;
#define HFUNC(x) SetHandler(&GenericFamily::x)
// Registers every command implemented by this family with `registry`.
// NOTE(review): the four numeric fields of each CommandId look like
// arity / first key / last key / key step, following the Redis command-table
// convention - confirm against CommandId's constructor.
void GenericFamily::Register(CommandRegistry* registry) {
  constexpr auto kSelectOpts = CO::LOADING | CO::FAST | CO::STALE;
  CommandRegistry& reg = *registry;

  reg << CI{"DEL", CO::WRITE, -2, 1, -1, 1}.HFUNC(Del);
  reg << CI{"PING", CO::STALE | CO::FAST, -1, 0, 0, 0}.HFUNC(Ping);
  reg << CI{"ECHO", CO::READONLY | CO::FAST, 2, 0, 0, 0}.HFUNC(Echo);
  reg << CI{"EXISTS", CO::READONLY | CO::FAST, -2, 1, -1, 1}.HFUNC(Exists);
  reg << CI{"EXPIRE", CO::WRITE | CO::FAST, 3, 1, 1, 1}.HFUNC(Expire);
  reg << CI{"EXPIREAT", CO::WRITE | CO::FAST, 3, 1, 1, 1}.HFUNC(ExpireAt);
  reg << CI{"SELECT", kSelectOpts, 2, 0, 0, 0}.HFUNC(Select);
  reg << CI{"TTL", CO::READONLY | CO::FAST | CO::RANDOM, 2, 1, 1, 1}.HFUNC(Ttl);
  reg << CI{"PTTL", CO::READONLY | CO::FAST | CO::RANDOM, 2, 1, 1, 1}.HFUNC(Pttl);
}
} // namespace dfly

58
server/generic_family.h Normal file
View file

@ -0,0 +1,58 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include "core/op_status.h"
#include "server/common_types.h"
namespace util {
class ProactorPool;
} // namespace util
namespace dfly {
class ConnectionContext;
class CommandRegistry;
class EngineShard;
// Groups the handlers of key-space commands that are not tied to a value type
// (DEL, EXISTS, EXPIRE, EXPIREAT, TTL, PTTL) together with the connection-level
// PING, ECHO and SELECT commands. All members are static; the class is used as
// a namespace with private helpers.
class GenericFamily {
public:
// Initializes the family's metrics on the given proactor pool.
static void Init(util::ProactorPool* pp);
// Shuts down what Init() set up.
static void Shutdown();
// Adds this family's commands to `registry`.
static void Register(CommandRegistry* registry);
private:
// Time resolution used by the TTL/expire helpers.
enum TimeUnit { SEC, MSEC };
// Arguments of an expire request as consumed by OpExpire.
struct ExpireParams {
// Expiry value expressed in `unit`.
int64_t ts;
// When false, `ts` is relative to the current time; when true it is an absolute timestamp.
bool absolute = false;
// Unit of `ts`.
TimeUnit unit = SEC;
};
// Command handlers. `args` holds the command name at index 0 followed by its
// arguments; replies are sent through `cntx`.
static void Del(CmdArgList args, ConnectionContext* cntx);
static void Ping(CmdArgList args, ConnectionContext* cntx);
static void Exists(CmdArgList args, ConnectionContext* cntx);
static void Expire(CmdArgList args, ConnectionContext* cntx);
static void ExpireAt(CmdArgList args, ConnectionContext* cntx);
static void Ttl(CmdArgList args, ConnectionContext* cntx);
static void Pttl(CmdArgList args, ConnectionContext* cntx);
static void Echo(CmdArgList args, ConnectionContext* cntx);
static void Select(CmdArgList args, ConnectionContext* cntx);
// Shared implementation of Ttl and Pttl; `unit` selects the unit of the reply.
static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit);
// Per-shard operations invoked from the handlers' transaction callbacks.
static OpStatus OpExpire(const OpArgs& op_args, std::string_view key, const ExpireParams& params);
static OpResult<uint64_t> OpTtl(Transaction* t, EngineShard* shard, std::string_view key);
static OpResult<uint32_t> OpDel(const OpArgs& op_args, ArgSlice keys);
static OpResult<uint32_t> OpExists(const OpArgs& op_args, ArgSlice keys);
};
} // namespace dfly

View file

@ -14,6 +14,7 @@
#include "server/conn_context.h"
#include "server/debugcmd.h"
#include "server/error.h"
#include "server/generic_family.h"
#include "server/string_family.h"
#include "server/transaction.h"
#include "util/metrics/metrics.h"
@ -64,6 +65,7 @@ void Service::Init(util::AcceptServer* acceptor) {
request_latency_usec.Init(&pp_);
ping_qps.Init(&pp_);
StringFamily::Init(&pp_);
GenericFamily::Init(&pp_);
cmd_req.Init(&pp_, {"type"});
}
@ -74,7 +76,7 @@ void Service::Shutdown() {
request_latency_usec.Shutdown();
ping_qps.Shutdown();
StringFamily::Shutdown();
GenericFamily::Shutdown();
shard_set_.RunBlockingInParallel([&](EngineShard*) { EngineShard::DestroyThreadLocal(); });
}
@ -173,21 +175,6 @@ void Service::RegisterHttp(HttpListenerBase* listener) {
CHECK_NOTNULL(listener);
}
void Service::Ping(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 2) {
return cntx->SendError(WrongNumArgsError("PING"));
}
ping_qps.Inc();
if (args.size() == 1) {
return cntx->SendSimpleRespString("PONG");
}
std::string_view arg = ArgS(args, 1);
DVLOG(2) << "Ping " << arg;
return cntx->SendSimpleRespString(arg);
}
void Service::Debug(CmdArgList args, ConnectionContext* cntx) {
ToUpper(&args[1]);
@ -216,9 +203,10 @@ inline CommandId::Handler HandlerFunc(Service* se, ServiceFunc f) {
void Service::RegisterCommands() {
using CI = CommandId;
registry_ << CI{"PING", CO::STALE | CO::FAST, -1, 0, 0, 0}.HFUNC(Ping)
<< CI{"DEBUG", CO::RANDOM | CO::READONLY, -2, 0, 0, 0}.HFUNC(Debug);
registry_ << CI{"DEBUG", CO::RANDOM | CO::READONLY, -2, 0, 0, 0}.HFUNC(Debug);
StringFamily::Register(&registry_);
GenericFamily::Register(&registry_);
}
} // namespace dfly

View file

@ -46,7 +46,6 @@ class Service {
}
private:
void Ping(CmdArgList args, ConnectionContext* cntx);
void Debug(CmdArgList args, ConnectionContext* cntx);
void RegisterCommands();

View file

@ -177,6 +177,11 @@ void ReplyBuilder::SendGetNotFound() {
}
}
// Sends `num` as an integer reply: the text ":" followed by the decimal value
// and kCRLF, written directly through the RESP serializer.
void ReplyBuilder::SendLong(long num) {
string str = absl::StrCat(":", num, kCRLF);
as_resp()->SendDirect(str);
}
void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
string res = absl::StrCat("*", count, kCRLF);
for (size_t i = 0; i < count; ++i) {

View file

@ -98,6 +98,9 @@ class ReplyBuilder {
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value);
void SendGetNotFound();
using StrOrNil = std::optional<std::string_view>;
void SendMGetResponse(const StrOrNil* arr, uint32_t count);
void SetBatchMode(bool mode) {
serializer_->SetBatchMode(mode);
}
@ -110,8 +113,11 @@ class ReplyBuilder {
as_resp()->SendNull();
}
using StrOrNil = std::optional<std::string_view>;
void SendMGetResponse(const StrOrNil* arr, uint32_t count);
void SendLong(long val);
void SendBulkString(std::string_view str) {
as_resp()->SendBulkString(str);
}
private:
RespSerializer* as_resp() {

View file

@ -80,7 +80,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
std::string_view value = ArgS(args, 2);
VLOG(2) << "Set " << key << " " << value;
SetCmd::SetParams sparams{0}; // TODO: db_index.
SetCmd::SetParams sparams{cntx->db_index()}; // TODO: db_index.
int64_t int_arg;
for (size_t i = 3; i < args.size(); ++i) {
@ -139,7 +139,7 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<string> {
OpResult<MainIterator> it_res = shard->db_slice().Find(0, key);
OpResult<MainIterator> it_res = shard->db_slice().Find(cntx->db_index(), key);
if (!it_res.ok())
return it_res.status();
@ -166,12 +166,12 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
std::string_view value = ArgS(args, 2);
std::optional<string> prev_val;
SetCmd::SetParams sparams{0};
SetCmd::SetParams sparams{cntx->db_index()};
sparams.prev_val = &prev_val;
ShardId sid = Shard(key, cntx->shard_set->size());
OpResult<void> result = cntx->shard_set->Await(sid, [&] {
EngineShard* es = EngineShard::tlocal();
EngineShard* es = EngineShard::tlocal();
SetCmd cmd(&es->db_slice());
return cmd.Set(sparams, key, value);
@ -245,7 +245,6 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendOk();
}
auto StringFamily::OpMGet(const Transaction* t, EngineShard* shard) -> MGetResponse {
auto args = t->ShardArgsInShard(shard->shard_id());
DCHECK(!args.empty());
@ -278,7 +277,6 @@ OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) {
return OpStatus::OK;
}
void StringFamily::Init(util::ProactorPool* pp) {
set_qps.Init(pp);
get_qps.Init(pp);

View file

@ -183,7 +183,7 @@ bool Transaction::RunInShard(ShardId sid) {
DVLOG(1) << "RunInShard: " << DebugId() << " sid:" << sid;
sid = TranslateSidInShard(sid);
sid = SidToId(sid);
auto& sd = dist_.shard_data[sid];
DCHECK(sd.local_mask & ARMED);
sd.local_mask &= ~ARMED;
@ -212,7 +212,7 @@ bool Transaction::RunInShard(ShardId sid) {
// This shard should own a reference for transaction as well as coordinator thread.
DCHECK_GT(use_count(), 1u);
CHECK_GE(Disarm(), 1u);
CHECK_GE(DecreaseRunCnt(), 1u);
// must be computed before intrusive_ptr_release call.
if (concluding) {
@ -298,7 +298,7 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) {
DCHECK_EQ(1u, dist_.shard_data.size());
dist_.shard_data[0].local_mask |= ARMED;
arm_count_.fetch_add(1, memory_order_release); // Decreases in RunLocal.
run_count_.fetch_add(1, memory_order_release); // Decreases in RunLocal.
auto schedule_cb = [&] { return ScheduleUniqueShard(EngineShard::tlocal()); };
run_eager = ess_->Await(unique_shard_id_, std::move(schedule_cb)); // serves as a barrier.
(void)run_eager;
@ -308,7 +308,7 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) {
}
DVLOG(1) << "Before DoneWait " << DebugId() << " " << args_.front();
WaitArm();
WaitForShardCallbacks();
DVLOG(1) << "After DoneWait";
cb_ = nullptr;
@ -324,7 +324,7 @@ void Transaction::Execute(RunnableType cb, bool conclude) {
ExecuteAsync(conclude);
DVLOG(1) << "Wait on " << DebugId();
WaitArm();
WaitForShardCallbacks();
DVLOG(1) << "Wait on " << DebugId() << " completed";
cb_ = nullptr;
dist_.out_of_order.store(false, memory_order_relaxed);
@ -349,7 +349,7 @@ void Transaction::ExecuteAsync(bool concluding_cb) {
use_count_.fetch_add(unique_shard_cnt_, memory_order_relaxed);
if (unique_shard_cnt_ == 1) {
dist_.shard_data[TranslateSidInShard(unique_shard_id_)].local_mask |= ARMED;
dist_.shard_data[SidToId(unique_shard_id_)].local_mask |= ARMED;
} else {
for (ShardId i = 0; i < dist_.shard_data.size(); ++i) {
auto& sd = dist_.shard_data[i];
@ -364,7 +364,7 @@ void Transaction::ExecuteAsync(bool concluding_cb) {
// with a write operation after a release fence. Specifically no writes below will be reordered
// upwards. Important, because it protects non-threadsafe local_mask from being accessed by
// IsArmedInShard in other threads.
arm_count_.fetch_add(unique_shard_cnt_, memory_order_acq_rel);
run_count_.fetch_add(unique_shard_cnt_, memory_order_acq_rel);
auto cb = [this] {
EngineShard* shard = EngineShard::tlocal();
@ -390,7 +390,7 @@ void Transaction::ExecuteAsync(bool concluding_cb) {
}
}
void Transaction::RunQuickSingle() {
void Transaction::RunQuickie() {
DCHECK_EQ(1u, dist_.shard_data.size());
DCHECK_EQ(0u, txid_);
@ -405,7 +405,7 @@ void Transaction::RunQuickSingle() {
sd.local_mask &= ~ARMED;
cb_ = nullptr; // We can do it because only a single shard runs the callback.
CHECK_GE(Disarm(), 1u);
CHECK_GE(DecreaseRunCnt(), 1u);
}
const char* Transaction::Name() const {
@ -437,7 +437,7 @@ bool Transaction::ScheduleUniqueShard(EngineShard* shard) {
// Fast path - for uncontended keys, just run the callback.
// That applies for single key operations like set, get, lpush etc.
if (shard->db_slice().CheckLock(mode, lock_args)) {
RunQuickSingle(); // TODO: for journal - this can become multi-shard
RunQuickie(); // TODO: for journal - this can become multi-shard
// transaction on replica.
return true;
}
@ -473,7 +473,7 @@ pair<bool, bool> Transaction::ScheduleInShard(EngineShard* shard) {
IntentLock::Mode mode = Mode();
bool lock_granted = false;
ShardId sid = TranslateSidInShard(shard->shard_id());
ShardId sid = SidToId(shard->shard_id());
auto& sd = dist_.shard_data[sid];
@ -520,7 +520,7 @@ pair<bool, bool> Transaction::ScheduleInShard(EngineShard* shard) {
}
bool Transaction::CancelInShard(EngineShard* shard) {
ShardId sid = TranslateSidInShard(shard->shard_id());
ShardId sid = SidToId(shard->shard_id());
auto& sd = dist_.shard_data[sid];
auto pos = sd.pq_pos;
@ -566,4 +566,12 @@ size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const {
return dist_.reverse_index[dist_.shard_data[shard_id].arg_start + arg_index];
}
// Atomically decrements run_count_ and returns the value it held before the
// decrement. When that value was 1, i.e. the counter has just dropped to zero,
// run_ec_ is notified.
inline uint32_t Transaction::DecreaseRunCnt() {
// We use release so that no stores will be reordered after.
uint32_t res = run_count_.fetch_sub(1, std::memory_order_release);
// Notify only on the transition to zero.
if (res == 1)
run_ec_.notify();
return res;
}
} // namespace dfly

View file

@ -86,14 +86,13 @@ class Transaction {
if (sid >= dist_.shard_data.size())
sid = 0;
// We use acquire so that no reordering will move before this load.
return arm_count_.load(std::memory_order_acquire) > 0 &&
return run_count_.load(std::memory_order_acquire) > 0 &&
dist_.shard_data[sid].local_mask & ARMED;
}
// Called from engine set shard threads.
uint16_t GetLocalMask(ShardId sid) const {
sid = TranslateSidInShard(sid);
return dist_.shard_data[sid].local_mask;
return dist_.shard_data[SidToId(sid)].local_mask;
}
uint32_t GetStateMask() const {
@ -138,6 +137,10 @@ class Transaction {
const char* Name() const;
DbIndex db_index() const {
return db_index_; // TODO: support multiple db indexes.
}
uint32_t unique_shard_cnt() const {
return unique_shard_cnt_;
}
@ -150,7 +153,7 @@ class Transaction {
bool RunInShard(ShardId sid);
private:
ShardId TranslateSidInShard(ShardId sid) const {
unsigned SidToId(ShardId sid) const {
return sid < dist_.shard_data.size() ? sid : 0;
}
@ -159,7 +162,7 @@ class Transaction {
void ExecuteAsync(bool concluding_cb);
// Optimized version of RunInShard for single shard uncontended cases.
void RunQuickSingle();
void RunQuickie();
//! Returns true if transaction run out-of-order during the scheduling phase.
bool ScheduleUniqueShard(EngineShard* shard);
@ -177,18 +180,16 @@ class Transaction {
//! Runs in the shard thread.
KeyLockArgs GetLockArgs(ShardId sid) const;
void WaitArm() {
arm_ec_.await([this] { return 0 == this->arm_count_.load(std::memory_order_relaxed); });
void WaitForShardCallbacks() {
run_ec_.await([this] { return 0 == run_count_.load(std::memory_order_relaxed); });
}
uint32_t Disarm() {
// We use release so that no stores will be reordered after.
uint32_t res = arm_count_.fetch_sub(1, std::memory_order_release);
arm_ec_.notify();
return res;
}
// Decrements the run count and returns its previous value.
uint32_t DecreaseRunCnt();
uint32_t use_count() const { return use_count_.load(std::memory_order_relaxed); }
uint32_t use_count() const {
return use_count_.load(std::memory_order_relaxed);
}
struct PerShardData {
uint32_t arg_start = 0; // Indices into args_ array.
@ -235,13 +236,14 @@ class Transaction {
EngineShardSet* ess_;
TxId txid_{0};
std::atomic_uint32_t use_count_{0}, arm_count_{0};
std::atomic_uint32_t use_count_{0}, run_count_{0};
// unique_shard_cnt_ and unique_shard_id_ are accessed only by the coordinator thread.
uint32_t unique_shard_cnt_{0}; // number of unique shards span by args_
ShardId unique_shard_id_{kInvalidSid};
// Written by coordination thread but may be read by Shard threads.
// A mask of State values. Mostly used for debugging and for invariant checks.
std::atomic<uint16_t> state_mask_{0};
DbIndex db_index_ = 0;
@ -253,7 +255,7 @@ class Transaction {
Dist dist_;
util::fibers_ext::EventCount arm_ec_;
util::fibers_ext::EventCount run_ec_;
//! Stores arguments of the transaction (i.e. keys + values) ordered by shards.
absl::InlinedVector<std::string_view, 4> args_;