From 28a2db1044c7b9afb8717e7fb6bcb16b69f8c592 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Wed, 2 Mar 2022 19:06:49 +0200 Subject: [PATCH] Implement hset method --- src/core/compact_object.cc | 97 ++++++++++++++++++++------------ src/core/compact_object_test.cc | 9 ++- src/redis/zmalloc.h | 1 - src/server/CMakeLists.txt | 1 + src/server/engine_shard_set.cc | 7 ++- src/server/engine_shard_set.h | 2 +- src/server/hset_family.cc | 77 +++++++++++++++++++++++++ src/server/hset_family.h | 22 +++++--- src/server/hset_family_test.cc | 36 ++++++++++++ src/server/list_family.cc | 4 +- src/server/main_service.cc | 4 ++ src/server/rdb_test.cc | 2 +- src/server/redis_parser_test.cc | 2 +- src/server/set_family.cc | 16 +++--- src/server/string_family.cc | 3 +- src/server/string_family_test.cc | 4 ++ src/server/transaction.cc | 9 ++- 17 files changed, 232 insertions(+), 64 deletions(-) create mode 100644 src/server/hset_family_test.cc diff --git a/src/core/compact_object.cc b/src/core/compact_object.cc index c2eecbef7..5c4994051 100644 --- a/src/core/compact_object.cc +++ b/src/core/compact_object.cc @@ -13,6 +13,7 @@ extern "C" { #include "redis/object.h" #include "redis/util.h" #include "redis/zmalloc.h" // for non-string objects. +#include "redis/zset.h" } #include @@ -47,6 +48,60 @@ size_t DictMallocSize(dict* d) { return res = dictSize(d) * 16; // approximation. } +inline void FreeObjSet(unsigned encoding, void* ptr) { + switch (encoding) { + case OBJ_ENCODING_HT: + dictRelease((dict*)ptr); + break; + case OBJ_ENCODING_INTSET: + zfree((void*)ptr); + break; + default: + LOG(FATAL) << "Unknown set encoding type"; + } +} + +size_t MallocUsedSet(unsigned encoding, void* ptr) { + switch (encoding) { + case OBJ_ENCODING_HT: + return DictMallocSize((dict*)ptr); + case OBJ_ENCODING_INTSET: + return intsetBlobLen((intset*)ptr); + default: + LOG(FATAL) << "Unknown set encoding type " << encoding; + } +} + +inline void FreeObjHash(unsigned encoding, void* ptr) { + switch (encoding) { + case OBJ_ENCODING_HT: + dictRelease((dict*)ptr); + break; + case OBJ_ENCODING_LISTPACK: + lpFree((uint8_t*)ptr); + break; + default: + LOG(FATAL) << "Unknown hset encoding type " << encoding; + } +} + +inline void FreeObjZset(unsigned encoding, void* ptr) { + zset* zs = (zset*)ptr; + switch (encoding) { + case OBJ_ENCODING_SKIPLIST: + zs = (zset*)ptr; + dictRelease(zs->dict); + zslFree(zs->zsl); + zfree(zs); + break; + case OBJ_ENCODING_LISTPACK: + zfree(ptr); + break; + default: + LOG(FATAL) << "Unknown sorted set encoding" << encoding; + } +} + // Deniel's Lemire function validate_ascii_fast() - under Apache/MIT license. // See https://github.com/lemire/fastvalidate-utf-8/ // The function returns true (1) if all chars passed in src are @@ -180,14 +235,7 @@ size_t RobjWrapper::MallocUsed() const { CHECK_EQ(encoding, OBJ_ENCODING_QUICKLIST); return QlMAllocSize((quicklist*)ptr); case OBJ_SET: - switch (encoding) { - case OBJ_ENCODING_HT: - return DictMallocSize((dict*)ptr); - case OBJ_ENCODING_INTSET: - return intsetBlobLen((intset*)ptr); - default: - LOG(FATAL) << "Unknown set encoding type"; - } + return MallocUsedSet(encoding, ptr); break; default: LOG(FATAL) << "Not supported " << type; @@ -216,43 +264,21 @@ void RobjWrapper::Free(std::pmr::memory_resource* mr) { switch (type) { case OBJ_STRING: DVLOG(2) << "Freeing string object"; - if (encoding == OBJ_ENCODING_RAW) { - blob.Free(mr); - } else { - CHECK_EQ(OBJ_ENCODING_INT, encoding); - } + DCHECK_EQ(OBJ_ENCODING_RAW, encoding); + blob.Free(mr); break; case OBJ_LIST: CHECK_EQ(encoding, OBJ_ENCODING_QUICKLIST); quicklistRelease((quicklist*)ptr); break; - case OBJ_SET: - switch (encoding) { - case OBJ_ENCODING_HT: - dictRelease((dict*)ptr); - break; - case OBJ_ENCODING_INTSET: - zfree((void*)ptr); - break; - default: - LOG(FATAL) << "Unknown set encoding type"; - } + FreeObjSet(encoding, ptr); break; case OBJ_ZSET: - LOG(FATAL) << "TBD"; + FreeObjZset(encoding, ptr); break; case OBJ_HASH: - switch (encoding) { - case OBJ_ENCODING_HT: - dictRelease((dict*)ptr); - break; - case OBJ_ENCODING_LISTPACK: - lpFree((uint8_t*)ptr); - break; - default: - LOG(FATAL) << "Unknown hset encoding type"; - } + FreeObjHash(encoding, ptr); break; case OBJ_MODULE: LOG(FATAL) << "Unsupported OBJ_MODULE type"; @@ -285,6 +311,7 @@ uint64_t RobjWrapper::HashCode() const { bool RobjWrapper::Equal(const RobjWrapper& ow) const { if (ow.type != type || ow.encoding != encoding) return false; + if (type == OBJ_STRING) { DCHECK_EQ(OBJ_ENCODING_RAW, encoding); return blob.AsView() == ow.blob.AsView(); diff --git a/src/core/compact_object_test.cc b/src/core/compact_object_test.cc index 9abcc924c..71d9dacfa 100644 --- a/src/core/compact_object_test.cc +++ b/src/core/compact_object_test.cc @@ -149,17 +149,22 @@ TEST_F(CompactObjectTest, HSet) { sds key1 = sdsnew("key1"); sds val1 = sdsnew("val1"); - // returns 0 on insert. EXPECT_EQ(0, hashTypeSet(os, key1, val1, HASH_SET_TAKE_FIELD | HASH_SET_TAKE_VALUE)); cobj_.SyncRObj(); } TEST_F(CompactObjectTest, ZSet) { - // unrelated, checking sds static encoding used in zset special strings. + // unrelated, checking that sds static encoding works. + // it is used in zset special strings. char kMinStrData[] = "\110" "minstring"; EXPECT_EQ(9, sdslen(kMinStrData + 1)); + robj* src = createZsetListpackObject(); + cobj_.ImportRObj(src); + + EXPECT_EQ(OBJ_ZSET, cobj_.ObjType()); + EXPECT_EQ(OBJ_ENCODING_LISTPACK, cobj_.Encoding()); } } // namespace dfly diff --git a/src/redis/zmalloc.h b/src/redis/zmalloc.h index 512526221..8b3d12cfd 100644 --- a/src/redis/zmalloc.h +++ b/src/redis/zmalloc.h @@ -113,7 +113,6 @@ size_t zmalloc_usable_size(const void* p); // roman: void zlibc_free(void *ptr); -extern __thread ssize_t used_memory_tl; void init_zmalloc_threadlocal(); #undef __zm_str diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 001482243..95106306d 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -17,6 +17,7 @@ cxx_link(dfly_test_lib dragonfly_lib gtest_main_ext) cxx_test(dragonfly_test dfly_test_lib LABELS DFLY) cxx_test(generic_family_test dfly_test_lib LABELS DFLY) +cxx_test(hset_family_test dfly_test_lib LABELS DFLY) cxx_test(list_family_test dfly_test_lib LABELS DFLY) cxx_test(memcache_parser_test dfly_test_lib LABELS DFLY) cxx_test(redis_parser_test dfly_test_lib LABELS DFLY) diff --git a/src/server/engine_shard_set.cc b/src/server/engine_shard_set.cc index ab5854b21..cd1451e1c 100644 --- a/src/server/engine_shard_set.cc +++ b/src/server/engine_shard_set.cc @@ -73,13 +73,16 @@ EngineShard::EngineShard(util::ProactorBase* pb, bool update_db_time, mi_heap_t* }); } - tmp_str = sdsempty(); + tmp_str1 = sdsempty(); + tmp_str2 = sdsempty(); } EngineShard::~EngineShard() { queue_.Shutdown(); fiber_q_.join(); - sdsfree(tmp_str); + sdsfree(tmp_str1); + sdsfree(tmp_str2); + if (periodic_task_) { ProactorBase::me()->CancelPeriodic(periodic_task_); } diff --git a/src/server/engine_shard_set.h b/src/server/engine_shard_set.h index 373bc57cb..818c55cd5 100644 --- a/src/server/engine_shard_set.h +++ b/src/server/engine_shard_set.h @@ -128,7 +128,7 @@ class EngineShard { } // for everyone to use for string transformations during atomic cpu sequences. - sds tmp_str; + sds tmp_str1, tmp_str2; private: EngineShard(util::ProactorBase* pb, bool update_db_time, mi_heap_t* heap); diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index d515d6488..d588dfaeb 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -4,10 +4,37 @@ #include "server/hset_family.h" +extern "C" { +#include "redis/listpack.h" +#include "redis/object.h" +#include "redis/redis_aux.h" +} + +#include "base/logging.h" #include "server/command_registry.h" +#include "server/conn_context.h" +#include "server/engine_shard_set.h" +#include "server/transaction.h" + +using namespace std; namespace dfly { +namespace { + +bool IsGoodForListpack(CmdArgList args, const uint8_t* lp) { + size_t sum = 0; + for (auto s : args) { + if (s.size() > server.hash_max_listpack_value) + return false; + sum += s.size(); + } + + return lpSafeToAdd(const_cast(lp), sum); +} + +} // namespace + void HSetFamily::HDel(CmdArgList args, ConnectionContext* cntx) { } @@ -24,6 +51,19 @@ void HSetFamily::HIncrBy(CmdArgList args, ConnectionContext* cntx) { } void HSetFamily::HSet(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 1); + + args.remove_prefix(2); + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpHSet(OpArgs{shard, t->db_index()}, key, args, false); + }; + + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (result) { + (*cntx)->SendLong(*result); + } else { + (*cntx)->SendError(result.status()); + } } void HSetFamily::HSetNx(CmdArgList args, ConnectionContext* cntx) { @@ -32,6 +72,43 @@ void HSetFamily::HSetNx(CmdArgList args, ConnectionContext* cntx) { void HSetFamily::HStrLen(CmdArgList args, ConnectionContext* cntx) { } +OpResult HSetFamily::OpHSet(const OpArgs& op_args, std::string_view key, + CmdArgList values, bool skip_if_exists) { + DCHECK(!values.empty() && 0 == values.size() % 2); + + auto& db_slice = op_args.shard->db_slice(); + const auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key); + + if (inserted) { + robj* ro = createHashObject(); + it->second.ImportRObj(ro); + } else { + if (it->second.ObjType() != OBJ_HASH) + return OpStatus::WRONG_TYPE; + } + + robj* hset = it->second.AsRObj(); + uint8_t* lp = (uint8_t*)hset->ptr; + + if (hset->encoding == OBJ_ENCODING_LISTPACK && !IsGoodForListpack(values, lp)) { + hashTypeConvert(hset, OBJ_ENCODING_HT); + } + unsigned created = 0; + + // TODO: we could avoid double copying by reimplementing hashTypeSet with better interface. + for (size_t i = 0; i < values.size(); i += 2) { + op_args.shard->tmp_str1 = + sdscpylen(op_args.shard->tmp_str1, values[i].data(), values[i].size()); + op_args.shard->tmp_str2 = + sdscpylen(op_args.shard->tmp_str2, values[i + 1].data(), values[i + 1].size()); + + created += !hashTypeSet(hset, op_args.shard->tmp_str1, op_args.shard->tmp_str2, HASH_SET_COPY); + } + it->second.SyncRObj(); + + return created; +} + using CI = CommandId; #define HFUNC(x) SetHandler(&HSetFamily::x) diff --git a/src/server/hset_family.h b/src/server/hset_family.h index 00d503517..112e394d9 100644 --- a/src/server/hset_family.h +++ b/src/server/hset_family.h @@ -4,6 +4,7 @@ #pragma once +#include "core/op_status.h" #include "server/common_types.h" namespace dfly { @@ -16,14 +17,19 @@ class HSetFamily { static void Register(CommandRegistry* registry); private: - static void HDel(CmdArgList args, ConnectionContext* cntx); - static void HLen(CmdArgList args, ConnectionContext* cntx); - static void HExists(CmdArgList args, ConnectionContext* cntx); - static void HGet(CmdArgList args, ConnectionContext* cntx); - static void HIncrBy(CmdArgList args, ConnectionContext* cntx); - static void HSet(CmdArgList args, ConnectionContext* cntx); - static void HSetNx(CmdArgList args, ConnectionContext* cntx); - static void HStrLen(CmdArgList args, ConnectionContext* cntx); + static void HDel(CmdArgList args, ConnectionContext* cntx); + static void HLen(CmdArgList args, ConnectionContext* cntx); + static void HExists(CmdArgList args, ConnectionContext* cntx); + static void HGet(CmdArgList args, ConnectionContext* cntx); + static void HIncrBy(CmdArgList args, ConnectionContext* cntx); + + // hmset is deprecated, we should not implement it unless we have to. + static void HSet(CmdArgList args, ConnectionContext* cntx); + static void HSetNx(CmdArgList args, ConnectionContext* cntx); + static void HStrLen(CmdArgList args, ConnectionContext* cntx); + + static OpResult OpHSet(const OpArgs& op_args, std::string_view key, CmdArgList values, + bool skip_if_exists); }; } // namespace dfly diff --git a/src/server/hset_family_test.cc b/src/server/hset_family_test.cc new file mode 100644 index 000000000..42f75b138 --- /dev/null +++ b/src/server/hset_family_test.cc @@ -0,0 +1,36 @@ +// Copyright 2022, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/hset_family.h" + +#include "base/gtest.h" +#include "base/logging.h" +#include "server/test_utils.h" + +using namespace testing; +using namespace std; +using namespace util; +using namespace boost; + +namespace dfly { + +class HSetFamilyTest : public BaseFamilyTest { + protected: +}; + +TEST_F(HSetFamilyTest, HSet) { + auto resp = Run({"hset", "x", "a"}); + EXPECT_THAT(resp[0], ErrArg("wrong number")); + + resp = Run({"hset", "x", "a", "b"}); + EXPECT_THAT(resp[0], IntArg(1)); + resp = Run({"hset", "x", "a", "b"}); + EXPECT_THAT(resp[0], IntArg(0)); + resp = Run({"hset", "x", "a", "c"}); + EXPECT_THAT(resp[0], IntArg(0)); + resp = Run({"hset", "y", "a", "c", "d", "e"}); + EXPECT_THAT(resp[0], IntArg(2)); +} + +} // namespace dfly diff --git a/src/server/list_family.cc b/src/server/list_family.cc index ff2c331d2..b74875020 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -349,8 +349,8 @@ OpResult ListFamily::OpPush(const OpArgs& op_args, std::string_view ke int pos = (dir == ListDir::LEFT) ? QUICKLIST_HEAD : QUICKLIST_TAIL; for (auto v : vals) { - es->tmp_str = sdscpylen(es->tmp_str, v.data(), v.size()); - quicklistPush(ql, es->tmp_str, sdslen(es->tmp_str), pos); + es->tmp_str1 = sdscpylen(es->tmp_str1, v.data(), v.size()); + quicklistPush(ql, es->tmp_str1, sdslen(es->tmp_str1), pos); } if (new_key) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 9e8387666..edc0cb4ba 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -19,12 +19,14 @@ extern "C" { #include "server/conn_context.h" #include "server/error.h" #include "server/generic_family.h" +#include "server/hset_family.h" #include "server/list_family.h" #include "server/script_mgr.h" #include "server/server_state.h" #include "server/set_family.h" #include "server/string_family.h" #include "server/transaction.h" +#include "server/zset_family.h" #include "util/metrics/metrics.h" #include "util/uring/uring_fiber_algo.h" #include "util/varz.h" @@ -828,6 +830,8 @@ void Service::RegisterCommands() { GenericFamily::Register(®istry_); ListFamily::Register(®istry_); SetFamily::Register(®istry_); + HSetFamily::Register(®istry_); + ZSetFamily::Register(®istry_); server_family_.Register(®istry_); diff --git a/src/server/rdb_test.cc b/src/server/rdb_test.cc index bcfd83d84..d0c29454a 100644 --- a/src/server/rdb_test.cc +++ b/src/server/rdb_test.cc @@ -31,7 +31,7 @@ class RdbTest : public testing::Test { pp_->Stop(); } - static void SetUpTestCase() { + static void SetUpTestSuite() { crc64_init(); init_zmalloc_threadlocal(); } diff --git a/src/server/redis_parser_test.cc b/src/server/redis_parser_test.cc index 37a964735..b0e2c39e4 100644 --- a/src/server/redis_parser_test.cc +++ b/src/server/redis_parser_test.cc @@ -38,7 +38,7 @@ MATCHER_P(ArrArg, expected, absl::StrCat(negation ? "is not" : "is", " equal to: class RedisParserTest : public testing::Test { protected: - static void SetUpTestCase() { + static void SetUpTestSuite() { init_zmalloc_threadlocal(); } diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 63f1803b8..5e80b9696 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -216,8 +216,8 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS uint32_t res = 0; for (auto val : vals) { - es->tmp_str = sdscpylen(es->tmp_str, val.data(), val.size()); - res += setTypeAdd(o, es->tmp_str); + es->tmp_str1 = sdscpylen(es->tmp_str1, val.data(), val.size()); + res += setTypeAdd(o, es->tmp_str1); } it->second.SyncRObj(); @@ -239,8 +239,8 @@ OpResult OpRem(const OpArgs& op_args, std::string_view key, const ArgS robj* o = find_res.value()->second.AsRObj(); for (auto val : vals) { - es->tmp_str = sdscpylen(es->tmp_str, val.data(), val.size()); - res += setTypeRemove(o, es->tmp_str); + es->tmp_str1 = sdscpylen(es->tmp_str1, val.data(), val.size()); + res += setTypeRemove(o, es->tmp_str1); } if (res && setTypeSize(o) == 0) { @@ -276,8 +276,8 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { OpResult res = es->db_slice().Find(t->db_index(), k, OBJ_SET); if (res && index == 0) { CHECK(!res->is_done()); - es->tmp_str = sdscpylen(es->tmp_str, member_.data(), member_.size()); - int found_memb = setTypeIsMember(res.value()->second.AsRObj(), es->tmp_str); + es->tmp_str1 = sdscpylen(es->tmp_str1, member_.data(), member_.size()); + int found_memb = setTypeIsMember(res.value()->second.AsRObj(), es->tmp_str1); found_[0] = (found_memb == 1); } else { found_[index] = res.status(); @@ -368,9 +368,9 @@ void SetFamily::SIsMember(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { OpResult find_res = shard->db_slice().Find(t->db_index(), key, OBJ_SET); - shard->tmp_str = sdscpylen(shard->tmp_str, val.data(), val.size()); + shard->tmp_str1 = sdscpylen(shard->tmp_str1, val.data(), val.size()); - int res = setTypeIsMember(find_res.value()->second.AsRObj(), shard->tmp_str); + int res = setTypeIsMember(find_res.value()->second.AsRObj(), shard->tmp_str1); return res == 1 ? OpStatus::OK : OpStatus::INVALID_VALUE; }; diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 3e085a6e4..4b3ae13ae 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -422,7 +422,8 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) { ArgSlice largs = t->ShardArgsInShard(es->shard_id()); - CHECK(!largs.empty() && largs.size() % 2 == 0); + + DCHECK(!largs.empty() && largs.size() % 2 == 0); SetCmd::SetParams params{t->db_index()}; SetCmd sg(&es->db_slice()); diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index b591aea11..05f96a456 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -95,6 +95,10 @@ TEST_F(StringFamilyTest, Set) { } TEST_F(StringFamilyTest, MGetSet) { + Run({"mset", "z", "0"}); // single key + auto resp = Run({"mget", "z"}); // single key + EXPECT_THAT(resp, RespEq("0")); + Run({"mset", "x", "0", "b", "0"}); ASSERT_EQ(2, GetDebugInfo("IO0").shards_count); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 9af70c415..eac56e665 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -151,6 +151,7 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) { CHECK_GT(args.size(), 1U); // first entry is the command name. DCHECK_EQ(unique_shard_cnt_, 0u); + DCHECK(args_.empty()); KeyIndex key_index = DetermineKeys(cid_, args); @@ -166,9 +167,13 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) { bool single_key = !multi_ && (key_index.start + key_index.step) >= key_index.end; if (single_key) { + DCHECK_GT(key_index.step, 0u); + shard_data_.resize(1); // Single key optimization - auto key = ArgS(args, key_index.start); - args_.push_back(key); + for (unsigned j = key_index.start; j < key_index.start + key_index.step; ++j) { + args_.push_back(ArgS(args, j)); + } + string_view key = args_.front(); unique_shard_cnt_ = 1; unique_shard_id_ = Shard(key, ess_->size());