WIP: Auto PostUpdate() (#2268)

* WIP: Auto `PostUpdate()`

* More `Find()` uses

* Final touches

* Fixes

* Fix bug and allow reassigning

* Rename to AutoUpdater

* Fix and add DCHECK

* Also check deletion count

* Use ccache instead of sccache

* Try to upgrade Helio

* off64_t

* off64_t

* Revert changes to CI
This commit is contained in:
Shahar Mike 2023-12-11 10:07:53 +02:00 committed by GitHub
parent c183bf69aa
commit 1ce3f983c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 391 additions and 238 deletions

View file

@ -325,6 +325,9 @@ class DashTable<_Key, _Value, Policy>::Iterator {
public:
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using IteratorPairType =
std::conditional_t<IsConst, detail::IteratorPair<const Key_t, const Value_t>,
detail::IteratorPair<Key_t, Value_t>>;
// Copy constructor from iterator to const_iterator.
template <bool TIsConst = IsConst, bool TIsSingleB,
@ -372,16 +375,9 @@ class DashTable<_Key, _Value, Policy>::Iterator {
return *this;
}
detail::IteratorPair<Key_t, Value_t> operator->() {
IteratorPairType operator->() const {
auto* seg = owner_->segment_[seg_id_];
return detail::IteratorPair<Key_t, Value_t>{seg->Key(bucket_id_, slot_id_),
seg->Value(bucket_id_, slot_id_)};
}
const detail::IteratorPair<Key_t, Value_t> operator->() const {
auto* seg = owner_->segment_[seg_id_];
return detail::IteratorPair<Key_t, Value_t>{seg->Key(bucket_id_, slot_id_),
seg->Value(bucket_id_, slot_id_)};
return {seg->Key(bucket_id_, slot_id_), seg->Value(bucket_id_, slot_id_)};
}
// Make it self-contained. Does not need container::end().

View file

@ -314,7 +314,7 @@ class ElementAccess {
};
std::optional<bool> ElementAccess::Exists(EngineShard* shard) {
auto res = shard->db_slice().Find(context_, key_, OBJ_STRING);
auto res = shard->db_slice().FindReadOnly(context_, key_, OBJ_STRING);
if (res.status() == OpStatus::WRONG_TYPE) {
return {};
}
@ -458,7 +458,8 @@ OpResult<std::string> RunBitOpNot(const OpArgs& op_args, ArgSlice keys) {
EngineShard* es = op_args.shard;
// if we found the value, just return, if not found then skip, otherwise report an error
auto key = keys.front();
OpResult<PrimeIterator> find_res = es->db_slice().Find(op_args.db_cntx, key, OBJ_STRING);
OpResult<PrimeConstIterator> find_res =
es->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (find_res) {
return GetString(find_res.value()->second, es);
} else {
@ -479,7 +480,8 @@ OpResult<std::string> RunBitOpOnShard(std::string_view op, const OpArgs& op_args
// collect all the value for this shard
for (auto& key : keys) {
OpResult<PrimeIterator> find_res = es->db_slice().Find(op_args.db_cntx, key, OBJ_STRING);
OpResult<PrimeConstIterator> find_res =
es->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (find_res) {
values.emplace_back(GetString(find_res.value()->second, es));
} else {
@ -1261,7 +1263,7 @@ OpResult<bool> ReadValueBitsetAt(const OpArgs& op_args, std::string_view key, ui
OpResult<std::string> ReadValue(const DbContext& context, std::string_view key,
EngineShard* shard) {
OpResult<PrimeIterator> it_res = shard->db_slice().Find(context, key, OBJ_STRING);
OpResult<PrimeConstIterator> it_res = shard->db_slice().FindReadOnly(context, key, OBJ_STRING);
if (!it_res.ok()) {
return it_res.status();
}

View file

@ -338,6 +338,83 @@ auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type)
return it;
}
DbSlice::AutoUpdater::AutoUpdater() {
}
DbSlice::AutoUpdater::AutoUpdater(AutoUpdater&& o) {
*this = std::move(o);
}
DbSlice::AutoUpdater& DbSlice::AutoUpdater::operator=(AutoUpdater&& o) {
Run();
fields_ = o.fields_;
o.Cancel();
return *this;
}
DbSlice::AutoUpdater::~AutoUpdater() {
Run();
}
void DbSlice::AutoUpdater::Run() {
if (fields_.action == DestructorAction::kDoNothing) {
return;
}
// Check that AutoUpdater does not run after a key was removed.
// If this CHECK() failed for you, it probably means that you deleted a key while having an auto
// updater in scope. You'll probably want to call Run() (or Cancel() - but be careful).
DCHECK(IsValid(fields_.db_slice->db_arr_[fields_.db_ind]->prime.Find(fields_.key)))
<< "Key was removed before PostUpdate() - this is a bug!";
// Make sure that the DB has not changed in size since this object was created.
// Adding or removing elements from the DB may invalidate iterators.
CHECK_EQ(fields_.db_size, fields_.db_slice->DbSize(fields_.db_ind))
<< "Attempting to run post-update after DB was modified";
CHECK_EQ(fields_.deletion_count, fields_.db_slice->deletion_count_)
<< "Attempting to run post-update after a deletion was issued";
DCHECK(fields_.action == DestructorAction::kRun);
CHECK_NE(fields_.db_slice, nullptr);
fields_.db_slice->PostUpdate(fields_.db_ind, fields_.it, fields_.key, fields_.key_existed);
Cancel(); // Reset to not run again
}
void DbSlice::AutoUpdater::Cancel() {
this->fields_ = {};
}
DbSlice::AutoUpdater::AutoUpdater(const Fields& fields) : fields_(fields) {
DCHECK(fields_.action == DestructorAction::kRun);
fields_.db_slice->PreUpdate(fields_.db_ind, fields_.it);
fields_.db_size = fields_.db_slice->DbSize(fields_.db_ind);
fields_.deletion_count = fields_.db_slice->deletion_count_;
}
OpResult<DbSlice::ItAndUpdater> DbSlice::FindMutable(const Context& cntx, string_view key,
unsigned req_obj_type) {
// TODO(#2252): Call an internal find version that does not handle post updates
auto it = FindExt(cntx, key).first;
if (!IsValid(it))
return OpStatus::KEY_NOTFOUND;
if (it->second.ObjType() != req_obj_type) {
return OpStatus::WRONG_TYPE;
}
return {
{it, AutoUpdater({AutoUpdater::DestructorAction::kRun, this, cntx.db_index, it, key, true})}};
}
auto DbSlice::FindReadOnly(const Context& cntx, string_view key, unsigned req_obj_type) const
-> OpResult<PrimeConstIterator> {
auto res = Find(cntx, key, req_obj_type);
return res.ok() ? OpResult<PrimeConstIterator>(res.value()) : res.status();
}
pair<PrimeIterator, ExpireIterator> DbSlice::FindExt(const Context& cntx, string_view key) const {
pair<PrimeIterator, ExpireIterator> res;
@ -562,6 +639,7 @@ bool DbSlice::Del(DbIndex db_ind, PrimeIterator it) {
}
PerformDeletion(it, shard_owner(), db.get());
deletion_count_++;
return true;
}

View file

@ -63,6 +63,44 @@ class DbSlice {
void operator=(const DbSlice&) = delete;
public:
class AutoUpdater {
public:
AutoUpdater();
AutoUpdater(AutoUpdater&& o);
AutoUpdater& operator=(AutoUpdater&& o);
~AutoUpdater();
void Run();
void Cancel();
private:
enum class DestructorAction {
kDoNothing,
kRun,
};
// Wrap members in a struct to auto generate operator=
struct Fields {
DestructorAction action = DestructorAction::kDoNothing;
DbSlice* db_slice = nullptr;
DbIndex db_ind = 0;
PrimeIterator it;
std::string_view key;
bool key_existed = false;
size_t db_size = 0;
size_t deletion_count = 0;
// TODO(#2252): Add heap size here, and only update memory in d'tor
};
AutoUpdater(const Fields& fields);
friend class DbSlice;
Fields fields_ = {};
};
struct Stats {
// DbStats db;
std::vector<DbStats> db_stats;
@ -148,9 +186,20 @@ class DbSlice {
return ExpirePeriod{time_ms - expire_base_[0]};
}
// TODO(#2252): Remove this in favor of FindMutable() / FindReadOnly()
OpResult<PrimeIterator> Find(const Context& cntx, std::string_view key,
unsigned req_obj_type) const;
struct ItAndUpdater {
PrimeIterator it;
AutoUpdater post_updater;
};
OpResult<ItAndUpdater> FindMutable(const Context& cntx, std::string_view key,
unsigned req_obj_type);
OpResult<PrimeConstIterator> FindReadOnly(const Context& cntx, std::string_view key,
unsigned req_obj_type) const;
// Returns (value, expire) dict entries if key exists, null if it does not exist or has expired.
std::pair<PrimeIterator, ExpireIterator> FindExt(const Context& cntx, std::string_view key) const;
@ -260,6 +309,7 @@ class DbSlice {
size_t DbSize(DbIndex db_ind) const;
// Callback functions called upon writing to the existing key.
// TODO(#2252): Remove these (or make them private)
void PreUpdate(DbIndex db_ind, PrimeIterator it);
void PostUpdate(DbIndex db_ind, PrimeIterator it, std::string_view key,
bool existing_entry = true);
@ -376,6 +426,7 @@ class DbSlice {
ssize_t memory_budget_ = SSIZE_MAX;
size_t bytes_per_object_ = 0;
size_t soft_budget_limit_ = 0;
size_t deletion_count_ = 0;
mutable SliceEvents events_; // we may change this even for const operations.

View file

@ -117,7 +117,7 @@ void PFAdd(CmdArgList args, ConnectionContext* cntx) {
OpResult<int64_t> CountHllsSingle(const OpArgs& op_args, string_view key) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> it = db_slice.Find(op_args.db_cntx, key, OBJ_STRING);
OpResult<PrimeConstIterator> it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (it.ok()) {
string hll;
string_view hll_view = it.value()->second.GetSlice(&hll);
@ -150,8 +150,8 @@ OpResult<vector<string>> ReadValues(const OpArgs& op_args, ArgSlice keys) {
try {
vector<string> values;
for (size_t i = 0; i < keys.size(); ++i) {
OpResult<PrimeIterator> it =
op_args.shard->db_slice().Find(op_args.db_cntx, keys[i], OBJ_STRING);
OpResult<PrimeConstIterator> it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_STRING);
if (it.ok()) {
string hll;
it.value()->second.GetString(&hll);

View file

@ -278,17 +278,18 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
* of returning no or very few elements. (taken from redis code at db.c line 904 */
constexpr size_t INTERATION_FACTOR = 10;
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_HASH);
OpResult<PrimeConstIterator> find_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (!find_res) {
DVLOG(1) << "ScanOp: find failed: " << find_res << ", baling out";
return find_res.status();
}
PrimeIterator it = find_res.value();
PrimeConstIterator it = find_res.value();
StringVec res;
uint32_t count = scan_op.limit * HASH_TABLE_ENTRIES_FACTOR;
PrimeValue& pv = it->second;
const PrimeValue& pv = it->second;
if (pv.Encoding() == kEncodingListPack) {
uint8_t* lp = (uint8_t*)pv.RObjPtr();
@ -342,15 +343,14 @@ OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, CmdArgList valu
DCHECK(!values.empty());
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_HASH);
if (!it_res)
return it_res.status();
op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, (*it_res)->second);
db_slice.PreUpdate(op_args.db_cntx.db_index, *it_res);
PrimeValue& pv = it_res->it->second;
op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, pv);
PrimeValue& pv = (*it_res)->second;
unsigned deleted = 0;
bool key_remove = false;
DbTableStats* stats = db_slice.MutableStats(op_args.db_cntx.db_index);
@ -387,7 +387,7 @@ OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, CmdArgList valu
}
}
db_slice.PostUpdate(op_args.db_cntx.db_index, *it_res, key);
it_res->post_updater.Run();
if (!key_remove)
op_args.shard->search_indices()->AddDoc(key, op_args.db_cntx, pv);
@ -396,7 +396,7 @@ OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, CmdArgList valu
if (enc == kEncodingListPack) {
stats->listpack_blob_cnt--;
}
db_slice.Del(op_args.db_cntx.db_index, *it_res);
db_slice.Del(op_args.db_cntx.db_index, it_res->it);
} else if (enc == kEncodingListPack) {
stats->listpack_bytes += lpBytes((uint8_t*)pv.RObjPtr());
}
@ -408,12 +408,12 @@ OpResult<vector<OptStr>> OpHMGet(const OpArgs& op_args, std::string_view key, Cm
DCHECK(!fields.empty());
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (!it_res)
return it_res.status();
PrimeValue& pv = (*it_res)->second;
const PrimeValue& pv = (*it_res)->second;
std::vector<OptStr> result(fields.size());
@ -466,7 +466,7 @@ OpResult<vector<OptStr>> OpHMGet(const OpArgs& op_args, std::string_view key, Cm
OpResult<uint32_t> OpLen(const OpArgs& op_args, string_view key) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (it_res) {
return HMapLength(op_args.db_cntx, (*it_res)->second);
@ -479,7 +479,7 @@ OpResult<uint32_t> OpLen(const OpArgs& op_args, string_view key) {
OpResult<int> OpExist(const OpArgs& op_args, string_view key, string_view field) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (!it_res) {
if (it_res.status() == OpStatus::KEY_NOTFOUND)
@ -503,7 +503,7 @@ OpResult<int> OpExist(const OpArgs& op_args, string_view key, string_view field)
OpResult<string> OpGet(const OpArgs& op_args, string_view key, string_view field) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (!it_res)
return it_res.status();
@ -531,7 +531,7 @@ OpResult<string> OpGet(const OpArgs& op_args, string_view key, string_view field
OpResult<vector<string>> OpGetAll(const OpArgs& op_args, string_view key, uint8_t mask) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (!it_res) {
if (it_res.status() == OpStatus::KEY_NOTFOUND)
return vector<string>{};
@ -582,7 +582,7 @@ OpResult<vector<string>> OpGetAll(const OpArgs& op_args, string_view key, uint8_
OpResult<size_t> OpStrLen(const OpArgs& op_args, string_view key, string_view field) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH);
if (!it_res) {
if (it_res.status() == OpStatus::KEY_NOTFOUND)
@ -1062,7 +1062,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
auto& db_slice = shard->db_slice();
DbContext db_context = t->GetDbContext();
auto it_res = db_slice.Find(db_context, key, OBJ_HASH);
auto it_res = db_slice.FindReadOnly(db_context, key, OBJ_HASH);
if (!it_res)
return it_res.status();
@ -1097,7 +1097,9 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) {
}
if (string_map->Empty()) {
db_slice.Del(db_context.db_index, *it_res);
auto it_mutable = db_slice.FindMutable(db_context, key, OBJ_HASH);
it_mutable->post_updater.Run();
db_slice.Del(db_context.db_index, it_mutable->it);
return facade::OpStatus::KEY_NOTFOUND;
}
} else if (pv.Encoding() == kEncodingListPack) {

View file

@ -146,21 +146,18 @@ error_code JsonReplace(JsonType& instance, string_view path, JsonReplaceCb callb
OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_view path,
JsonReplaceCb callback, JsonReplaceVerify verify_op = JsonReplaceVerifyNoOp) {
OpResult<PrimeIterator> it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON);
auto it_res = op_args.shard->db_slice().FindMutable(op_args.db_cntx, key, OBJ_JSON);
if (!it_res.ok()) {
return it_res.status();
}
PrimeIterator entry_it = it_res.value();
auto& db_slice = op_args.shard->db_slice();
auto db_index = op_args.db_cntx.db_index;
PrimeConstIterator entry_it = it_res->it;
JsonType* json_val = entry_it->second.GetJson();
DCHECK(json_val) << "should have a valid JSON object for key '" << key << "' the type for it is '"
<< entry_it->second.ObjType() << "'";
JsonType& json_entry = *json_val;
op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, entry_it->second);
db_slice.PreUpdate(db_index, entry_it);
// Run the update operation on this entry
error_code ec = JsonReplace(json_entry, path, callback);
@ -172,7 +169,7 @@ OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_vi
// Make sure that we don't have other internal issue with the operation
OpStatus res = verify_op(json_entry);
if (res == OpStatus::OK) {
db_slice.PostUpdate(db_index, entry_it, key);
it_res->post_updater.Run();
op_args.shard->search_indices()->AddDoc(key, op_args.db_cntx, entry_it->second);
}
@ -180,7 +177,8 @@ OpStatus UpdateEntry(const OpArgs& op_args, std::string_view key, std::string_vi
}
OpResult<JsonType*> GetJson(const OpArgs& op_args, string_view key) {
OpResult<PrimeIterator> it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON);
OpResult<PrimeConstIterator> it_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_JSON);
if (!it_res.ok())
return it_res.status();
@ -982,7 +980,8 @@ vector<OptString> OpJsonMGet(JsonExpression expression, const Transaction* t, En
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < args.size(); ++i) {
OpResult<PrimeIterator> it_res = db_slice.Find(t->GetDbContext(), args[i], OBJ_JSON);
OpResult<PrimeConstIterator> it_res =
db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_JSON);
if (!it_res.ok())
continue;
@ -1068,8 +1067,8 @@ OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
// and its not JSON, it would return an error.
if (path == "." || path == "$") {
if (is_nx_condition || is_xx_condition) {
OpResult<PrimeIterator> it_res =
op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_JSON);
OpResult<PrimeConstIterator> it_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_JSON);
bool key_exists = (it_res.status() != OpStatus::KEY_NOTFOUND);
if (is_nx_condition && key_exists) {
return false;

View file

@ -184,7 +184,7 @@ std::string OpBPop(Transaction* t, EngineShard* shard, std::string_view key, Lis
DVLOG(2) << "popping from " << key << " " << t->DebugId();
auto& db_slice = shard->db_slice();
auto it_res = db_slice.Find(t->GetDbContext(), key, OBJ_LIST);
auto it_res = db_slice.FindMutable(t->GetDbContext(), key, OBJ_LIST);
if (!it_res) {
auto messages = debugMessages.All();
@ -203,14 +203,13 @@ std::string OpBPop(Transaction* t, EngineShard* shard, std::string_view key, Lis
CHECK(it_res) << t->DebugId() << " " << key; // must exist and must be ok.
PrimeIterator it = *it_res;
PrimeIterator it = it_res->it;
quicklist* ql = GetQL(it->second);
absl::StrAppend(debugMessages.Next(), "OpBPop: ", key, " by ", t->DebugId());
db_slice.PreUpdate(t->GetDbIndex(), it);
std::string value = ListPop(dir, ql);
db_slice.PostUpdate(t->GetDbIndex(), it, key);
it_res->post_updater.Run();
if (quicklistCount(ql) == 0) {
DVLOG(1) << "deleting key " << key << " " << t->DebugId();
@ -230,20 +229,18 @@ std::string OpBPop(Transaction* t, EngineShard* shard, std::string_view key, Lis
OpResult<string> OpMoveSingleShard(const OpArgs& op_args, string_view src, string_view dest,
ListDir src_dir, ListDir dest_dir) {
auto& db_slice = op_args.shard->db_slice();
auto src_res = db_slice.Find(op_args.db_cntx, src, OBJ_LIST);
auto src_res = db_slice.FindMutable(op_args.db_cntx, src, OBJ_LIST);
if (!src_res)
return src_res.status();
PrimeIterator src_it = *src_res;
PrimeIterator src_it = src_res->it;
quicklist* src_ql = GetQL(src_it->second);
if (src == dest) { // simple case.
db_slice.PreUpdate(op_args.db_cntx.db_index, src_it);
string val = ListPop(src_dir, src_ql);
int pos = (dest_dir == ListDir::LEFT) ? QUICKLIST_HEAD : QUICKLIST_TAIL;
quicklistPush(src_ql, val.data(), val.size(), pos);
db_slice.PostUpdate(op_args.db_cntx.db_index, src_it, src);
return val;
}
@ -252,6 +249,7 @@ OpResult<string> OpMoveSingleShard(const OpArgs& op_args, string_view src, strin
PrimeIterator dest_it;
bool new_key = false;
try {
src_res->post_updater.Run();
tie(dest_it, new_key) = db_slice.AddOrFind(op_args.db_cntx, dest);
} catch (bad_alloc&) {
return OpStatus::OUT_OF_MEMORY;
@ -265,7 +263,9 @@ OpResult<string> OpMoveSingleShard(const OpArgs& op_args, string_view src, strin
dest_it->second.ImportRObj(obj);
// Insertion of dest could invalidate src_it. Find it again.
src_it = db_slice.GetTables(op_args.db_cntx.db_index).first->Find(src);
src_res = db_slice.FindMutable(op_args.db_cntx, src, OBJ_LIST);
src_it = src_res->it;
DCHECK(IsValid(src_it));
} else {
if (dest_it->second.ObjType() != OBJ_LIST)
return OpStatus::WRONG_TYPE;
@ -274,13 +274,11 @@ OpResult<string> OpMoveSingleShard(const OpArgs& op_args, string_view src, strin
db_slice.PreUpdate(op_args.db_cntx.db_index, dest_it);
}
db_slice.PreUpdate(op_args.db_cntx.db_index, src_it);
string val = ListPop(src_dir, src_ql);
int pos = (dest_dir == ListDir::LEFT) ? QUICKLIST_HEAD : QUICKLIST_TAIL;
quicklistPush(dest_ql, val.data(), val.size(), pos);
db_slice.PostUpdate(op_args.db_cntx.db_index, src_it, src);
src_res->post_updater.Run();
db_slice.PostUpdate(op_args.db_cntx.db_index, dest_it, dest, !new_key);
if (quicklistCount(src_ql) == 0) {
@ -293,7 +291,7 @@ OpResult<string> OpMoveSingleShard(const OpArgs& op_args, string_view src, strin
// Read-only peek operation that determines whether the list exists and optionally
// returns the first from left/right value without popping it from the list.
OpResult<string> Peek(const OpArgs& op_args, string_view key, ListDir dir, bool fetch) {
auto it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
auto it_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST);
if (!it_res) {
return it_res.status();
}
@ -321,6 +319,7 @@ OpResult<uint32_t> OpPush(const OpArgs& op_args, std::string_view key, ListDir d
bool new_key = false;
if (skip_notexist) {
// TODO(#2252): Move to FindMutable() once AddOrFindMutable() is implemented
auto it_res = es->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
if (!it_res)
return 0; // Redis returns 0 for nonexisting keys for the *PUSHX actions.
@ -384,13 +383,12 @@ OpResult<uint32_t> OpPush(const OpArgs& op_args, std::string_view key, ListDir d
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, ListDir dir, uint32_t count,
bool return_results, bool journal_rewrite) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST);
if (!it_res)
return it_res.status();
PrimeIterator it = *it_res;
PrimeIterator it = it_res->it;
quicklist* ql = GetQL(it->second);
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
StringVec res;
if (quicklistCount(ql) < count) {
@ -408,7 +406,7 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, ListDir dir, u
}
}
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
it_res->post_updater.Run();
if (quicklistCount(ql) == 0) {
absl::StrAppend(debugMessages.Next(), "OpPop Del: ", key, " by ", op_args.tx->DebugId());
@ -486,7 +484,7 @@ OpResult<string> MoveTwoShards(Transaction* trans, string_view src, string_view
}
OpResult<uint32_t> OpLen(const OpArgs& op_args, std::string_view key) {
auto res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
auto res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST);
if (!res)
return res.status();
@ -496,7 +494,7 @@ OpResult<uint32_t> OpLen(const OpArgs& op_args, std::string_view key) {
}
OpResult<string> OpIndex(const OpArgs& op_args, std::string_view key, long index) {
auto res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
auto res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST);
if (!res)
return res.status();
quicklist* ql = GetQL(res.value()->second);
@ -520,7 +518,8 @@ OpResult<string> OpIndex(const OpArgs& op_args, std::string_view key, long index
OpResult<vector<uint32_t>> OpPos(const OpArgs& op_args, std::string_view key,
std::string_view element, int rank, int count, int max_len) {
OpResult<PrimeIterator> it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
OpResult<PrimeConstIterator> it_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST);
if (!it_res.ok())
return it_res.status();
@ -564,11 +563,11 @@ OpResult<vector<uint32_t>> OpPos(const OpArgs& op_args, std::string_view key,
OpResult<int> OpInsert(const OpArgs& op_args, string_view key, string_view pivot, string_view elem,
int insert_param) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST);
if (!it_res)
return it_res.status();
quicklist* ql = GetQL(it_res.value()->second);
quicklist* ql = GetQL(it_res->it->second);
quicklistEntry entry = container_utils::QLEntry();
quicklistIter* qiter = quicklistGetIterator(ql, AL_START_HEAD);
bool found = false;
@ -582,14 +581,12 @@ OpResult<int> OpInsert(const OpArgs& op_args, string_view key, string_view pivot
int res = -1;
if (found) {
db_slice.PreUpdate(op_args.db_cntx.db_index, *it_res);
if (insert_param == LIST_TAIL) {
quicklistInsertAfter(qiter, &entry, elem.data(), elem.size());
} else {
DCHECK_EQ(LIST_HEAD, insert_param);
quicklistInsertBefore(qiter, &entry, elem.data(), elem.size());
}
db_slice.PostUpdate(op_args.db_cntx.db_index, *it_res, key);
res = quicklistCount(ql);
}
quicklistReleaseIterator(qiter);
@ -598,11 +595,11 @@ OpResult<int> OpInsert(const OpArgs& op_args, string_view key, string_view pivot
OpResult<uint32_t> OpRem(const OpArgs& op_args, string_view key, string_view elem, long count) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST);
if (!it_res)
return it_res.status();
PrimeIterator it = *it_res;
PrimeIterator it = it_res->it;
quicklist* ql = GetQL(it->second);
int iter_direction = AL_START_HEAD;
@ -618,7 +615,6 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, string_view key, string_view ele
unsigned removed = 0;
const uint8_t* elem_ptr = reinterpret_cast<const uint8_t*>(elem.data());
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
while (quicklistNext(qiter, &entry)) {
if (quicklistCompare(&entry, elem_ptr, elem.size())) {
quicklistDelEntry(qiter, &entry);
@ -627,7 +623,8 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, string_view key, string_view ele
break;
}
}
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
it_res->post_updater.Run();
quicklistReleaseIterator(qiter);
@ -640,16 +637,14 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, string_view key, string_view ele
OpStatus OpSet(const OpArgs& op_args, string_view key, string_view elem, long index) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST);
if (!it_res)
return it_res.status();
PrimeIterator it = *it_res;
PrimeIterator it = it_res->it;
quicklist* ql = GetQL(it->second);
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
int replaced = quicklistReplaceAtIndex(ql, index, elem.data(), elem.size());
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
if (!replaced) {
return OpStatus::OUT_OF_RANGE;
@ -659,11 +654,11 @@ OpStatus OpSet(const OpArgs& op_args, string_view key, string_view elem, long in
OpStatus OpTrim(const OpArgs& op_args, string_view key, long start, long end) {
auto& db_slice = op_args.shard->db_slice();
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_LIST);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_LIST);
if (!it_res)
return it_res.status();
PrimeIterator it = *it_res;
PrimeIterator it = it_res->it;
quicklist* ql = GetQL(it->second);
long llen = quicklistCount(ql);
@ -690,10 +685,10 @@ OpStatus OpTrim(const OpArgs& op_args, string_view key, long start, long end) {
rtrim = llen - end - 1;
}
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
quicklistDelRange(ql, 0, ltrim);
quicklistDelRange(ql, -rtrim, rtrim);
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
it_res->post_updater.Run();
if (quicklistCount(ql) == 0) {
CHECK(db_slice.Del(op_args.db_cntx.db_index, it));
@ -702,7 +697,7 @@ OpStatus OpTrim(const OpArgs& op_args, string_view key, long start, long end) {
}
OpResult<StringVec> OpRange(const OpArgs& op_args, std::string_view key, long start, long end) {
auto res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
auto res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_LIST);
if (!res)
return res.status();

View file

@ -213,7 +213,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
size_t expired_count = 0;
for (size_t i = 0; i < search_results.ids.size(); i++) {
auto key = key_index_.Get(search_results.ids[i]);
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
auto it = db_slice.FindReadOnly(op_args.db_cntx, key, base_->GetObjCode());
if (!it || !IsValid(*it)) { // Item must have expired
expired_count++;

View file

@ -695,20 +695,18 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, string_view key, const ArgSlice&
bool journal_rewrite) {
auto* es = op_args.shard;
auto& db_slice = es->db_slice();
OpResult<PrimeIterator> find_res = db_slice.Find(op_args.db_cntx, key, OBJ_SET);
auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
if (!find_res) {
return find_res.status();
}
db_slice.PreUpdate(op_args.db_cntx.db_index, *find_res);
CompactObj& co = find_res.value()->second;
CompactObj& co = find_res->it->second;
auto [removed, isempty] = RemoveSet(op_args.db_cntx, vals, &co);
db_slice.PostUpdate(op_args.db_cntx.db_index, *find_res, key);
find_res->post_updater.Run();
if (isempty) {
CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res.value()));
CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res->it));
}
if (journal_rewrite && op_args.shard->journal()) {
vector<string_view> mapped(vals.size() + 1);
@ -749,7 +747,7 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) {
for (auto k : largs) {
unsigned index = (k == src_) ? 0 : 1;
OpResult<PrimeIterator> res = es->db_slice().Find(t->GetDbContext(), k, OBJ_SET);
OpResult<PrimeConstIterator> res = es->db_slice().FindReadOnly(t->GetDbContext(), k, OBJ_SET);
if (res && index == 0) { // successful src find.
DCHECK(!res->is_done());
const CompactObj& val = res.value()->second;
@ -815,10 +813,10 @@ OpResult<StringVec> OpUnion(const OpArgs& op_args, ArgSlice keys) {
absl::flat_hash_set<string> uniques;
for (string_view key : keys) {
OpResult<PrimeIterator> find_res =
op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_SET);
OpResult<PrimeConstIterator> find_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
if (find_res) {
PrimeValue& pv = find_res.value()->second;
const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms));
@ -843,14 +841,15 @@ OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
DVLOG(1) << "OpDiff from " << keys.front();
EngineShard* es = op_args.shard;
OpResult<PrimeIterator> find_res = es->db_slice().Find(op_args.db_cntx, keys.front(), OBJ_SET);
OpResult<PrimeConstIterator> find_res =
es->db_slice().FindReadOnly(op_args.db_cntx, keys.front(), OBJ_SET);
if (!find_res) {
return find_res.status();
}
absl::flat_hash_set<string> uniques;
PrimeValue& pv = find_res.value()->second;
const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms));
@ -864,7 +863,8 @@ OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!uniques.empty()); // otherwise the key would not exist.
for (size_t i = 1; i < keys.size(); ++i) {
OpResult<PrimeIterator> diff_res = es->db_slice().Find(op_args.db_cntx, keys[i], OBJ_SET);
OpResult<PrimeConstIterator> diff_res =
es->db_slice().FindReadOnly(op_args.db_cntx, keys[i], OBJ_SET);
if (!diff_res) {
if (diff_res.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;
@ -901,12 +901,12 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
StringVec result;
if (keys.size() == 1) {
OpResult<PrimeIterator> find_res =
es->db_slice().Find(t->GetDbContext(), keys.front(), OBJ_SET);
OpResult<PrimeConstIterator> find_res =
es->db_slice().FindReadOnly(t->GetDbContext(), keys.front(), OBJ_SET);
if (!find_res)
return find_res.status();
PrimeValue& pv = find_res.value()->second;
const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));
@ -926,7 +926,8 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
OpStatus status = OpStatus::OK;
for (size_t i = 0; i < keys.size(); ++i) {
OpResult<PrimeIterator> find_res = es->db_slice().Find(t->GetDbContext(), keys[i], OBJ_SET);
OpResult<PrimeConstIterator> find_res =
es->db_slice().FindReadOnly(t->GetDbContext(), keys[i], OBJ_SET);
if (!find_res) {
if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND ||
find_res.status() != OpStatus::KEY_NOTFOUND) {
@ -976,7 +977,7 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
// count - how many elements to pop.
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> find_res = db_slice.Find(op_args.db_cntx, key, OBJ_SET);
auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
return find_res.status();
@ -984,7 +985,7 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count
if (count == 0)
return result;
PrimeIterator it = find_res.value();
PrimeIterator it = find_res->it;
size_t slen = it->second.Size();
/* CASE 1:
@ -1003,6 +1004,7 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count
});
// Delete the set as it is now empty
find_res->post_updater.Run();
CHECK(db_slice.Del(op_args.db_cntx.db_index, it));
// Replicate as DEL.
@ -1011,7 +1013,6 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count
}
} else {
SetType st{it->second.RObjPtr(), it->second.Encoding()};
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
if (st.second == kEncodingIntSet) {
intset* is = (intset*)st.first;
int64_t val = 0;
@ -1035,20 +1036,19 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count
std::copy(result.begin(), result.end(), mapped.begin() + 1);
RecordJournal(op_args, "SREM"sv, mapped);
}
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
}
return result;
}
OpResult<StringVec> OpScan(const OpArgs& op_args, string_view key, uint64_t* cursor,
const ScanOpts& scan_op) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_SET);
OpResult<PrimeConstIterator> find_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
return find_res.status();
PrimeIterator it = find_res.value();
PrimeConstIterator it = find_res.value();
StringVec res;
if (it->second.Encoding() == kEncodingIntSet) {
@ -1094,7 +1094,8 @@ void SIsMember(CmdArgList args, ConnectionContext* cntx) {
string_view val = ArgS(args, 1);
auto cb = [&](Transaction* t, EngineShard* shard) {
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET);
OpResult<PrimeConstIterator> find_res =
shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (find_res) {
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
@ -1125,7 +1126,8 @@ void SMIsMember(CmdArgList args, ConnectionContext* cntx) {
memberships.reserve(vals.size());
auto cb = [&](Transaction* t, EngineShard* shard) {
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET);
OpResult<PrimeConstIterator> find_res =
shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (find_res) {
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FindInSet(memberships, t->GetDbContext(), st, vals);
@ -1191,7 +1193,8 @@ void SCard(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<uint32_t> {
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET);
OpResult<PrimeConstIterator> find_res =
shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (!find_res) {
return find_res.status();
}
@ -1366,12 +1369,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) {
const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
StringVec result;
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET);
OpResult<PrimeConstIterator> find_res =
shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (!find_res) {
return find_res.status();
}
PrimeValue& pv = find_res.value()->second;
const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));

View file

@ -611,6 +611,7 @@ OpResult<streamID> OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL
pair<PrimeIterator, bool> add_res;
if (opts.no_mkstream) {
// TODO(#2252): Replace with FindMutable() once AddOrFindMutable() is implemented
auto res_it = db_slice.Find(op_args.db_cntx, opts.key, OBJ_STREAM);
if (!res_it) {
return res_it.status();
@ -668,7 +669,7 @@ OpResult<streamID> OpAdd(const OpArgs& op_args, const AddTrimOpts& opts, CmdArgL
OpResult<RecordVec> OpRange(const OpArgs& op_args, string_view key, const RangeOpts& opts) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
OpResult<PrimeConstIterator> res_it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
@ -680,7 +681,7 @@ OpResult<RecordVec> OpRange(const OpArgs& op_args, string_view key, const RangeO
streamIterator si;
int64_t numfields;
streamID id;
CompactObj& cobj = (*res_it)->second;
const CompactObj& cobj = (*res_it)->second;
stream* s = (stream*)cobj.RObjPtr();
streamID sstart = opts.start.val, send = opts.end.val;
@ -797,6 +798,14 @@ OpResult<RecordVec> OpRangeFromConsumerPEL(const OpArgs& op_args, string_view ke
return result;
}
namespace {
// Our C-API doesn't use const, so we have to const cast.
// Only intended for read-only functions.
stream* GetReadOnlyStream(const CompactObj& cobj) {
return const_cast<stream*>((const stream*)cobj.RObjPtr());
}
} // namespace
// Returns a map of stream to the ID of the last entry in the stream. Any
// streams not found are omitted from the result.
OpResult<vector<pair<string_view, streamID>>> OpLastIDs(const OpArgs& op_args,
@ -807,7 +816,7 @@ OpResult<vector<pair<string_view, streamID>>> OpLastIDs(const OpArgs& op_args,
vector<pair<string_view, streamID>> last_ids;
for (string_view key : args) {
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
OpResult<PrimeConstIterator> res_it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STREAM);
if (!res_it) {
if (res_it.status() == OpStatus::KEY_NOTFOUND) {
continue;
@ -815,8 +824,8 @@ OpResult<vector<pair<string_view, streamID>>> OpLastIDs(const OpArgs& op_args,
return res_it.status();
}
CompactObj& cobj = (*res_it)->second;
stream* s = (stream*)cobj.RObjPtr();
const CompactObj& cobj = (*res_it)->second;
stream* s = GetReadOnlyStream(cobj);
streamID last_id = s->last_id;
if (s->length) {
@ -869,10 +878,10 @@ vector<RecordVec> OpRead(const OpArgs& op_args, const ArgSlice& args, const Read
OpResult<uint32_t> OpLen(const OpArgs& op_args, string_view key) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
OpResult<PrimeConstIterator> res_it = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
CompactObj& cobj = (*res_it)->second;
const CompactObj& cobj = (*res_it)->second;
stream* s = (stream*)cobj.RObjPtr();
return s->length;
}
@ -880,12 +889,12 @@ OpResult<uint32_t> OpLen(const OpArgs& op_args, string_view key) {
OpResult<vector<GroupInfo>> OpListGroups(const DbContext& db_cntx, string_view key,
EngineShard* shard) {
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(db_cntx, key, OBJ_STREAM);
OpResult<PrimeConstIterator> res_it = db_slice.FindReadOnly(db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
vector<GroupInfo> result;
CompactObj& cobj = (*res_it)->second;
const CompactObj& cobj = (*res_it)->second;
stream* s = (stream*)cobj.RObjPtr();
if (s->cgroups) {
@ -1012,12 +1021,12 @@ void GetConsumers(stream* s, streamCG* cg, long long count, GroupInfo* ginfo) {
OpResult<StreamInfo> OpStreams(const DbContext& db_cntx, string_view key, EngineShard* shard,
int full, size_t count) {
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(db_cntx, key, OBJ_STREAM);
OpResult<PrimeConstIterator> res_it = db_slice.FindReadOnly(db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
vector<StreamInfo> result;
CompactObj& cobj = (*res_it)->second;
const CompactObj& cobj = (*res_it)->second;
stream* s = (stream*)cobj.RObjPtr();
StreamInfo sinfo;
@ -1075,13 +1084,13 @@ OpResult<StreamInfo> OpStreams(const DbContext& db_cntx, string_view key, Engine
OpResult<vector<ConsumerInfo>> OpConsumers(const DbContext& db_cntx, EngineShard* shard,
string_view stream_name, string_view group_name) {
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(db_cntx, stream_name, OBJ_STREAM);
OpResult<PrimeConstIterator> res_it = db_slice.FindReadOnly(db_cntx, stream_name, OBJ_STREAM);
if (!res_it)
return res_it.status();
vector<ConsumerInfo> result;
CompactObj& cobj = (*res_it)->second;
stream* s = (stream*)cobj.RObjPtr();
const CompactObj& cobj = (*res_it)->second;
stream* s = GetReadOnlyStream(cobj);
shard->tmp_str1 = sdscpylen(shard->tmp_str1, group_name.data(), group_name.length());
streamCG* cg = streamLookupCG(s, shard->tmp_str1);
if (cg == NULL) {
@ -1120,6 +1129,7 @@ struct CreateOpts {
OpStatus OpCreate(const OpArgs& op_args, string_view key, const CreateOpts& opts) {
auto* shard = op_args.shard;
auto& db_slice = shard->db_slice();
// TODO(#2252): Replace with FindMutable() once new AddNew() is implemented
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
int64_t entries_read = SCG_INVALID_ENTRIES_READ;
if (!res_it) {
@ -1158,19 +1168,24 @@ OpStatus OpCreate(const OpArgs& op_args, string_view key, const CreateOpts& opts
return OpStatus::BUSY_GROUP;
}
OpResult<pair<stream*, streamCG*>> FindGroup(const OpArgs& op_args, string_view key,
string_view gname) {
struct FindGroupResult {
stream* s = nullptr;
streamCG* cg = nullptr;
DbSlice::AutoUpdater post_updater;
};
OpResult<FindGroupResult> FindGroup(const OpArgs& op_args, string_view key, string_view gname) {
auto* shard = op_args.shard;
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
CompactObj& cobj = (*res_it)->second;
pair<stream*, streamCG*> res;
res.first = (stream*)cobj.RObjPtr();
CompactObj& cobj = res_it->it->second;
FindGroupResult res;
res.s = (stream*)cobj.RObjPtr();
shard->tmp_str1 = sdscpylen(shard->tmp_str1, gname.data(), gname.size());
res.second = streamLookupCG(res.first, shard->tmp_str1);
res.cg = streamLookupCG(res.s, shard->tmp_str1);
res.post_updater = std::move(res_it->post_updater);
return res;
}
@ -1231,12 +1246,10 @@ void AppendClaimResultItem(ClaimInfo& result, stream* s, streamID id) {
// XCLAIM key group consumer min-idle-time id
OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimOpts& opts,
absl::Span<streamID> ids) {
OpResult<pair<stream*, streamCG*>> cgr_res = FindGroup(op_args, key, opts.group);
auto cgr_res = FindGroup(op_args, key, opts.group);
if (!cgr_res)
return cgr_res.status();
stream* s = cgr_res->first;
streamCG* scg = cgr_res->second;
if (!scg) {
if (!cgr_res->cg) {
return OpStatus::SKIPPED;
}
streamConsumer* consumer = nullptr;
@ -1246,8 +1259,8 @@ OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimO
streamID last_id = opts.last_id;
if (opts.flags & kClaimLastID) {
if (streamCompareID(&last_id, &scg->last_id) > 0) {
scg->last_id = last_id;
if (streamCompareID(&last_id, &cgr_res->cg->last_id) > 0) {
cgr_res->cg->last_id = last_id;
}
}
@ -1255,11 +1268,11 @@ OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimO
std::array<uint8_t, sizeof(streamID)> buf;
StreamEncodeID(buf.begin(), &id);
streamNACK* nack = (streamNACK*)raxFind(scg->pel, buf.begin(), sizeof(buf));
if (!streamEntryExists(s, &id)) {
streamNACK* nack = (streamNACK*)raxFind(cgr_res->cg->pel, buf.begin(), sizeof(buf));
if (!streamEntryExists(cgr_res->s, &id)) {
if (nack != raxNotFound) {
/* Release the NACK */
raxRemove(scg->pel, buf.begin(), sizeof(buf), nullptr);
raxRemove(cgr_res->cg->pel, buf.begin(), sizeof(buf), nullptr);
raxRemove(nack->consumer->pel, buf.begin(), sizeof(buf), nullptr);
streamFreeNACK(nack);
}
@ -1271,7 +1284,7 @@ OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimO
if ((opts.flags & kClaimForce) && nack == raxNotFound) {
/* Create the NACK. */
nack = streamCreateNACK(nullptr);
raxInsert(scg->pel, buf.begin(), sizeof(buf), nack, nullptr);
raxInsert(cgr_res->cg->pel, buf.begin(), sizeof(buf), nack, nullptr);
}
// We found the nack, continue.
@ -1287,9 +1300,9 @@ OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimO
// Try to get the consumer. If not found, create a new one.
op_args.shard->tmp_str1 =
sdscpylen(op_args.shard->tmp_str1, opts.consumer.data(), opts.consumer.size());
if ((consumer = streamLookupConsumer(scg, op_args.shard->tmp_str1, SLC_NO_REFRESH)) ==
if ((consumer = streamLookupConsumer(cgr_res->cg, op_args.shard->tmp_str1, SLC_NO_REFRESH)) ==
nullptr) {
consumer = streamCreateConsumer(scg, op_args.shard->tmp_str1, nullptr, 0,
consumer = streamCreateConsumer(cgr_res->cg, op_args.shard->tmp_str1, nullptr, 0,
SCC_NO_NOTIFY | SCC_NO_DIRTIFY);
}
@ -1319,7 +1332,7 @@ OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimO
}
/* Send the reply for this entry. */
AppendClaimResultItem(result, s, id);
AppendClaimResultItem(result, cgr_res->s, id);
}
}
return result;
@ -1327,16 +1340,13 @@ OpResult<ClaimInfo> OpClaim(const OpArgs& op_args, string_view key, const ClaimO
// XGROUP DESTROY key groupname
OpStatus OpDestroyGroup(const OpArgs& op_args, string_view key, string_view gname) {
OpResult<pair<stream*, streamCG*>> cgr_res = FindGroup(op_args, key, gname);
auto cgr_res = FindGroup(op_args, key, gname);
if (!cgr_res)
return cgr_res.status();
stream* s = cgr_res->first;
streamCG* scg = cgr_res->second;
if (scg) {
raxRemove(s->cgroups, (uint8_t*)(gname.data()), gname.size(), NULL);
streamFreeCG(scg);
if (cgr_res->cg) {
raxRemove(cgr_res->s->cgroups, (uint8_t*)(gname.data()), gname.size(), NULL);
streamFreeCG(cgr_res->cg);
return OpStatus::OK;
}
@ -1366,7 +1376,7 @@ vector<GroupConsumerPair> OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA
if (!group_res) {
continue;
}
if (group = group_res->second; !group) {
if (group = group_res->cg; !group) {
continue;
}
@ -1385,10 +1395,10 @@ vector<GroupConsumerPair> OpGetGroupConsumerPairs(ArgSlice slice_args, const OpA
// XGROUP CREATECONSUMER key groupname consumername
OpResult<uint32_t> OpCreateConsumer(const OpArgs& op_args, string_view key, string_view gname,
string_view consumer_name) {
OpResult<pair<stream*, streamCG*>> cgroup_res = FindGroup(op_args, key, gname);
auto cgroup_res = FindGroup(op_args, key, gname);
if (!cgroup_res)
return cgroup_res.status();
streamCG* cg = cgroup_res->second;
streamCG* cg = cgroup_res->cg;
if (cg == nullptr)
return OpStatus::SKIPPED;
@ -1405,11 +1415,11 @@ OpResult<uint32_t> OpCreateConsumer(const OpArgs& op_args, string_view key, stri
// XGROUP DELCONSUMER key groupname consumername
OpResult<uint32_t> OpDelConsumer(const OpArgs& op_args, string_view key, string_view gname,
string_view consumer_name) {
OpResult<pair<stream*, streamCG*>> cgroup_res = FindGroup(op_args, key, gname);
auto cgroup_res = FindGroup(op_args, key, gname);
if (!cgroup_res)
return cgroup_res.status();
streamCG* cg = cgroup_res->second;
streamCG* cg = cgroup_res->cg;
if (cg == nullptr)
return OpStatus::SKIPPED;
@ -1427,18 +1437,18 @@ OpResult<uint32_t> OpDelConsumer(const OpArgs& op_args, string_view key, string_
}
OpStatus OpSetId(const OpArgs& op_args, string_view key, string_view gname, string_view id) {
OpResult<pair<stream*, streamCG*>> cgr_res = FindGroup(op_args, key, gname);
auto cgr_res = FindGroup(op_args, key, gname);
if (!cgr_res)
return cgr_res.status();
streamCG* cg = cgr_res->second;
streamCG* cg = cgr_res->cg;
if (cg == nullptr)
return OpStatus::SKIPPED;
streamID sid;
ParsedStreamId parsed_id;
if (id == "$") {
sid = cgr_res->first->last_id;
sid = cgr_res->s->last_id;
} else {
if (ParseID(id, true, 0, &parsed_id)) {
sid = parsed_id.val;
@ -1454,11 +1464,11 @@ OpStatus OpSetId(const OpArgs& op_args, string_view key, string_view gname, stri
OpStatus OpSetId2(const OpArgs& op_args, string_view key, const streamID& sid) {
auto* shard = op_args.shard;
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
CompactObj& cobj = (*res_it)->second;
CompactObj& cobj = res_it->it->second;
stream* stream_inst = (stream*)cobj.RObjPtr();
long long entries_added = -1;
streamID max_xdel_id{0, 0};
@ -1493,11 +1503,11 @@ OpStatus OpSetId2(const OpArgs& op_args, string_view key, const streamID& sid) {
OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, absl::Span<streamID> ids) {
auto* shard = op_args.shard;
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_STREAM);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STREAM);
if (!res_it)
return res_it.status();
CompactObj& cobj = (*res_it)->second;
CompactObj& cobj = res_it->it->second;
stream* stream_inst = (stream*)cobj.RObjPtr();
uint32_t deleted = 0;
@ -1536,12 +1546,11 @@ OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, absl::Span<stre
// XACK key groupname id [id ...]
OpResult<uint32_t> OpAck(const OpArgs& op_args, string_view key, string_view gname,
absl::Span<streamID> ids) {
OpResult<pair<stream*, streamCG*>> res = FindGroup(op_args, key, gname);
auto res = FindGroup(op_args, key, gname);
if (!res)
return res.status();
auto [stream_inst, cg] = *res;
if (cg == nullptr || stream_inst == nullptr) {
if (res->cg == nullptr || res->s == nullptr) {
return 0;
}
@ -1554,9 +1563,9 @@ OpResult<uint32_t> OpAck(const OpArgs& op_args, string_view key, string_view gna
// Lookup the ID in the group PEL: it will have a reference to the
// NACK structure that will have a reference to the consumer, so that
// we are able to remove the entry from both PELs.
streamNACK* nack = (streamNACK*)raxFind(cg->pel, buf, sizeof(buf));
streamNACK* nack = (streamNACK*)raxFind(res->cg->pel, buf, sizeof(buf));
if (nack != raxNotFound) {
raxRemove(cg->pel, buf, sizeof(buf), nullptr);
raxRemove(res->cg->pel, buf, sizeof(buf), nullptr);
raxRemove(nack->consumer->pel, buf, sizeof(buf), nullptr);
streamFreeNACK(nack);
acknowledged++;
@ -1566,10 +1575,11 @@ OpResult<uint32_t> OpAck(const OpArgs& op_args, string_view key, string_view gna
}
OpResult<ClaimInfo> OpAutoClaim(const OpArgs& op_args, string_view key, const ClaimOpts& opts) {
OpResult<pair<stream*, streamCG*>> cgr_res = FindGroup(op_args, key, opts.group);
auto cgr_res = FindGroup(op_args, key, opts.group);
if (!cgr_res)
return cgr_res.status();
auto [stream, group] = *cgr_res;
stream* stream = cgr_res->s;
streamCG* group = cgr_res->cg;
if (stream == nullptr || group == nullptr) {
return OpStatus::KEY_NOTFOUND;
@ -1770,12 +1780,12 @@ PendingExtendedResultList GetPendingExtendedResult(streamCG* cg, streamConsumer*
}
OpResult<PendingResult> OpPending(const OpArgs& op_args, string_view key, const PendingOpts& opts) {
OpResult<pair<stream*, streamCG*>> cgroup_res = FindGroup(op_args, key, opts.group_name);
auto cgroup_res = FindGroup(op_args, key, opts.group_name);
if (!cgroup_res) {
return cgroup_res.status();
}
streamCG* cg = cgroup_res->second;
streamCG* cg = cgroup_res->cg;
if (cg == nullptr) {
return OpStatus::SKIPPED;
}
@ -1923,7 +1933,7 @@ void XGroupHelp(CmdArgList args, ConnectionContext* cntx) {
OpResult<int64_t> OpTrim(const OpArgs& op_args, const AddTrimOpts& opts) {
auto* shard = op_args.shard;
auto& db_slice = shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, opts.key, OBJ_STREAM);
auto res_it = db_slice.FindMutable(op_args.db_cntx, opts.key, OBJ_STREAM);
if (!res_it) {
if (res_it.status() == OpStatus::KEY_NOTFOUND) {
return 0;
@ -1931,7 +1941,7 @@ OpResult<int64_t> OpTrim(const OpArgs& op_args, const AddTrimOpts& opts) {
return res_it.status();
}
CompactObj& cobj = (*res_it)->second;
CompactObj& cobj = res_it->it->second;
stream* s = (stream*)cobj.RObjPtr();
return StreamTrim(opts, s);

View file

@ -80,7 +80,7 @@ OpResult<uint32_t> OpSetRange(const OpArgs& op_args, string_view key, size_t sta
size_t range_len = start + value.size();
if (range_len == 0) {
auto it_res = db_slice.Find(op_args.db_cntx, key, OBJ_STRING);
auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (it_res) {
return it_res.value()->second.Size();
} else {
@ -114,7 +114,7 @@ OpResult<uint32_t> OpSetRange(const OpArgs& op_args, string_view key, size_t sta
OpResult<string> OpGetRange(const OpArgs& op_args, string_view key, int32_t start, int32_t end) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> it_res = db_slice.Find(op_args.db_cntx, key, OBJ_STRING);
OpResult<PrimeConstIterator> it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
if (!it_res.ok())
return it_res.status();
@ -154,10 +154,7 @@ size_t ExtendExisting(const OpArgs& op_args, PrimeIterator it, string_view key,
else
new_val = absl::StrCat(slice, val);
auto& db_slice = shard->db_slice();
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
it->second.SetString(new_val);
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, true);
return new_val.size();
}
@ -170,6 +167,7 @@ OpResult<uint32_t> ExtendOrSet(const OpArgs& op_args, string_view key, string_vi
auto [it, inserted] = db_slice.AddOrFind(op_args.db_cntx, key);
if (inserted) {
it->second.SetString(val);
// TODO(#2252): We currently only call PostUpdate() (no PreUpdate()), make sure this is fixed
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, false);
return val.size();
@ -178,17 +176,20 @@ OpResult<uint32_t> ExtendOrSet(const OpArgs& op_args, string_view key, string_vi
if (it->second.ObjType() != OBJ_STRING)
return OpStatus::WRONG_TYPE;
return ExtendExisting(op_args, it, key, val, prepend);
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
size_t res = ExtendExisting(op_args, it, key, val, prepend);
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, true);
return res;
}
OpResult<bool> ExtendOrSkip(const OpArgs& op_args, string_view key, string_view val, bool prepend) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> it_res = db_slice.Find(op_args.db_cntx, key, OBJ_STRING);
auto it_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_STRING);
if (!it_res) {
return false;
}
return ExtendExisting(op_args, *it_res, key, val, prepend);
return ExtendExisting(op_args, it_res->it, key, val, prepend);
}
OpResult<string> OpGet(const OpArgs& op_args, string_view key, bool del_hit = false,
@ -509,11 +510,12 @@ SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const
auto& db_slice = shard->db_slice();
SinkReplyBuilder::MGetResponse response(args.size());
absl::InlinedVector<PrimeIterator, 32> iters(args.size());
absl::InlinedVector<PrimeConstIterator, 32> iters(args.size());
size_t total_size = 0;
for (size_t i = 0; i < args.size(); ++i) {
OpResult<PrimeIterator> it_res = db_slice.Find(t->GetDbContext(), args[i], OBJ_STRING);
OpResult<PrimeConstIterator> it_res =
db_slice.FindReadOnly(t->GetDbContext(), args[i], OBJ_STRING);
if (!it_res)
continue;
iters[i] = *it_res;
@ -524,7 +526,7 @@ SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const
char* next = response.storage_list->data;
for (size_t i = 0; i < args.size(); ++i) {
PrimeIterator it = iters[i];
PrimeConstIterator it = iters[i];
if (it.is_done())
continue;
@ -1292,7 +1294,8 @@ void StringFamily::StrLen(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<size_t> {
OpResult<PrimeIterator> it_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_STRING);
OpResult<PrimeConstIterator> it_res =
shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_STRING);
if (!it_res.ok())
return it_res.status();

View file

@ -25,8 +25,9 @@ using PrimeTable = DashTable<PrimeKey, PrimeValue, detail::PrimeTablePolicy>;
using ExpireTable = DashTable<PrimeKey, ExpirePeriod, detail::ExpireTablePolicy>;
/// Iterators are invalidated when new keys are added to the table or some entries are deleted.
/// Iterators are still valid if a different entry in the table was mutated.
/// Iterators are still valid if a different entry in the table was mutated.
using PrimeIterator = PrimeTable::iterator;
using PrimeConstIterator = PrimeTable::const_iterator;
using ExpireIterator = ExpireTable::iterator;
inline bool IsValid(PrimeIterator it) {
@ -37,6 +38,10 @@ inline bool IsValid(ExpireIterator it) {
return !it.is_done();
}
inline bool IsValid(PrimeConstIterator it) {
return !it.is_done();
}
struct SlotStats {
uint64_t key_count = 0;
uint64_t total_reads = 0;

View file

@ -147,7 +147,7 @@ int ZsetDel(detail::RobjWrapper* robj_wrapper, sds ele) {
}
// taken from t_zset.c
std::optional<double> GetZsetScore(detail::RobjWrapper* robj_wrapper, sds member) {
std::optional<double> GetZsetScore(const detail::RobjWrapper* robj_wrapper, sds member) {
if (robj_wrapper->encoding() == OBJ_ENCODING_LISTPACK) {
double score;
if (zzlFind((uint8_t*)robj_wrapper->inner_obj(), member, &score) == NULL)
@ -186,6 +186,7 @@ OpResult<PrimeIterator> FindZEntry(const ZParams& zparams, const OpArgs& op_args
size_t member_len) {
auto& db_slice = op_args.shard->db_slice();
if (zparams.flags & ZADD_IN_XX) {
// TODO(#2252): Replace once AddOrFindMutable() exists
return db_slice.Find(op_args.db_cntx, key, OBJ_ZSET);
}
@ -721,10 +722,11 @@ void SendAtLeastOneKeyError(ConnectionContext* cntx) {
enum class AggType : uint8_t { SUM, MIN, MAX, NOOP };
using ScoredMap = absl::flat_hash_map<std::string, double>;
ScoredMap FromObject(CompactObj& co, double weight) {
ScoredMap FromObject(const CompactObj& co, double weight) {
ZSetFamily::RangeParams params;
params.with_scores = true;
IntervalVisitor vis(Action::RANGE, params, &co);
// RANGE is a read-only operation, but requires const_cast
IntervalVisitor vis(Action::RANGE, params, &const_cast<CompactObj&>(co));
vis(ZSetFamily::IndexInterval(0, -1));
ScoredArray arr = vis.PopResult();
@ -795,16 +797,16 @@ void InterScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
dest->swap(*src);
}
using KeyIterWeightVec = vector<pair<PrimeIterator, double>>;
using KeyIterWeightVec = vector<pair<PrimeConstIterator, double>>;
ScoredMap UnionShardKeysWithScore(const KeyIterWeightVec& key_iter_weight_vec, AggType agg_type) {
ScoredMap result;
for (const auto& key_iter_wieght : key_iter_weight_vec) {
if (key_iter_wieght.first.is_done()) {
for (const auto& key_iter_weight : key_iter_weight_vec) {
if (key_iter_weight.first.is_done()) {
continue;
}
ScoredMap sm = FromObject(key_iter_wieght.first->second, key_iter_wieght.second);
ScoredMap sm = FromObject(key_iter_weight.first->second, key_iter_weight.second);
if (result.empty()) {
result.swap(sm);
} else {
@ -852,7 +854,7 @@ OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest
auto& db_slice = shard->db_slice();
KeyIterWeightVec key_weight_vec(keys.size());
for (unsigned j = 0; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->GetDbContext(), keys[j], OBJ_ZSET);
auto it_res = db_slice.FindReadOnly(t->GetDbContext(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
if (!it_res)
@ -1273,9 +1275,9 @@ bool ParseLimit(string_view offset_str, string_view limit_str, ZSetFamily::Range
ScoredArray OpBZPop(Transaction* t, EngineShard* shard, std::string_view key, bool is_max) {
auto& db_slice = shard->db_slice();
auto it_res = db_slice.Find(t->GetDbContext(), key, OBJ_ZSET);
auto it_res = db_slice.FindMutable(t->GetDbContext(), key, OBJ_ZSET);
CHECK(it_res) << t->DebugId() << " " << key; // must exist and must be ok.
PrimeIterator it = *it_res;
PrimeIterator it = it_res->it;
ZSetFamily::RangeParams range_params;
range_params.reverse = is_max;
@ -1291,12 +1293,12 @@ ScoredArray OpBZPop(Transaction* t, EngineShard* shard, std::string_view key, bo
IntervalVisitor iv{Action::POP, range_spec.params, &pv};
std::visit(iv, range_spec.interval);
db_slice.PostUpdate(t->GetDbIndex(), *it_res, key);
it_res->post_updater.Run();
auto zlen = pv.Size();
if (zlen == 0) {
DVLOG(1) << "deleting key " << key << " " << t->DebugId();
CHECK(db_slice.Del(t->GetDbIndex(), *it_res));
CHECK(db_slice.Del(t->GetDbIndex(), it_res->it));
}
OpArgs op_args = t->GetOpArgs(shard);
@ -1364,7 +1366,7 @@ vector<ScoredMap> OpFetch(EngineShard* shard, Transaction* t) {
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < keys.size(); ++i) {
auto it = db_slice.Find(t->GetDbContext(), keys[i], OBJ_ZSET);
auto it = db_slice.FindReadOnly(t->GetDbContext(), keys[i], OBJ_ZSET);
if (!it) {
results.push_back({});
continue;
@ -1380,22 +1382,20 @@ vector<ScoredMap> OpFetch(EngineShard* shard, Transaction* t) {
auto OpPopCount(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, string_view key)
-> OpResult<ScoredArray> {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it);
PrimeValue& pv = res_it.value()->second;
PrimeValue& pv = res_it->it->second;
IntervalVisitor iv{Action::POP, range_spec.params, &pv};
std::visit(iv, range_spec.interval);
db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key);
res_it->post_updater.Run();
auto zlen = pv.Size();
if (zlen == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value()));
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it->it));
}
return iv.PopResult();
@ -1403,11 +1403,13 @@ auto OpPopCount(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args,
auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, string_view key)
-> OpResult<ScoredArray> {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
PrimeValue& pv = res_it.value()->second;
// Action::RANGE is read-only, but requires mutable pointer, thus const_cast
PrimeValue& pv = const_cast<PrimeValue&>(res_it.value()->second);
IntervalVisitor iv{Action::RANGE, range_spec.params, &pv};
std::visit(iv, range_spec.interval);
@ -1417,11 +1419,13 @@ auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, st
auto OpRanges(const std::vector<ZSetFamily::ZRangeSpec>& range_specs, const OpArgs& op_args,
string_view key) -> OpResult<vector<ScoredArray>> {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
PrimeValue& pv = res_it.value()->second;
// Action::RANGE is read-only, but requires mutable pointer, thus const_cast
PrimeValue& pv = const_cast<PrimeValue&>(res_it.value()->second);
vector<ScoredArray> result_arrays;
for (auto& range_spec : range_specs) {
IntervalVisitor iv{Action::RANGE, range_spec.params, &pv};
@ -1435,21 +1439,19 @@ auto OpRanges(const std::vector<ZSetFamily::ZRangeSpec>& range_specs, const OpAr
OpResult<unsigned> OpRemRange(const OpArgs& op_args, string_view key,
const ZSetFamily::ZRangeSpec& range_spec) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it);
PrimeValue& pv = res_it.value()->second;
PrimeValue& pv = res_it->it->second;
IntervalVisitor iv{Action::REMOVE, range_spec.params, &pv};
std::visit(iv, range_spec.interval);
db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key);
res_it->post_updater.Run();
auto zlen = pv.Size();
if (zlen == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value()));
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it->it));
}
return iv.removed();
@ -1457,11 +1459,12 @@ OpResult<unsigned> OpRemRange(const OpArgs& op_args, string_view key,
OpResult<unsigned> OpRank(const OpArgs& op_args, string_view key, string_view member,
bool reverse) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
if (robj_wrapper->encoding() == OBJ_ENCODING_LISTPACK) {
unsigned char* zl = (uint8_t*)robj_wrapper->inner_obj();
unsigned char *eptr, *sptr;
@ -1503,11 +1506,12 @@ OpResult<unsigned> OpRank(const OpArgs& op_args, string_view key, string_view me
OpResult<unsigned> OpCount(const OpArgs& op_args, std::string_view key,
const ZSetFamily::ScoreInterval& interval) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
zrangespec range = GetZrangeSpec(false, interval);
unsigned count = 0;
@ -1553,13 +1557,14 @@ OpResult<unsigned> OpCount(const OpArgs& op_args, std::string_view key,
OpResult<unsigned> OpLexCount(const OpArgs& op_args, string_view key,
const ZSetFamily::LexInterval& interval) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
zlexrangespec range = GetLexRange(false, interval);
unsigned count = 0;
detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
if (robj_wrapper->encoding() == OBJ_ENCODING_LISTPACK) {
uint8_t* zl = (uint8_t*)robj_wrapper->inner_obj();
@ -1597,12 +1602,11 @@ OpResult<unsigned> OpLexCount(const OpArgs& op_args, string_view key,
OpResult<unsigned> OpRem(const OpArgs& op_args, string_view key, ArgSlice members) {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET);
auto res_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it);
detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
detail::RobjWrapper* robj_wrapper = res_it->it->second.GetRobjWrapper();
sds& tmp_str = op_args.shard->tmp_str1;
unsigned deleted = 0;
for (string_view member : members) {
@ -1610,25 +1614,26 @@ OpResult<unsigned> OpRem(const OpArgs& op_args, string_view key, ArgSlice member
deleted += ZsetDel(robj_wrapper, tmp_str);
}
auto zlen = robj_wrapper->Size();
db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key);
res_it->post_updater.Run();
if (zlen == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value()));
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it->it));
}
return deleted;
}
OpResult<double> OpScore(const OpArgs& op_args, string_view key, string_view member) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
PrimeValue& pv = res_it.value()->second;
const PrimeValue& pv = res_it.value()->second;
sds& tmp_str = op_args.shard->tmp_str1;
tmp_str = sdscpylen(tmp_str, member.data(), member.size());
detail::RobjWrapper* robj_wrapper = pv.GetRobjWrapper();
const detail::RobjWrapper* robj_wrapper = pv.GetRobjWrapper();
auto res = GetZsetScore(robj_wrapper, tmp_str);
if (!res)
return OpStatus::KEY_NOTFOUND;
@ -1636,13 +1641,14 @@ OpResult<double> OpScore(const OpArgs& op_args, string_view key, string_view mem
}
OpResult<MScoreResponse> OpMScore(const OpArgs& op_args, string_view key, ArgSlice members) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> res_it =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
MScoreResponse scores(members.size());
detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
const detail::RobjWrapper* robj_wrapper = res_it.value()->second.GetRobjWrapper();
sds& tmp_str = op_args.shard->tmp_str1;
for (size_t i = 0; i < members.size(); i++) {
@ -1657,20 +1663,21 @@ OpResult<MScoreResponse> OpMScore(const OpArgs& op_args, string_view key, ArgSli
OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor,
const ScanOpts& scan_op) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
OpResult<PrimeConstIterator> find_res =
op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!find_res)
return find_res.status();
PrimeIterator it = find_res.value();
PrimeValue& pv = it->second;
PrimeConstIterator it = find_res.value();
const PrimeValue& pv = it->second;
StringVec res;
char buf[128];
if (pv.Encoding() == OBJ_ENCODING_LISTPACK) {
ZSetFamily::RangeParams params;
params.with_scores = true;
IntervalVisitor iv{Action::RANGE, params, &pv};
IntervalVisitor iv{Action::RANGE, params, const_cast<PrimeValue*>(&pv)};
iv(ZSetFamily::IndexInterval{0, kuint32max});
ScoredArray arr = iv.PopResult();
@ -1865,7 +1872,8 @@ void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<uint32_t> {
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_ZSET);
OpResult<PrimeConstIterator> find_res =
shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_ZSET);
if (!find_res) {
return find_res.status();
}