chore: several improvements around sorted map (#1699)

* chore: several improvements around sorted map

1. Pass memory_resource to sorted_map.
2. Get rid of GetDict leaky accessor in SortedMap and introduce a proper
   Scan method.
3. Introduce correct BPTree type inside SortedMap::DFImpl.
4. Added a test for bptree_test that covers sds comparison
   (apparently, sdscmp can return values outside of [-1, 1] range).
   Fixed bptree code to support a proper spec for three-way comparison.
5. Expose pointers to internal objects allocated by score_map so we could insert them
   into bptree.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>

* chore: fix comments

Signed-off-by: Roman Gershman <roman@dragonflydb.io>

---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2023-08-15 18:08:59 +03:00 committed by GitHub
parent e22c131b7c
commit ec22e73a28
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 226 additions and 124 deletions

View file

@ -65,7 +65,7 @@ template <typename T, typename Policy = BPTreePolicy<T>> class BPTree {
void Clear();
BPTreeNode* DEBUG_root() {
const BPTreeNode* DEBUG_root() const {
return root_;
}
@ -299,10 +299,10 @@ void BPTree<T, Policy>::InsertToFullLeaf(KeyT item, const BPTreePath& path) {
assert(node->NumItems() < Layout::kMaxLeafKeys);
if (insert_pos <= node->NumItems()) {
assert(Comp()(item, median) == -1);
assert(Comp()(item, median) < 0);
node->LeafInsert(insert_pos, item);
} else {
assert(Comp()(item, median) == 1);
assert(Comp()(item, median) > 0);
right->LeafInsert(insert_pos - node->NumItems() - 1, item);
}
@ -359,12 +359,12 @@ void BPTree<T, Policy>::InsertToFullLeaf(KeyT item, const BPTreePath& path) {
assert(node->NumItems() < Layout::kMaxInnerKeys);
if (pos <= node->NumItems()) {
assert(Comp()(median, next_median) == -1);
assert(Comp()(median, next_median) < 0);
node->InnerInsert(pos, median, right);
node->IncreaseTreeCount(1);
} else {
assert(Comp()(median, next_median) == 1);
assert(Comp()(median, next_median) > 0);
next_right->InnerInsert(pos - node->NumItems() - 1, median, right);
@ -566,7 +566,7 @@ template <typename T, typename Policy> void BPTree<T, Policy>::Delete(BPTreePath
path.DigRight();
BPTreeNode* leaf = path.Last().first;
assert(Comp()(leaf->Key(leaf->NumItems() - 1), node->Key(key_pos)) == -1);
assert(Comp()(leaf->Key(leaf->NumItems() - 1), node->Key(key_pos)) < 0);
// set a new separator.
node->SetKey(key_pos, leaf->Key(leaf->NumItems() - 1));

View file

@ -22,35 +22,14 @@ using namespace std;
namespace dfly {
class BPTreeSetTest : public ::testing::Test {
using Node = detail::BPTreeNode<uint64_t>;
namespace {
protected:
static constexpr size_t kNumElems = 7000;
template <typename Node, typename Policy>
bool ValidateNode(const Node* node, typename Node::KeyT ubound) {
typename Policy::KeyCompareTo cmp;
BPTreeSetTest() : mi_alloc_(mi_heap_get_backing()), bptree_(&mi_alloc_) {
}
static void SetUpTestSuite() {
}
void FillTree(unsigned factor = 1) {
for (unsigned i = 0; i < kNumElems; ++i) {
bptree_.Insert(i * factor);
}
}
bool Validate();
static bool Validate(const Node* node, uint64_t ubound);
MiMemoryResource mi_alloc_;
BPTree<uint64_t> bptree_;
mt19937 generator_{1};
};
bool BPTreeSetTest::Validate(const Node* node, uint64_t ubound) {
for (unsigned i = 1; i < node->NumItems(); ++i) {
if (node->Key(i - 1) >= node->Key(i))
if (cmp(node->Key(i - 1), node->Key(i)) > -1)
return false;
}
@ -71,25 +50,72 @@ bool BPTreeSetTest::Validate(const Node* node, uint64_t ubound) {
}
}
return node->Key(node->NumItems() - 1) < ubound;
return cmp(node->Key(node->NumItems() - 1), ubound) == -1;
}
// BPTree policy used by the sds tests: keys are (double, sds) pairs ordered by the
// double first and by the sds string second.
struct ZsetPolicy {
  struct KeyT {
    double d;  // primary sort component
    sds s;     // tie-breaker, compared with sdscmp
  };

  struct KeyCompareTo {
    // Three-way comparison: negative if left < right, positive if left > right,
    // zero if equal. Declared const so it can be invoked through a const comparator.
    int operator()(const KeyT& left, const KeyT& right) const {
      if (left.d < right.d)
        return -1;
      if (left.d > right.d)
        return 1;

      // Note that sdscmp can return values outside of [-1, 1] range.
      // The result is deliberately not normalized here.
      return sdscmp(left.s, right.s);
    }
  };
};
using SDSTree = BPTree<ZsetPolicy::KeyT, ZsetPolicy>;
} // namespace
// Test fixture that owns a BPTree<uint64_t> allocating through mi_alloc_.
class BPTreeSetTest : public ::testing::Test {
using Node = detail::BPTreeNode<uint64_t>;
protected:
// Number of keys inserted by FillTree().
static constexpr size_t kNumElems = 7000;
// mi_alloc_ wraps the heap returned by mi_heap_get_backing(); bptree_ receives its address.
BPTreeSetTest() : mi_alloc_(mi_heap_get_backing()), bptree_(&mi_alloc_) {
}
// No suite-level setup is performed.
static void SetUpTestSuite() {
}
// Inserts i * factor for every i in [0, kNumElems).
// The product is computed in unsigned arithmetic before being passed to Insert.
void FillTree(unsigned factor = 1) {
for (unsigned i = 0; i < kNumElems; ++i) {
bptree_.Insert(i * factor);
}
}
// Defined out of class.
bool Validate();
// Declared before bptree_ because bptree_ is constructed from &mi_alloc_.
MiMemoryResource mi_alloc_;
BPTree<uint64_t> bptree_;
// Random generator constructed with the fixed seed 1.
mt19937 generator_{1};
};
bool BPTreeSetTest::Validate() {
auto* root = bptree_.DEBUG_root();
if (!root)
return true;
// node, upper bound
std::vector<pair<Node*, uint64_t>> stack;
vector<pair<const Node*, uint64_t>> stack;
stack.emplace_back(root, UINT64_MAX);
while (!stack.empty()) {
Node* node = stack.back().first;
const Node* node = stack.back().first;
uint64_t ubound = stack.back().second;
stack.pop_back();
if (!Validate(node, ubound))
if (!ValidateNode<Node, BPTreePolicy<uint64_t>>(node, ubound))
return false;
if (!node->IsLeaf()) {
@ -353,25 +379,24 @@ TEST_F(BPTreeSetTest, MemoryUsage) {
LOG(INFO) << "btree after: " << mi_alloc.used() << " bytes";
}
struct ZsetPolicy {
struct KeyT {
double d;
sds s;
};
TEST_F(BPTreeSetTest, InsertSDS) {
vector<ZsetPolicy::KeyT> vals;
for (unsigned i = 0; i < 256; ++i) {
sds s = sdsempty();
struct KeyCompareTo {
int operator()(const KeyT& left, const KeyT& right) {
if (left.d < right.d)
return -1;
if (left.d > right.d)
return 1;
s = sdscatfmt(s, "a%u", i);
vals.emplace_back(ZsetPolicy::KeyT{.d = 1000, .s = s});
}
return sdscmp(left.s, right.s);
}
};
};
SDSTree tree(&mi_alloc_);
for (size_t i = 0; i < vals.size(); ++i) {
ASSERT_TRUE(tree.Insert(vals[i]));
}
using SDSTree = BPTree<ZsetPolicy::KeyT, ZsetPolicy>;
for (auto v : vals) {
sdsfree(v.s);
}
}
static string RandomString(mt19937& rand, unsigned len) {
const string_view alpanum = "1234567890abcdefghijklmnopqrstuvwxyz";

View file

@ -430,7 +430,7 @@ int RobjWrapper::ZsetAdd(double score, sds ele, int in_flags, int* out_flags, do
* becomes too long *before* executing zzlInsert. */
if (zl_len + 1 > server.zset_max_listpack_entries ||
sdslen(ele) > server.zset_max_listpack_value || !lpSafeToAdd(lp, sdslen(ele))) {
unique_ptr<SortedMap> ss = SortedMap::FromListPack(lp);
unique_ptr<SortedMap> ss = SortedMap::FromListPack(tl.local_mr, lp);
lpFree(lp);
inner_obj_ = ss.release();
encoding_ = OBJ_ENCODING_SKIPLIST;

View file

@ -19,28 +19,13 @@ namespace dfly {
namespace {
union DoubleUnion {
double d;
uint64_t u;
};
inline double GetValue(sds key) {
char* valptr = key + sdslen(key) + 1;
DoubleUnion u;
u.u = absl::little_endian::Load64(valptr);
return u.d;
return absl::bit_cast<double>(absl::little_endian::Load64(valptr));
}
} // namespace
ScoreMap::~ScoreMap() {
Clear();
}
bool ScoreMap::AddOrUpdate(string_view field, double value) {
void* AllocateScored(string_view field, double value) {
size_t meta_offset = field.size() + 1;
DoubleUnion u;
u.d = value;
// The layout is:
// key, '\0', 8-byte double value
@ -50,23 +35,35 @@ bool ScoreMap::AddOrUpdate(string_view field, double value) {
memcpy(newkey, field.data(), field.size());
}
absl::little_endian::Store64(newkey + meta_offset, u.u);
absl::little_endian::Store64(newkey + meta_offset, absl::bit_cast<uint64_t>(value));
return newkey;
}
} // namespace
// Destructor: calls Clear() on this map.
ScoreMap::~ScoreMap() {
Clear();
}
pair<void*, bool> ScoreMap::AddOrUpdate(string_view field, double value) {
void* newkey = AllocateScored(field, value);
// Replace the whole entry.
sds prev_entry = (sds)AddOrReplaceObj(newkey, false);
if (prev_entry) {
ObjDelete(prev_entry, false);
return false;
return {newkey, false};
}
return true;
return {newkey, true};
}
bool ScoreMap::AddOrSkip(std::string_view field, double value) {
std::pair<void*, bool> ScoreMap::AddOrSkip(std::string_view field, double value) {
void* obj = FindInternal(&field, 1); // 1 - string_view
if (obj)
return false;
return {obj, false};
return AddOrUpdate(field, value);
}

View file

@ -77,13 +77,13 @@ class ScoreMap : public DenseSet {
}
};
// Returns true if field was added
// otherwise updates its value and returns false.
bool AddOrUpdate(std::string_view field, double value);
// Returns a pointer to the internal object and the insertion result,
// i.e. true if the field was added; otherwise updates its value and returns false.
std::pair<void*, bool> AddOrUpdate(std::string_view field, double value);
// Returns the internal object pointer paired with true if the field was added,
// or paired with false if it already exists. In that case no update is done.
bool AddOrSkip(std::string_view field, double value);
std::pair<void*, bool> AddOrSkip(std::string_view field, double value);
bool Erase(std::string_view s1);
@ -92,6 +92,11 @@ class ScoreMap : public DenseSet {
/// @return sds
std::optional<double> Find(std::string_view key);
// Looks up `ele` through FindInternal and returns the internal object if found,
// otherwise nullptr.
// NOTE(review): the second argument (0) presumably tags the key as an sds, mirroring
// the "1 - string_view" tag used by AddOrSkip - confirm against DenseSet.
void* FindObj(sds ele) {
return FindInternal(ele, 0);
}
void Clear();
iterator begin() {

View file

@ -54,7 +54,7 @@ class ScoreMapTest : public ::testing::Test {
};
TEST_F(ScoreMapTest, Basic) {
EXPECT_TRUE(sm_->AddOrUpdate("foo", 5));
EXPECT_TRUE(sm_->AddOrUpdate("foo", 5).second);
EXPECT_EQ(5, sm_->Find("foo"));
auto it = sm_->begin();
@ -70,13 +70,13 @@ TEST_F(ScoreMapTest, Basic) {
}
size_t sz = sm_->ObjMallocUsed();
EXPECT_FALSE(sm_->AddOrUpdate("foo", 17));
EXPECT_FALSE(sm_->AddOrUpdate("foo", 17).second);
EXPECT_EQ(sm_->ObjMallocUsed(), sz);
it = sm_->begin();
EXPECT_EQ(17, it->second);
EXPECT_FALSE(sm_->AddOrSkip("foo", 31));
EXPECT_FALSE(sm_->AddOrSkip("foo", 31).second);
EXPECT_EQ(17, it->second);
}

View file

@ -13,6 +13,7 @@ extern "C" {
#include "redis/zmalloc.h"
}
#include "base/endian.h"
#include "base/logging.h"
using namespace std;
@ -56,6 +57,12 @@ inline bool IsUnder(bool reverse, double score, const zrangespec& spec) {
return reverse ? zslValueGteMin(score, &spec) : zslValueLteMax(score, &spec);
}
double GetObjScore(const void* obj) {
sds s = (sds)obj;
char* ptr = s + sdslen(s) + 1;
return absl::bit_cast<double>(absl::little_endian::Load64(ptr));
}
} // namespace
void SortedMap::RdImpl::Init() {
@ -356,6 +363,33 @@ bool SortedMap::RdImpl::Iterate(unsigned start_rank, unsigned len, bool reverse,
return success;
}
uint64_t SortedMap::RdImpl::Scan(uint64_t cursor,
absl::FunctionRef<void(std::string_view, double)> cb) const {
auto scanCb = [](void* privdata, const dictEntry* de) {
auto* cb = reinterpret_cast<absl::FunctionRef<void(std::string_view, double)>*>(privdata);
sds key = (sds)de->key;
double score = *reinterpret_cast<double*>(dictGetVal(de));
(*cb)(std::string_view(key, sdslen(key)), score);
};
return dictScan(this->dict, cursor, scanCb, NULL, &cb);
}
// Three-way comparison of two scored sds objects: ordered by score first, and by
// sdscmp on the strings when neither score is greater than the other.
// The sdscmp result is returned as is.
int SortedMap::DfImpl::ScoreSdsPolicy::KeyCompareTo::operator()(ScoreSds a, ScoreSds b) const {
  const double lhs_score = GetObjScore(a);
  const double rhs_score = GetObjScore(b);

  if (lhs_score > rhs_score)
    return 1;
  if (lhs_score < rhs_score)
    return -1;

  // Neither score is greater: fall back to comparing the member strings.
  return sdscmp((sds)a, (sds)b);
}
int SortedMap::DfImpl::Add(double score, sds ele, int in_flags, int* out_flags, double* newscore) {
LOG(FATAL) << "TBD";
return 0;
@ -366,7 +400,7 @@ optional<double> SortedMap::DfImpl::GetScore(sds ele) const {
return std::nullopt;
}
void SortedMap::DfImpl::Init() {
void SortedMap::DfImpl::Init(PMR_NS::memory_resource* mr) {
LOG(FATAL) << "TBD";
}
@ -427,19 +461,26 @@ bool SortedMap::DfImpl::Iterate(unsigned start_rank, unsigned len, bool reverse,
return false;
}
// Not implemented yet: logs "TBD" at FATAL severity.
// `cursor` and `cb` are unused; the trailing return is a placeholder value.
uint64_t SortedMap::DfImpl::Scan(uint64_t cursor,
absl::FunctionRef<void(std::string_view, double)> cb) const {
LOG(FATAL) << "TBD";
return 0;
}
/***************************************************************************/
/* SortedMap */
/***************************************************************************/
SortedMap::SortedMap() : impl_(RdImpl()) {
std::visit(Overload{[](RdImpl& impl) { impl.Init(); }, [](DfImpl& impl) { impl.Init(); }}, impl_);
SortedMap::SortedMap(PMR_NS::memory_resource* mr) : impl_(RdImpl()), mr_res_(mr) {
std::visit(Overload{[](RdImpl& impl) { impl.Init(); }, [mr](DfImpl& impl) { impl.Init(mr); }},
impl_);
}
SortedMap::~SortedMap() {
std::visit(Overload{[](RdImpl& impl) { impl.Free(); }, [](DfImpl& impl) { impl.Free(); }}, impl_);
std::visit(Overload{[](auto& impl) { impl.Free(); }}, impl_);
}
// taken from zsetConvert
unique_ptr<SortedMap> SortedMap::FromListPack(const uint8_t* lp) {
unique_ptr<SortedMap> SortedMap::FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp) {
uint8_t* zl = (uint8_t*)lp;
unsigned char *eptr, *sptr;
unsigned char* vstr;
@ -447,7 +488,7 @@ unique_ptr<SortedMap> SortedMap::FromListPack(const uint8_t* lp) {
long long vlong;
sds ele;
unique_ptr<SortedMap> zs(new SortedMap());
unique_ptr<SortedMap> zs(new SortedMap(res));
eptr = lpSeek(zl, 0);
if (eptr != NULL) {

View file

@ -39,14 +39,14 @@ class SortedMap {
using ScoredMember = std::pair<std::string, double>;
using ScoredArray = std::vector<ScoredMember>;
SortedMap();
SortedMap(PMR_NS::memory_resource* res);
SortedMap(const SortedMap&) = delete;
SortedMap& operator=(const SortedMap&) = delete;
~SortedMap();
// The ownership for the returned SortedMap stays with the caller
static std::unique_ptr<SortedMap> FromListPack(const uint8_t* lp);
static std::unique_ptr<SortedMap> FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp);
size_t Size() const {
return std::visit(Overload{[](const auto& impl) { return impl.Size(); }}, impl_);
@ -66,6 +66,7 @@ class SortedMap {
impl_);
}
// Forwards (score, member) to Insert() of whichever implementation impl_ holds
// and returns its result. Takes ownership over member.
bool Insert(double score, sds member) {
  auto do_insert = [&](auto& impl) { return impl.Insert(score, member); };
  return std::visit(do_insert, impl_);
}
@ -78,9 +79,8 @@ class SortedMap {
return std::visit(Overload{[](const auto& impl) { return impl.MallocSize(); }}, impl_);
}
// TODO: to get rid of this method.
dict* GetDict() const {
return std::get<RdImpl>(impl_).dict;
uint64_t Scan(uint64_t cursor, absl::FunctionRef<void(std::string_view, double)> cb) const {
return std::visit([&](const auto& impl) { return impl.Scan(cursor, cb); }, impl_);
}
size_t DeleteRangeByRank(unsigned start, unsigned end) {
@ -203,13 +203,26 @@ class SortedMap {
// Stops iteration if cb returns false. Returns false in this case.
bool Iterate(unsigned start_rank, unsigned len, bool reverse,
absl::FunctionRef<bool(sds, double)> cb) const;
uint64_t Scan(uint64_t cursor, absl::FunctionRef<void(std::string_view, double)> cb) const;
};
struct DfImpl {
ScoreMap* score_map = nullptr;
BPTree<uint64_t>* bptree = nullptr; // just a stub for now.
using ScoreSds = void*;
void Init();
struct ScoreSdsPolicy {
using KeyT = ScoreSds;
struct KeyCompareTo {
int operator()(KeyT a, KeyT b) const;
};
};
using ScoreTree = BPTree<ScoreSds, ScoreSdsPolicy>;
ScoreTree* score_tree = nullptr; // just a stub for now.
void Init(PMR_NS::memory_resource* mr);
void Free();
@ -260,9 +273,12 @@ class SortedMap {
// Stops iteration if cb returns false. Returns false in this case.
bool Iterate(unsigned start_rank, unsigned len, bool reverse,
absl::FunctionRef<bool(sds, double)> cb) const;
uint64_t Scan(uint64_t cursor, absl::FunctionRef<void(std::string_view, double)> cb) const;
};
std::variant<RdImpl, DfImpl> impl_;
PMR_NS::memory_resource* mr_res_;
};
} // namespace detail

View file

@ -22,6 +22,9 @@ using detail::SortedMap;
class SortedMapTest : public ::testing::Test {
protected:
// Constructs mr_ from the heap returned by mi_heap_get_backing().
SortedMapTest() : mr_(mi_heap_get_backing()) {
}
static void SetUpTestSuite() {
// configure redis lib zmalloc which requires mimalloc heap to work.
auto* tlh = mi_heap_get_backing();
@ -31,10 +34,13 @@ class SortedMapTest : public ::testing::Test {
// Thin helper that forwards to zslInsert(zsl, score, ele).
void AddMember(zskiplist* zsl, double score, sds ele) {
zslInsert(zsl, score, ele);
}
MiMemoryResource mr_;
};
TEST_F(SortedMapTest, Add) {
SortedMap sm;
SortedMap sm(&mr_);
int out_flags;
double new_score;
@ -52,6 +58,32 @@ TEST_F(SortedMapTest, Add) {
EXPECT_EQ(1, res);
EXPECT_EQ(ZADD_OUT_UPDATED, out_flags);
EXPECT_EQ(3, new_score);
EXPECT_EQ(3, sm.GetScore(ele));
}
// Inserts 972 members and checks that repeatedly calling Scan until the cursor
// returns to 0 invokes the callback exactly once per inserted member.
TEST_F(SortedMapTest, Scan) {
  constexpr unsigned kNumMembers = 972;

  SortedMap sm(&mr_);
  for (unsigned i = 0; i < kNumMembers; ++i) {
    sm.Insert(i, sdsfromlonglong(i));
  }

  unsigned visited = 0;
  auto count_cb = [&visited](string_view, double) { ++visited; };

  uint64_t cursor = 0;
  do {
    cursor = sm.Scan(cursor, count_cb);
  } while (cursor != 0);

  EXPECT_EQ(kNumMembers, visited);
}
// Inserts 256 distinct members ("a0" .. "a255") that all share the score 1000 and
// expects every insertion to succeed.
TEST_F(SortedMapTest, Insert) {
  SortedMap sm(&mr_);
  for (unsigned idx = 0; idx < 256; ++idx) {
    // The map takes ownership of the sds, so it is not freed here.
    sds member = sdscatfmt(sdsempty(), "a%u", idx);
    ASSERT_TRUE(sm.Insert(1000, member));
  }
}
} // namespace dfly

View file

@ -10,7 +10,7 @@
/* Input flags. */
#define ZADD_IN_NONE 0
#define ZADD_IN_INCR (1 << 0) /* Increment the score instead of setting it. */
#define ZADD_IN_NX (1 << 1) /* Don't touch elements not already existing. */
#define ZADD_IN_NX (1 << 1) /* Don't touch elements already existing. */
#define ZADD_IN_XX (1 << 2) /* Only touch elements already existing. */
#define ZADD_IN_GT (1 << 3) /* Only update existing when new scores are higher. */
#define ZADD_IN_LT (1 << 4) /* Only update existing when new scores are lower. */

View file

@ -696,7 +696,7 @@ void RdbLoaderBase::OpaqueObjLoader::CreateList(const LoadTrace* ltrace) {
void RdbLoaderBase::OpaqueObjLoader::CreateZSet(const LoadTrace* ltrace) {
size_t zsetlen = ltrace->blob_count();
detail::SortedMap* zs = new detail::SortedMap;
detail::SortedMap* zs = new detail::SortedMap(CompactObj::memory_resource());
unsigned encoding = OBJ_ENCODING_SKIPLIST;
auto cleanup = absl::MakeCleanup([&] { delete zs; });
@ -995,7 +995,7 @@ void RdbLoaderBase::OpaqueObjLoader::HandleBlob(string_view blob) {
unsigned encoding = OBJ_ENCODING_LISTPACK;
void* inner;
if (lpBytes(lp) > server.zset_max_listpack_entries) {
inner = detail::SortedMap::FromListPack(lp).release();
inner = detail::SortedMap::FromListPack(CompactObj::memory_resource(), lp).release();
lpFree(lp);
encoding = OBJ_ENCODING_SKIPLIST;
} else {

View file

@ -164,7 +164,7 @@ OpResult<PrimeIterator> FindZEntry(const ZParams& zparams, const OpArgs& op_args
if (add_res.second || zparams.override) {
if (member_len > kMaxListPackValue) {
detail::SortedMap* zs = new detail::SortedMap();
detail::SortedMap* zs = new detail::SortedMap(CompactObj::memory_resource());
pv.InitRobj(OBJ_ZSET, OBJ_ENCODING_SKIPLIST, zs);
} else {
unsigned char* lp = lpNew(0);
@ -1564,35 +1564,21 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
} else {
CHECK_EQ(unsigned(OBJ_ENCODING_SKIPLIST), pv.Encoding());
uint32_t count = scan_op.limit;
detail::SortedMap* zs = (detail::SortedMap*)pv.RObjPtr();
dict* ht = zs->GetDict();
detail::SortedMap* sm = (detail::SortedMap*)pv.RObjPtr();
long maxiterations = count * 10;
uint64_t cur = *cursor;
struct ScanArgs {
char* sbuf;
StringVec* res;
const ScanOpts* scan_op;
} sargs = {buf, &res, &scan_op};
auto scanCb = [](void* privdata, const dictEntry* de) {
ScanArgs* sargs = (ScanArgs*)privdata;
sds key = (sds)de->key;
if (!sargs->scan_op->Matches(key)) {
return;
auto cb = [&](string_view str, double score) {
if (scan_op.Matches(str)) {
res.emplace_back(str);
char* str = RedisReplyBuilder::FormatDouble(score, buf, sizeof(buf));
res.emplace_back(str);
}
double score = *(double*)dictGetVal(de);
sargs->res->emplace_back(key, sdslen(key));
char* str = RedisReplyBuilder::FormatDouble(score, sargs->sbuf, sizeof(buf));
sargs->res->emplace_back(str);
};
do {
*cursor = dictScan(ht, *cursor, scanCb, NULL, &sargs);
} while (*cursor && maxiterations-- && res.size() < count);
cur = sm->Scan(cur, cb);
} while (cur && maxiterations-- && res.size() < count);
*cursor = cur;
}
return res;