feat(server): Watch & Unwatch commands (#277)

* chore(server): Refactor ConnectionState & DbSlice for watched key support

* feat(server): Add WATCH & UNWATCH commands

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
Co-authored-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2022-09-13 19:14:11 +03:00 committed by GitHub
parent d3359f1a0a
commit cb024a23ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 389 additions and 89 deletions

View file

@ -212,6 +212,12 @@ void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
}
void ConnectionContext::OnClose() {
if (!conn_state.exec_info.watched_keys.empty()) {
shard_set->RunBriefInParallel([this](EngineShard* shard) {
return shard->db_slice().UnregisterConnectionWatches(&conn_state.exec_info);
});
}
if (!conn_state.subscribe_info)
return;
@ -244,4 +250,16 @@ string ConnectionContext::GetContextInfo() const {
return index ? absl::StrCat("flags:", buf) : string();
}
// Returns the MULTI/EXEC bookkeeping to its blank state: forgets the local
// watch data, drops the queued command body and marks the transaction inactive.
// Like ClearWatched(), this only touches connection-local state.
void ConnectionState::ExecInfo::Clear() {
  ClearWatched();
  body.clear();
  state = EXEC_INACTIVE;
}
// Resets the connection-local watch data: the "existed" counter, the dirty
// flag and the list of watched keys. Nothing is unregistered here; the
// per-database watch tables are not touched by this function.
void ConnectionState::ExecInfo::ClearWatched() {
  watched_existed = 0;
  watched_dirty.store(false, memory_order_relaxed);
  watched_keys.clear();
}
} // namespace dfly

View file

@ -23,40 +23,41 @@ struct StoredCmd {
};
struct ConnectionState {
DbIndex db_index = 0;
// MULTI-EXEC transaction related data.
struct ExecInfo {
enum ExecState { EXEC_INACTIVE, EXEC_COLLECT, EXEC_ERROR };
enum ExecState { EXEC_INACTIVE, EXEC_COLLECT, EXEC_ERROR };
ExecInfo() = default;
// ExecInfo is immovable due to being referenced from DbSlice.
ExecInfo(ExecInfo&&) = delete;
ExecState exec_state = EXEC_INACTIVE;
std::vector<StoredCmd> exec_body;
// Return true if ExecInfo is active (after MULTI)
bool IsActive() { return state != EXEC_INACTIVE; }
enum MCGetMask {
FETCH_CAS_VER = 1,
// Resets to blank state after EXEC or DISCARD
void Clear();
// Resets local watched keys info. Does not unregister the keys from DbSlices.
void ClearWatched();
ExecState state = EXEC_INACTIVE;
std::vector<StoredCmd> body;
// List of keys registered with WATCH
std::vector<std::pair<DbIndex, std::string>> watched_keys;
// Set if a watched key was changed before EXEC
std::atomic_bool watched_dirty = false;
// Number of times watch was called on an existing key.
uint32_t watched_existed = 0;
};
// used for memcache set/get commands.
// For set op - it's the flag value we are storing along with the value.
// For get op - we use it as a mask of MCGetMask values.
uint32_t memcache_flag = 0;
// If it's a replication client - then it holds positive sync session id.
uint32_t sync_session_id = 0;
// Lua-script related data.
struct Script {
struct ScriptInfo {
bool is_write = true;
absl::flat_hash_set<std::string_view> keys;
};
std::optional<Script> script_info;
// PUB-SUB messaging related data.
struct SubscribeInfo {
// TODO: to provide unique_strings across service. This will allow us to use string_view here.
absl::flat_hash_set<std::string> channels;
absl::flat_hash_set<std::string> patterns;
util::fibers_ext::BlockingCounter borrow_token;
bool IsEmpty() const {
return channels.empty() && patterns.empty();
}
@ -65,10 +66,28 @@ struct ConnectionState {
return channels.size() + patterns.size();
}
SubscribeInfo() : borrow_token(0) {
}
// TODO: to provide unique_strings across service. This will allow us to use string_view here.
absl::flat_hash_set<std::string> channels;
absl::flat_hash_set<std::string> patterns;
util::fibers_ext::BlockingCounter borrow_token{0};
};
enum MCGetMask {
FETCH_CAS_VER = 1,
};
DbIndex db_index = 0;
// used for memcache set/get commands.
// For set op - it's the flag value we are storing along with the value.
// For get op - we use it as a mask of MCGetMask values.
uint32_t memcache_flag = 0;
// If it's a replication client - then it holds positive sync session id.
uint32_t sync_session_id = 0;
ExecInfo exec_info;
std::optional<ScriptInfo> script_info;
std::unique_ptr<SubscribeInfo> subscribe_info;
};

View file

@ -460,6 +460,10 @@ void DbSlice::FlushDb(DbIndex db_ind) {
if (db_ind != kDbAll) {
auto& db = db_arr_[db_ind];
if (db) {
InvalidateDbWatches(db_ind);
}
auto db_ptr = std::move(db);
DCHECK(!db);
CreateDb(db_ind);
@ -470,6 +474,12 @@ void DbSlice::FlushDb(DbIndex db_ind) {
return;
}
for (size_t i = 0; i < db_arr_.size(); i++) {
if (db_arr_[i]) {
InvalidateDbWatches(i);
}
}
auto all_dbs = std::move(db_arr_);
db_arr_.resize(all_dbs.size());
for (size_t i = 0; i < db_arr_.size(); ++i) {
@ -545,7 +555,7 @@ pair<PrimeIterator, bool> DbSlice::AddEntry(DbIndex db_ind, string_view key, Pri
auto& it = res.first;
it->second = std::move(obj);
PostUpdate(db_ind, it, false);
PostUpdate(db_ind, it, key, false);
if (expire_at_ms) {
it->second.SetExpire(true);
@ -651,7 +661,7 @@ void DbSlice::PreUpdate(DbIndex db_ind, PrimeIterator it) {
it.SetVersion(NextVersion());
}
void DbSlice::PostUpdate(DbIndex db_ind, PrimeIterator it, bool existing) {
void DbSlice::PostUpdate(DbIndex db_ind, PrimeIterator it, std::string_view key, bool existing) {
DbTableStats* stats = MutableStats(db_ind);
size_t value_heap_size = it->second.MallocUsed();
@ -660,6 +670,18 @@ void DbSlice::PostUpdate(DbIndex db_ind, PrimeIterator it, bool existing) {
stats->strval_memory_usage += value_heap_size;
if (existing)
stats->update_value_amount += value_heap_size;
auto& watched_keys = db_arr_[db_ind]->watched_keys;
if (!watched_keys.empty()) {
// Check if the key is watched.
if (auto wit = watched_keys.find(key); wit != watched_keys.end()) {
for (auto conn_ptr : wit->second) {
conn_ptr->watched_dirty.store(true, memory_order_relaxed);
}
// No connections need to watch it anymore.
watched_keys.erase(wit);
}
}
}
pair<PrimeIterator, ExpireIterator> DbSlice::ExpireIfNeeded(DbIndex db_ind,
@ -830,4 +852,28 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t
return freed_memory_fun();
};
// Records that `exec_info` watches `key` in database `db_indx`.
// The watcher list for the key is created on first use; registering the same
// connection twice for one key appends it twice.
void DbSlice::RegisterWatchedKey(DbIndex db_indx, std::string_view key,
                                 ConnectionState::ExecInfo* exec_info) {
  auto& watchers = db_arr_[db_indx]->watched_keys[key];
  watchers.push_back(exec_info);
}
// Removes `exec_info` from the watcher list of every key the connection
// watches (as listed in exec_info->watched_keys). Keys that are not present
// in this slice's watch table are skipped; a watcher list that becomes empty
// is dropped from the table.
void DbSlice::UnregisterConnectionWatches(ConnectionState::ExecInfo* exec_info) {
  for (const auto& [db_indx, key] : exec_info->watched_keys) {
    auto& db_watches = db_arr_[db_indx]->watched_keys;

    auto entry = db_watches.find(key);
    if (entry == db_watches.end())
      continue;

    // Erase-remove: drop every occurrence of this connection from the list.
    auto& watchers = entry->second;
    watchers.erase(std::remove(watchers.begin(), watchers.end(), exec_info), watchers.end());

    if (watchers.empty())
      db_watches.erase(entry);
  }
}
// Marks every connection that watches any key of database `db_indx` as dirty.
// The watch table itself is left unchanged by this function.
void DbSlice::InvalidateDbWatches(DbIndex db_indx) {
  const auto& db_watches = db_arr_[db_indx]->watched_keys;
  for (const auto& entry : db_watches) {
    for (auto* watcher : entry.second) {
      watcher->watched_dirty.store(true, memory_order_relaxed);
    }
  }
}
} // namespace dfly

View file

@ -9,6 +9,7 @@
#include "facade/op_status.h"
#include "server/common.h"
#include "server/table.h"
#include "server/conn_context.h"
namespace util {
class ProactorBase;
@ -217,7 +218,7 @@ class DbSlice {
// Callback functions called upon writing to the existing key.
void PreUpdate(DbIndex db_ind, PrimeIterator it);
void PostUpdate(DbIndex db_ind, PrimeIterator it, bool existing_entry = true);
void PostUpdate(DbIndex db_ind, PrimeIterator it, std::string_view key, bool existing_entry = true);
DbTableStats* MutableStats(DbIndex db_ind) {
return &db_arr_[db_ind]->stats;
@ -261,6 +262,14 @@ class DbSlice {
caching_mode_ = 1;
}
void RegisterWatchedKey(DbIndex db_indx, std::string_view key, ConnectionState::ExecInfo* exec_info);
// Unregister all watched key entries for connection.
void UnregisterConnectionWatches(ConnectionState::ExecInfo* exec_info);
// Invalidate all watched keys in database. Used on FLUSH.
void InvalidateDbWatches(DbIndex db_indx);
private:
void CreateDb(DbIndex index);
size_t EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* table);
@ -269,6 +278,7 @@ class DbSlice {
return version_++;
}
private:
ShardId shard_id_;
uint8_t caching_mode_ : 1;

View file

@ -581,6 +581,92 @@ TEST_F(DflyEngineTest, PUnsubscribe) {
EXPECT_THAT(resp.GetVec(), ElementsAre("punsubscribe", "b*", IntArg(0)));
}
// End-to-end scenarios for WATCH / UNWATCH combined with MULTI / EXEC.
// The sections below run on the same connection and build on each other's
// state (key values, selected database), so their order matters.
TEST_F(DflyEngineTest, Watch) {
// EXEC replies with NIL when it is aborted and with an array when it ran.
auto kExecFail = ArgType(RespExpr::NIL);
auto kExecSuccess = ArgType(RespExpr::ARRAY);
// Check watch doesn't run in multi.
Run({"multi"});
ASSERT_THAT(Run({"watch", "a"}), ErrArg("WATCH inside MULTI is not allowed"));
Run({"discard"});
// Check watch on existing key: overwriting it after WATCH must abort EXEC.
Run({"set", "a", "1"});
EXPECT_EQ(Run({"watch", "a"}), "OK");
Run({"set", "a", "2"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecFail);
// Check watch data cleared after EXEC: no WATCH is re-issued here, so a
// later write to "a" must not abort the next EXEC.
Run({"set", "a", "1"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecSuccess);
// Check watch on non-existent key: creating it after WATCH must abort EXEC.
Run({"del", "b"});
EXPECT_EQ(Run({"watch", "b"}), "OK"); // didn't exist yet
Run({"set", "b", "1"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecFail);
// Check EXEC doesn't miss watched key expiration.
Run({"watch", "a"});
Run({"expire", "a", "1"});
UpdateTime(expire_now_ + 1000);
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecFail);
// Check unwatch: a write after UNWATCH must not abort EXEC.
Run({"watch", "a"});
Run({"unwatch"});
Run({"set", "a", "3"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecSuccess);
// Check that modifying several watched keys aborts EXEC
// (both "a" and "b" are overwritten after a single WATCH on them).
Run({"watch", "a", "b"});
Run({"set", "a", "2"});
Run({"set", "b", "2"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecFail);
// Check EXPIRE + new key: "c" is created while "a" expires, so the number of
// existing watched keys is unchanged, yet EXEC must still be aborted.
Run({"set", "a", "1"});
Run({"del", "c"});
Run({"watch", "c"}); // didn't exist yet
Run({"watch", "a"});
Run({"set", "c", "1"});
Run({"expire", "a", "1"}); // a existed
UpdateTime(expire_now_ + 1000);
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecFail);
// Check FLUSHDB touches watched keys
Run({"select", "1"});
Run({"set", "a", "1"});
Run({"watch", "a"});
Run({"flushdb"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecFail);
// Check multi db watches are not supported: WATCH on db 1 followed by EXEC on
// db 0 must reply with an error.
Run({"select", "1"});
Run({"set", "a", "1"});
Run({"watch", "a"});
Run({"select", "0"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), ArgType(RespExpr::ERROR));
// Check watch keys are isolated between databases.
// The connection is on db 0 here (selected in the previous section).
Run({"set", "a", "1"});
Run({"watch", "a"});
Run({"select", "1"});
Run({"set", "a", "2"}); // changing a on db 1
Run({"select", "0"});
Run({"multi"});
ASSERT_THAT(Run({"exec"}), kExecSuccess);
}
// TODO: to test transactions with a single shard since then all transactions become local.
// To consider having a parameter in dragonfly engine controlling number of shards
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.

View file

@ -28,6 +28,9 @@ class GenericFamily {
static void Register(CommandRegistry* registry);
// Accessed by Service::Exec and Service::Watch as a utility.
static OpResult<uint32_t> OpExists(const OpArgs& op_args, ArgSlice keys);
private:
enum TimeUnit { SEC, MSEC };
@ -65,7 +68,6 @@ class GenericFamily {
static OpResult<uint64_t> OpTtl(Transaction* t, EngineShard* shard, std::string_view key);
static OpResult<uint32_t> OpDel(const OpArgs& op_args, ArgSlice keys);
static OpResult<uint32_t> OpExists(const OpArgs& op_args, ArgSlice keys);
static OpResult<void> OpRen(const OpArgs& op_args, std::string_view from, std::string_view to,
bool skip_exists);
static OpResult<uint32_t> OpStick(const OpArgs& op_args, ArgSlice keys);

View file

@ -509,7 +509,7 @@ OpResult<uint32_t> HSetFamily::OpSet(const OpArgs& op_args, string_view key, Cmd
}
}
it->second.SyncRObj();
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
return created;
}
@ -548,7 +548,7 @@ OpResult<uint32_t> HSetFamily::OpDel(const OpArgs& op_args, string_view key, Cmd
co.SyncRObj();
db_slice.PostUpdate(op_args.db_ind, *it_res);
db_slice.PostUpdate(op_args.db_ind, *it_res, key);
if (key_remove) {
if (hset->encoding == OBJ_ENCODING_LISTPACK) {
stats->listpack_blob_cnt--;
@ -874,7 +874,7 @@ OpStatus HSetFamily::OpIncrBy(const OpArgs& op_args, string_view key, string_vie
}
it->second.SyncRObj();
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
return OpStatus::OK;
}

View file

@ -59,7 +59,7 @@ void SetString(const OpArgs& op_args, string_view key, const string& value) {
auto& db_slice = op_args.shard->db_slice();
auto [it_output, added] = db_slice.AddOrFind(op_args.db_ind, key);
it_output->second.SetString(value);
db_slice.PostUpdate(op_args.db_ind, it_output);
db_slice.PostUpdate(op_args.db_ind, it_output, key);
RecordJournal(op_args, it_output->first, it_output->second);
}

View file

@ -277,7 +277,7 @@ OpStatus BPopper::Pop(Transaction* t, EngineShard* shard) {
db_slice.PreUpdate(t->db_index(), it);
value_ = ListPop(dir_, ql);
db_slice.PostUpdate(t->db_index(), it);
db_slice.PostUpdate(t->db_index(), it, key_);
if (quicklistCount(ql) == 0) {
CHECK(shard->db_slice().Del(t->db_index(), it));
}
@ -300,7 +300,7 @@ OpResult<string> OpRPopLPushSingleShard(const OpArgs& op_args, string_view src,
string val = ListPop(ListDir::RIGHT, src_ql);
quicklistPushHead(src_ql, val.data(), val.size());
db_slice.PostUpdate(op_args.db_ind, src_it);
db_slice.PostUpdate(op_args.db_ind, src_it, src);
return val;
}
@ -336,8 +336,8 @@ OpResult<string> OpRPopLPushSingleShard(const OpArgs& op_args, string_view src,
string val = ListPop(ListDir::RIGHT, src_ql);
quicklistPushHead(dest_ql, val.data(), val.size());
db_slice.PostUpdate(op_args.db_ind, src_it);
db_slice.PostUpdate(op_args.db_ind, dest_it, !new_key);
db_slice.PostUpdate(op_args.db_ind, src_it, src);
db_slice.PostUpdate(op_args.db_ind, dest_it, dest, !new_key);
if (quicklistCount(src_ql) == 0) {
CHECK(db_slice.Del(op_args.db_ind, src_it));
@ -418,7 +418,7 @@ OpResult<uint32_t> OpPush(const OpArgs& op_args, std::string_view key, ListDir d
es->blocking_controller()->AwakeWatched(op_args.db_ind, key);
}
} else {
es->db_slice().PostUpdate(op_args.db_ind, it, true);
es->db_slice().PostUpdate(op_args.db_ind, it, key, true);
}
return quicklistCount(ql);
@ -451,7 +451,7 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, ListDir dir, u
}
}
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
if (quicklistCount(ql) == 0) {
CHECK(db_slice.Del(op_args.db_ind, it));
@ -891,7 +891,7 @@ OpResult<int> ListFamily::OpInsert(const OpArgs& op_args, string_view key, strin
DCHECK_EQ(LIST_HEAD, insert_param);
quicklistInsertBefore(qiter, &entry, elem.data(), elem.size());
}
db_slice.PostUpdate(op_args.db_ind, *it_res);
db_slice.PostUpdate(op_args.db_ind, *it_res, key);
res = quicklistCount(ql);
}
quicklistReleaseIterator(qiter);
@ -931,7 +931,7 @@ OpResult<uint32_t> ListFamily::OpRem(const OpArgs& op_args, string_view key, str
break;
}
}
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
quicklistReleaseIterator(qiter);
@ -954,7 +954,7 @@ OpStatus ListFamily::OpSet(const OpArgs& op_args, string_view key, string_view e
db_slice.PreUpdate(op_args.db_ind, it);
int replaced = quicklistReplaceAtIndex(ql, index, elem.data(), elem.size());
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
if (!replaced) {
return OpStatus::OUT_OF_RANGE;
@ -998,7 +998,7 @@ OpStatus ListFamily::OpTrim(const OpArgs& op_args, string_view key, long start,
db_slice.PreUpdate(op_args.db_ind, it);
quicklistDelRange(ql, 0, ltrim);
quicklistDelRange(ql, -rtrim, rtrim);
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
if (quicklistCount(ql) == 0) {
CHECK(db_slice.Del(op_args.db_ind, it));

View file

@ -20,7 +20,6 @@ extern "C" {
#include "base/logging.h"
#include "facade/dragonfly_connection.h"
#include "facade/error.h"
#include "server/conn_context.h"
#include "server/error.h"
#include "server/generic_family.h"
#include "server/hset_family.h"
@ -395,8 +394,8 @@ void Service::Shutdown() {
}
static void MultiSetError(ConnectionContext* cntx) {
if (cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE) {
cntx->conn_state.exec_state = ConnectionState::EXEC_ERROR;
if (cntx->conn_state.exec_info.IsActive()) {
cntx->conn_state.exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR;
}
}
@ -450,8 +449,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
bool is_write_cmd = (cid->opt_mask() & CO::WRITE) ||
(under_script && dfly_cntx->conn_state.script_info->is_write);
bool under_multi =
dfly_cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd;
bool under_multi = dfly_cntx->conn_state.exec_info.IsActive() && !is_trans_cmd;
if (!etl.is_master && is_write_cmd && !dfly_cntx->is_replicating) {
(*cntx)->SendError("-READONLY You can't write against a read only replica.");
@ -482,20 +480,25 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx)
(*cntx)->SendError("Can not call SELECT within a transaction");
return;
}
if (cmd_name == "WATCH") {
(*cntx)->SendError("WATCH inside MULTI is not allowed");
return;
}
}
std::move(multi_error).Cancel();
etl.connection_stats.cmd_count_map[cmd_name]++;
if (dfly_cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd) {
if (dfly_cntx->conn_state.exec_info.IsActive() && !is_trans_cmd) {
// TODO: protect against aggregating huge transactions.
StoredCmd stored_cmd{cid};
stored_cmd.cmd.reserve(args.size());
for (size_t i = 0; i < args.size(); ++i) {
stored_cmd.cmd.emplace_back(ArgS(args, i));
}
dfly_cntx->conn_state.exec_body.push_back(std::move(stored_cmd));
dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd));
return (*cntx)->SendSimpleString("QUEUED");
}
@ -725,14 +728,62 @@ void Service::Quit(CmdArgList args, ConnectionContext* cntx) {
}
void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
if (cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE) {
if (cntx->conn_state.exec_info.IsActive()) {
return (*cntx)->SendError("MULTI calls can not be nested");
}
cntx->conn_state.exec_state = ConnectionState::EXEC_COLLECT;
cntx->conn_state.exec_info.state = ConnectionState::ExecInfo::EXEC_COLLECT;
// TODO: to protect against huge exec transactions.
return (*cntx)->SendOk();
}
// WATCH key [key ...]
// Registers every key argument with the DbSlice of its shard and remembers how
// many of them existed at that moment; the count is accumulated into
// ExecInfo::watched_existed. Always replies OK.
void Service::Watch(CmdArgList args, ConnectionContext* cntx) {
  auto& exec_info = cntx->conn_state.exec_info;

  // Nothing to do if a previous WATCH already got invalidated: EXEC is going
  // to fail anyway.
  const bool already_dirty = exec_info.watched_dirty.load(memory_order_relaxed);
  if (already_dirty) {
    (*cntx)->SendOk();
    return;
  }

  atomic_uint32_t existing_count = 0;

  // Shard callback: register this shard's keys first, then count how many of
  // them currently exist.
  auto register_and_count = [&](Transaction* t, EngineShard* shard) {
    ArgSlice shard_keys = t->ShardArgsInShard(shard->shard_id());
    auto& db_slice = shard->db_slice();
    for (auto key : shard_keys) {
      db_slice.RegisterWatchedKey(cntx->db_index(), key, &exec_info);
    }

    auto exists_res = GenericFamily::OpExists(t->GetOpArgs(shard), shard_keys);
    existing_count.fetch_add(exists_res.value_or(0), memory_order_relaxed);
    return OpStatus::OK;
  };
  cntx->transaction->ScheduleSingleHop(std::move(register_and_count));

  // Duplicate keys are stored to keep correct count.
  exec_info.watched_existed += existing_count.load(memory_order_relaxed);

  // args[0] is the command name, the keys start at index 1.
  size_t arg_idx = 1;
  while (arg_idx < args.size()) {
    exec_info.watched_keys.emplace_back(cntx->db_index(), ArgS(args, arg_idx));
    ++arg_idx;
  }

  (*cntx)->SendOk();
}
// Drops every watch held by the connection: unregisters it from the DbSlices
// (only when it actually watches something) and then resets the local watch
// data, including the dirty flag.
// Shared by UNWATCH, DISCARD and EXEC.
void UnwatchAllKeys(ConnectionContext* cntx) {
  auto& exec_info = cntx->conn_state.exec_info;

  const bool has_watches = !exec_info.watched_keys.empty();
  if (has_watches) {
    shard_set->RunBriefInParallel([&exec_info](EngineShard* shard) {
      shard->db_slice().UnregisterConnectionWatches(&exec_info);
    });
  }

  // Must run even without watched keys so that the remaining watch state is reset.
  exec_info.ClearWatched();
}
// UNWATCH
// Forgets all keys watched by this connection and replies OK.
void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
  UnwatchAllKeys(cntx);
  (*cntx)->SendOk();
}
void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) {
DCHECK(cntx->transaction);
InterpreterReplier replier(reply);
@ -836,7 +887,7 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
// TODO: to determine whether the script is RO by scanning all "redis.p?call" calls
// and checking whether all invocations consist of RO commands.
// we can do it once during script insertion into script mgr.
cntx->conn_state.script_info.emplace(ConnectionState::Script{});
cntx->conn_state.script_info.emplace(ConnectionState::ScriptInfo{});
for (size_t i = 0; i < eval_args.keys.size(); ++i) {
cntx->conn_state.script_info->keys.insert(ArgS(eval_args.keys, i));
}
@ -880,35 +931,96 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
void Service::Discard(CmdArgList args, ConnectionContext* cntx) {
RedisReplyBuilder* rb = (*cntx).operator->();
if (cntx->conn_state.exec_state == ConnectionState::EXEC_INACTIVE) {
if (!cntx->conn_state.exec_info.IsActive()) {
return rb->SendError("DISCARD without MULTI");
}
cntx->conn_state.exec_state = ConnectionState::EXEC_INACTIVE;
cntx->conn_state.exec_body.clear();
UnwatchAllKeys(cntx);
cntx->conn_state.exec_info.Clear();
rb->SendOk();
}
// Returns true if none of the connection's watched keys expired.
//
// Re-runs EXISTS over all watched keys as a single-hop transaction and compares
// the number of keys that exist now with the number recorded at WATCH time
// (ExecInfo::watched_existed). The dirty flag is consulted as well, see the
// note before the return statement.
//
// cntx     - connection whose watched keys are checked; its transaction is
//            re-initialized here to run the EXISTS hop.
// registry - used to look up the EXISTS command descriptor.
bool CheckWatchedKeyExpiry(ConnectionContext* cntx, const CommandRegistry& registry) {
  static char EXISTS[] = "EXISTS";
  auto& exec_info = cntx->conn_state.exec_info;

  // Build the argument list "EXISTS <key1> <key2> ..." on top of the stored
  // key strings (no copies, the slices point into exec_info.watched_keys).
  CmdArgVec str_list(exec_info.watched_keys.size() + 1);
  str_list[0] = MutableSlice{EXISTS, strlen(EXISTS)};
  for (size_t i = 1; i < str_list.size(); i++) {
    std::string& s = exec_info.watched_keys[i - 1].second;
    str_list[i] = MutableSlice{s.data(), s.size()};
  }

  atomic_uint32_t watch_exist_count{0};

  // Adds the number of existing keys among this shard's arguments to the total.
  auto cb = [&watch_exist_count](Transaction* t, EngineShard* shard) {
    ArgSlice args = t->ShardArgsInShard(shard->shard_id());
    auto res = GenericFamily::OpExists(t->GetOpArgs(shard), args);
    watch_exist_count.fetch_add(res.value_or(0), memory_order_relaxed);
    return OpStatus::OK;
  };

  VLOG(1) << "Checking expired watch keys";

  cntx->transaction->SetExecCmd(registry.Find(EXISTS));
  cntx->transaction->InitByArgs(cntx->conn_state.db_index,
                                CmdArgList{str_list.data(), str_list.size()});
  OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
  CHECK_EQ(OpStatus::OK, status);

  // The comparison can still be true even if a key expired due to another one being created.
  // So we have to check the watched_dirty flag, which is set if a key expired.
  return watch_exist_count.load() == exec_info.watched_existed &&
         !exec_info.watched_dirty.load(memory_order_relaxed);
}
// Returns true if at least one key in exec_info.watched_keys was registered
// for a database other than db_indx; false otherwise (including when no keys
// are watched).
bool IsWatchingOtherDbs(DbIndex db_indx, const ConnectionState::ExecInfo& exec_info) {
  bool found_other_db = false;
  for (const auto& watched : exec_info.watched_keys) {
    if (watched.first != db_indx) {
      found_other_db = true;
      break;
    }
  }
  return found_other_db;
}
void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
RedisReplyBuilder* rb = (*cntx).operator->();
if (cntx->conn_state.exec_state == ConnectionState::EXEC_INACTIVE) {
if (!cntx->conn_state.exec_info.IsActive()) {
return rb->SendError("EXEC without MULTI");
}
if (cntx->conn_state.exec_state == ConnectionState::EXEC_ERROR) {
cntx->conn_state.exec_state = ConnectionState::EXEC_INACTIVE;
cntx->conn_state.exec_body.clear();
auto& exec_info = cntx->conn_state.exec_info;
absl::Cleanup exec_clear = [&cntx, &exec_info] {
UnwatchAllKeys(cntx);
exec_info.Clear();
};
if (IsWatchingOtherDbs(cntx->db_index(), exec_info)) {
return rb->SendError("Dragonfly does not allow WATCH and EXEC on different databases");
}
if (exec_info.state == ConnectionState::ExecInfo::EXEC_ERROR) {
return rb->SendError("-EXECABORT Transaction discarded because of previous errors");
}
VLOG(1) << "StartExec " << cntx->conn_state.exec_body.size();
rb->StartArray(cntx->conn_state.exec_body.size());
if (!cntx->conn_state.exec_body.empty()) {
if (exec_info.watched_dirty.load(memory_order_relaxed)) {
return rb->SendNull();
}
// EXEC should not run if any of the watched keys expired.
if (!exec_info.watched_keys.empty() && !CheckWatchedKeyExpiry(cntx, registry_)) {
cntx->transaction->UnlockMulti();
return rb->SendNull();
}
VLOG(1) << "StartExec " << exec_info.body.size();
rb->StartArray(exec_info.body.size());
if (!exec_info.body.empty()) {
CmdArgVec str_list;
for (auto& scmd : cntx->conn_state.exec_body) {
for (auto& scmd : exec_info.body) {
str_list.resize(scmd.cmd.size());
for (size_t i = 0; i < scmd.cmd.size(); ++i) {
string& s = scmd.cmd[i];
@ -929,12 +1041,10 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
break;
}
VLOG(1) << "Exec unlocking " << cntx->conn_state.exec_body.size() << " commands";
VLOG(1) << "Exec unlocking " << exec_info.body.size() << " commands";
cntx->transaction->UnlockMulti();
}
cntx->conn_state.exec_state = ConnectionState::EXEC_INACTIVE;
cntx->conn_state.exec_body.clear();
VLOG(1) << "Exec completed";
}
@ -1150,6 +1260,8 @@ void Service::RegisterCommands() {
registry_
<< CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0}.HFUNC(Quit)
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0}.HFUNC(Multi)
<< CI{"WATCH", CO::LOADING, -2, 1, -1, 1}.HFUNC(Watch)
<< CI{"UNWATCH", CO::LOADING, 1, 0, 0, 0}.HFUNC(Unwatch)
<< CI{"DISCARD", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0}.MFUNC(Discard)
<< CI{"EVAL", CO::NOSCRIPT | CO::VARIADIC_KEYS, -3, 3, 3, 1}.MFUNC(Eval).SetValidator(
&EvalValidator)

View file

@ -87,6 +87,9 @@ class Service : public facade::ServiceInterface {
static void Quit(CmdArgList args, ConnectionContext* cntx);
static void Multi(CmdArgList args, ConnectionContext* cntx);
static void Watch(CmdArgList args, ConnectionContext* cntx);
static void Unwatch(CmdArgList args, ConnectionContext* cntx);
void Discard(CmdArgList args, ConnectionContext* cntx);
void Eval(CmdArgList args, ConnectionContext* cntx);
void EvalSha(CmdArgList args, ConnectionContext* cntx);

View file

@ -442,7 +442,7 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, ArgSlice v
res = AddStrSet(std::move(vals), &co);
}
db_slice.PostUpdate(op_args.db_ind, it, !new_key);
db_slice.PostUpdate(op_args.db_ind, it, key, !new_key);
return res;
}
@ -460,7 +460,7 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, std::string_view key, const ArgS
CompactObj& co = find_res.value()->second;
auto [removed, isempty] = RemoveSet(vals, &co);
db_slice.PostUpdate(op_args.db_ind, *find_res);
db_slice.PostUpdate(op_args.db_ind, *find_res, key);
if (isempty) {
CHECK(db_slice.Del(op_args.db_ind, find_res.value()));
@ -1157,7 +1157,7 @@ OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key
} else {
result = PopStrSet(count, st);
}
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
}
return result;
}

View file

@ -100,7 +100,7 @@ OpResult<uint32_t> OpSetRange(const OpArgs& op_args, string_view key, size_t sta
memcpy(s.data() + start, value.data(), value.size());
it->second.SetString(s);
db_slice.PostUpdate(op_args.db_ind, it, !added);
db_slice.PostUpdate(op_args.db_ind, it, key, !added);
RecordJournal(op_args, it->first, it->second);
return it->second.Size();
@ -138,7 +138,7 @@ OpResult<string> OpGetRange(const OpArgs& op_args, string_view key, int32_t star
return string(slice.substr(start, end - start + 1));
};
size_t ExtendExisting(const OpArgs& op_args, PrimeIterator it, string_view val, bool prepend) {
size_t ExtendExisting(const OpArgs& op_args, PrimeIterator it, string_view key, string_view val, bool prepend) {
string tmp, new_val;
auto* shard = op_args.shard;
string_view slice = GetSlice(shard, it->second, &tmp);
@ -150,7 +150,7 @@ size_t ExtendExisting(const OpArgs& op_args, PrimeIterator it, string_view val,
auto& db_slice = shard->db_slice();
db_slice.PreUpdate(op_args.db_ind, it);
it->second.SetString(new_val);
db_slice.PostUpdate(op_args.db_ind, it, true);
db_slice.PostUpdate(op_args.db_ind, it, key, true);
RecordJournal(op_args, it->first, it->second);
return new_val.size();
@ -164,7 +164,7 @@ OpResult<uint32_t> ExtendOrSet(const OpArgs& op_args, string_view key, string_vi
auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key);
if (inserted) {
it->second.SetString(val);
db_slice.PostUpdate(op_args.db_ind, it, false);
db_slice.PostUpdate(op_args.db_ind, it, key, false);
RecordJournal(op_args, it->first, it->second);
return val.size();
@ -173,7 +173,7 @@ OpResult<uint32_t> ExtendOrSet(const OpArgs& op_args, string_view key, string_vi
if (it->second.ObjType() != OBJ_STRING)
return OpStatus::WRONG_TYPE;
return ExtendExisting(op_args, it, val, prepend);
return ExtendExisting(op_args, it, key, val, prepend);
}
OpResult<bool> ExtendOrSkip(const OpArgs& op_args, std::string_view key, std::string_view val,
@ -184,7 +184,7 @@ OpResult<bool> ExtendOrSkip(const OpArgs& op_args, std::string_view key, std::st
return false;
}
return ExtendExisting(op_args, *it_res, val, prepend);
return ExtendExisting(op_args, *it_res, key, val, prepend);
}
OpResult<string> OpGet(const OpArgs& op_args, string_view key) {
@ -206,7 +206,7 @@ OpResult<double> OpIncrFloat(const OpArgs& op_args, std::string_view key, double
if (inserted) {
char* str = RedisReplyBuilder::FormatDouble(val, buf, sizeof(buf));
it->second.SetString(str);
db_slice.PostUpdate(op_args.db_ind, it, false);
db_slice.PostUpdate(op_args.db_ind, it, key, false);
RecordJournal(op_args, it->first, it->second);
return val;
@ -238,7 +238,7 @@ OpResult<double> OpIncrFloat(const OpArgs& op_args, std::string_view key, double
db_slice.PreUpdate(op_args.db_ind, it);
it->second.SetString(str);
db_slice.PostUpdate(op_args.db_ind, it, true);
db_slice.PostUpdate(op_args.db_ind, it, key, true);
RecordJournal(op_args, it->first, it->second);
return base;
@ -290,7 +290,7 @@ OpResult<int64_t> OpIncrBy(const OpArgs& op_args, std::string_view key, int64_t
DCHECK(!it->second.IsExternal());
db_slice.PreUpdate(op_args.db_ind, it);
it->second.SetInt(new_val);
db_slice.PostUpdate(op_args.db_ind, it);
db_slice.PostUpdate(op_args.db_ind, it, key);
RecordJournal(op_args, it->first, it->second);
return new_val;
@ -330,7 +330,7 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
// Make sure that we have this key, and only add it if it does exist
if (params.how == SET_IF_EXISTS) {
if (IsValid(it)) {
return SetExisting(params, it, expire_it, value);
return SetExisting(params, it, expire_it, key, value);
} else {
return OpStatus::SKIPPED;
}
@ -352,14 +352,14 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
PrimeIterator it = get<0>(add_res);
if (!get<2>(add_res)) { // Existing.
return SetExisting(params, it, get<1>(add_res), value);
return SetExisting(params, it, get<1>(add_res), key, value);
}
//
// Adding new value.
PrimeValue tvalue{value};
tvalue.SetFlag(params.memcache_flags != 0);
it->second = std::move(tvalue);
db_slice.PostUpdate(params.db_index, it, false);
db_slice.PostUpdate(params.db_index, it, key, false);
if (params.expire_after_ms) {
db_slice.UpdateExpire(params.db_index, it, params.expire_after_ms + db_slice.Now());
@ -381,7 +381,7 @@ OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value
}
OpStatus SetCmd::SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it,
string_view value) {
string_view key, string_view value) {
if (params.how == SET_IF_NOTEXIST)
return OpStatus::SKIPPED;
@ -429,7 +429,7 @@ OpStatus SetCmd::SetExisting(const SetParams& params, PrimeIterator it, ExpireIt
}
}
db_slice.PostUpdate(params.db_index, it);
db_slice.PostUpdate(params.db_index, it, key);
RecordJournal(op_args_, it->first, it->second);
return OpStatus::OK;

View file

@ -46,7 +46,7 @@ class SetCmd {
private:
OpStatus SetExisting(const SetParams& params, PrimeIterator it, ExpireIterator e_it,
std::string_view value);
std::string_view key, std::string_view value);
};
class StringFamily {

View file

@ -11,6 +11,7 @@
#include "core/expire_period.h"
#include "core/intent_lock.h"
#include "server/conn_context.h"
#include "server/detail/table.h"
namespace dfly {
@ -64,6 +65,9 @@ struct DbTable : boost::intrusive_ref_counter<DbTable, boost::thread_unsafe_coun
// Contains transaction locks
LockTable trans_locks;
// Stores a list of dependent connections for each watched key.
absl::flat_hash_map<std::string, std::vector<ConnectionState::ExecInfo*>> watched_keys;
mutable DbTableStats stats;
ExpireTable::Cursor expire_cursor;
PrimeTable::Cursor prime_cursor;

View file

@ -812,7 +812,7 @@ OpResult<AddResult> OpAdd(const OpArgs& op_args, const ZParams& zparams, string_
DVLOG(2) << "ZAdd " << zobj->ptr;
res_it.value()->second.SyncRObj();
op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it);
op_args.shard->db_slice().PostUpdate(op_args.db_ind, *res_it, key);
if (zparams.flags & ZADD_IN_INCR) {
aresult.new_score = new_score;
@ -1627,7 +1627,7 @@ OpResult<unsigned> ZSetFamily::OpRem(const OpArgs& op_args, string_view key, Arg
}
auto zlen = zsetLength(zobj);
res_it.value()->second.SyncRObj();
db_slice.PostUpdate(op_args.db_ind, *res_it);
db_slice.PostUpdate(op_args.db_ind, *res_it, key);
if (zlen == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_ind, res_it.value()));
@ -1681,7 +1681,7 @@ OpResult<unsigned> ZSetFamily::OpRemRange(const OpArgs& op_args, string_view key
std::visit(iv, range_spec.interval);
res_it.value()->second.SyncRObj();
db_slice.PostUpdate(op_args.db_ind, *res_it);
db_slice.PostUpdate(op_args.db_ind, *res_it, key);
auto zlen = zsetLength(zobj);
if (zlen == 0) {