diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index bd4a3ac19..02a5e69f9 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -3,7 +3,7 @@ cxx_link(dragonfly base dragonfly_lib) add_library(dragonfly_lib command_registry.cc common.cc config_flags.cc conn_context.cc db_slice.cc debugcmd.cc dragonfly_listener.cc - dragonfly_connection.cc engine_shard_set.cc + dragonfly_connection.cc engine_shard_set.cc generic_family.cc main_service.cc memcache_parser.cc redis_parser.cc reply_builder.cc string_family.cc transaction.cc) diff --git a/server/command_registry.cc b/server/command_registry.cc index f0448741e..fabb9fab3 100644 --- a/server/command_registry.cc +++ b/server/command_registry.cc @@ -60,6 +60,13 @@ void CommandRegistry::Command(CmdArgList args, ConnectionContext* cntx) { cntx->SendRespBlob(resp); } +CommandRegistry& CommandRegistry::operator<<(CommandId cmd) { + string_view k = cmd.name(); + CHECK(cmd_map_.emplace(k, std::move(cmd)).second) << k; + + return *this; +} + namespace CO { const char* OptName(CO::CommandOpt fl) { diff --git a/server/command_registry.h b/server/command_registry.h index 12bc22c9c..a378dfd08 100644 --- a/server/command_registry.h +++ b/server/command_registry.h @@ -111,12 +111,7 @@ class CommandRegistry { public: CommandRegistry(); - CommandRegistry& operator<<(CommandId cmd) { - const char* k = cmd.name(); - cmd_map_.emplace(k, std::move(cmd)); - - return *this; - } + CommandRegistry& operator<<(CommandId cmd); const CommandId* Find(std::string_view cmd) const { auto it = cmd_map_.find(cmd); diff --git a/server/common.cc b/server/common.cc index 38f4a3192..f1e2af236 100644 --- a/server/common.cc +++ b/server/common.cc @@ -21,6 +21,8 @@ string WrongNumArgsError(std::string_view cmd) { const char kSyntaxErr[] = "syntax error"; const char kInvalidIntErr[] = "value is not an integer or out of range"; const char kUintErr[] = "value is out of range, must be positive"; +const char kDbIndOutOfRangeErr[] = "DB index is out of range"; +const char kInvalidDbIndErr[] = "invalid DB index"; } // namespace dfly @@ -39,4 +41,4 @@ ostream& operator<<(ostream& os, dfly::CmdArgList ras) { return os; } -} // namespace std \ No newline at end of file +} // namespace std diff --git a/server/common_types.h b/server/common_types.h index f1b4463e4..5134326df 100644 --- a/server/common_types.h +++ b/server/common_types.h @@ -40,6 +40,11 @@ struct KeyLockArgs { unsigned key_step; }; +struct OpArgs { + EngineShard* shard; + DbIndex db_ind; +}; + inline std::string_view ArgS(CmdArgList args, size_t i) { auto arg = args[i]; return std::string_view(arg.data(), arg.size()); diff --git a/server/conn_context.h b/server/conn_context.h index 37f374869..fe53b75fd 100644 --- a/server/conn_context.h +++ b/server/conn_context.h @@ -14,6 +14,8 @@ class EngineShardSet; class CommandId; struct ConnectionState { + DbIndex db_index = 0; + enum Mask : uint32_t { ASYNC_DISPATCH = 1, // whether a command is handled via async dispatch. CONN_CLOSING = 2, // could be because of unrecoverable error or planned action. @@ -45,6 +47,10 @@ class ConnectionContext : public ReplyBuilder { Protocol protocol() const; + DbIndex db_index() const { + return conn_state.db_index; + } + ConnectionState conn_state; private: diff --git a/server/db_slice.cc b/server/db_slice.cc index 73b1d4d5e..34cbfeeb6 100644 --- a/server/db_slice.cc +++ b/server/db_slice.cc @@ -106,6 +106,22 @@ void DbSlice::CreateDb(DbIndex index) { } } +bool DbSlice::Del(DbIndex db_ind, const MainIterator& it) { + auto& db = db_arr_[db_ind]; + if (it == MainIterator{}) { + return false; + } + + if (it->second.HasExpire()) { + CHECK_EQ(1u, db->expire_table.erase(it->first)); + } + + db->stats.obj_memory_usage -= (it->first.capacity() + it->second.str.capacity()); + db->main_table.erase(it); + + return true; +} + // Returns true if a state has changed, false otherwise. bool DbSlice::Expire(DbIndex db_ind, MainIterator it, uint64_t at) { auto& db = db_arr_[db_ind]; diff --git a/server/db_slice.h b/server/db_slice.h index 258e19e6c..eb7f79b1b 100644 --- a/server/db_slice.h +++ b/server/db_slice.h @@ -64,6 +64,8 @@ class DbSlice { // Creates a database with index `db_ind`. If such database exists does nothing. void ActivateDb(DbIndex db_ind); + bool Del(DbIndex db_ind, const MainIterator& it); + ShardId shard_id() const { return shard_id_; } diff --git a/server/error.h b/server/error.h index 0f0c84bad..aec842f9b 100644 --- a/server/error.h +++ b/server/error.h @@ -13,6 +13,8 @@ std::string WrongNumArgsError(std::string_view cmd); extern const char kSyntaxErr[]; extern const char kInvalidIntErr[]; extern const char kUintErr[]; +extern const char kDbIndOutOfRangeErr[]; +extern const char kInvalidDbIndErr[]; #ifndef RETURN_ON_ERR @@ -24,4 +26,4 @@ extern const char kUintErr[]; } while (0) #endif -} // namespace dfly \ No newline at end of file +} // namespace dfly diff --git a/server/generic_family.cc b/server/generic_family.cc new file mode 100644 index 000000000..7ab5a374e --- /dev/null +++ b/server/generic_family.cc @@ -0,0 +1,266 @@ +// Copyright 2021, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/generic_family.h" + +#include "base/logging.h" +#include "server/command_registry.h" +#include "server/conn_context.h" +#include "server/engine_shard_set.h" +#include "server/error.h" +#include "server/transaction.h" +#include "util/varz.h" + +DEFINE_uint32(dbnum, 16, "Number of databases"); + +namespace dfly { +using namespace std; + +namespace { + +DEFINE_VARZ(VarzQps, ping_qps); + +} // namespace + +void GenericFamily::Init(util::ProactorPool* pp) { + ping_qps.Init(pp); +} + +void GenericFamily::Shutdown() { + ping_qps.Shutdown(); +} + +void GenericFamily::Del(CmdArgList args, ConnectionContext* cntx) { + Transaction* transaction = cntx->transaction; + VLOG(1) << "Del " << ArgS(args, 1); + + atomic_uint32_t result{0}; + auto cb = [&result](const Transaction* t, EngineShard* shard) { + ArgSlice args = t->ShardArgsInShard(shard->shard_id()); + auto res = OpDel(OpArgs{shard, t->db_index()}, args); + result.fetch_add(res.value_or(0), memory_order_relaxed); + + return OpStatus::OK; + }; + + OpStatus status = transaction->ScheduleSingleHop(std::move(cb)); + CHECK_EQ(OpStatus::OK, status); + + DVLOG(2) << "Del ts " << transaction->txid(); + + cntx->SendLong(result.load(memory_order_release)); +} + +void GenericFamily::Ping(CmdArgList args, ConnectionContext* cntx) { + if (args.size() > 2) { + return cntx->SendError("wrong number of arguments for 'ping' command"); + } + ping_qps.Inc(); + + // We synchronously block here until the engine sends us the payload and notifies that + // the I/O operation has been processed. + if (args.size() == 1) { + return cntx->SendSimpleRespString("PONG"); + } else { + std::string_view arg = ArgS(args, 1); + DVLOG(2) << "Ping " << arg; + + return cntx->SendBulkString(arg); + } +} + +void GenericFamily::Exists(CmdArgList args, ConnectionContext* cntx) { + Transaction* transaction = cntx->transaction; + VLOG(1) << "Exists " << ArgS(args, 1); + + atomic_uint32_t result{0}; + + auto cb = [&result](Transaction* t, EngineShard* shard) { + ArgSlice args = t->ShardArgsInShard(shard->shard_id()); + auto res = OpExists(OpArgs{shard, t->db_index()}, args); + result.fetch_add(res.value_or(0), memory_order_relaxed); + + return OpStatus::OK; + }; + + OpStatus status = transaction->ScheduleSingleHop(std::move(cb)); + CHECK_EQ(OpStatus::OK, status); + + return cntx->SendLong(result.load(memory_order_release)); +} + +void GenericFamily::Expire(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view sec = ArgS(args, 2); + int64_t int_arg; + + if (!absl::SimpleAtoi(sec, &int_arg)) { + return cntx->SendError(kInvalidIntErr); + } + + int_arg = std::max(int_arg, -1L); + ExpireParams params{.ts = int_arg}; + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpExpire(OpArgs{shard, t->db_index()}, key, params); + }; + OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); + + cntx->SendLong(status == OpStatus::OK); +} + +void GenericFamily::ExpireAt(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view sec = ArgS(args, 2); + int64_t int_arg; + + if (!absl::SimpleAtoi(sec, &int_arg)) { + return cntx->SendError(kInvalidIntErr); + } + int_arg = std::max(int_arg, 0L); + ExpireParams params{.ts = int_arg, .absolute = true}; + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpExpire(OpArgs{shard, t->db_index()}, key, params); + }; + OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb)); + cntx->SendLong(status == OpStatus::OK); +} + +void GenericFamily::Ttl(CmdArgList args, ConnectionContext* cntx) { + TtlGeneric(args, cntx, TimeUnit::SEC); +} + +void GenericFamily::Pttl(CmdArgList args, ConnectionContext* cntx) { + TtlGeneric(args, cntx, TimeUnit::MSEC); +} + +void GenericFamily::TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit) { + std::string_view key = ArgS(args, 1); + + auto cb = [&](Transaction* t, EngineShard* shard) { return OpTtl(t, shard, key); }; + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + + if (result) { + long ttl = (unit == TimeUnit::SEC) ? (result.value() + 500) / 1000 : result.value(); + cntx->SendLong(ttl); + } else { + switch (result.status()) { + case OpStatus::KEY_NOTFOUND: + cntx->SendLong(-1); + break; + default: + cntx->SendLong(-2); + } + } +} + +void GenericFamily::Select(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + int64_t index; + if (!absl::SimpleAtoi(key, &index)) { + return cntx->SendError(kInvalidDbIndErr); + } + if (index < 0 || index >= FLAGS_dbnum) { + return cntx->SendError(kDbIndOutOfRangeErr); + } + cntx->conn_state.db_index = index; + auto cb = [index](EngineShard* shard) { + shard->db_slice().ActivateDb(index); + return OpStatus::OK; + }; + cntx->shard_set->RunBriefInParallel(std::move(cb)); + + return cntx->SendOk(); +} + +void GenericFamily::Echo(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + return cntx->SendBulkString(key); +} + +OpStatus GenericFamily::OpExpire(const OpArgs& op_args, std::string_view key, + const ExpireParams& params) { + auto& db_slice = op_args.shard->db_slice(); + auto [it, expire_it] = db_slice.FindExt(op_args.db_ind, key); + if (it == MainIterator{}) + return OpStatus::KEY_NOTFOUND; + + int64_t abs_msec = (params.unit == TimeUnit::SEC) ? params.ts * 1000 : params.ts; + + if (!params.absolute) { + abs_msec += db_slice.Now(); + } + + if (abs_msec <= int64_t(db_slice.Now())) { + CHECK(db_slice.Del(op_args.db_ind, it)); + } else if (expire_it != ExpireIterator{}) { + expire_it->second = abs_msec; + } else { + db_slice.Expire(op_args.db_ind, it, abs_msec); + } + + return OpStatus::OK; +} + +OpResult GenericFamily::OpTtl(Transaction* t, EngineShard* shard, std::string_view key) { + auto& db_slice = shard->db_slice(); + auto [it, expire] = db_slice.FindExt(t->db_index(), key); + if (it == MainIterator{}) + return OpStatus::KEY_NOTFOUND; + + if (expire == ExpireIterator{}) + return OpStatus::SKIPPED; + + int64_t ttl_ms = expire->second - db_slice.Now(); + DCHECK_GT(ttl_ms, 0); // Otherwise FindExt would return null. + return ttl_ms; +} + +OpResult GenericFamily::OpDel(const OpArgs& op_args, ArgSlice keys) { + DVLOG(1) << "Del: " << keys[0]; + auto& db_slice = op_args.shard->db_slice(); + + uint32_t res = 0; + + for (uint32_t i = 0; i < keys.size(); ++i) { + auto fres = db_slice.FindExt(op_args.db_ind, keys[i]); + if (fres.first == MainIterator{}) + continue; + res += int(db_slice.Del(op_args.db_ind, fres.first)); + } + + return res; +} + +OpResult GenericFamily::OpExists(const OpArgs& op_args, ArgSlice keys) { + DVLOG(1) << "Exists: " << keys[0]; + auto& db_slice = op_args.shard->db_slice(); + uint32_t res = 0; + + for (uint32_t i = 0; i < keys.size(); ++i) { + auto find_res = db_slice.FindExt(op_args.db_ind, keys[i]); + res += (find_res.first != MainIterator{}); + } + return res; +} + +using CI = CommandId; + +#define HFUNC(x) SetHandler(&GenericFamily::x) + +void GenericFamily::Register(CommandRegistry* registry) { + constexpr auto kSelectOpts = CO::LOADING | CO::FAST | CO::STALE; + *registry << CI{"DEL", CO::WRITE, -2, 1, -1, 1}.HFUNC(Del) + << CI{"PING", CO::STALE | CO::FAST, -1, 0, 0, 0}.HFUNC(Ping) + << CI{"ECHO", CO::READONLY | CO::FAST, 2, 0, 0, 0}.HFUNC(Echo) + << CI{"EXISTS", CO::READONLY | CO::FAST, -2, 1, -1, 1}.HFUNC(Exists) + << CI{"EXPIRE", CO::WRITE | CO::FAST, 3, 1, 1, 1}.HFUNC(Expire) + << CI{"EXPIREAT", CO::WRITE | CO::FAST, 3, 1, 1, 1}.HFUNC(ExpireAt) + << CI{"SELECT", kSelectOpts, 2, 0, 0, 0}.HFUNC(Select) + << CI{"TTL", CO::READONLY | CO::FAST | CO::RANDOM, 2, 1, 1, 1}.HFUNC(Ttl) + << CI{"PTTL", CO::READONLY | CO::FAST | CO::RANDOM, 2, 1, 1, 1}.HFUNC(Pttl); +} + +} // namespace dfly diff --git a/server/generic_family.h b/server/generic_family.h new file mode 100644 index 000000000..f0dafe176 --- /dev/null +++ b/server/generic_family.h @@ -0,0 +1,58 @@ +// Copyright 2021, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "core/op_status.h" +#include "server/common_types.h" + +namespace util { +class ProactorPool; +} // namespace util + +namespace dfly { + +class ConnectionContext; +class CommandRegistry; +class EngineShard; + +class GenericFamily { + public: + static void Init(util::ProactorPool* pp); + static void Shutdown(); + + static void Register(CommandRegistry* registry); + + private: + enum TimeUnit { SEC, MSEC }; + + struct ExpireParams { + int64_t ts; + bool absolute = false; + + TimeUnit unit = SEC; + }; + + static void Del(CmdArgList args, ConnectionContext* cntx); + static void Ping(CmdArgList args, ConnectionContext* cntx); + static void Exists(CmdArgList args, ConnectionContext* cntx); + static void Expire(CmdArgList args, ConnectionContext* cntx); + static void ExpireAt(CmdArgList args, ConnectionContext* cntx); + + static void Ttl(CmdArgList args, ConnectionContext* cntx); + static void Pttl(CmdArgList args, ConnectionContext* cntx); + + static void Echo(CmdArgList args, ConnectionContext* cntx); + static void Select(CmdArgList args, ConnectionContext* cntx); + + static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit); + + static OpStatus OpExpire(const OpArgs& op_args, std::string_view key, const ExpireParams& params); + + static OpResult OpTtl(Transaction* t, EngineShard* shard, std::string_view key); + static OpResult OpDel(const OpArgs& op_args, ArgSlice keys); + static OpResult OpExists(const OpArgs& op_args, ArgSlice keys); +}; + +} // namespace dfly diff --git a/server/main_service.cc b/server/main_service.cc index a5a4e7690..813591108 100644 --- a/server/main_service.cc +++ b/server/main_service.cc @@ -14,6 +14,7 @@ #include "server/conn_context.h" #include "server/debugcmd.h" #include "server/error.h" +#include "server/generic_family.h" #include "server/string_family.h" #include "server/transaction.h" #include "util/metrics/metrics.h" @@ -64,6 +65,7 @@ void Service::Init(util::AcceptServer* acceptor) { request_latency_usec.Init(&pp_); ping_qps.Init(&pp_); StringFamily::Init(&pp_); + GenericFamily::Init(&pp_); cmd_req.Init(&pp_, {"type"}); } @@ -74,7 +76,7 @@ void Service::Shutdown() { request_latency_usec.Shutdown(); ping_qps.Shutdown(); StringFamily::Shutdown(); - + GenericFamily::Shutdown(); shard_set_.RunBlockingInParallel([&](EngineShard*) { EngineShard::DestroyThreadLocal(); }); } @@ -173,21 +175,6 @@ void Service::RegisterHttp(HttpListenerBase* listener) { CHECK_NOTNULL(listener); } -void Service::Ping(CmdArgList args, ConnectionContext* cntx) { - if (args.size() > 2) { - return cntx->SendError(WrongNumArgsError("PING")); - } - ping_qps.Inc(); - - if (args.size() == 1) { - return cntx->SendSimpleRespString("PONG"); - } - std::string_view arg = ArgS(args, 1); - DVLOG(2) << "Ping " << arg; - - return cntx->SendSimpleRespString(arg); -} - void Service::Debug(CmdArgList args, ConnectionContext* cntx) { ToUpper(&args[1]); @@ -216,9 +203,10 @@ inline CommandId::Handler HandlerFunc(Service* se, ServiceFunc f) { void Service::RegisterCommands() { using CI = CommandId; - registry_ << CI{"PING", CO::STALE | CO::FAST, -1, 0, 0, 0}.HFUNC(Ping) - << CI{"DEBUG", CO::RANDOM | CO::READONLY, -2, 0, 0, 0}.HFUNC(Debug); + registry_ << CI{"DEBUG", CO::RANDOM | CO::READONLY, -2, 0, 0, 0}.HFUNC(Debug); + StringFamily::Register(®istry_); + GenericFamily::Register(®istry_); } } // namespace dfly diff --git a/server/main_service.h b/server/main_service.h index ec304d07b..d182c45f9 100644 --- a/server/main_service.h +++ b/server/main_service.h @@ -46,7 +46,6 @@ class Service { } private: - void Ping(CmdArgList args, ConnectionContext* cntx); void Debug(CmdArgList args, ConnectionContext* cntx); void RegisterCommands(); diff --git a/server/reply_builder.cc b/server/reply_builder.cc index a5a840d26..c4e9e18b9 100644 --- a/server/reply_builder.cc +++ b/server/reply_builder.cc @@ -177,6 +177,11 @@ void ReplyBuilder::SendGetNotFound() { } } +void ReplyBuilder::SendLong(long num) { + string str = absl::StrCat(":", num, kCRLF); + as_resp()->SendDirect(str); +} + void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) { string res = absl::StrCat("*", count, kCRLF); for (size_t i = 0; i < count; ++i) { diff --git a/server/reply_builder.h b/server/reply_builder.h index 7ca38f188..0a319bd95 100644 --- a/server/reply_builder.h +++ b/server/reply_builder.h @@ -98,6 +98,9 @@ class ReplyBuilder { void SendGetReply(std::string_view key, uint32_t flags, std::string_view value); void SendGetNotFound(); + using StrOrNil = std::optional; + void SendMGetResponse(const StrOrNil* arr, uint32_t count); + void SetBatchMode(bool mode) { serializer_->SetBatchMode(mode); } @@ -110,8 +113,11 @@ class ReplyBuilder { as_resp()->SendNull(); } - using StrOrNil = std::optional; - void SendMGetResponse(const StrOrNil* arr, uint32_t count); + void SendLong(long val); + + void SendBulkString(std::string_view str) { + as_resp()->SendBulkString(str); + } private: RespSerializer* as_resp() { diff --git a/server/string_family.cc b/server/string_family.cc index 3692b9074..c6bbfc16d 100644 --- a/server/string_family.cc +++ b/server/string_family.cc @@ -80,7 +80,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) { std::string_view value = ArgS(args, 2); VLOG(2) << "Set " << key << " " << value; - SetCmd::SetParams sparams{0}; // TODO: db_index. + SetCmd::SetParams sparams{cntx->db_index()}; // TODO: db_index. int64_t int_arg; for (size_t i = 3; i < args.size(); ++i) { @@ -139,7 +139,7 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) { std::string_view key = ArgS(args, 1); auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - OpResult it_res = shard->db_slice().Find(0, key); + OpResult it_res = shard->db_slice().Find(cntx->db_index(), key); if (!it_res.ok()) return it_res.status(); @@ -166,12 +166,12 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) { std::string_view value = ArgS(args, 2); std::optional prev_val; - SetCmd::SetParams sparams{0}; + SetCmd::SetParams sparams{cntx->db_index()}; sparams.prev_val = &prev_val; ShardId sid = Shard(key, cntx->shard_set->size()); OpResult result = cntx->shard_set->Await(sid, [&] { - EngineShard* es = EngineShard::tlocal(); + EngineShard* es = EngineShard::tlocal(); SetCmd cmd(&es->db_slice()); return cmd.Set(sparams, key, value); @@ -245,7 +245,6 @@ void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) { return cntx->SendOk(); } - auto StringFamily::OpMGet(const Transaction* t, EngineShard* shard) -> MGetResponse { auto args = t->ShardArgsInShard(shard->shard_id()); DCHECK(!args.empty()); @@ -278,7 +277,6 @@ OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) { return OpStatus::OK; } - void StringFamily::Init(util::ProactorPool* pp) { set_qps.Init(pp); get_qps.Init(pp); diff --git a/server/transaction.cc b/server/transaction.cc index c48e6e748..1feb3a466 100644 --- a/server/transaction.cc +++ b/server/transaction.cc @@ -183,7 +183,7 @@ bool Transaction::RunInShard(ShardId sid) { DVLOG(1) << "RunInShard: " << DebugId() << " sid:" << sid; - sid = TranslateSidInShard(sid); + sid = SidToId(sid); auto& sd = dist_.shard_data[sid]; DCHECK(sd.local_mask & ARMED); sd.local_mask &= ~ARMED; @@ -212,7 +212,7 @@ bool Transaction::RunInShard(ShardId sid) { // This shard should own a reference for transaction as well as coordinator thread. DCHECK_GT(use_count(), 1u); - CHECK_GE(Disarm(), 1u); + CHECK_GE(DecreaseRunCnt(), 1u); // must be computed before intrusive_ptr_release call. if (concluding) { @@ -298,7 +298,7 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) { DCHECK_EQ(1u, dist_.shard_data.size()); dist_.shard_data[0].local_mask |= ARMED; - arm_count_.fetch_add(1, memory_order_release); // Decreases in RunLocal. + run_count_.fetch_add(1, memory_order_release); // Decreases in RunLocal. auto schedule_cb = [&] { return ScheduleUniqueShard(EngineShard::tlocal()); }; run_eager = ess_->Await(unique_shard_id_, std::move(schedule_cb)); // serves as a barrier. (void)run_eager; @@ -308,7 +308,7 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) { } DVLOG(1) << "Before DoneWait " << DebugId() << " " << args_.front(); - WaitArm(); + WaitForShardCallbacks(); DVLOG(1) << "After DoneWait"; cb_ = nullptr; @@ -324,7 +324,7 @@ void Transaction::Execute(RunnableType cb, bool conclude) { ExecuteAsync(conclude); DVLOG(1) << "Wait on " << DebugId(); - WaitArm(); + WaitForShardCallbacks(); DVLOG(1) << "Wait on " << DebugId() << " completed"; cb_ = nullptr; dist_.out_of_order.store(false, memory_order_relaxed); @@ -349,7 +349,7 @@ void Transaction::ExecuteAsync(bool concluding_cb) { use_count_.fetch_add(unique_shard_cnt_, memory_order_relaxed); if (unique_shard_cnt_ == 1) { - dist_.shard_data[TranslateSidInShard(unique_shard_id_)].local_mask |= ARMED; + dist_.shard_data[SidToId(unique_shard_id_)].local_mask |= ARMED; } else { for (ShardId i = 0; i < dist_.shard_data.size(); ++i) { auto& sd = dist_.shard_data[i]; @@ -364,7 +364,7 @@ void Transaction::ExecuteAsync(bool concluding_cb) { // with a write operation after a release fence. Specifically no writes below will be reordered // upwards. Important, because it protects non-threadsafe local_mask from being accessed by // IsArmedInShard in other threads. - arm_count_.fetch_add(unique_shard_cnt_, memory_order_acq_rel); + run_count_.fetch_add(unique_shard_cnt_, memory_order_acq_rel); auto cb = [this] { EngineShard* shard = EngineShard::tlocal(); @@ -390,7 +390,7 @@ void Transaction::ExecuteAsync(bool concluding_cb) { } } -void Transaction::RunQuickSingle() { +void Transaction::RunQuickie() { DCHECK_EQ(1u, dist_.shard_data.size()); DCHECK_EQ(0u, txid_); @@ -405,7 +405,7 @@ void Transaction::RunQuickSingle() { sd.local_mask &= ~ARMED; cb_ = nullptr; // We can do it because only a single shard runs the callback. - CHECK_GE(Disarm(), 1u); + CHECK_GE(DecreaseRunCnt(), 1u); } const char* Transaction::Name() const { @@ -437,7 +437,7 @@ bool Transaction::ScheduleUniqueShard(EngineShard* shard) { // Fast path - for uncontended keys, just run the callback. // That applies for single key operations like set, get, lpush etc. if (shard->db_slice().CheckLock(mode, lock_args)) { - RunQuickSingle(); // TODO: for journal - this can become multi-shard + RunQuickie(); // TODO: for journal - this can become multi-shard // transaction on replica. return true; } @@ -473,7 +473,7 @@ pair Transaction::ScheduleInShard(EngineShard* shard) { IntentLock::Mode mode = Mode(); bool lock_granted = false; - ShardId sid = TranslateSidInShard(shard->shard_id()); + ShardId sid = SidToId(shard->shard_id()); auto& sd = dist_.shard_data[sid]; @@ -520,7 +520,7 @@ pair Transaction::ScheduleInShard(EngineShard* shard) { } bool Transaction::CancelInShard(EngineShard* shard) { - ShardId sid = TranslateSidInShard(shard->shard_id()); + ShardId sid = SidToId(shard->shard_id()); auto& sd = dist_.shard_data[sid]; auto pos = sd.pq_pos; @@ -566,4 +566,12 @@ size_t Transaction::ReverseArgIndex(ShardId shard_id, size_t arg_index) const { return dist_.reverse_index[dist_.shard_data[shard_id].arg_start + arg_index]; } +inline uint32_t Transaction::DecreaseRunCnt() { + // We use release so that no stores will be reordered after. + uint32_t res = run_count_.fetch_sub(1, std::memory_order_release); + if (res == 1) + run_ec_.notify(); + return res; +} + } // namespace dfly diff --git a/server/transaction.h b/server/transaction.h index f87572aa3..eb6eebda5 100644 --- a/server/transaction.h +++ b/server/transaction.h @@ -86,14 +86,13 @@ class Transaction { if (sid >= dist_.shard_data.size()) sid = 0; // We use acquire so that no reordering will move before this load. - return arm_count_.load(std::memory_order_acquire) > 0 && + return run_count_.load(std::memory_order_acquire) > 0 && dist_.shard_data[sid].local_mask & ARMED; } // Called from engine set shard threads. uint16_t GetLocalMask(ShardId sid) const { - sid = TranslateSidInShard(sid); - return dist_.shard_data[sid].local_mask; + return dist_.shard_data[SidToId(sid)].local_mask; } uint32_t GetStateMask() const { @@ -138,6 +137,10 @@ class Transaction { const char* Name() const; + DbIndex db_index() const { + return db_index_; // TODO: support multiple db indexes. + } + uint32_t unique_shard_cnt() const { return unique_shard_cnt_; } @@ -150,7 +153,7 @@ class Transaction { bool RunInShard(ShardId sid); private: - ShardId TranslateSidInShard(ShardId sid) const { + unsigned SidToId(ShardId sid) const { return sid < dist_.shard_data.size() ? sid : 0; } @@ -159,7 +162,7 @@ class Transaction { void ExecuteAsync(bool concluding_cb); // Optimized version of RunInShard for single shard uncontended cases. - void RunQuickSingle(); + void RunQuickie(); //! Returns true if transaction run out-of-order during the scheduling phase. bool ScheduleUniqueShard(EngineShard* shard); @@ -177,18 +180,16 @@ class Transaction { //! Runs in the shard thread. KeyLockArgs GetLockArgs(ShardId sid) const; - void WaitArm() { - arm_ec_.await([this] { return 0 == this->arm_count_.load(std::memory_order_relaxed); }); + void WaitForShardCallbacks() { + run_ec_.await([this] { return 0 == run_count_.load(std::memory_order_relaxed); }); } - uint32_t Disarm() { - // We use release so that no stores will be reordered after. - uint32_t res = arm_count_.fetch_sub(1, std::memory_order_release); - arm_ec_.notify(); - return res; - } + // Returns the previous value of arm count. + uint32_t DecreaseRunCnt(); - uint32_t use_count() const { return use_count_.load(std::memory_order_relaxed); } + uint32_t use_count() const { + return use_count_.load(std::memory_order_relaxed); + } struct PerShardData { uint32_t arg_start = 0; // Indices into args_ array. @@ -235,13 +236,14 @@ class Transaction { EngineShardSet* ess_; TxId txid_{0}; - std::atomic_uint32_t use_count_{0}, arm_count_{0}; + std::atomic_uint32_t use_count_{0}, run_count_{0}; // unique_shard_cnt_ and unique_shard_id_ is accessed only by coordinator thread. uint32_t unique_shard_cnt_{0}; // number of unique shards span by args_ ShardId unique_shard_id_{kInvalidSid}; // Written by coordination thread but may be read by Shard threads. + // A mask of State values. Mostly used for debugging and for invariant checks. std::atomic state_mask_{0}; DbIndex db_index_ = 0; @@ -253,7 +255,7 @@ class Transaction { Dist dist_; - util::fibers_ext::EventCount arm_ec_; + util::fibers_ext::EventCount run_ec_; //! Stores arguments of the transaction (i.e. keys + values) ordered by shards. absl::InlinedVector args_;