feat(server): Implement PFMERGE (#1180)

* feat(server): Implement PFMERGE.

* Disable lock check on failing tests.
This commit is contained in:
Chaka 2023-05-04 15:12:48 +03:00 committed by GitHub
parent cb82680aca
commit a2f68d1b3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 221 additions and 42 deletions

View file

@ -1039,32 +1039,68 @@ int64_t pfcountSingle(struct HllBufferPtr hll_ptr) {
return card; return card;
} }
/* Folds the dense-encoded HLL `to` into a raw register accumulator, keeping
 * the larger value for every register index.
 *
 * `max` must point at the START of a buffer laid out as a header area of
 * HLL_HDR_SIZE bytes followed by HLL_REGISTERS one-byte registers: the
 * accumulated registers are read and written at `max + HLL_HDR_SIZE`.
 * The header area itself is not touched here. */
static void hllMergeDense(uint8_t* max, struct HllBufferPtr to) {
  struct hllhdr* src_hdr = (struct hllhdr*)to.hll;
  uint8_t* merged = max + HLL_HDR_SIZE; /* Skip the header area. */

  for (int idx = 0; idx < HLL_REGISTERS; ++idx) {
    uint8_t reg;
    HLL_DENSE_GET_REGISTER(reg, src_hdr->registers, idx);
    if (merged[idx] < reg) {
      merged[idx] = reg;
    }
  }
}
int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count) { int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count) {
struct hllhdr* hdr; struct hllhdr* hdr;
uint8_t max[HLL_HDR_SIZE + HLL_REGISTERS], *registers; uint8_t max[HLL_HDR_SIZE + HLL_REGISTERS];
/* Compute an HLL with M[i] = MAX(M[i]_j). */ /* Compute an HLL with M[i] = MAX(M[i]_j). */
memset(max, 0, sizeof(max)); memset(max, 0, sizeof(max));
hdr = (struct hllhdr*)max; hdr = (struct hllhdr*)max;
hdr->encoding = HLL_RAW; /* Special internal-only encoding. */ hdr->encoding = HLL_RAW; /* Special internal-only encoding. */
registers = max + HLL_HDR_SIZE;
for (size_t j = 0; j < hlls_count; j++) { for (size_t j = 0; j < hlls_count; j++) {
/* Check type and size. */ /* Check type and size. */
struct HllBufferPtr hll = hlls[j]; struct HllBufferPtr hll = hlls[j];
if (isValidHLL(hll) != HLL_VALID_DENSE) if (isValidHLL(hll) != HLL_VALID_DENSE) {
return C_ERR; return C_ERR;
/* Merge dense-encoded HLL */
uint8_t val;
struct hllhdr* hll_hdr = (struct hllhdr*)hll.hll;
for (int i = 0; i < HLL_REGISTERS; i++) {
HLL_DENSE_GET_REGISTER(val, hll_hdr->registers, i);
if (val > registers[i])
registers[i] = val;
} }
hllMergeDense(max, hll);
} }
/* Compute cardinality of the resulting set. */ /* Compute cardinality of the resulting set. */
return hllCount(hdr, NULL); return hllCount(hdr, NULL);
} }
int pfmerge(struct HllBufferPtr* in_hlls, size_t in_hlls_count, struct HllBufferPtr out_hll) {
if (isValidHLL(out_hll) != HLL_VALID_DENSE) {
return C_ERR;
}
uint8_t max[HLL_REGISTERS];
/* Compute an HLL with M[i] = MAX(M[i]_j).
* We store the maximum into the max array of registers. We'll write
* it to the target variable later. */
memset(max, 0, sizeof(max));
for (size_t j = 0; j < in_hlls_count; j++) {
struct HllBufferPtr hll = in_hlls[j];
if (isValidHLL(hll) != HLL_VALID_DENSE) {
return C_ERR;
}
hllMergeDense(max, hll);
}
struct hllhdr* hdr = (struct hllhdr*)out_hll.hll;
for (size_t j = 0; j < HLL_REGISTERS; j++) {
hllDenseSet(hdr->registers, j, max[j]);
}
HLL_INVALIDATE_CACHE(hdr);
return C_OK;
}

View file

@ -51,4 +51,10 @@ int64_t pfcountSingle(struct HllBufferPtr hll_ptr);
* All `hlls` elements must be valid, dense-encoded HLLs. */ * All `hlls` elements must be valid, dense-encoded HLLs. */
int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count); int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count);
/* Merges the array of HLLs pointed to by `in_hlls`, of size `in_hlls_count`, into `out_hll`.
* Returns 0 upon success, otherwise a negative number.
* Failure can occur when any of `in_hlls` or `out_hll` is not a dense-encoded HLL.
* `out_hll` *can* be one of the elements in `in_hlls`. */
int pfmerge(struct HllBufferPtr* in_hlls, size_t in_hlls_count, struct HllBufferPtr out_hll);
#endif #endif

View file

@ -291,6 +291,8 @@ TEST_F(DflyEngineTest, LimitMemory) {
} }
TEST_F(DflyEngineTest, FlushAll) { TEST_F(DflyEngineTest, FlushAll) {
DisableLockCheck();
auto fb0 = pp_->at(0)->LaunchFiber([&] { Run({"flushall"}); }); auto fb0 = pp_->at(0)->LaunchFiber([&] { Run({"flushall"}); });
auto fb1 = pp_->at(1)->LaunchFiber([&] { auto fb1 = pp_->at(1)->LaunchFiber([&] {

View file

@ -143,51 +143,60 @@ OpResult<int64_t> CountHllsSingle(const OpArgs& op_args, string_view key) {
} }
} }
vector<OpResult<string>> ReadValues(const OpArgs& op_args, ArgSlice keys) { OpResult<vector<string>> ReadValues(const OpArgs& op_args, ArgSlice keys) {
vector<OpResult<string>> values; try {
for (size_t i = 0; i < keys.size(); ++i) { vector<string> values;
OpResult<PrimeIterator> it = for (size_t i = 0; i < keys.size(); ++i) {
op_args.shard->db_slice().Find(op_args.db_cntx, keys[i], OBJ_STRING); OpResult<PrimeIterator> it =
if (it.ok()) { op_args.shard->db_slice().Find(op_args.db_cntx, keys[i], OBJ_STRING);
string hll; if (it.ok()) {
it.value()->second.GetString(&hll); string hll;
ConvertToDenseIfNeeded(&hll); it.value()->second.GetString(&hll);
if (isValidHLL(StringToHllPtr(hll)) != HLL_VALID_DENSE) { ConvertToDenseIfNeeded(&hll);
values.push_back(OpStatus::INVALID_VALUE); if (isValidHLL(StringToHllPtr(hll)) != HLL_VALID_DENSE) {
} else { return OpStatus::INVALID_VALUE;
values.push_back(std::move(hll)); } else {
values.push_back(std::move(hll));
}
} else if (it.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;
} }
} else if (it.status() == OpStatus::WRONG_TYPE) { }
values.push_back(OpStatus::WRONG_TYPE); return values;
} catch (const std::bad_alloc&) {
return OpStatus::OUT_OF_MEMORY;
}
}
vector<HllBufferPtr> ConvertShardVector(const vector<vector<string>>& hlls) {
vector<HllBufferPtr> ptrs;
ptrs.reserve(hlls.size());
for (auto& shard_hlls : hlls) {
for (auto& hll : shard_hlls) {
ptrs.push_back(StringToHllPtr(hll));
} }
} }
return values; return ptrs;
} }
OpResult<int64_t> PFCountMulti(CmdArgList args, ConnectionContext* cntx) { OpResult<int64_t> PFCountMulti(CmdArgList args, ConnectionContext* cntx) {
vector<vector<OpResult<string>>> hlls; vector<vector<string>> hlls;
hlls.resize(shard_set->size()); hlls.resize(shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) { auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id(); ShardId sid = shard->shard_id();
ArgSlice shard_args = t->GetShardArgs(shard->shard_id()); ArgSlice shard_args = t->GetShardArgs(shard->shard_id());
hlls[sid] = ReadValues(t->GetOpArgs(shard), shard_args); auto result = ReadValues(t->GetOpArgs(shard), shard_args);
return OpStatus::OK; if (result.ok()) {
hlls[sid] = std::move(result.value());
}
return result.status();
}; };
Transaction* trans = cntx->transaction; Transaction* trans = cntx->transaction;
trans->ScheduleSingleHop(std::move(cb)); trans->ScheduleSingleHop(std::move(cb));
vector<HllBufferPtr> ptrs; vector<HllBufferPtr> ptrs = ConvertShardVector(hlls);
ptrs.reserve(hlls.size());
for (auto& shard_hlls : hlls) {
for (auto& hll : shard_hlls) {
if (!hll.ok()) {
return hll.status();
}
ptrs.push_back(StringToHllPtr(hll.value()));
}
}
int64_t pf_count = pfcountMulti(ptrs.data(), ptrs.size()); int64_t pf_count = pfcountMulti(ptrs.data(), ptrs.size());
if (pf_count < 0) { if (pf_count < 0) {
return OpStatus::INVALID_VALUE; return OpStatus::INVALID_VALUE;
@ -211,13 +220,75 @@ void PFCount(CmdArgList args, ConnectionContext* cntx) {
} }
} }
// Implements PFMERGE as a two-hop transaction.
//
// Hop 1 reads the HLL values of all keys on every participating shard. The
// merge itself is then computed on the calling thread, and the concluding hop
// stores the merged HLL under the destination key (the first argument).
//
// Returns the result of pfmerge() (0 on success) when all reads succeeded, or
// the failing status reported by ReadValues() otherwise. The transaction is
// always concluded before returning.
OpResult<int> PFMergeInternal(CmdArgList args, ConnectionContext* cntx) {
  vector<vector<string>> hlls;
  hlls.resize(shard_set->size());

  // Shard callbacks run concurrently, hence the atomic. It keeps the actual
  // failure status (e.g. WRONG_TYPE, OUT_OF_MEMORY) instead of collapsing all
  // failures into INVALID_VALUE. If several shards fail, any one of their
  // statuses is reported.
  std::atomic<OpStatus> read_status{OpStatus::OK};

  auto cb = [&](Transaction* t, EngineShard* shard) {
    ShardId sid = shard->shard_id();
    ArgSlice shard_args = t->GetShardArgs(sid);
    auto result = ReadValues(t->GetOpArgs(shard), shard_args);
    if (result.ok()) {
      hlls[sid] = std::move(result.value());
    } else {
      read_status = result.status();
    }
    return result.status();
  };

  Transaction* trans = cntx->transaction;
  trans->Schedule();
  trans->Execute(std::move(cb), false);

  // Runs an empty concluding hop so the transaction is finished without
  // modifying any key.
  auto conclude_without_write = [trans] {
    trans->Execute([](Transaction*, EngineShard*) { return OpStatus::OK; }, true);
  };

  if (OpStatus status = read_status.load(); status != OpStatus::OK) {
    conclude_without_write();
    return status;
  }

  vector<HllBufferPtr> ptrs = ConvertShardVector(hlls);

  string hll;
  hll.resize(getDenseHllSize());
  createDenseHll(StringToHllPtr(hll));
  int result = pfmerge(ptrs.data(), ptrs.size(), StringToHllPtr(hll));
  if (result != 0) {
    // Do not store a partially merged HLL when the merge itself failed.
    conclude_without_write();
    return result;
  }

  auto set_cb = [&](Transaction* t, EngineShard* shard) {
    string_view key = ArgS(args, 0);

    // The concluding hop may run on every shard that holds one of the command's
    // keys. Only the shard whose key set contains the destination may create
    // or update it; other shards must not get a stray copy of the key.
    ArgSlice shard_args = t->GetShardArgs(shard->shard_id());
    bool owns_destination = false;
    for (string_view shard_key : shard_args) {
      if (shard_key == key) {
        owns_destination = true;
        break;
      }
    }
    if (!owns_destination) {
      return OpStatus::OK;
    }

    const OpArgs& op_args = t->GetOpArgs(shard);
    auto& db_slice = op_args.shard->db_slice();
    auto [it, inserted] = db_slice.AddOrFind(t->GetDbContext(), key);
    db_slice.PreUpdate(op_args.db_cntx.db_index, it);
    it->second.SetString(hll);
    db_slice.PostUpdate(op_args.db_cntx.db_index, it, key, !inserted);
    return OpStatus::OK;
  };
  trans->Execute(std::move(set_cb), true);

  return result;
}
// Command handler for PFMERGE: runs the merge and translates its outcome into
// a reply. A failed operation status is delegated to HandleOpValueResult(), a
// non-zero merge result is reported as an invalid-HLL error, and zero yields OK.
void PFMerge(CmdArgList args, ConnectionContext* cntx) {
  OpResult<int> merge_res = PFMergeInternal(args, cntx);

  if (!merge_res.ok()) {
    HandleOpValueResult(merge_res, cntx);
    return;
  }

  if (merge_res.value() != 0) {
    (*cntx)->SendError(HllFamily::kInvalidHllErr);
    return;
  }

  (*cntx)->SendOk();
}
} // namespace } // namespace
void HllFamily::Register(CommandRegistry* registry) { void HllFamily::Register(CommandRegistry* registry) {
using CI = CommandId; using CI = CommandId;
*registry << CI{"PFADD", CO::WRITE, -3, 1, 1, 1}.SetHandler(PFAdd) *registry << CI{"PFADD", CO::WRITE, -3, 1, 1, 1}.SetHandler(PFAdd)
<< CI{"PFCOUNT", CO::WRITE, -2, 1, -1, 1}.SetHandler(PFCount); << CI{"PFCOUNT", CO::WRITE, -2, 1, -1, 1}.SetHandler(PFCount)
<< CI{"PFMERGE", CO::WRITE, -2, 1, -1, 1}.SetHandler(PFMerge);
} }
const char HllFamily::kInvalidHllErr[] = "Key is not a valid HyperLogLog string value."; const char HllFamily::kInvalidHllErr[] = "Key is not a valid HyperLogLog string value.";

View file

@ -95,4 +95,41 @@ TEST_F(HllFamilyTest, CountMultiple) {
EXPECT_EQ(CheckedInt({"pfcount", "key1", "key4"}), 5); EXPECT_EQ(CheckedInt({"pfcount", "key1", "key4"}), 5);
} }
// Merging two HLLs with disjoint elements into a key that was not created
// beforehand must succeed and yield the combined cardinality (3 + 2 = 5).
TEST_F(HllFamilyTest, MergeToNew) {
EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
EXPECT_EQ(CheckedInt({"pfadd", "key2", "4", "5"}), 1);
EXPECT_EQ(Run({"pfmerge", "key3", "key1", "key2"}), "OK");
EXPECT_EQ(CheckedInt({"pfcount", "key3"}), 5);
}
// The destination (key2) already holds an HLL and is not listed again as a
// source: its existing elements must be kept, so the count covers both keys.
TEST_F(HllFamilyTest, MergeToExisting) {
EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
EXPECT_EQ(CheckedInt({"pfadd", "key2", "4", "5"}), 1);
EXPECT_EQ(Run({"pfmerge", "key2", "key1"}), "OK");
EXPECT_EQ(CheckedInt({"pfcount", "key2"}), 5);
}
// A source key that was never created (key2) is not an error: the merge
// succeeds and the result only reflects the existing source (key1).
TEST_F(HllFamilyTest, MergeNonExisting) {
EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
EXPECT_EQ(Run({"pfmerge", "key3", "key1", "key2"}), "OK");
EXPECT_EQ(CheckedInt({"pfcount", "key3"}), 3);
}
// All five sources draw from the same three elements {1, 2, 3}; merging them
// must count each distinct element once, giving a cardinality of 3.
TEST_F(HllFamilyTest, MergeOverlapping) {
EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
EXPECT_EQ(CheckedInt({"pfadd", "key2", "2", "3"}), 1);
EXPECT_EQ(CheckedInt({"pfadd", "key3", "1", "3"}), 1);
EXPECT_EQ(CheckedInt({"pfadd", "key4", "2", "3"}), 1);
EXPECT_EQ(CheckedInt({"pfadd", "key5", "3"}), 1);
EXPECT_EQ(Run({"pfmerge", "key6", "key1", "key2", "key3", "key4", "key5"}), "OK");
EXPECT_EQ(CheckedInt({"pfcount", "key6"}), 3);
}
// key2 is a plain string that is not a HyperLogLog value. Merging it must
// fail with the invalid-HLL error and leave the destination (key1) unchanged,
// which the final count of 3 verifies.
TEST_F(HllFamilyTest, MergeInvalid) {
EXPECT_EQ(CheckedInt({"pfadd", "key1", "1", "2", "3"}), 1);
EXPECT_EQ(Run({"set", "key2", "..."}), "OK");
EXPECT_THAT(Run({"pfmerge", "key1", "key2"}), ErrArg(HllFamily::kInvalidHllErr));
EXPECT_EQ(CheckedInt({"pfcount", "key1"}), 3);
}
} // namespace dfly } // namespace dfly

View file

@ -124,6 +124,8 @@ TEST_F(MultiTest, Multi) {
} }
TEST_F(MultiTest, MultiGlobalCommands) { TEST_F(MultiTest, MultiGlobalCommands) {
DisableLockCheck();
ASSERT_THAT(Run({"set", "key", "val"}), "OK"); ASSERT_THAT(Run({"set", "key", "val"}), "OK");
ASSERT_THAT(Run({"multi"}), "OK"); ASSERT_THAT(Run({"multi"}), "OK");

View file

@ -368,6 +368,9 @@ TEST_P(HllRdbTest, Hll) {
EXPECT_EQ(CheckedInt({"pfadd", GetParam(), "2"}), 1); EXPECT_EQ(CheckedInt({"pfadd", GetParam(), "2"}), 1);
EXPECT_EQ(CheckedInt({"pfcount", GetParam()}), 2); EXPECT_EQ(CheckedInt({"pfcount", GetParam()}), 2);
EXPECT_EQ(Run({"pfmerge", "key3", GetParam(), "key2"}), "OK");
EXPECT_EQ(CheckedInt({"pfcount", "key3"}), 2);
} }
INSTANTIATE_TEST_SUITE_P(HllRdbTest, HllRdbTest, Values("key-sparse", "key-dense")); INSTANTIATE_TEST_SUITE_P(HllRdbTest, HllRdbTest, Values("key-sparse", "key-dense"));

View file

@ -156,7 +156,25 @@ void BaseFamilyTest::SetUp() {
LOG(INFO) << "Starting " << test_info->name(); LOG(INFO) << "Starting " << test_info->name();
} }
void BaseFamilyTest::DisableLockCheck() {
check_locks_ = false;
}
// Returns the total number of entries in the `trans_locks` tables of all
// databases across all shards. The per-shard work runs in parallel, so the
// counter is atomic.
unsigned BaseFamilyTest::NumLocked() {
  atomic_uint count = 0;
  shard_set->RunBriefInParallel([&](EngineShard* shard) {
    for (const auto& db : shard->db_slice().databases()) {
      // NOTE(review): database slots are presumably pointer-like and may be
      // empty for indexes that hold no table; skip them rather than
      // dereferencing a null entry.
      if (db == nullptr) {
        continue;
      }
      count += db->trans_locks.size();
    }
  });
  return count;
}
void BaseFamilyTest::TearDown() { void BaseFamilyTest::TearDown() {
if (check_locks_) {
CHECK_EQ(NumLocked(), 0U);
}
service_->Shutdown(); service_->Shutdown();
service_.reset(); service_.reset();
pp_->Stop(); pp_->Stop();

View file

@ -90,9 +90,13 @@ class BaseFamilyTest : public ::testing::Test {
const facade::Connection::PubMessage::MessageData& GetPublishedMessage(std::string_view conn_id, const facade::Connection::PubMessage::MessageData& GetPublishedMessage(std::string_view conn_id,
size_t index) const; size_t index) const;
void DisableLockCheck();
static unsigned NumLocked();
std::unique_ptr<util::ProactorPool> pp_; std::unique_ptr<util::ProactorPool> pp_;
std::unique_ptr<Service> service_; std::unique_ptr<Service> service_;
unsigned num_threads_ = 3; unsigned num_threads_ = 3;
bool check_locks_ = true;
absl::flat_hash_map<std::string, std::unique_ptr<TestConnWrapper>> connections_; absl::flat_hash_map<std::string, std::unique_ptr<TestConnWrapper>> connections_;
Mutex mu_; Mutex mu_;