fix(DenseSet): Rename Size() to UpperBoundSize() and add SizeSlow() (#2130)

Each call site is then switched to whichever variant fits: the cheap upper bound, or the accurate but O(n) `SizeSlow()`.

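Specifically, this fixes a serialization bug where we could send
malformed responses when the upper-bound size (now `UpperBoundSize()`) was
used to write the array length: that count may include expired elements
that are never actually written out.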
This commit is contained in:
Shahar Mike 2023-11-06 08:52:08 +02:00 committed by GitHub
parent f809fb04bc
commit 7e23c14c35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 48 additions and 41 deletions

View file

@ -286,7 +286,7 @@ size_t RobjWrapper::Size() const {
}
case kEncodingStrMap2: {
StringSet* ss = (StringSet*)inner_obj_;
return ss->Size();
return ss->UpperBoundSize();
}
default:
LOG(FATAL) << "Unexpected encoding " << encoding_;
@ -300,7 +300,7 @@ size_t RobjWrapper::Size() const {
case kEncodingStrMap2: {
StringMap* sm = (StringMap*)inner_obj_;
return sm->Size();
return sm->UpperBoundSize();
}
}
default:;

View file

@ -671,4 +671,17 @@ bool DenseSet::ExpireIfNeededInternal(DensePtr* prev, DensePtr* node) const {
return deleted;
}
// Sweeps the whole set so that expired entries get dropped.
// Nothing is deleted explicitly here: the sweep relies on the iterator
// removing expired items as it moves across them, so visiting every entry
// once is all that is needed.
void DenseSet::CollectExpired() {
  for (auto cursor = IteratorBase(this, false); cursor.curr_entry_ != nullptr; cursor.Advance()) {
    // Intentionally empty: advancing the cursor is the work.
  }
}
// Returns the element count after first purging expired entries.
// CollectExpired() walks every item, so this is a full pass over the set;
// the purge must run before size_ is read for the count to be accurate.
size_t DenseSet::SizeSlow() {
CollectExpired();
return size_;
}
} // namespace dfly

View file

@ -210,10 +210,15 @@ class DenseSet {
explicit DenseSet(MemoryResource* mr = PMR_NS::get_default_resource());
virtual ~DenseSet();
size_t Size() const {
// Returns the number of elements currently stored in the container. Note that some of these
// elements might have expired and can't be accessed, so the value is only an upper bound on
// the number of reachable elements. Use SizeSlow() when an accurate count is required.
size_t UpperBoundSize() const {
return size_;
}
// Returns an accurate size, post-expiration. O(n).
size_t SizeSlow();
// Returns true when the stored-element counter is zero.
// NOTE(review): this reads the same size_ counter as UpperBoundSize(), so it presumably
// reports non-empty while only expired-but-uncollected elements remain — confirm.
bool Empty() const {
return size_ == 0;
}
@ -261,6 +266,8 @@ class DenseSet {
virtual uint32_t ObjExpireTime(const void* obj) const = 0;
virtual void ObjDelete(void* obj, bool has_ttl) const = 0;
void CollectExpired();
bool EraseInternal(void* obj, uint32_t cookie) {
auto [prev, found] = Find(obj, BucketId(obj, cookie), cookie);
if (found) {

View file

@ -642,7 +642,7 @@ optional<unsigned> SortedMap::DfImpl::GetRank(sds ele, bool reverse) const {
optional rank = score_tree->GetRank(obj);
DCHECK(rank);
return reverse ? score_map->Size() - *rank - 1 : *rank;
return reverse ? score_map->UpperBoundSize() - *rank - 1 : *rank;
}
SortedMap::ScoredArray SortedMap::DfImpl::GetRange(const zrangespec& range, unsigned offset,
@ -904,8 +904,8 @@ size_t SortedMap::DfImpl::DeleteRangeByLex(const zlexrangespec& range) {
}
SortedMap::ScoredArray SortedMap::DfImpl::PopTopScores(unsigned count, bool reverse) {
DCHECK_EQ(score_map->Size(), score_tree->Size());
size_t sz = score_map->Size();
DCHECK_EQ(score_map->UpperBoundSize(), score_tree->Size());
size_t sz = score_map->UpperBoundSize();
ScoredArray res;
@ -913,7 +913,7 @@ SortedMap::ScoredArray SortedMap::DfImpl::PopTopScores(unsigned count, bool reve
return res;
if (count >= sz)
count = score_map->Size();
count = score_map->UpperBoundSize();
res.reserve(count);
unsigned rank = 0;

View file

@ -232,7 +232,7 @@ class SortedMap {
bool Delete(sds ele);
size_t Size() const {
return score_map->Size();
return score_map->UpperBoundSize();
}
size_t MallocSize() const;

View file

@ -105,12 +105,6 @@ void StringMap::Clear() {
ClearInternal();
}
void StringMap::CollectExpired() {
// Simply iterating over all items will remove expired
for (auto it = begin(); it != end(); ++it) {
}
}
optional<pair<sds, sds>> StringMap::RandomPair() {
// Iteration may remove elements, and so we need to loop if we happen to reach the end
while (true) {
@ -121,7 +115,7 @@ optional<pair<sds, sds>> StringMap::RandomPair() {
break;
}
it += rand() % Size();
it += rand() % UpperBoundSize();
if (it != end()) {
return std::make_pair(it->first, it->second);
}
@ -131,9 +125,7 @@ optional<pair<sds, sds>> StringMap::RandomPair() {
void StringMap::RandomPairsUnique(unsigned int count, std::vector<sds>& keys,
std::vector<sds>& vals, bool with_value) {
CollectExpired();
unsigned int total_size = Size();
unsigned int total_size = SizeSlow();
unsigned int index = 0;
if (count > total_size)
count = total_size;
@ -162,11 +154,9 @@ void StringMap::RandomPairsUnique(unsigned int count, std::vector<sds>& keys,
void StringMap::RandomPairs(unsigned int count, std::vector<sds>& keys, std::vector<sds>& vals,
bool with_value) {
CollectExpired();
using RandomPick = std::pair<unsigned int, unsigned int>;
std::vector<RandomPick> picks;
unsigned int total_size = Size();
unsigned int total_size = SizeSlow();
for (unsigned int i = 0; i < count; ++i) {
RandomPick pick{rand() % total_size, i};

View file

@ -153,8 +153,6 @@ class StringMap : public DenseSet {
// Returns new pointer (stays same if key utilization is enough) and if reallocation happened.
std::pair<sds, bool> ReallocIfNeeded(void* obj, float ratio);
void CollectExpired();
uint64_t Hash(const void* obj, uint32_t cookie) const final;
bool ObjEqual(const void* left, const void* right, uint32_t right_cookie) const final;
size_t ObjectAllocSize(const void* obj) const final;

View file

@ -105,7 +105,7 @@ TEST_F(StringMapTest, Ttl) {
EXPECT_FALSE(sm_->AddOrUpdate("bla", "val2", 1));
sm_->set_time(1);
EXPECT_TRUE(sm_->AddOrUpdate("bla", "val2", 1));
EXPECT_EQ(1u, sm_->Size());
EXPECT_EQ(1u, sm_->UpperBoundSize());
EXPECT_FALSE(sm_->AddOrSkip("bla", "val3", 2));
@ -168,7 +168,7 @@ TEST_F(StringMapTest, ReallocIfNeeded) {
// Check we waste significantly less now
EXPECT_GT(wasted_before, wasted_after * 2);
EXPECT_EQ(sm_->Size(), 1000);
EXPECT_EQ(sm_->UpperBoundSize(), 1000);
for (size_t i = 0; i < 1000; i++)
EXPECT_EQ(sm_->Find(build_str(i * 10))->second, build_str(i * 10 + 1));
}

View file

@ -90,7 +90,7 @@ TEST_F(StringSetTest, Basic) {
EXPECT_FALSE(ss_->Add("bar"sv));
EXPECT_TRUE(ss_->Contains("foo"sv));
EXPECT_TRUE(ss_->Contains("bar"sv));
EXPECT_EQ(2, ss_->Size());
EXPECT_EQ(2, ss_->UpperBoundSize());
}
TEST_F(StringSetTest, StandardAddErase) {
@ -150,11 +150,11 @@ TEST_F(StringSetTest, Resizing) {
for (size_t i = 0; i < num_strs; ++i) {
EXPECT_TRUE(ss_->Add(strs[i]));
EXPECT_EQ(ss_->Size(), i + 1);
EXPECT_EQ(ss_->UpperBoundSize(), i + 1);
// make sure we haven't lost any items after a grow
// which happens every power of 2
if (i != 0 && (ss_->Size() & (ss_->Size() - 1)) == 0) {
if (i != 0 && (ss_->UpperBoundSize() & (ss_->UpperBoundSize() - 1)) == 0) {
for (size_t j = 0; j < i; ++j) {
EXPECT_TRUE(ss_->Contains(strs[j]));
}
@ -343,12 +343,12 @@ TEST_F(StringSetTest, Pop) {
}
while (!ss_->Empty()) {
size_t size = ss_->Size();
size_t size = ss_->UpperBoundSize();
auto str = ss_->Pop();
DCHECK(ss_->Size() == to_insert.size() - 1);
DCHECK(ss_->UpperBoundSize() == to_insert.size() - 1);
DCHECK(str.has_value());
DCHECK(to_insert.count(str.value()));
DCHECK_EQ(ss_->Size(), size - 1);
DCHECK_EQ(ss_->UpperBoundSize(), size - 1);
to_insert.erase(str.value());
}
@ -394,12 +394,12 @@ TEST_F(StringSetTest, Ttl) {
ss_->set_time(1);
EXPECT_TRUE(ss_->Add("bla"sv, 1));
EXPECT_EQ(1u, ss_->Size());
EXPECT_EQ(1u, ss_->UpperBoundSize());
for (unsigned i = 0; i < 100; ++i) {
EXPECT_TRUE(ss_->Add(StrCat("foo", i), 1));
}
EXPECT_EQ(101u, ss_->Size());
EXPECT_EQ(101u, ss_->UpperBoundSize());
it = ss_->Find("foo50");
EXPECT_STREQ("foo50", *it);
EXPECT_EQ(2u, it.ExpiryTime());

View file

@ -119,7 +119,7 @@ size_t HMapLength(const DbContext& db_cntx, const CompactObj& co) {
void* ptr = co.RObjPtr();
if (co.Encoding() == kEncodingStrMap2) {
StringMap* sm = GetStringMap(co, db_cntx);
return sm->Size();
return sm->UpperBoundSize();
}
DCHECK_EQ(kEncodingListPack, co.Encoding());
@ -379,7 +379,7 @@ OpResult<uint32_t> OpDel(const OpArgs& op_args, string_view key, CmdArgList valu
bool res = sm->Erase(ToSV(s));
if (res) {
++deleted;
if (sm->Size() == 0) {
if (sm->UpperBoundSize() == 0) {
key_remove = true;
break;
}
@ -565,8 +565,7 @@ OpResult<vector<string>> OpGetAll(const OpArgs& op_args, string_view key, uint8_
DCHECK_EQ(pv.Encoding(), kEncodingStrMap2);
StringMap* sm = GetStringMap(pv, op_args.db_cntx);
// Some items could have expired, yet accounted for in Size(), so reserve() might overshoot
res.reserve(sm->Size() * (keyval ? 2 : 1));
res.reserve(sm->UpperBoundSize() * (keyval ? 2 : 1));
for (const auto& k_v : *sm) {
if (mask & FIELDS) {
res.emplace_back(k_v.first, sdslen(k_v.first));
@ -1077,7 +1076,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) {
}
} else {
size_t actual_count =
(count >= 0) ? std::min(size_t(count), string_map->Size()) : abs(count);
(count >= 0) ? std::min(size_t(count), string_map->UpperBoundSize()) : abs(count);
std::vector<sds> keys, vals;
if (count >= 0) {
string_map->RandomPairsUnique(actual_count, keys, vals, with_values);

View file

@ -408,7 +408,7 @@ error_code RdbSerializer::SaveSetObject(const PrimeValue& obj) {
} else if (obj.Encoding() == kEncodingStrMap2) {
StringSet* set = (StringSet*)obj.RObjPtr();
RETURN_ON_ERR(SaveLen(set->Size()));
RETURN_ON_ERR(SaveLen(set->SizeSlow()));
for (sds ele : *set) {
RETURN_ON_ERR(SaveString(string_view{ele, sdslen(ele)}));
@ -430,7 +430,7 @@ error_code RdbSerializer::SaveHSetObject(const PrimeValue& pv) {
if (pv.Encoding() == kEncodingStrMap2) {
StringMap* string_map = (StringMap*)pv.RObjPtr();
RETURN_ON_ERR(SaveLen(string_map->Size()));
RETURN_ON_ERR(SaveLen(string_map->SizeSlow()));
for (const auto& k_v : *string_map) {
RETURN_ON_ERR(SaveString(string_view{k_v.first, sdslen(k_v.first)}));

View file

@ -260,7 +260,7 @@ uint32_t SetTypeLen(const DbContext& db_context, const SetType& set) {
if (IsDenseEncoding(set)) {
StringSet* ss = (StringSet*)set.first;
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
return ss->Size();
return ss->UpperBoundSize();
}
DCHECK_EQ(set.second, kEncodingStrMap);