diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 15993d084..77871054d 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -166,7 +166,7 @@ OpStatus IncrementValue(optional prev_val, IncrByParam* param) { param->emplace(new_val); return OpStatus::OK; -}; +} OpStatus OpIncrBy(const OpArgs& op_args, string_view key, string_view field, IncrByParam* param) { auto& db_slice = op_args.GetDbSlice(); @@ -264,6 +264,62 @@ OpStatus OpIncrBy(const OpArgs& op_args, string_view key, string_view field, Inc return OpStatus::OK; } +struct KeyCleanup { + using CleanupFuncT = std::function; + explicit KeyCleanup(CleanupFuncT func, const std::string_view key_view) + : f{std::move(func)}, key{key_view} { + } + ~KeyCleanup() { + if (armed) { + f(key); + } + } + + void arm() { + armed = true; + } + + CleanupFuncT f; + std::string key; + bool armed{false}; +}; + +void DeleteKey(DbSlice& db_slice, const OpArgs& op_args, std::string_view key) { + if (auto del_it = db_slice.FindMutable(op_args.db_cntx, key, OBJ_HASH); del_it) { + del_it->post_updater.Run(); + db_slice.Del(op_args.db_cntx, del_it->it); + if (op_args.shard->journal()) { + RecordJournal(op_args, "DEL"sv, {key}); + } + } +} + +std::pair, KeyCleanup> FindReadOnly(DbSlice& db_slice, + const OpArgs& op_args, + std::string_view key) { + return std::pair{db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH), + KeyCleanup{[&](const auto& k) { DeleteKey(db_slice, op_args, k); }, key}}; +} + +// The find and contains functions perform the usual search on string maps, with the added argument +// KeyCleanup. This object is armed if the string map becomes empty during search due to keys being +// expired. An armed object on destruction removes the key which has just become empty. +StringMap::iterator Find(StringMap* sm, const std::string_view field, KeyCleanup& defer_cleanup) { + auto it = sm->Find(field); + if (sm->Empty()) { + defer_cleanup.arm(); + } + return it; +} + +bool Contains(StringMap* sm, const std::string_view field, KeyCleanup& defer_cleanup) { + auto result = sm->Contains(field); + if (sm->Empty()) { + defer_cleanup.arm(); + } + return result; +} + OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor, const ScanOpts& scan_op) { constexpr size_t HASH_TABLE_ENTRIES_FACTOR = 2; // return key/value @@ -274,7 +330,8 @@ OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t * of returning no or very few elements. (taken from redis code at db.c line 904 */ constexpr size_t INTERATION_FACTOR = 10; - auto find_res = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_HASH); + DbSlice& db_slice = op_args.GetDbSlice(); + auto [find_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key); if (!find_res) { DVLOG(1) << "ScanOp: find failed: " << find_res << ", baling out"; @@ -328,6 +385,10 @@ OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t do { *cursor = sm->Scan(*cursor, scanCb); } while (*cursor && max_iterations-- && res.size() < count); + + if (sm->Empty()) { + defer_cleanup.arm(); + } } return res; @@ -368,13 +429,15 @@ OpResult OpDel(const OpArgs& op_args, string_view key, CmdArgList valu StringMap* sm = GetStringMap(pv, op_args.db_cntx); for (auto s : values) { - bool res = sm->Erase(ToSV(s)); - if (res) { + if (sm->Erase(ToSV(s))) { ++deleted; - if (sm->UpperBoundSize() == 0) { - key_remove = true; - break; - } + } + + // Even if the previous Erase op did not erase anything, it can remove expired fields as a + // side effect. + if (sm->Empty()) { + key_remove = true; + break; } } } @@ -395,7 +458,7 @@ OpResult> OpHMGet(const OpArgs& op_args, std::string_view key, Cm DCHECK(!fields.empty()); auto& db_slice = op_args.GetDbSlice(); - auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); + auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key); if (!it_res) return it_res.status(); @@ -443,8 +506,7 @@ OpResult> OpHMGet(const OpArgs& op_args, std::string_view key, Cm StringMap* sm = GetStringMap(pv, op_args.db_cntx); for (size_t i = 0; i < fields.size(); ++i) { - auto it = sm->Find(ToSV(fields[i])); - if (it != sm->end()) { + if (auto it = Find(sm, ToSV(fields[i]), defer_cleanup); it != sm->end()) { result[i].emplace(it->second, sdslen(it->second)); } } @@ -468,7 +530,7 @@ OpResult OpLen(const OpArgs& op_args, string_view key) { OpResult OpExist(const OpArgs& op_args, string_view key, string_view field) { auto& db_slice = op_args.GetDbSlice(); - auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); + auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key); if (!it_res) { if (it_res.status() == OpStatus::KEY_NOTFOUND) @@ -486,13 +548,13 @@ OpResult OpExist(const OpArgs& op_args, string_view key, string_view field) DCHECK_EQ(kEncodingStrMap2, pv.Encoding()); StringMap* sm = GetStringMap(pv, op_args.db_cntx); - - return sm->Contains(field) ? 1 : 0; + return Contains(sm, field, defer_cleanup) ? 1 : 0; }; OpResult OpGet(const OpArgs& op_args, string_view key, string_view field) { auto& db_slice = op_args.GetDbSlice(); - auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); + auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key); + if (!it_res) return it_res.status(); @@ -510,12 +572,11 @@ OpResult OpGet(const OpArgs& op_args, string_view key, string_view field DCHECK_EQ(pv.Encoding(), kEncodingStrMap2); StringMap* sm = GetStringMap(pv, op_args.db_cntx); - auto it = sm->Find(field); + if (const auto it = Find(sm, field, defer_cleanup); it != sm->end()) { + return string(it->second, sdslen(it->second)); + } - if (it == sm->end()) - return OpStatus::KEY_NOTFOUND; - - return string(it->second, sdslen(it->second)); + return OpStatus::KEY_NOTFOUND; } OpResult> OpGetAll(const OpArgs& op_args, string_view key, uint8_t mask) { @@ -570,10 +631,7 @@ OpResult> OpGetAll(const OpArgs& op_args, string_view key, uint8_ // and the enconding is guaranteed to be a DenseSet since we only support expiring // value with that enconding. if (res.empty()) { - // post_updater will run immediately - auto it = db_slice.FindMutable(op_args.db_cntx, key).it; - - db_slice.Del(op_args.db_cntx, it); + DeleteKey(db_slice, op_args, key); } return res; @@ -581,7 +639,7 @@ OpResult> OpGetAll(const OpArgs& op_args, string_view key, uint8_ OpResult OpStrLen(const OpArgs& op_args, string_view key, string_view field) { auto& db_slice = op_args.GetDbSlice(); - auto it_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_HASH); + auto [it_res, defer_cleanup] = FindReadOnly(db_slice, op_args, key); if (!it_res) { if (it_res.status() == OpStatus::KEY_NOTFOUND) @@ -601,7 +659,7 @@ OpResult OpStrLen(const OpArgs& op_args, string_view key, string_view fi DCHECK_EQ(pv.Encoding(), kEncodingStrMap2); StringMap* sm = GetStringMap(pv, op_args.db_cntx); - auto it = sm->Find(field); + auto it = Find(sm, field, defer_cleanup); return it != sm->end() ? sdslen(it->second) : 0; } diff --git a/src/server/hset_family_test.cc b/src/server/hset_family_test.cc index bd21fa7a9..12497dba2 100644 --- a/src/server/hset_family_test.cc +++ b/src/server/hset_family_test.cc @@ -521,4 +521,24 @@ TEST_F(HSetFamilyTest, ScanAfterExpireSet) { EXPECT_THAT(vec, Contains("avalue").Times(1)); } +TEST_F(HSetFamilyTest, KeyRemovedWhenEmpty) { + auto test_cmd = [&](const std::function& f, const std::string_view tag) { + EXPECT_THAT(Run({"HSET", "a", "afield", "avalue"}), IntArg(1)); + EXPECT_THAT(Run({"HEXPIRE", "a", "1", "FIELDS", "1", "afield"}), IntArg(1)); + AdvanceTime(1000); + + EXPECT_THAT(Run({"EXISTS", "a"}), IntArg(1)); + f(); + EXPECT_THAT(Run({"EXISTS", "a"}), IntArg(0)) << "failed when testing " << tag; + }; + + test_cmd([&] { EXPECT_THAT(Run({"HGET", "a", "afield"}), ArgType(RespExpr::NIL)); }, "HGET"); + test_cmd([&] { EXPECT_THAT(Run({"HGETALL", "a"}), RespArray(ElementsAre())); }, "HGETALL"); + test_cmd([&] { EXPECT_THAT(Run({"HDEL", "a", "afield"}), IntArg(0)); }, "HDEL"); + test_cmd([&] { EXPECT_THAT(Run({"HSCAN", "a", "0"}).GetVec()[0], "0"); }, "HSCAN"); + test_cmd([&] { EXPECT_THAT(Run({"HMGET", "a", "afield"}), ArgType(RespExpr::NIL)); }, "HMGET"); + test_cmd([&] { EXPECT_THAT(Run({"HEXISTS", "a", "afield"}), IntArg(0)); }, "HEXISTS"); + test_cmd([&] { EXPECT_THAT(Run({"HSTRLEN", "a", "afield"}), IntArg(0)); }, "HSTRLEN"); +} + } // namespace dfly