fix(json_family): Fix memory tracking for JSON (#4777)

* fix(json_family): Fix memory tracking for JSON

fixes dragonflydb#4725

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* small fix

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
Stepan Bagritsevich 2025-03-20 10:23:58 +01:00 committed by GitHub
parent ce5c44b57b
commit ca65a49a0f
3 changed files with 52 additions and 7 deletions


@@ -55,6 +55,19 @@ size_t QlMAllocSize(quicklist* ql, bool slow) {
   return node_size + ql->count * 16;  // we account for each member 16 bytes.
 }
 
+size_t UpdateSize(size_t size, int64_t update) {
+  if (update >= 0) {
+    return size + update;
+  }
+
+  int64_t result = static_cast<int64_t>(size) + update;
+  if (result < 0) {
+    DCHECK(false) << "Can't decrease " << size << " from " << -update;
+    LOG_EVERY_N(ERROR, 30) << "Can't decrease " << size << " from " << -update;
+  }
+  return result;
+}
+
 inline void FreeObjSet(unsigned encoding, void* ptr, MemoryResource* mr) {
   switch (encoding) {
     case kEncodingStrMap2: {
@@ -927,10 +940,7 @@ void CompactObj::SetJson(JsonType&& j) {
 
 void CompactObj::SetJsonSize(int64_t size) {
   if (taglen_ == JSON_TAG && JsonEnconding() == kEncodingJsonCons) {
     // JSON.SET or if mem hasn't changed from a JSON op then we just update.
-    if (size < 0) {
-      DCHECK(static_cast<int64_t>(u_.json_obj.cons.bytes_used) >= size);
-    }
-    u_.json_obj.cons.bytes_used += size;
+    u_.json_obj.cons.bytes_used = UpdateSize(u_.json_obj.cons.bytes_used, size);
   }
 }
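
The new UpdateSize helper is the core of the fix: the old path added a signed delta straight into an unsigned counter, so a decrement larger than the current size silently wrapped around to a huge value. Below is a minimal standalone sketch of the same guard, with DCHECK/LOG_EVERY_N swapped for a plain stderr message so it builds without glog; the wrap-on-underflow return deliberately mirrors the patch.

#include <cstdint>
#include <cstdio>

// Sketch of the UpdateSize guard above (names and logging are stand-ins).
size_t UpdateSizeSketch(size_t size, int64_t update) {
  if (update >= 0) {
    return size + update;
  }
  int64_t result = static_cast<int64_t>(size) + update;
  if (result < 0) {
    // The patch fires a DCHECK in debug builds and LOG_EVERY_N(ERROR, 30) otherwise.
    std::fprintf(stderr, "Can't decrease %zu by %lld\n", size,
                 static_cast<long long>(-update));
  }
  return result;  // still wraps if negative, but no longer silently
}

int main() {
  std::printf("%zu\n", UpdateSizeSketch(100, -40));  // 60: normal decrease
  std::printf("%zu\n", UpdateSizeSketch(10, -40));   // underflow: reported, then wraps
}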


@@ -339,6 +339,8 @@ OpResult<DbSlice::ItAndUpdater> SetJson(const OpArgs& op_args, string_view key,
 
   op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, res.it->second);
 
+  JsonMemTracker tracker;
+
   std::optional<JsonType> parsed_json = ShardJsonFromString(json_str);
   if (!parsed_json) {
     VLOG(1) << "got invalid JSON string '" << json_str << "' cannot be saved";
@@ -354,7 +356,10 @@ OpResult<DbSlice::ItAndUpdater> SetJson(const OpArgs& op_args, string_view key,
   } else {
     res.it->second.SetJson(std::move(*parsed_json));
   }
+
+  tracker.SetJsonSize(res.it->second, res.is_new);
+
   op_args.shard->search_indices()->AddDoc(key, op_args.db_cntx, res.it->second);
 
   return std::move(res);
 }
@@ -942,7 +947,6 @@ OpResult<long> OpDel(const OpArgs& op_args, string_view key, string_view path,
     return 0;
   }
 
-  JsonMemTracker tracker;
   // FindMutable because we need to run the AutoUpdater at the end which will account
   // the deltas calculated from the MemoryTracker
   auto it_res = op_args.GetDbSlice().FindMutable(op_args.db_cntx, key, OBJ_JSON);
@@ -953,6 +957,7 @@ OpResult<long> OpDel(const OpArgs& op_args, string_view key, string_view path,
 
   PrimeValue& pv = it_res->it->second;
   JsonType* json_val = pv.GetJson();
+  JsonMemTracker tracker;
   absl::Cleanup update_size_on_exit([tracker, &pv]() mutable { tracker.SetJsonSize(pv, false); });
 
   if (json_path.HoldsJsonPath()) {
@@ -1329,10 +1334,8 @@ OpResult<bool> OpSet(const OpArgs& op_args, string_view key, string_view path,
     }
   }
 
-  JsonMemTracker mem_tracker;
   auto st = SetJson(op_args, key, json_str);
   RETURN_ON_BAD_STATUS(st);
-  mem_tracker.SetJsonSize(st->it->second, st->is_new);
   return true;
 }
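
In json_family.cc the change is about where JsonMemTracker lives rather than what it does: in OpSet the tracker moves from the caller into SetJson so the measurement brackets the parse and the insert, and in OpDel it is constructed only after the FindMutable lookup succeeds and is flushed through absl::Cleanup on scope exit. The sketch below illustrates that construct-late, charge-on-exit pattern; MemTrackerSketch, Value, and the global counter are hypothetical stand-ins, not Dragonfly's actual JsonMemTracker.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the shard allocator counter the real tracker samples.
static size_t g_allocated = 0;

struct Value {
  size_t bytes_used = 0;
  // Would route through UpdateSize() in the real code to catch underflow.
  void AddSize(int64_t delta) {
    bytes_used = static_cast<size_t>(static_cast<int64_t>(bytes_used) + delta);
  }
};

class MemTrackerSketch {
 public:
  // Samples usage at construction; constructing it before unrelated
  // allocations (as the pre-patch OpDel did) skews the delta.
  MemTrackerSketch() : start_(g_allocated) {}

  // Charges only the allocation delta observed since construction.
  void SetSize(Value& v) {
    v.AddSize(static_cast<int64_t>(g_allocated) - static_cast<int64_t>(start_));
  }

 private:
  size_t start_;
};

int main() {
  Value v;
  MemTrackerSketch tracker;  // constructed after the lookup succeeds, as in the patched OpDel
  g_allocated += 128;        // simulate a JSON mutation allocating 128 bytes
  tracker.SetSize(v);        // scope-exit charge (absl::Cleanup in the patch)
  std::printf("bytes_used = %zu\n", v.bytes_used);  // -> 128
}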


@@ -2951,3 +2951,35 @@ async def test_preempt_in_atomic_section_of_heartbeat(df_factory: DflyInstanceFactory):
     await wait_for_replicas_state(*c_replicas)
     await fill_task
+
+
+async def test_bug_in_json_memory_tracking(df_factory: DflyInstanceFactory):
+    """
+    This test reproduces a bug in the JSON memory tracking.
+    """
+    master = df_factory.create(proactor_threads=2, serialization_max_chunk_size=1)
+    replicas = [df_factory.create(proactor_threads=2) for i in range(2)]
+
+    # Start instances and connect clients
+    df_factory.start_all([master] + replicas)
+    c_master = master.client()
+    c_replicas = [replica.client() for replica in replicas]
+
+    total = 100000
+    await c_master.execute_command(f"DEBUG POPULATE {total} tmp 1000 TYPE SET ELEMENTS 100")
+
+    threshold = 25000
+    for i in range(threshold):
+        rand = random.randint(1, 4)
+        await c_master.execute_command(f"EXPIRE tmp:{i} {rand} NX")
+
+    seeder = StaticSeeder(key_target=100_000)
+    fill_task = asyncio.create_task(seeder.run(master.client()))
+
+    for replica in c_replicas:
+        await replica.execute_command(f"REPLICAOF LOCALHOST {master.port}")
+
+    async with async_timeout.timeout(240):
+        await wait_for_replicas_state(*c_replicas)
+
+    await fill_task