From a2f68d1b3b14d4e412fb96141915800cb94e1907 Mon Sep 17 00:00:00 2001
From: Chaka
Date: Thu, 4 May 2023 15:12:48 +0300
Subject: [PATCH] feat(server): Implement PFMERGE (#1180)

* feat(server): Implement PFMERGE.

* Disable lock check on failing tests.
---
 src/redis/hyperloglog.c       |  60 ++++++++++++----
 src/redis/hyperloglog.h       |   6 ++
 src/server/dragonfly_test.cc  |   2 +
 src/server/hll_family.cc      | 131 ++++++++++++++++++++++++++--------
 src/server/hll_family_test.cc |  37 ++++++++++
 src/server/multi_test.cc      |   2 +
 src/server/rdb_test.cc        |   3 +
 src/server/test_utils.cc      |  18 +++++
 src/server/test_utils.h       |   4 ++
 9 files changed, 221 insertions(+), 42 deletions(-)

diff --git a/src/redis/hyperloglog.c b/src/redis/hyperloglog.c
index 470e8628f..8d9b67ef8 100644
--- a/src/redis/hyperloglog.c
+++ b/src/redis/hyperloglog.c
@@ -1039,32 +1039,68 @@ int64_t pfcountSingle(struct HllBufferPtr hll_ptr) {
   return card;
 }
 
+/* Merge dense HLL `to` into raw buffer `max`: HLL_HDR_SIZE header + HLL_REGISTERS bytes. */
+static void hllMergeDense(uint8_t* max, struct HllBufferPtr to) {
+  uint8_t* registers = max + HLL_HDR_SIZE;
+  uint8_t val;
+  struct hllhdr* hll_hdr = (struct hllhdr*)to.hll;
+
+  for (int i = 0; i < HLL_REGISTERS; i++) {
+    HLL_DENSE_GET_REGISTER(val, hll_hdr->registers, i);
+    if (val > registers[i]) {
+      registers[i] = val;
+    }
+  }
+}
+
 int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count) {
   struct hllhdr* hdr;
-  uint8_t max[HLL_HDR_SIZE + HLL_REGISTERS], *registers;
+  uint8_t max[HLL_HDR_SIZE + HLL_REGISTERS];
 
   /* Compute an HLL with M[i] = MAX(M[i]_j). */
   memset(max, 0, sizeof(max));
   hdr = (struct hllhdr*)max;
   hdr->encoding = HLL_RAW; /* Special internal-only encoding. */
-  registers = max + HLL_HDR_SIZE;
   for (size_t j = 0; j < hlls_count; j++) {
     /* Check type and size. */
     struct HllBufferPtr hll = hlls[j];
-    if (isValidHLL(hll) != HLL_VALID_DENSE)
+    if (isValidHLL(hll) != HLL_VALID_DENSE) {
       return C_ERR;
-
-    /* Merge dense-encoded HLL */
-    uint8_t val;
-    struct hllhdr* hll_hdr = (struct hllhdr*)hll.hll;
-
-    for (int i = 0; i < HLL_REGISTERS; i++) {
-      HLL_DENSE_GET_REGISTER(val, hll_hdr->registers, i);
-      if (val > registers[i])
-        registers[i] = val;
     }
+
+    hllMergeDense(max, hll);
   }
 
   /* Compute cardinality of the resulting set. */
   return hllCount(hdr, NULL);
 }
+
+int pfmerge(struct HllBufferPtr* in_hlls, size_t in_hlls_count, struct HllBufferPtr out_hll) {
+  if (isValidHLL(out_hll) != HLL_VALID_DENSE) {
+    return C_ERR;
+  }
+
+  uint8_t max[HLL_HDR_SIZE + HLL_REGISTERS];
+
+  /* Compute an HLL with M[i] = MAX(M[i]_j) in a raw scratch buffer.
+   * hllMergeDense() writes registers at max + HLL_HDR_SIZE, so the buffer
+   * carries that header prefix; the registers go to `out_hll` below. */
+  memset(max, 0, sizeof(max));
+
+  for (size_t j = 0; j < in_hlls_count; j++) {
+    struct HllBufferPtr hll = in_hlls[j];
+    if (isValidHLL(hll) != HLL_VALID_DENSE) {
+      return C_ERR;
+    }
+
+    hllMergeDense(max, hll);
+  }
+
+  struct hllhdr* hdr = (struct hllhdr*)out_hll.hll;
+  for (size_t j = 0; j < HLL_REGISTERS; j++) {
+    hllDenseSet(hdr->registers, j, max[HLL_HDR_SIZE + j]);
+  }
+  HLL_INVALIDATE_CACHE(hdr);
+
+  return C_OK;
+}
diff --git a/src/redis/hyperloglog.h b/src/redis/hyperloglog.h
index d03f15e7f..8eeb81635 100644
--- a/src/redis/hyperloglog.h
+++ b/src/redis/hyperloglog.h
@@ -51,4 +51,10 @@ int64_t pfcountSingle(struct HllBufferPtr hll_ptr);
  * All `hlls` elements must be valid, dense-encoded HLLs. */
 int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count);
 
+/* Merges array of HLLs pointed to by `in_hlls` of size `in_hlls_count` into `out_hll`.
+ * Returns 0 upon success, otherwise a negative number.
+ * Failure can occur when any of `in_hlls` or `out_hll` is not a dense-encoded HLL.
+ * `out_hll` *can* be one of the elements in `in_hlls`. */
+int pfmerge(struct HllBufferPtr* in_hlls, size_t in_hlls_count, struct HllBufferPtr out_hll);
+
 #endif
diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc
index a5a6b7662..2b55f9428 100644
--- a/src/server/dragonfly_test.cc
+++ b/src/server/dragonfly_test.cc
@@ -291,6 +291,8 @@ TEST_F(DflyEngineTest, LimitMemory) {
 }
 
 TEST_F(DflyEngineTest, FlushAll) {
+  DisableLockCheck();
+
   auto fb0 = pp_->at(0)->LaunchFiber([&] { Run({"flushall"}); });
 
   auto fb1 = pp_->at(1)->LaunchFiber([&] {
diff --git a/src/server/hll_family.cc b/src/server/hll_family.cc
index d42f00734..be5402ecf 100644
--- a/src/server/hll_family.cc
+++ b/src/server/hll_family.cc
@@ -143,51 +143,60 @@ OpResult CountHllsSingle(const OpArgs& op_args, string_view key) {
   }
 }
 
-vector> ReadValues(const OpArgs& op_args, ArgSlice keys) {
-  vector> values;
-  for (size_t i = 0; i < keys.size(); ++i) {
-    OpResult it =
-        op_args.shard->db_slice().Find(op_args.db_cntx, keys[i], OBJ_STRING);
-    if (it.ok()) {
-      string hll;
-      it.value()->second.GetString(&hll);
-      ConvertToDenseIfNeeded(&hll);
-      if (isValidHLL(StringToHllPtr(hll)) != HLL_VALID_DENSE) {
-        values.push_back(OpStatus::INVALID_VALUE);
-      } else {
-        values.push_back(std::move(hll));
+OpResult> ReadValues(const OpArgs& op_args, ArgSlice keys) {
+  try {
+    vector values;
+    for (size_t i = 0; i < keys.size(); ++i) {
+      OpResult it =
+          op_args.shard->db_slice().Find(op_args.db_cntx, keys[i], OBJ_STRING);
+      if (it.ok()) {
+        string hll;
+        it.value()->second.GetString(&hll);
+        ConvertToDenseIfNeeded(&hll);
+        if (isValidHLL(StringToHllPtr(hll)) != HLL_VALID_DENSE) {
+          return OpStatus::INVALID_VALUE;
+        } else {
+          values.push_back(std::move(hll));
+        }
+      } else if (it.status() == OpStatus::WRONG_TYPE) {
+        return OpStatus::WRONG_TYPE;
       }
-    } else if (it.status() == OpStatus::WRONG_TYPE) {
-      values.push_back(OpStatus::WRONG_TYPE);
+    }
+    return values;
+  } catch (const std::bad_alloc&) {
+    return OpStatus::OUT_OF_MEMORY;
+  }
+}
+
+vector ConvertShardVector(const vector>& hlls) {
+  vector ptrs;
+  ptrs.reserve(hlls.size());
+  for (auto& shard_hlls : hlls) {
+    for (auto& hll : shard_hlls) {
+      ptrs.push_back(StringToHllPtr(hll));
     }
   }
-  return values;
+  return ptrs;
 }
 
 OpResult PFCountMulti(CmdArgList args, ConnectionContext* cntx) {
-  vector>> hlls;
+  vector> hlls;
   hlls.resize(shard_set->size());
 
   auto cb = [&](Transaction* t, EngineShard* shard) {
     ShardId sid = shard->shard_id();
     ArgSlice shard_args = t->GetShardArgs(shard->shard_id());
-    hlls[sid] = ReadValues(t->GetOpArgs(shard), shard_args);
-    return OpStatus::OK;
+    auto result = ReadValues(t->GetOpArgs(shard), shard_args);
+    if (result.ok()) {
+      hlls[sid] = std::move(result.value());
+    }
+    return result.status();
   };
 
   Transaction* trans = cntx->transaction;
   trans->ScheduleSingleHop(std::move(cb));
 
-  vector ptrs;
-  ptrs.reserve(hlls.size());
-  for (auto& shard_hlls : hlls) {
-    for (auto& hll : shard_hlls) {
-      if (!hll.ok()) {
-        return hll.status();
-      }
-      ptrs.push_back(StringToHllPtr(hll.value()));
-    }
-  }
+  vector ptrs = ConvertShardVector(hlls);
   int64_t pf_count = pfcountMulti(ptrs.data(), ptrs.size());
   if (pf_count < 0) {
     return OpStatus::INVALID_VALUE;
@@ -211,13 +220,75 @@ void PFCount(CmdArgList args, ConnectionContext* cntx) {
   }
 }
 
+OpResult PFMergeInternal(CmdArgList args, ConnectionContext* cntx) {
+  vector> hlls;
+  hlls.resize(shard_set->size());
+
+  atomic<OpStatus> status = OpStatus::OK;  // Error reported by a shard's ReadValues(), if any.
+  auto cb = [&](Transaction* t, EngineShard* shard) {
+    ShardId sid = shard->shard_id();
+    ArgSlice shard_args = t->GetShardArgs(shard->shard_id());
+    auto result = ReadValues(t->GetOpArgs(shard), shard_args);
+    if (result.ok()) {
+      hlls[sid] = std::move(result.value());
+    } else {
+      status = result.status();  // Keep the real cause (WRONG_TYPE, OUT_OF_MEMORY, ...).
+    }
+    return result.status();
+  };
+
+  Transaction* trans = cntx->transaction;
+  trans->Schedule();
+  trans->Execute(std::move(cb), false);
+
+  if (status.load() != OpStatus::OK) {
+    trans->Execute([](Transaction*, EngineShard*) { return OpStatus::OK; }, true);
+    return status.load();
+  }
+
+  vector ptrs = ConvertShardVector(hlls);
+
+  string hll;
+  hll.resize(getDenseHllSize());
+  createDenseHll(StringToHllPtr(hll));
+  int result = pfmerge(ptrs.data(), ptrs.size(), StringToHllPtr(hll));
+
+  auto set_cb = [&](Transaction* t, EngineShard* shard) {
+    string_view key = ArgS(args, 0);
+    const OpArgs& op_args = t->GetOpArgs(shard);
+    auto& db_slice = op_args.shard->db_slice();
+    auto [it, inserted] = db_slice.AddOrFind(t->GetDbContext(), key);
+    db_slice.PreUpdate(op_args.db_cntx.db_index, it);
+    it->second.SetString(hll);
+    db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, !inserted);
+    return OpStatus::OK;
+  };
+  trans->Execute(std::move(set_cb), true);
+
+  return result;
+}
+
+void PFMerge(CmdArgList args, ConnectionContext* cntx) {
+  OpResult result = PFMergeInternal(args, cntx);
+  if (result.ok()) {
+    if (result.value() == 0) {
+      (*cntx)->SendOk();
+    } else {
+      (*cntx)->SendError(HllFamily::kInvalidHllErr);
+    }
+  } else {
+    HandleOpValueResult(result, cntx);
+  }
+}
+
 }  // namespace
 
 void HllFamily::Register(CommandRegistry* registry) {
   using CI = CommandId;
 
   *registry << CI{"PFADD", CO::WRITE, -3, 1, 1, 1}.SetHandler(PFAdd)
-            << CI{"PFCOUNT", CO::WRITE, -2, 1, -1, 1}.SetHandler(PFCount);
+            << CI{"PFCOUNT", CO::WRITE, -2, 1, -1, 1}.SetHandler(PFCount)
+            << CI{"PFMERGE", CO::WRITE, -2, 1, -1, 1}.SetHandler(PFMerge);
 }
 
 const char HllFamily::kInvalidHllErr[] = "Key is not a valid HyperLogLog string value.";
diff --git a/src/server/hll_family_test.cc b/src/server/hll_family_test.cc
index aa68fbd7d..33117ccdb 100644
--- a/src/server/hll_family_test.cc
+++ b/src/server/hll_family_test.cc
@@ -95,4 +95,41 @@ TEST_F(HllFamilyTest, CountMultiple) {
   EXPECT_EQ(CheckedInt({"pfcount", "key1", "key4"}), 5);
 }
 
+TEST_F(HllFamilyTest, MergeToNew) {
+  EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
+  EXPECT_EQ(CheckedInt({"pfadd", "key2", "4", "5"}), 1);
+  EXPECT_EQ(Run({"pfmerge", "key3", "key1", "key2"}), "OK");
+  EXPECT_EQ(CheckedInt({"pfcount", "key3"}), 5);
+}
+
+TEST_F(HllFamilyTest, MergeToExisting) {
+  EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
+  EXPECT_EQ(CheckedInt({"pfadd", "key2", "4", "5"}), 1);
+  EXPECT_EQ(Run({"pfmerge", "key2", "key1"}), "OK");
+  EXPECT_EQ(CheckedInt({"pfcount", "key2"}), 5);
+}
+
+TEST_F(HllFamilyTest, MergeNonExisting) {
+  EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
+  EXPECT_EQ(Run({"pfmerge", "key3", "key1", "key2"}), "OK");
+  EXPECT_EQ(CheckedInt({"pfcount", "key3"}), 3);
+}
+
+TEST_F(HllFamilyTest, MergeOverlapping) {
+  EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
+  EXPECT_EQ(CheckedInt({"pfadd", "key2", "2", "3"}), 1);
+  EXPECT_EQ(CheckedInt({"pfadd", "key3", "1", "3"}), 1);
+  EXPECT_EQ(CheckedInt({"pfadd", "key4", "2", "3"}), 1);
+  EXPECT_EQ(CheckedInt({"pfadd", "key5", "3"}), 1);
+  EXPECT_EQ(Run({"pfmerge", "key6", "key1", "key2", "key3", "key4", "key5"}), "OK");
+  EXPECT_EQ(CheckedInt({"pfcount", "key6"}), 3);
+}
+
+TEST_F(HllFamilyTest, MergeInvalid) {
+  EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
+  EXPECT_EQ(Run({"set", "key2", "..."}), "OK");
+  EXPECT_THAT(Run({"pfmerge", "key1", "key2"}), ErrArg(HllFamily::kInvalidHllErr));
+  EXPECT_EQ(CheckedInt({"pfcount", "key1"}), 3);
+}
+
 }  // namespace dfly
diff --git a/src/server/multi_test.cc b/src/server/multi_test.cc
index 965c24cd7..ee242ae77 100644
--- a/src/server/multi_test.cc
+++ b/src/server/multi_test.cc
@@ -124,6 +124,8 @@ TEST_F(MultiTest, Multi) {
 }
 
 TEST_F(MultiTest, MultiGlobalCommands) {
+  DisableLockCheck();
+
   ASSERT_THAT(Run({"set", "key", "val"}), "OK");
 
   ASSERT_THAT(Run({"multi"}), "OK");
diff --git a/src/server/rdb_test.cc b/src/server/rdb_test.cc
index b4369954f..e174abb3c 100644
--- a/src/server/rdb_test.cc
+++ b/src/server/rdb_test.cc
@@ -368,6 +368,9 @@ TEST_P(HllRdbTest, Hll) {
 
   EXPECT_EQ(CheckedInt({"pfadd", GetParam(), "2"}), 1);
   EXPECT_EQ(CheckedInt({"pfcount", GetParam()}), 2);
+
+  EXPECT_EQ(Run({"pfmerge", "key3", GetParam(), "key2"}), "OK");
+  EXPECT_EQ(CheckedInt({"pfcount", "key3"}), 2);
 }
 
 INSTANTIATE_TEST_SUITE_P(HllRdbTest, HllRdbTest, Values("key-sparse", "key-dense"));
diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc
index 9af7d18b9..ba17a1561 100644
--- a/src/server/test_utils.cc
+++ b/src/server/test_utils.cc
@@ -156,7 +156,25 @@ void BaseFamilyTest::SetUp() {
   LOG(INFO) << "Starting " << test_info->name();
 }
 
+void BaseFamilyTest::DisableLockCheck() {
+  check_locks_ = false;
+}
+
+unsigned BaseFamilyTest::NumLocked() {
+  atomic_uint count = 0;
+  shard_set->RunBriefInParallel([&](EngineShard* shard) {
+    for (const auto& db : shard->db_slice().databases()) {
+      count += db->trans_locks.size();
+    }
+  });
+  return count;
+}
+
 void BaseFamilyTest::TearDown() {
+  if (check_locks_) {
+    CHECK_EQ(NumLocked(), 0U);
+  }
+
   service_->Shutdown();
   service_.reset();
   pp_->Stop();
diff --git a/src/server/test_utils.h b/src/server/test_utils.h
index c7f5f059b..2088a407b 100644
--- a/src/server/test_utils.h
+++ b/src/server/test_utils.h
@@ -90,9 +90,13 @@ class BaseFamilyTest : public ::testing::Test {
   const facade::Connection::PubMessage::MessageData& GetPublishedMessage(std::string_view conn_id,
                                                                          size_t index) const;
 
+  void DisableLockCheck();
+  static unsigned NumLocked();
+
   std::unique_ptr pp_;
   std::unique_ptr service_;
   unsigned num_threads_ = 3;
+  bool check_locks_ = true;
 
   absl::flat_hash_map> connections_;
   Mutex mu_;