feat(replica): support multi transaction command (#634)

This commit is contained in:
adiholden 2023-01-04 09:11:30 +02:00 committed by GitHub
parent b944324bbf
commit 3065946b9a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 223 additions and 106 deletions

View file

@ -15,22 +15,16 @@ JournalExecutor::JournalExecutor(Service* service)
conn_context_.journal_emulated = true; conn_context_.journal_emulated = true;
} }
void JournalExecutor::Execute(std::vector<journal::ParsedEntry>& entries) { void JournalExecutor::Execute(DbIndex dbid, std::vector<journal::ParsedEntry::CmdData>& cmds) {
DCHECK_GT(entries.size(), 1U); DCHECK_GT(cmds.size(), 1U);
conn_context_.conn_state.db_index = entries.front().dbid; conn_context_.conn_state.db_index = dbid;
std::string multi_cmd = {"MULTI"}; std::string multi_cmd = {"MULTI"};
auto ms = MutableSlice{&multi_cmd[0], multi_cmd.size()}; auto ms = MutableSlice{&multi_cmd[0], multi_cmd.size()};
auto span = CmdArgList{&ms, 1}; auto span = CmdArgList{&ms, 1};
service_->DispatchCommand(span, &conn_context_); service_->DispatchCommand(span, &conn_context_);
for (auto& entry : entries) { for (auto& cmd : cmds) {
if (entry.payload) { Execute(cmd);
DCHECK_EQ(entry.dbid, conn_context_.conn_state.db_index);
span = CmdArgList{entry.payload->data(), entry.payload->size()};
service_->DispatchCommand(span, &conn_context_);
}
} }
std::string exec_cmd = {"EXEC"}; std::string exec_cmd = {"EXEC"};
@ -39,13 +33,14 @@ void JournalExecutor::Execute(std::vector<journal::ParsedEntry>& entries) {
service_->DispatchCommand(span, &conn_context_); service_->DispatchCommand(span, &conn_context_);
} }
void JournalExecutor::Execute(journal::ParsedEntry& entry) { void JournalExecutor::Execute(DbIndex dbid, journal::ParsedEntry::CmdData& cmd) {
conn_context_.conn_state.db_index = entry.dbid; conn_context_.conn_state.db_index = dbid;
if (entry.payload) { // TODO - when this is false? Execute(cmd);
auto span = CmdArgList{entry.payload->data(), entry.payload->size()}; }
void JournalExecutor::Execute(journal::ParsedEntry::CmdData& cmd) {
auto span = CmdArgList{cmd.cmd_args.data(), cmd.cmd_args.size()};
service_->DispatchCommand(span, &conn_context_); service_->DispatchCommand(span, &conn_context_);
}
} }
} // namespace dfly } // namespace dfly

View file

@ -14,10 +14,11 @@ class Service;
class JournalExecutor { class JournalExecutor {
public: public:
JournalExecutor(Service* service); JournalExecutor(Service* service);
void Execute(std::vector<journal::ParsedEntry>& entries); void Execute(DbIndex dbid, std::vector<journal::ParsedEntry::CmdData>& cmds);
void Execute(journal::ParsedEntry& entry); void Execute(DbIndex dbid, journal::ParsedEntry::CmdData& cmd);
private: private:
void Execute(journal::ParsedEntry::CmdData& cmd);
Service* service_; Service* service_;
ConnectionContext conn_context_; ConnectionContext conn_context_;
io::NullSink null_sink_; io::NullSink null_sink_;

View file

@ -115,6 +115,10 @@ void Journal::RecordEntry(const Entry& entry) {
journal_slice.AddLogRecord(entry); journal_slice.AddLogRecord(entry);
} }
// Forwards to journal_slice.GetLastTxId() and returns its result unchanged.
TxId Journal::GetLastTxId() {
return journal_slice.GetLastTxId();
}
/* /*
void Journal::OpArgs(TxId txid, Op opcode, Span keys) { void Journal::OpArgs(TxId txid, Op opcode, Span keys) {
DCHECK(journal_slice.IsOpen()); DCHECK(journal_slice.IsOpen());

View file

@ -54,6 +54,7 @@ class Journal {
LSN GetLsn() const; LSN GetLsn() const;
void RecordEntry(const Entry& entry); void RecordEntry(const Entry& entry);
TxId GetLastTxId();
private: private:
mutable boost::fibers::mutex state_mu_; mutable boost::fibers::mutex state_mu_;

View file

@ -117,7 +117,7 @@ error_code JournalSlice::Close() {
void JournalSlice::AddLogRecord(const Entry& entry) { void JournalSlice::AddLogRecord(const Entry& entry) {
DCHECK(ring_buffer_); DCHECK(ring_buffer_);
last_txid_ = entry.txid;
iterating_cb_arr_ = true; iterating_cb_arr_ = true;
for (const auto& k_v : change_cb_arr_) { for (const auto& k_v : change_cb_arr_) {
k_v.second(entry); k_v.second(entry);

View file

@ -47,6 +47,10 @@ class JournalSlice {
uint32_t RegisterOnChange(ChangeCallback cb); uint32_t RegisterOnChange(ChangeCallback cb);
void UnregisterOnChange(uint32_t); void UnregisterOnChange(uint32_t);
// Returns last_txid_: the txid of the entry most recently passed to AddLogRecord,
// or 0 if no entry has been added yet.
TxId GetLastTxId() {
return last_txid_;
}
private: private:
struct RingItem; struct RingItem;
@ -62,7 +66,7 @@ class JournalSlice {
uint32_t slice_index_ = UINT32_MAX; uint32_t slice_index_ = UINT32_MAX;
uint32_t next_cb_id_ = 1; uint32_t next_cb_id_ = 1;
TxId last_txid_ = 0;
std::error_code status_ec_; std::error_code status_ec_;
bool lameduck_ = false; bool lameduck_ = false;

View file

@ -42,6 +42,11 @@ void JournalWriter::Write(std::string_view sv) {
void JournalWriter::Write(CmdArgList args) { void JournalWriter::Write(CmdArgList args) {
Write(args.size()); Write(args.size());
size_t cmd_size = 0;
for (auto v : args) {
cmd_size += v.size();
}
Write(cmd_size);
for (auto v : args) for (auto v : args)
Write(facade::ToSV(v)); Write(facade::ToSV(v));
} }
@ -50,6 +55,13 @@ void JournalWriter::Write(std::pair<std::string_view, ArgSlice> args) {
auto [cmd, tail_args] = args; auto [cmd, tail_args] = args;
Write(1 + tail_args.size()); Write(1 + tail_args.size());
size_t cmd_size = cmd.size();
for (auto v : tail_args) {
cmd_size += v.size();
}
Write(cmd_size);
Write(cmd); Write(cmd);
for (auto v : tail_args) for (auto v : tail_args)
Write(v); Write(v);
@ -71,6 +83,8 @@ void JournalWriter::Write(const journal::Entry& entry) {
case journal::Op::SELECT: case journal::Op::SELECT:
return Write(entry.dbid); return Write(entry.dbid);
case journal::Op::COMMAND: case journal::Op::COMMAND:
case journal::Op::MULTI_COMMAND:
case journal::Op::EXEC:
Write(entry.txid); Write(entry.txid);
Write(entry.shard_cnt); Write(entry.shard_cnt);
return std::visit([this](const auto& payload) { return Write(payload); }, entry.payload); return std::visit([this](const auto& payload) { return Write(payload); }, entry.payload);
@ -80,7 +94,7 @@ void JournalWriter::Write(const journal::Entry& entry) {
} }
JournalReader::JournalReader(io::Source* source, DbIndex dbid) JournalReader::JournalReader(io::Source* source, DbIndex dbid)
: str_buf_{}, source_{source}, buf_{4096}, dbid_{dbid} { : source_{source}, buf_{4_KB}, dbid_{dbid} {
} }
void JournalReader::SetDb(DbIndex dbid) { void JournalReader::SetDb(DbIndex dbid) {
@ -134,63 +148,63 @@ template io::Result<uint16_t> JournalReader::ReadUInt<uint16_t>();
template io::Result<uint32_t> JournalReader::ReadUInt<uint32_t>(); template io::Result<uint32_t> JournalReader::ReadUInt<uint32_t>();
template io::Result<uint64_t> JournalReader::ReadUInt<uint64_t>(); template io::Result<uint64_t> JournalReader::ReadUInt<uint64_t>();
io::Result<size_t> JournalReader::ReadString() { io::Result<size_t> JournalReader::ReadString(char* buffer) {
size_t size = 0; size_t size = 0;
SET_OR_UNEXPECT(ReadUInt<uint64_t>(), size); SET_OR_UNEXPECT(ReadUInt<uint64_t>(), size);
if (auto ec = EnsureRead(size); ec) if (auto ec = EnsureRead(size); ec)
return make_unexpected(ec); return make_unexpected(ec);
unsigned offset = str_buf_.size(); buf_.ReadAndConsume(size, buffer);
str_buf_.resize(offset + size);
buf_.ReadAndConsume(size, str_buf_.data() + offset);
return size; return size;
} }
std::error_code JournalReader::Read(CmdArgVec* vec) { std::error_code JournalReader::ReadCommand(journal::ParsedEntry::CmdData* data) {
size_t num_strings = 0; size_t num_strings = 0;
SET_OR_RETURN(ReadUInt<uint64_t>(), num_strings); SET_OR_RETURN(ReadUInt<uint64_t>(), num_strings);
vec->resize(num_strings); data->cmd_args.resize(num_strings);
size_t cmd_size = 0;
SET_OR_RETURN(ReadUInt<uint64_t>(), cmd_size);
// Read all strings consecutively. // Read all strings consecutively.
str_buf_.clear(); data->command_buf = make_unique<char[]>(cmd_size);
for (auto& span : *vec) { char* ptr = data->command_buf.get();
for (auto& span : data->cmd_args) {
size_t size; size_t size;
SET_OR_RETURN(ReadString(), size); SET_OR_RETURN(ReadString(ptr), size);
span = MutableSlice{nullptr, size}; span = MutableSlice{ptr, size};
ptr += size;
} }
// Set span pointers, now that string buffer won't reallocate.
char* ptr = str_buf_.data();
for (auto& span : *vec) {
span = {ptr, span.size()};
ptr += span.size();
}
return std::error_code{}; return std::error_code{};
} }
io::Result<journal::ParsedEntry> JournalReader::ReadEntry() { io::Result<journal::ParsedEntry> JournalReader::ReadEntry() {
uint8_t opcode; uint8_t int_op;
SET_OR_UNEXPECT(ReadUInt<uint8_t>(), opcode); SET_OR_UNEXPECT(ReadUInt<uint8_t>(), int_op);
journal::Op opcode = static_cast<journal::Op>(int_op);
journal::ParsedEntry entry{static_cast<journal::Op>(opcode), dbid_}; if (opcode == journal::Op::SELECT) {
switch (entry.opcode) {
case journal::Op::COMMAND:
SET_OR_UNEXPECT(ReadUInt<uint64_t>(), entry.txid);
SET_OR_UNEXPECT(ReadUInt<uint32_t>(), entry.shard_cnt);
entry.payload = CmdArgVec{};
if (auto ec = Read(&*entry.payload); ec)
return make_unexpected(ec);
break;
case journal::Op::SELECT:
SET_OR_UNEXPECT(ReadUInt<uint16_t>(), dbid_); SET_OR_UNEXPECT(ReadUInt<uint16_t>(), dbid_);
return ReadEntry(); return ReadEntry();
default: }
break;
}; journal::ParsedEntry entry;
entry.dbid = dbid_;
entry.opcode = opcode;
SET_OR_UNEXPECT(ReadUInt<uint64_t>(), entry.txid);
SET_OR_UNEXPECT(ReadUInt<uint32_t>(), entry.shard_cnt);
if (opcode == journal::Op::EXEC) {
return entry;
}
auto ec = ReadCommand(&entry.cmd);
if (ec)
return make_unexpected(ec);
return entry; return entry;
} }

View file

@ -63,14 +63,13 @@ struct JournalReader {
// Read unsigned integer in packed encoding. // Read unsigned integer in packed encoding.
template <typename UT> io::Result<UT> ReadUInt(); template <typename UT> io::Result<UT> ReadUInt();
// Read and append string to string buffer, return size. // Read and copy to buffer, return size.
io::Result<size_t> ReadString(); io::Result<size_t> ReadString(char* buffer);
// Read argument array into string buffer. // Read argument array into string buffer.
std::error_code Read(CmdArgVec* vec); std::error_code ReadCommand(journal::ParsedEntry::CmdData* entry);
private: private:
std::string str_buf_; // last parsed entry points here
io::Source* source_; io::Source* source_;
base::IoBuf buf_; base::IoBuf buf_;
DbIndex dbid_; DbIndex dbid_;

View file

@ -16,6 +16,8 @@ enum class Op : uint8_t {
NOOP = 0, NOOP = 0,
SELECT = 6, SELECT = 6,
COMMAND = 10, COMMAND = 10,
MULTI_COMMAND = 11,
EXEC = 12,
}; };
struct EntryBase { struct EntryBase {
@ -35,30 +37,25 @@ struct Entry : public EntryBase {
std::pair<std::string_view, ArgSlice> // Command and its shard parts. std::pair<std::string_view, ArgSlice> // Command and its shard parts.
>; >;
Entry(TxId txid, DbIndex dbid, Payload pl, uint32_t shard_cnt) Entry(TxId txid, Op opcode, DbIndex dbid, uint32_t shard_cnt, Payload pl)
: EntryBase{txid, journal::Op::COMMAND, dbid, shard_cnt}, payload{pl} { : EntryBase{txid, opcode, dbid, shard_cnt}, payload{pl} {
} }
Entry(journal::Op opcode, DbIndex dbid) : EntryBase{0, opcode, dbid, 0}, payload{} { Entry(journal::Op opcode, DbIndex dbid) : EntryBase{0, opcode, dbid, 0}, payload{} {
} }
Entry(TxId txid, journal::Op opcode, DbIndex dbid, uint32_t shard_cnt)
: EntryBase{txid, opcode, dbid, shard_cnt}, payload{} {
}
Payload payload; Payload payload;
}; };
struct ParsedEntry : public EntryBase { struct ParsedEntry : public EntryBase {
// Payload represents the parsed command. struct CmdData {
using Payload = std::optional<CmdArgVec>; std::unique_ptr<char[]> command_buf;
CmdArgVec cmd_args; // represents the parsed command.
ParsedEntry() = default; };
CmdData cmd;
ParsedEntry(journal::Op opcode, DbIndex dbid) : EntryBase{0, opcode, dbid, 0}, payload{} {
}
ParsedEntry(TxId txid, DbIndex dbid, Payload pl, uint32_t shard_cnt)
: EntryBase{txid, journal::Op::COMMAND, dbid, shard_cnt}, payload{pl} {
}
Payload payload;
}; };
using ChangeCallback = std::function<void(const Entry&)>; using ChangeCallback = std::function<void(const Entry&)>;

View file

@ -47,7 +47,8 @@ struct EntryPayloadVisitor {
std::string ExtractPayload(journal::ParsedEntry& entry) { std::string ExtractPayload(journal::ParsedEntry& entry) {
std::string out; std::string out;
EntryPayloadVisitor visitor{&out}; EntryPayloadVisitor visitor{&out};
CmdArgList list{entry.payload->data(), entry.payload->size()};
CmdArgList list{entry.cmd.cmd_args.data(), entry.cmd.cmd_args.size()};
visitor(list); visitor(list);
if (out.size() > 0 && out.back() == ' ') if (out.size() > 0 && out.back() == ' ')
@ -97,13 +98,15 @@ TEST(Journal, WriteRead) {
auto list = [v = &lists](auto... ss) { return StoreList(v, ss...); }; auto list = [v = &lists](auto... ss) { return StoreList(v, ss...); };
std::vector<journal::Entry> test_entries = { std::vector<journal::Entry> test_entries = {
{0, 0, make_pair("MSET", slice("A", "1", "B", "2")), 2}, {0, journal::Op::COMMAND, 0, 2, make_pair("MSET", slice("A", "1", "B", "2"))},
{0, 0, make_pair("MSET", slice("C", "3")), 2}, {0, journal::Op::COMMAND, 0, 2, make_pair("MSET", slice("C", "3"))},
{1, 0, list("DEL", "A", "B"), 2}, {1, journal::Op::COMMAND, 0, 2, list("DEL", "A", "B")},
{2, 1, list("LPUSH", "l", "v1", "v2"), 1}, {2, journal::Op::COMMAND, 1, 1, list("LPUSH", "l", "v1", "v2")},
{3, 0, make_pair("MSET", slice("D", "4")), 1}, {3, journal::Op::COMMAND, 0, 1, make_pair("MSET", slice("D", "4"))},
{4, 1, list("DEL", "l1"), 1}, {4, journal::Op::COMMAND, 1, 1, list("DEL", "l1")},
{5, 2, list("SET", "E", "2"), 1}}; {5, journal::Op::COMMAND, 2, 1, list("SET", "E", "2")},
{6, journal::Op::MULTI_COMMAND, 2, 1, list("SET", "E", "2")},
{6, journal::Op::EXEC, 2, 1}};
// Write all entries to string file. // Write all entries to string file.
JournalWriter writer{}; JournalWriter writer{};

View file

@ -1961,7 +1961,9 @@ error_code RdbLoaderBase::HandleJournalBlob(Service* service, DbIndex dbid) {
while (done < num_entries) { while (done < num_entries) {
journal::ParsedEntry entry{}; journal::ParsedEntry entry{};
SET_OR_RETURN(journal_reader_.ReadEntry(), entry); SET_OR_RETURN(journal_reader_.ReadEntry(), entry);
ex.Execute(entry); if (entry.opcode == journal::Op::COMMAND || entry.opcode == journal::Op::MULTI_COMMAND) {
ex.Execute(entry.dbid, entry.cmd);
}
done++; done++;
} }

View file

@ -751,21 +751,32 @@ void Replica::StableSyncDflyFb(Context* cntx) {
JournalReader reader{&ps, 0}; JournalReader reader{&ps, 0};
JournalExecutor executor{&service_}; JournalExecutor executor{&service_};
while (!cntx->IsCancelled()) {
TranactionData tx_data;
while (!cntx->IsCancelled()) { while (!cntx->IsCancelled()) {
auto res = reader.ReadEntry(); auto res = reader.ReadEntry();
if (!res) { if (!res) {
cntx->ReportError(res.error(), "Journal format error"); cntx->ReportError(res.error(), "Journal format error");
return; return;
} }
ExecuteEntry(&executor, std::move(res.value())); bool should_execute = tx_data.UpdateFromParsedEntry(std::move(*res));
if (should_execute == true) {
break;
}
}
ExecuteCmd(&executor, std::move(tx_data), cntx);
last_io_time_ = sock_->proactor()->GetMonotonicTimeNs(); last_io_time_ = sock_->proactor()->GetMonotonicTimeNs();
} }
return; return;
} }
void Replica::ExecuteEntry(JournalExecutor* executor, journal::ParsedEntry&& entry) { void Replica::ExecuteCmd(JournalExecutor* executor, TranactionData&& tx_data, Context* cntx) {
if (entry.shard_cnt <= 1) { // not multi shard cmd if (cntx->IsCancelled()) {
executor->Execute(entry); return;
}
if (tx_data.shard_cnt <= 1) { // not multi shard cmd
executor->Execute(tx_data.dbid, tx_data.commands.front());
return; return;
} }
@ -781,17 +792,20 @@ void Replica::ExecuteEntry(JournalExecutor* executor, journal::ParsedEntry&& ent
// Only the first fiber to reach the transaction will create data for transaction in map // Only the first fiber to reach the transaction will create data for transaction in map
multi_shard_exe_->map_mu.lock(); multi_shard_exe_->map_mu.lock();
auto [it, was_insert] = multi_shard_exe_->tx_sync_execution.emplace(entry.txid, entry.shard_cnt); auto [it, was_insert] =
VLOG(2) << "txid: " << entry.txid << " unique_shard_cnt_: " << entry.shard_cnt multi_shard_exe_->tx_sync_execution.emplace(tx_data.txid, tx_data.shard_cnt);
VLOG(2) << "txid: " << tx_data.txid << " unique_shard_cnt_: " << tx_data.shard_cnt
<< " was_insert: " << was_insert; << " was_insert: " << was_insert;
TxId txid = entry.txid; TxId txid = tx_data.txid;
// entries_vec will store all entries of trasaction and will be executed by the fiber that // entries_vec will store all entries of trasaction and will be executed by the fiber that
// inserted the txid to map. In case of global command the inserting fiber will executed his // inserted the txid to map. In case of global command the inserting fiber will executed his
// entry. // entry.
bool global_cmd = (entry.payload.value().size() == 1); bool global_cmd = (tx_data.commands.size() == 1 && tx_data.commands.front().cmd_args.size() == 1);
if (!global_cmd) { if (!global_cmd) {
it->second.entries_vec.push_back(std::move(entry)); for (auto& cmd : tx_data.commands) {
it->second.commands.push_back(std::move(cmd));
}
} }
auto& tx_sync = it->second; auto& tx_sync = it->second;
@ -800,14 +814,17 @@ void Replica::ExecuteEntry(JournalExecutor* executor, journal::ParsedEntry&& ent
// step 1 // step 1
tx_sync.barrier.wait(); tx_sync.barrier.wait();
// step 2 // step 2
if (was_insert) { if (was_insert) {
if (global_cmd) { if (global_cmd) {
executor->Execute(entry); executor->Execute(tx_data.dbid, tx_data.commands.front());
} else { } else {
executor->Execute(tx_sync.entries_vec); executor->Execute(tx_data.dbid, tx_sync.commands);
} }
} }
// step 3 // step 3
tx_sync.barrier.wait(); tx_sync.barrier.wait();
@ -1049,4 +1066,25 @@ error_code Replica::SendCommand(string_view command, ReqSerializer* serializer)
return ec; return ec;
} }
// Folds one parsed journal entry into this transaction.
// Returns true when the accumulated transaction should be executed by the caller,
// false when more entries are needed (or the opcode is not supported).
bool Replica::TranactionData::UpdateFromParsedEntry(journal::ParsedEntry&& entry) {
  switch (entry.opcode) {
    case journal::Op::MULTI_COMMAND:
      // A command that belongs to a multi transaction: buffer it and keep reading.
      commands.push_back(std::move(entry.cmd));
      return false;

    case journal::Op::COMMAND:
      // Stand-alone command: take its metadata together with the command itself.
      txid = entry.txid;
      shard_cnt = entry.shard_cnt;
      dbid = entry.dbid;
      commands.push_back(std::move(entry.cmd));
      return true;

    case journal::Op::EXEC:
      // EXEC carries no command; only the transaction metadata is taken from it.
      shard_cnt = entry.shard_cnt;
      dbid = entry.dbid;
      txid = entry.txid;
      return true;

    default:
      DCHECK(false) << "Unsupported opcode";
      return false;
  }
}
} // namespace dfly } // namespace dfly

View file

@ -50,6 +50,17 @@ class Replica {
R_SYNC_OK = 0x10, R_SYNC_OK = 0x10,
}; };
// Holds the commands of a transaction in a single shard.
// Once all commands have been received the transaction can be executed.
struct TranactionData {
  // Zero-initialized so that a default-constructed instance never holds
  // indeterminate values before the first entry is applied.
  TxId txid = 0;
  uint32_t shard_cnt = 0;
  DbIndex dbid = 0;
  // Commands accumulated so far, in the order they were parsed.
  std::vector<journal::ParsedEntry::CmdData> commands;

  // Updates the data from a ParsedEntry and returns whether it is ready for execution.
  bool UpdateFromParsedEntry(journal::ParsedEntry&& entry);
};
struct MultiShardExecution { struct MultiShardExecution {
boost::fibers::mutex map_mu; boost::fibers::mutex map_mu;
@ -58,7 +69,7 @@ class Replica {
std::atomic_uint32_t counter; std::atomic_uint32_t counter;
TxExecutionSync(uint32_t counter) : barrier(counter), counter(counter) { TxExecutionSync(uint32_t counter) : barrier(counter), counter(counter) {
} }
std::vector<journal::ParsedEntry> entries_vec; std::vector<journal::ParsedEntry::CmdData> commands;
}; };
std::unordered_map<TxId, TxExecutionSync> tx_sync_execution; std::unordered_map<TxId, TxExecutionSync> tx_sync_execution;
@ -142,7 +153,7 @@ class Replica {
// Send command, update last_io_time, return error. // Send command, update last_io_time, return error.
std::error_code SendCommand(std::string_view command, facade::ReqSerializer* serializer); std::error_code SendCommand(std::string_view command, facade::ReqSerializer* serializer);
void ExecuteEntry(JournalExecutor* executor, journal::ParsedEntry&& entry); void ExecuteCmd(JournalExecutor* executor, TranactionData&& tx_data, Context* cntx);
public: /* Utility */ public: /* Utility */
struct Info { struct Info {

View file

@ -19,7 +19,7 @@ using nonstd::make_unexpected;
VLOG(1) << "Error while calling " #expr; \ VLOG(1) << "Error while calling " #expr; \
return exp_val.error(); \ return exp_val.error(); \
} \ } \
dest = exp_val.value(); \ dest = std::move(exp_val.value()); \
} while (0) } while (0)
#define SET_OR_UNEXPECT(expr, dest) \ #define SET_OR_UNEXPECT(expr, dest) \

View file

@ -328,7 +328,8 @@ bool Transaction::RunInShard(EngineShard* shard) {
// runnable concludes current operation, and should_release which tells // runnable concludes current operation, and should_release which tells
// whether we should unlock the keys. should_release is false for multi and // whether we should unlock the keys. should_release is false for multi and
// equal to concluding otherwise. // equal to concluding otherwise.
bool should_release = (coordinator_state_ & COORD_EXEC_CONCLUDING) && !multi_; bool is_concluding = (coordinator_state_ & COORD_EXEC_CONCLUDING);
bool should_release = is_concluding && !multi_;
IntentLock::Mode mode = Mode(); IntentLock::Mode mode = Mode();
// We make sure that we lock exactly once for each (multi-hop) transaction inside // We make sure that we lock exactly once for each (multi-hop) transaction inside
@ -373,7 +374,7 @@ bool Transaction::RunInShard(EngineShard* shard) {
/*************************************************************************/ /*************************************************************************/
if (!was_suspended && should_release) // Check last hop & non suspended. if (!was_suspended && is_concluding) // Check last hop & non suspended.
LogJournalOnShard(shard); LogJournalOnShard(shard);
// at least the coordinator thread owns the reference. // at least the coordinator thread owns the reference.
@ -631,6 +632,10 @@ void Transaction::UnlockMulti() {
sharded_keys[sid].push_back(k_v); sharded_keys[sid].push_back(k_v);
} }
if (ServerState::tlocal()->journal()) {
SetMultiUniqueShardCount();
}
uint32_t prev = run_count_.fetch_add(shard_data_.size(), memory_order_relaxed); uint32_t prev = run_count_.fetch_add(shard_data_.size(), memory_order_relaxed);
DCHECK_EQ(prev, 0u); DCHECK_EQ(prev, 0u);
@ -643,6 +648,33 @@ void Transaction::UnlockMulti() {
VLOG(1) << "UnlockMultiEnd " << DebugId(); VLOG(1) << "UnlockMultiEnd " << DebugId();
} }
// Counts the shards whose journal last recorded this transaction's txid and stores
// the count in unique_shard_cnt_.
// A callback is added for every shard in shard_data_; each one compares the shard
// journal's GetLastTxId() with txid_ and then calls DecreaseRunCnt().
// NOTE(review): relies on WaitForShardCallbacks() returning only after every callback
// has run, since the callbacks capture a local by reference - confirm.
void Transaction::SetMultiUniqueShardCount() {
  uint32_t prev = run_count_.fetch_add(shard_data_.size(), memory_order_relaxed);
  DCHECK_EQ(prev, 0u);

  std::atomic<uint32_t> unique_shard_cnt{0};

  auto update_shard_cnt = [&] {
    EngineShard* shard = EngineShard::tlocal();
    auto journal = shard->journal();

    // Count this shard only if it has a journal whose last entry belongs to this transaction.
    if (journal != nullptr && journal->GetLastTxId() == txid_) {
      unique_shard_cnt.fetch_add(1, std::memory_order_relaxed);
    }

    this->DecreaseRunCnt();
  };

  for (ShardId i = 0; i < shard_data_.size(); ++i) {
    // Pass by copy: the same closure is handed out once per shard, so it must not be
    // moved from inside the loop.
    shard_set->Add(i, update_shard_cnt);
  }
  WaitForShardCallbacks();

  // memory_order_release is not a permitted ordering for a load; acquire is.
  unique_shard_cnt_ = unique_shard_cnt.load(std::memory_order_acquire);
}
void Transaction::Schedule() { void Transaction::Schedule() {
if (multi_ && multi_->is_expanding) { if (multi_ && multi_->is_expanding) {
LockMulti(); LockMulti();
@ -1080,6 +1112,11 @@ void Transaction::ExpireShardCb(EngineShard* shard) {
} }
void Transaction::UnlockMultiShardCb(const std::vector<KeyList>& sharded_keys, EngineShard* shard) { void Transaction::UnlockMultiShardCb(const std::vector<KeyList>& sharded_keys, EngineShard* shard) {
auto journal = shard->journal();
if (journal != nullptr && journal->GetLastTxId() == txid_) {
journal->RecordEntry(journal::Entry{txid_, journal::Op::EXEC, db_index_, unique_shard_cnt_});
}
if (multi_->multi_opts & CO::GLOBAL_TRANS) { if (multi_->multi_opts & CO::GLOBAL_TRANS) {
shard->shard_lock()->Release(IntentLock::EXCLUSIVE); shard->shard_lock()->Release(IntentLock::EXCLUSIVE);
} }
@ -1221,7 +1258,12 @@ void Transaction::LogJournalOnShard(EngineShard* shard) {
entry_payload = entry_payload =
make_pair(facade::ToSV(cmd_with_full_args_.front()), ShardArgsInShard(shard->shard_id())); make_pair(facade::ToSV(cmd_with_full_args_.front()), ShardArgsInShard(shard->shard_id()));
} }
journal->RecordEntry(journal::Entry{txid_, db_index_, entry_payload, unique_shard_cnt_}); journal::Op opcode = journal::Op::COMMAND;
if (multi_) {
opcode = journal::Op::MULTI_COMMAND;
}
journal->RecordEntry(journal::Entry{txid_, opcode, db_index_, unique_shard_cnt_, entry_payload});
} }
void Transaction::BreakOnShutdown() { void Transaction::BreakOnShutdown() {

View file

@ -123,6 +123,12 @@ class Transaction {
} }
void UnlockMulti(); void UnlockMulti();
// In a multi transaction we calculate the unique shard count of the transaction
// after all transaction commands were executed, by checking the last txid written to
// all journals.
// This value is written to the journal so that the replica will be able to apply the
// multi command atomically.
void SetMultiUniqueShardCount();
TxId txid() const { TxId txid() const {
return txid_; return txid_;