// Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "server/rdb_load.h" extern "C" { #include "redis/intset.h" #include "redis/listpack.h" #include "redis/lzfP.h" /* LZF compression library */ #include "redis/rdb.h" #include "redis/stream.h" #include "redis/util.h" #include "redis/ziplist.h" #include "redis/zmalloc.h" #include "redis/zset.h" } #include #include #include #include #include "base/endian.h" #include "base/flags.h" #include "base/logging.h" #include "core/string_set.h" #include "server/engine_shard_set.h" #include "server/error.h" #include "server/hset_family.h" #include "server/journal/executor.h" #include "server/journal/serializer.h" #include "server/main_service.h" #include "server/rdb_extensions.h" #include "server/script_mgr.h" #include "server/serializer_commons.h" #include "server/server_state.h" #include "server/set_family.h" #include "strings/human_readable.h" ABSL_DECLARE_FLAG(int32_t, list_max_listpack_size); ABSL_DECLARE_FLAG(int32_t, list_compress_depth); ABSL_DECLARE_FLAG(uint32_t, dbnum); ABSL_DECLARE_FLAG(bool, use_set2); namespace dfly { using namespace std; using base::IoBuf; using nonstd::make_unexpected; using namespace util; using absl::GetFlag; using rdb::errc; namespace { class error_category : public std::error_category { public: const char* name() const noexcept final { return "dragonfly.rdbload"; } string message(int ev) const final; error_condition default_error_condition(int ev) const noexcept final; bool equivalent(int ev, const error_condition& condition) const noexcept final { return condition.value() == ev && &condition.category() == this; } bool equivalent(const error_code& error, int ev) const noexcept final { return error.value() == ev && &error.category() == this; } }; string error_category::message(int ev) const { switch (ev) { case errc::wrong_signature: return "Wrong signature while trying to load from rdb file"; default: return absl::StrCat("Internal error when loading RDB file ", ev); break; } } error_condition error_category::default_error_condition(int ev) const noexcept { return error_condition{ev, *this}; } error_category rdb_category; inline error_code RdbError(errc ev) { return error_code{ev, rdb_category}; } inline auto Unexpected(errc ev) { return make_unexpected(RdbError(ev)); } static const error_code kOk; struct ZiplistCbArgs { long count; dict* fields; unsigned char** lp; }; /* callback for hashZiplistConvertAndValidateIntegrity. * Check that the ziplist doesn't have duplicate hash field names. * The ziplist element pointed by 'p' will be converted and stored into listpack. */ int ziplistPairsEntryConvertAndValidate(unsigned char* p, unsigned int head_count, void* userdata) { unsigned char* str; unsigned int slen; long long vll; ZiplistCbArgs* data = (ZiplistCbArgs*)userdata; if (data->fields == NULL) { data->fields = dictCreate(&hashDictType); dictExpand(data->fields, head_count / 2); } if (!ziplistGet(p, &str, &slen, &vll)) return 0; /* Even records are field names, add to dict and check that's not a dup */ if (((data->count) & 1) == 0) { sds field = str ? sdsnewlen(str, slen) : sdsfromlonglong(vll); if (dictAdd(data->fields, field, NULL) != DICT_OK) { /* Duplicate, return an error */ sdsfree(field); return 0; } } if (str) { *(data->lp) = lpAppend(*(data->lp), (unsigned char*)str, slen); } else { *(data->lp) = lpAppendInteger(*(data->lp), vll); } (data->count)++; return 1; } /* callback for ziplistValidateIntegrity. * The ziplist element pointed by 'p' will be converted and stored into listpack. */ int ziplistEntryConvertAndValidate(unsigned char* p, unsigned int head_count, void* userdata) { unsigned char* str; unsigned int slen; long long vll; unsigned char** lp = (unsigned char**)userdata; if (!ziplistGet(p, &str, &slen, &vll)) return 0; if (str) *lp = lpAppend(*lp, (unsigned char*)str, slen); else *lp = lpAppendInteger(*lp, vll); return 1; } /* Validate the integrity of the data structure while converting it to * listpack and storing it at 'lp'. * The function is safe to call on non-validated ziplists, it returns 0 * when encounter an integrity validation issue. */ int ziplistPairsConvertAndValidateIntegrity(const uint8_t* zl, size_t size, unsigned char** lp) { /* Keep track of the field names to locate duplicate ones */ ZiplistCbArgs data = {0, NULL, lp}; int ret = ziplistValidateIntegrity(const_cast(zl), size, 1, ziplistPairsEntryConvertAndValidate, &data); /* make sure we have an even number of records. */ if (data.count & 1) ret = 0; if (data.fields) dictRelease(data.fields); return ret; } bool resizeStringSet(robj* set, size_t size, bool use_set2) { if (use_set2) { ((dfly::StringSet*)set->ptr)->Reserve(size); return true; } else { return dictTryExpand((dict*)set->ptr, size) != DICT_OK; } } } // namespace class DecompressImpl { public: DecompressImpl() : uncompressed_mem_buf_{16_KB} { } ~DecompressImpl() { } virtual io::Result Decompress(std::string_view str) = 0; protected: base::IoBuf uncompressed_mem_buf_; }; class ZstdDecompress : public DecompressImpl { public: ZstdDecompress() { dctx_ = ZSTD_createDCtx(); } ~ZstdDecompress() { ZSTD_freeDCtx(dctx_); } io::Result Decompress(std::string_view str); private: ZSTD_DCtx* dctx_; }; io::Result ZstdDecompress::Decompress(std::string_view str) { // Prepare membuf memory to uncompressed string. auto uncomp_size = ZSTD_getFrameContentSize(str.data(), str.size()); if (uncomp_size == ZSTD_CONTENTSIZE_UNKNOWN) { LOG(ERROR) << "Zstd compression missing frame content size"; return Unexpected(errc::invalid_encoding); } if (uncomp_size == ZSTD_CONTENTSIZE_ERROR) { LOG(ERROR) << "Invalid ZSTD compressed string"; return Unexpected(errc::invalid_encoding); } uncompressed_mem_buf_.Reserve(uncomp_size + 1); // Uncompress string to membuf IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); if (dest.size() < uncomp_size) { return Unexpected(errc::out_of_memory); } size_t const d_size = ZSTD_decompressDCtx(dctx_, dest.data(), dest.size(), str.data(), str.size()); if (d_size == 0 || d_size != uncomp_size) { LOG(ERROR) << "Invalid ZSTD compressed string"; return Unexpected(errc::rdb_file_corrupted); } uncompressed_mem_buf_.CommitWrite(d_size); // Add opcode of compressed blob end to membuf. dest = uncompressed_mem_buf_.AppendBuffer(); if (dest.size() < 1) { return Unexpected(errc::out_of_memory); } dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; uncompressed_mem_buf_.CommitWrite(1); return &uncompressed_mem_buf_; } class Lz4Decompress : public DecompressImpl { public: Lz4Decompress() { auto result = LZ4F_createDecompressionContext(&dctx_, LZ4F_VERSION); CHECK(!LZ4F_isError(result)); } ~Lz4Decompress() { auto result = LZ4F_freeDecompressionContext(dctx_); CHECK(!LZ4F_isError(result)); } io::Result Decompress(std::string_view str); private: LZ4F_dctx* dctx_; }; io::Result Lz4Decompress::Decompress(std::string_view data) { LZ4F_frameInfo_t frame_info; size_t frame_size = data.size(); // Get content size from frame data size_t consumed = frame_size; // The nb of bytes consumed from data will be written into consumed size_t res = LZ4F_getFrameInfo(dctx_, &frame_info, data.data(), &consumed); if (LZ4F_isError(res)) { return make_unexpected(error_code{int(res), generic_category()}); } if (frame_info.contentSize == 0) { LOG(ERROR) << "Missing frame content size"; return Unexpected(errc::rdb_file_corrupted); } // reserve place for uncompressed data and end opcode size_t reserve = frame_info.contentSize + 1; uncompressed_mem_buf_.Reserve(reserve); IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); if (dest.size() < reserve) { return Unexpected(errc::out_of_memory); } // Uncompress data to membuf string_view src = data.substr(consumed); size_t src_size = src.size(); size_t ret = 1; while (ret != 0) { IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); size_t dest_capacity = dest.size(); // It will read up to src_size bytes from src, // and decompress data into dest, of capacity dest_capacity // The nb of bytes consumed from src will be written into src_size // The nb of bytes decompressed into dest will be written into dest_capacity ret = LZ4F_decompress(dctx_, dest.data(), &dest_capacity, src.data(), &src_size, nullptr); if (LZ4F_isError(ret)) { return make_unexpected(error_code{int(ret), generic_category()}); } consumed += src_size; uncompressed_mem_buf_.CommitWrite(dest_capacity); src = src.substr(src_size); src_size = src.size(); } if (consumed != frame_size) { return Unexpected(errc::rdb_file_corrupted); } if (uncompressed_mem_buf_.InputLen() != frame_info.contentSize) { return Unexpected(errc::rdb_file_corrupted); } // Add opcode of compressed blob end to membuf. dest = uncompressed_mem_buf_.AppendBuffer(); if (dest.size() < 1) { return Unexpected(errc::out_of_memory); } dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; uncompressed_mem_buf_.CommitWrite(1); return &uncompressed_mem_buf_; } class RdbLoaderBase::OpaqueObjLoader { public: OpaqueObjLoader(int rdb_type, PrimeValue* pv) : rdb_type_(rdb_type), pv_(pv) { } void operator()(robj* o) { pv_->ImportRObj(o); } void operator()(long long val) { pv_->SetInt(val); } void operator()(const base::PODArray& str); void operator()(const LzfString& lzfstr); void operator()(const unique_ptr& ptr); std::error_code ec() const { return ec_; } private: void CreateSet(const LoadTrace* ltrace); void CreateHMap(const LoadTrace* ltrace); void CreateList(const LoadTrace* ltrace); void CreateZSet(const LoadTrace* ltrace); void CreateStream(const LoadTrace* ltrace); void HandleBlob(string_view blob); sds ToSds(const RdbVariant& obj); string_view ToSV(const RdbVariant& obj); std::error_code ec_; int rdb_type_; base::PODArray tset_blob_; PrimeValue* pv_; }; RdbLoaderBase::RdbLoaderBase() : origin_mem_buf_{16_KB} { mem_buf_ = &origin_mem_buf_; } RdbLoaderBase::~RdbLoaderBase() { } void RdbLoaderBase::OpaqueObjLoader::operator()(const base::PODArray& str) { string_view sv(str.data(), str.size()); HandleBlob(sv); } void RdbLoaderBase::OpaqueObjLoader::operator()(const LzfString& lzfstr) { string tmp(lzfstr.uncompressed_len, '\0'); if (lzf_decompress(lzfstr.compressed_blob.data(), lzfstr.compressed_blob.size(), tmp.data(), tmp.size()) == 0) { LOG(ERROR) << "Invalid LZF compressed string"; ec_ = RdbError(errc::rdb_file_corrupted); return; } HandleBlob(tmp); } void RdbLoaderBase::OpaqueObjLoader::operator()(const unique_ptr& ptr) { switch (rdb_type_) { case RDB_TYPE_SET: CreateSet(ptr.get()); break; case RDB_TYPE_HASH: CreateHMap(ptr.get()); break; case RDB_TYPE_LIST_QUICKLIST: case RDB_TYPE_LIST_QUICKLIST_2: CreateList(ptr.get()); break; case RDB_TYPE_ZSET: case RDB_TYPE_ZSET_2: CreateZSet(ptr.get()); break; case RDB_TYPE_STREAM_LISTPACKS: CreateStream(ptr.get()); break; default: LOG(FATAL) << "Unsupported rdb type " << rdb_type_; } } void RdbLoaderBase::OpaqueObjLoader::CreateSet(const LoadTrace* ltrace) { size_t len = ltrace->arr.size(); bool is_intset = true; if (len <= SetFamily::MaxIntsetEntries()) { for (size_t i = 0; i < len; i++) { if (!holds_alternative(ltrace->arr[i].rdb_var)) { is_intset = false; break; } } } else { /* Use a regular set when there are too many entries. */ is_intset = false; } robj* res = nullptr; sds sdsele = nullptr; auto cleanup = absl::MakeCleanup([&] { if (sdsele) sdsfree(sdsele); decrRefCount(res); }); if (is_intset) { res = createIntsetObject(); long long llval; for (size_t i = 0; i < len; i++) { llval = get(ltrace->arr[i].rdb_var); uint8_t success; res->ptr = intsetAdd((intset*)res->ptr, llval, &success); if (!success) { LOG(ERROR) << "Duplicate set members detected"; ec_ = RdbError(errc::duplicate_key); return; } } } else { bool use_set2 = GetFlag(FLAGS_use_set2); if (use_set2) { StringSet* set = new StringSet{CompactObj::memory_resource()}; res = createObject(OBJ_SET, set); res->encoding = OBJ_ENCODING_HT; } else { res = createSetObject(); } // TODO: to move this logic to set_family similarly to ConvertToStrSet. /* It's faster to expand the dict to the right size asap in order * to avoid rehashing */ if (len > DICT_HT_INITIAL_SIZE && !resizeStringSet(res, len, use_set2)) { LOG(ERROR) << "OOM in dictTryExpand " << len; ec_ = RdbError(errc::out_of_memory); return; } if (use_set2) { for (size_t i = 0; i < len; i++) { sdsele = ToSds(ltrace->arr[i].rdb_var); if (!sdsele) return; if (!((StringSet*)res->ptr)->AddSds(sdsele)) { LOG(ERROR) << "Duplicate set members detected"; ec_ = RdbError(errc::duplicate_key); return; } } } else { for (size_t i = 0; i < len; i++) { sdsele = ToSds(ltrace->arr[i].rdb_var); if (!sdsele) return; if (dictAdd((dict*)res->ptr, sdsele, NULL) != DICT_OK) { LOG(ERROR) << "Duplicate set members detected"; ec_ = RdbError(errc::duplicate_key); return; } } } } pv_->ImportRObj(res); std::move(cleanup).Cancel(); } void RdbLoaderBase::OpaqueObjLoader::CreateHMap(const LoadTrace* ltrace) { size_t len = ltrace->arr.size() / 2; /* Too many entries? Use a hash table right from the start. */ bool keep_lp = (len <= server.hash_max_listpack_entries); size_t lp_size = 0; if (keep_lp) { for (const auto& str : ltrace->arr) { size_t str_len = StrLen(str.rdb_var); lp_size += str_len; if (str_len > server.hash_max_listpack_value) { keep_lp = false; break; } } } robj* res = nullptr; if (keep_lp) { uint8_t* lp = lpNew(lp_size); for (size_t i = 0; i < len; ++i) { /* Add pair to listpack */ string_view sv = ToSV(ltrace->arr[i * 2].rdb_var); lp = lpAppend(lp, reinterpret_cast(sv.data()), sv.size()); sv = ToSV(ltrace->arr[i * 2 + 1].rdb_var); lp = lpAppend(lp, reinterpret_cast(sv.data()), sv.size()); } if (ec_) { lpFree(lp); return; } lp = lpShrinkToFit(lp); res = createObject(OBJ_HASH, lp); res->encoding = OBJ_ENCODING_LISTPACK; } else { dict* hmap = dictCreate(&hashDictType); auto cleanup = absl::MakeCleanup([&] { dictRelease(hmap); }); if (len > DICT_HT_INITIAL_SIZE) { if (dictTryExpand(hmap, len) != DICT_OK) { LOG(ERROR) << "OOM in dictTryExpand " << len; ec_ = RdbError(errc::out_of_memory); return; } } for (size_t i = 0; i < len; ++i) { sds key = ToSds(ltrace->arr[i * 2].rdb_var); sds val = ToSds(ltrace->arr[i * 2 + 1].rdb_var); if (!key || !val) return; /* Add pair to hash table */ int ret = dictAdd(hmap, key, val); if (ret == DICT_ERR) { LOG(ERROR) << "Duplicate hash fields detected"; ec_ = RdbError(errc::rdb_file_corrupted); return; } } res = createObject(OBJ_HASH, hmap); res->encoding = OBJ_ENCODING_HT; std::move(cleanup).Cancel(); } DCHECK(res); pv_->ImportRObj(res); } void RdbLoaderBase::OpaqueObjLoader::CreateList(const LoadTrace* ltrace) { quicklist* ql = quicklistNew(GetFlag(FLAGS_list_max_listpack_size), GetFlag(FLAGS_list_compress_depth)); auto cleanup = absl::Cleanup([&] { quicklistRelease(ql); }); for (size_t i = 0; i < ltrace->arr.size(); ++i) { unsigned container = ltrace->arr[i].encoding; string_view sv = ToSV(ltrace->arr[i].rdb_var); if (ec_) return; if (container == QUICKLIST_NODE_CONTAINER_PLAIN) { quicklistAppendPlainNode(ql, (uint8_t*)sv.data(), sv.size()); continue; } uint8_t* lp = nullptr; if (rdb_type_ == RDB_TYPE_LIST_QUICKLIST_2) { uint8_t* src = (uint8_t*)sv.data(); if (!lpValidateIntegrity(src, sv.size(), 0, NULL, NULL)) { LOG(ERROR) << "Listpack integrity check failed."; ec_ = RdbError(errc::rdb_file_corrupted); return; } if (lpLength(src) == 0) { continue; } lp = (uint8_t*)zmalloc(sv.size()); memcpy(lp, src, sv.size()); } else { lp = lpNew(sv.size()); if (!ziplistValidateIntegrity((uint8_t*)sv.data(), sv.size(), 1, ziplistEntryConvertAndValidate, &lp)) { LOG(ERROR) << "Ziplist integrity check failed."; zfree(lp); ec_ = RdbError(errc::rdb_file_corrupted); return; } /* Silently skip empty ziplists, if we'll end up with empty quicklist we'll fail later. */ if (lpLength(lp) == 0) { zfree(lp); continue; } lp = lpShrinkToFit(lp); } quicklistAppendListpack(ql, lp); } if (quicklistCount(ql) == 0) { ec_ = RdbError(errc::empty_key); return; } robj* res = createObject(OBJ_LIST, ql); res->encoding = OBJ_ENCODING_QUICKLIST; std::move(cleanup).Cancel(); pv_->ImportRObj(res); } void RdbLoaderBase::OpaqueObjLoader::CreateZSet(const LoadTrace* ltrace) { robj* res = createZsetObject(); zset* zs = (zset*)res->ptr; auto cleanup = absl::Cleanup([&] { decrRefCount(res); }); size_t zsetlen = ltrace->arr.size(); if (zsetlen > DICT_HT_INITIAL_SIZE && dictTryExpand(zs->dict, zsetlen) != DICT_OK) { LOG(ERROR) << "OOM in dictTryExpand " << zsetlen; ec_ = RdbError(errc::out_of_memory); return; } size_t maxelelen = 0, totelelen = 0; for (size_t i = 0; i < zsetlen; ++i) { sds sdsele = ToSds(ltrace->arr[i].rdb_var); if (!sdsele) return; double score = ltrace->arr[i].score; /* Don't care about integer-encoded strings. */ if (sdslen(sdsele) > maxelelen) maxelelen = sdslen(sdsele); totelelen += sdslen(sdsele); zskiplistNode* znode = zslInsert(zs->zsl, score, sdsele); int ret = dictAdd(zs->dict, sdsele, &znode->score); if (ret != DICT_OK) { LOG(ERROR) << "Duplicate zset fields detected"; sdsfree(sdsele); ec_ = RdbError(errc::rdb_file_corrupted); return; } } /* Convert *after* loading, since sorted sets are not stored ordered. */ if (zsetLength(res) <= server.zset_max_listpack_entries && maxelelen <= server.zset_max_listpack_value && lpSafeToAdd(NULL, totelelen)) { zsetConvert(res, OBJ_ENCODING_LISTPACK); } std::move(cleanup).Cancel(); pv_->ImportRObj(res); } void RdbLoaderBase::OpaqueObjLoader::CreateStream(const LoadTrace* ltrace) { CHECK(ltrace->stream_trace); robj* res = createStreamObject(); stream* s = (stream*)res->ptr; auto cleanup = absl::Cleanup([&] { decrRefCount(res); }); size_t lpcnt = ltrace->arr.size() / 2; for (size_t i = 0; i < lpcnt; ++i) { string_view nodekey = ToSV(ltrace->arr[i * 2].rdb_var); string_view data = ToSV(ltrace->arr[i * 2 + 1].rdb_var); uint8_t* lp = (uint8_t*)data.data(); if (!streamValidateListpackIntegrity(lp, data.size(), 0)) { LOG(ERROR) << "Stream listpack integrity check failed."; ec_ = RdbError(errc::rdb_file_corrupted); return; } unsigned char* first = lpFirst(lp); if (first == NULL) { /* Serialized listpacks should never be empty, since on * deletion we should remove the radix tree key if the * resulting listpack is empty. */ LOG(ERROR) << "Empty listpack inside stream"; ec_ = RdbError(errc::rdb_file_corrupted); return; } uint8_t* copy_lp = (uint8_t*)zmalloc(data.size()); memcpy(copy_lp, lp, data.size()); /* Insert the key in the radix tree. */ int retval = raxTryInsert(s->rax_tree, (unsigned char*)nodekey.data(), nodekey.size(), copy_lp, NULL); if (!retval) { zfree(copy_lp); LOG(ERROR) << "Listpack re-added with existing key"; ec_ = RdbError(errc::rdb_file_corrupted); return; } } s->length = ltrace->stream_trace->stream_len; s->last_id.ms = ltrace->stream_trace->ms; s->last_id.seq = ltrace->stream_trace->seq; for (const auto& cg : ltrace->stream_trace->cgroup) { string_view cgname = ToSV(cg.name); streamID cg_id; cg_id.ms = cg.ms; cg_id.seq = cg.seq; streamCG* cgroup = streamCreateCG(s, cgname.data(), cgname.size(), &cg_id, 0); if (cgroup == NULL) { LOG(ERROR) << "Duplicated consumer group name " << cgname; ec_ = RdbError(errc::duplicate_key); return; } for (const auto& pel : cg.pel_arr) { streamNACK* nack = streamCreateNACK(NULL); nack->delivery_time = pel.delivery_time; nack->delivery_count = pel.delivery_count; if (!raxTryInsert(cgroup->pel, const_cast(pel.rawid.data()), pel.rawid.size(), nack, NULL)) { LOG(ERROR) << "Duplicated global PEL entry loading stream consumer group"; ec_ = RdbError(errc::duplicate_key); streamFreeNACK(nack); return; } } for (const auto& cons : cg.cons_arr) { sds cname = ToSds(cons.name); streamConsumer* consumer = streamCreateConsumer(cgroup, cname, NULL, 0, SCC_NO_NOTIFY | SCC_NO_DIRTIFY); sdsfree(cname); if (!consumer) { LOG(ERROR) << "Duplicate stream consumer detected."; ec_ = RdbError(errc::duplicate_key); return; } consumer->seen_time = cons.seen_time; /* Create the PEL (pending entries list) about entries owned by this specific * consumer. */ for (const auto& rawid : cons.nack_arr) { uint8_t* ptr = const_cast(rawid.data()); streamNACK* nack = (streamNACK*)raxFind(cgroup->pel, ptr, rawid.size()); if (nack == raxNotFound) { LOG(ERROR) << "Consumer entry not found in group global PEL"; ec_ = RdbError(errc::rdb_file_corrupted); return; } /* Set the NACK consumer, that was left to NULL when * loading the global PEL. Then set the same shared * NACK structure also in the consumer-specific PEL. */ nack->consumer = consumer; if (!raxTryInsert(consumer->pel, ptr, rawid.size(), nack, NULL)) { LOG(ERROR) << "Duplicated consumer PEL entry loading a stream consumer group"; streamFreeNACK(nack); ec_ = RdbError(errc::duplicate_key); return; } } } } std::move(cleanup).Cancel(); pv_->ImportRObj(res); } void RdbLoaderBase::OpaqueObjLoader::HandleBlob(string_view blob) { if (rdb_type_ == RDB_TYPE_STRING) { pv_->SetString(blob); return; } robj* res = nullptr; if (rdb_type_ == RDB_TYPE_SET_INTSET) { if (!intsetValidateIntegrity((const uint8_t*)blob.data(), blob.size(), 0)) { LOG(ERROR) << "Intset integrity check failed."; ec_ = RdbError(errc::rdb_file_corrupted); return; } const intset* is = (const intset*)blob.data(); unsigned len = intsetLen(is); if (len > SetFamily::MaxIntsetEntries()) { res = createSetObject(); if (!SetFamily::ConvertToStrSet(is, len, res)) { LOG(ERROR) << "OOM in ConvertToStrSet " << len; decrRefCount(res); ec_ = RdbError(errc::out_of_memory); return; } } else { intset* mine = (intset*)zmalloc(blob.size()); memcpy(mine, blob.data(), blob.size()); res = createObject(OBJ_SET, mine); res->encoding = OBJ_ENCODING_INTSET; } } else if (rdb_type_ == RDB_TYPE_HASH_ZIPLIST) { unsigned char* lp = lpNew(blob.size()); if (!ziplistPairsConvertAndValidateIntegrity((const uint8_t*)blob.data(), blob.size(), &lp)) { LOG(ERROR) << "Zset ziplist integrity check failed."; zfree(lp); ec_ = RdbError(errc::rdb_file_corrupted); return; } if (lpLength(lp) == 0) { lpFree(lp); ec_ = RdbError(errc::empty_key); return; } res = createObject(OBJ_HASH, lp); res->encoding = OBJ_ENCODING_LISTPACK; if (lpBytes(lp) > HSetFamily::MaxListPackLen()) hashTypeConvert(res, OBJ_ENCODING_HT); else res->ptr = lpShrinkToFit((uint8_t*)res->ptr); } else if (rdb_type_ == RDB_TYPE_ZSET_ZIPLIST) { unsigned char* lp = lpNew(blob.size()); if (!ziplistPairsConvertAndValidateIntegrity((uint8_t*)blob.data(), blob.size(), &lp)) { LOG(ERROR) << "Zset ziplist integrity check failed."; zfree(lp); ec_ = RdbError(errc::rdb_file_corrupted); return; } if (lpLength(lp) == 0) { lpFree(lp); ec_ = RdbError(errc::empty_key); return; } res = createObject(OBJ_ZSET, lp); res->encoding = OBJ_ENCODING_LISTPACK; if (lpBytes(lp) > server.zset_max_listpack_entries) zsetConvert(res, OBJ_ENCODING_SKIPLIST); else res->ptr = lpShrinkToFit(lp); } else { LOG(FATAL) << "Unsupported rdb type " << rdb_type_; } pv_->ImportRObj(res); } sds RdbLoaderBase::OpaqueObjLoader::ToSds(const RdbVariant& obj) { if (holds_alternative(obj)) { return sdsfromlonglong(get(obj)); } const base::PODArray* ch_arr = get_if>(&obj); if (ch_arr) { return sdsnewlen(ch_arr->data(), ch_arr->size()); } const LzfString* lzf = get_if(&obj); if (lzf) { sds res = sdsnewlen(NULL, lzf->uncompressed_len); if (lzf_decompress(lzf->compressed_blob.data(), lzf->compressed_blob.size(), res, lzf->uncompressed_len) == 0) { LOG(ERROR) << "Invalid LZF compressed string"; ec_ = RdbError(errc::rdb_file_corrupted); sdsfree(res); return nullptr; } return res; } LOG(FATAL) << "Unexpected variant"; return nullptr; } string_view RdbLoaderBase::OpaqueObjLoader::ToSV(const RdbVariant& obj) { if (holds_alternative(obj)) { tset_blob_.resize(32); auto val = get(obj); char* next = absl::numbers_internal::FastIntToBuffer(val, tset_blob_.data()); return string_view{tset_blob_.data(), size_t(next - tset_blob_.data())}; } const base::PODArray* ch_arr = get_if>(&obj); if (ch_arr) { return string_view(ch_arr->data(), ch_arr->size()); } const LzfString* lzf = get_if(&obj); if (lzf) { tset_blob_.resize(lzf->uncompressed_len); if (lzf_decompress(lzf->compressed_blob.data(), lzf->compressed_blob.size(), tset_blob_.data(), lzf->uncompressed_len) == 0) { LOG(ERROR) << "Invalid LZF compressed string"; ec_ = RdbError(errc::rdb_file_corrupted); return string_view{}; } return string_view{tset_blob_.data(), tset_blob_.size()}; } LOG(FATAL) << "Unexpected variant"; return string_view{}; } std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { if (size == 0) return kOk; uint8_t* next = (uint8_t*)dest; size_t bytes_read; size_t to_copy = std::min(mem_buf_->InputLen(), size); DVLOG(2) << "Copying " << to_copy << " bytes"; ::memcpy(next, mem_buf_->InputBuffer().data(), to_copy); mem_buf_->ConsumeInput(to_copy); size -= to_copy; if (size == 0) return kOk; next += to_copy; if (size + bytes_read_ > source_limit_) { LOG(ERROR) << "Out of bound read " << size + bytes_read_ << " vs " << source_limit_; return RdbError(errc::rdb_file_corrupted); } if (size > 512) { // Worth reading directly into next. io::MutableBytes mb{next, size}; SET_OR_RETURN(src_->Read(mb), bytes_read); if (bytes_read < size) return RdbError(errc::rdb_file_corrupted); bytes_read_ += bytes_read; DCHECK_LE(bytes_read_, source_limit_); return kOk; } io::MutableBytes mb = mem_buf_->AppendBuffer(); // Must be because mem_buf_ is be empty. DCHECK_GT(mb.size(), size); if (bytes_read_ + mb.size() > source_limit_) { mb = mb.subspan(0, source_limit_ - bytes_read_); } SET_OR_RETURN(src_->ReadAtLeast(mb, size), bytes_read); if (bytes_read < size) return RdbError(errc::rdb_file_corrupted); bytes_read_ += bytes_read; DCHECK_LE(bytes_read_, source_limit_); mem_buf_->CommitWrite(bytes_read); ::memcpy(next, mem_buf_->InputBuffer().data(), size); mem_buf_->ConsumeInput(size); return kOk; } size_t RdbLoaderBase::StrLen(const RdbVariant& tset) { const base::PODArray* arr = get_if>(&tset); if (arr) return arr->size(); if (holds_alternative(tset)) { auto val = get(tset); char buf[32]; char* next = absl::numbers_internal::FastIntToBuffer(val, buf); return (next - buf); } const LzfString* lzf = get_if(&tset); if (lzf) return lzf->uncompressed_len; LOG(DFATAL) << "should not reach"; return 0; } auto RdbLoaderBase::FetchGenericString() -> io::Result { bool isencoded; size_t len; SET_OR_UNEXPECT(LoadLen(&isencoded), len); if (isencoded) { switch (len) { case RDB_ENC_INT8: case RDB_ENC_INT16: case RDB_ENC_INT32: return FetchIntegerObject(len); case RDB_ENC_LZF: return FetchLzfStringObject(); default: LOG(ERROR) << "Unknown RDB string encoding len " << len; return Unexpected(errc::rdb_file_corrupted); } } string res; if (len > 0) { res.resize(len); error_code ec = FetchBuf(len, res.data()); if (ec) { return make_unexpected(ec); } } return res; } auto RdbLoaderBase::FetchLzfStringObject() -> io::Result { bool zerocopy_decompress = true; const uint8_t* cbuf = NULL; uint64_t clen, len; SET_OR_UNEXPECT(LoadLen(NULL), clen); SET_OR_UNEXPECT(LoadLen(NULL), len); if (len > 1ULL << 29 || len <= clen || clen == 0) { LOG(ERROR) << "Bad compressed string"; return Unexpected(rdb::rdb_file_corrupted); } if (mem_buf_->InputLen() >= clen) { cbuf = mem_buf_->InputBuffer().data(); } else { compr_buf_.resize(clen); zerocopy_decompress = false; /* Load the compressed representation and uncompress it to target. */ error_code ec = FetchBuf(clen, compr_buf_.data()); if (ec) { return make_unexpected(ec); } cbuf = compr_buf_.data(); } string res(len, 0); if (lzf_decompress(cbuf, clen, res.data(), len) == 0) { LOG(ERROR) << "Invalid LZF compressed string"; return Unexpected(errc::rdb_file_corrupted); } // FetchBuf consumes the input but if we have not went through that path // we need to consume now. if (zerocopy_decompress) mem_buf_->ConsumeInput(clen); return res; } auto RdbLoaderBase::FetchIntegerObject(int enctype) -> io::Result { long long val; if (enctype == RDB_ENC_INT8) { SET_OR_UNEXPECT(FetchInt(), val); } else if (enctype == RDB_ENC_INT16) { SET_OR_UNEXPECT(FetchInt(), val); } else if (enctype == RDB_ENC_INT32) { SET_OR_UNEXPECT(FetchInt(), val); } else { return Unexpected(errc::invalid_encoding); } char buf[32]; absl::numbers_internal::FastIntToBuffer(val, buf); return string(buf); } io::Result RdbLoaderBase::FetchBinaryDouble() { union { uint64_t val; double d; } u; static_assert(sizeof(u) == sizeof(uint64_t)); auto ec = EnsureRead(8); if (ec) return make_unexpected(ec); uint8_t buf[8]; mem_buf_->ReadAndConsume(8, buf); u.val = base::LE::LoadT(buf); return u.d; } io::Result RdbLoaderBase::FetchDouble() { uint8_t len; SET_OR_UNEXPECT(FetchInt(), len); constexpr double kInf = std::numeric_limits::infinity(); switch (len) { case 255: return -kInf; case 254: return kInf; case 253: return std::numeric_limits::quiet_NaN(); default:; } char buf[256]; error_code ec = FetchBuf(len, buf); if (ec) return make_unexpected(ec); buf[len] = '\0'; double val; if (sscanf(buf, "%lg", &val) != 1) return Unexpected(errc::rdb_file_corrupted); return val; } auto RdbLoaderBase::ReadKey() -> io::Result { return FetchGenericString(); } error_code RdbLoaderBase::ReadObj(int rdbtype, OpaqueObj* dest) { io::Result iores; switch (rdbtype) { case RDB_TYPE_STRING: { dest->rdb_type = RDB_TYPE_STRING; /* Read string value */ return ReadStringObj(&dest->obj); } case RDB_TYPE_SET: iores = ReadSet(); break; case RDB_TYPE_SET_INTSET: iores = ReadIntSet(); break; case RDB_TYPE_HASH_ZIPLIST: iores = ReadHZiplist(); break; case RDB_TYPE_HASH: iores = ReadHMap(); break; case RDB_TYPE_ZSET: case RDB_TYPE_ZSET_2: iores = ReadZSet(rdbtype); break; case RDB_TYPE_ZSET_ZIPLIST: iores = ReadZSetZL(); break; case RDB_TYPE_LIST_QUICKLIST: case RDB_TYPE_LIST_QUICKLIST_2: iores = ReadListQuicklist(rdbtype); break; case RDB_TYPE_STREAM_LISTPACKS: iores = ReadStreams(); break; break; default: LOG(ERROR) << "Unsupported rdb type " << rdbtype; return RdbError(errc::invalid_encoding); } if (!iores) return iores.error(); *dest = std::move(*iores); return error_code{}; } error_code RdbLoaderBase::ReadStringObj(RdbVariant* dest) { bool isencoded; size_t len; SET_OR_RETURN(LoadLen(&isencoded), len); if (isencoded) { switch (len) { case RDB_ENC_INT8: case RDB_ENC_INT16: case RDB_ENC_INT32: { io::Result io_int = ReadIntObj(len); if (!io_int) return io_int.error(); dest->emplace(*io_int); return error_code{}; } case RDB_ENC_LZF: { io::Result lzf = ReadLzf(); if (!lzf) return lzf.error(); dest->emplace(std::move(lzf.value())); return error_code{}; } default: LOG(ERROR) << "Unknown RDB string encoding " << len; return RdbError(errc::rdb_file_corrupted); } } auto& blob = dest->emplace>(); blob.resize(len); error_code ec = FetchBuf(len, blob.data()); return ec; } io::Result RdbLoaderBase::ReadIntObj(int enctype) { long long val; if (enctype == RDB_ENC_INT8) { SET_OR_UNEXPECT(FetchInt(), val); } else if (enctype == RDB_ENC_INT16) { SET_OR_UNEXPECT(FetchInt(), val); } else if (enctype == RDB_ENC_INT32) { SET_OR_UNEXPECT(FetchInt(), val); } else { return Unexpected(errc::invalid_encoding); } return val; } auto RdbLoaderBase::ReadLzf() -> io::Result { uint64_t clen; LzfString res; SET_OR_UNEXPECT(LoadLen(NULL), clen); SET_OR_UNEXPECT(LoadLen(NULL), res.uncompressed_len); if (res.uncompressed_len > 1ULL << 29) { LOG(ERROR) << "Uncompressed length is too big " << res.uncompressed_len; return Unexpected(errc::rdb_file_corrupted); } res.compressed_blob.resize(clen); /* Load the compressed representation and uncompress it to target. */ error_code ec = FetchBuf(clen, res.compressed_blob.data()); if (ec) { return make_unexpected(ec); } return res; } auto RdbLoaderBase::ReadSet() -> io::Result { size_t len; SET_OR_UNEXPECT(LoadLen(NULL), len); if (len == 0) return Unexpected(errc::empty_key); unique_ptr res(new LoadTrace); res->arr.resize(len); for (size_t i = 0; i < len; i++) { error_code ec = ReadStringObj(&res->arr[i].rdb_var); if (ec) { return make_unexpected(ec); } } return OpaqueObj{std::move(res), RDB_TYPE_SET}; } auto RdbLoaderBase::ReadIntSet() -> io::Result { RdbVariant obj; error_code ec = ReadStringObj(&obj); if (ec) { return make_unexpected(ec); } const LzfString* lzf = get_if(&obj); const base::PODArray* arr = get_if>(&obj); if (lzf) { if (lzf->uncompressed_len == 0 || lzf->compressed_blob.empty()) return Unexpected(errc::rdb_file_corrupted); } else if (arr) { if (arr->empty()) return Unexpected(errc::rdb_file_corrupted); } else { return Unexpected(errc::rdb_file_corrupted); } return OpaqueObj{std::move(obj), RDB_TYPE_SET_INTSET}; } auto RdbLoaderBase::ReadHZiplist() -> io::Result { RdbVariant str_obj; error_code ec = ReadStringObj(&str_obj); if (ec) return make_unexpected(ec); if (StrLen(str_obj) == 0) { return Unexpected(errc::rdb_file_corrupted); } return OpaqueObj{std::move(str_obj), RDB_TYPE_HASH_ZIPLIST}; } auto RdbLoaderBase::ReadHMap() -> io::Result { uint64_t len; SET_OR_UNEXPECT(LoadLen(nullptr), len); if (len == 0) return Unexpected(errc::empty_key); unique_ptr load_trace(new LoadTrace); load_trace->arr.resize(len * 2); for (size_t i = 0; i < load_trace->arr.size(); ++i) { error_code ec = ReadStringObj(&load_trace->arr[i].rdb_var); if (ec) return make_unexpected(ec); } return OpaqueObj{std::move(load_trace), RDB_TYPE_HASH}; } auto RdbLoaderBase::ReadZSet(int rdbtype) -> io::Result { /* Read sorted set value. */ uint64_t zsetlen; SET_OR_UNEXPECT(LoadLen(nullptr), zsetlen); if (zsetlen == 0) return Unexpected(errc::empty_key); unique_ptr load_trace(new LoadTrace); load_trace->arr.resize(zsetlen); double score; for (size_t i = 0; i < load_trace->arr.size(); ++i) { error_code ec = ReadStringObj(&load_trace->arr[i].rdb_var); if (ec) return make_unexpected(ec); if (rdbtype == RDB_TYPE_ZSET_2) { SET_OR_UNEXPECT(FetchBinaryDouble(), score); } else { SET_OR_UNEXPECT(FetchDouble(), score); } if (isnan(score)) { LOG(ERROR) << "Zset with NAN score detected"; return Unexpected(errc::rdb_file_corrupted); } load_trace->arr[i].score = score; } return OpaqueObj{std::move(load_trace), rdbtype}; } auto RdbLoaderBase::ReadZSetZL() -> io::Result { RdbVariant str_obj; error_code ec = ReadStringObj(&str_obj); if (ec) return make_unexpected(ec); if (StrLen(str_obj) == 0) { return Unexpected(errc::rdb_file_corrupted); } return OpaqueObj{std::move(str_obj), RDB_TYPE_ZSET_ZIPLIST}; } auto RdbLoaderBase::ReadListQuicklist(int rdbtype) -> io::Result { uint64_t len; SET_OR_UNEXPECT(LoadLen(nullptr), len); if (len == 0) return Unexpected(errc::empty_key); unique_ptr load_trace(new LoadTrace); load_trace->arr.resize(len); for (size_t i = 0; i < len; ++i) { uint64_t container = QUICKLIST_NODE_CONTAINER_PACKED; if (rdbtype == RDB_TYPE_LIST_QUICKLIST_2) { SET_OR_UNEXPECT(LoadLen(nullptr), container); if (container != QUICKLIST_NODE_CONTAINER_PACKED && container != QUICKLIST_NODE_CONTAINER_PLAIN) { LOG(ERROR) << "Quicklist integrity check failed."; return Unexpected(errc::rdb_file_corrupted); } } RdbVariant var; error_code ec = ReadStringObj(&var); if (ec) return make_unexpected(ec); if (StrLen(var) == 0) { return Unexpected(errc::rdb_file_corrupted); } load_trace->arr[i].rdb_var = std::move(var); load_trace->arr[i].encoding = container; } return OpaqueObj{std::move(load_trace), rdbtype}; } auto RdbLoaderBase::ReadStreams() -> io::Result { uint64_t listpacks; SET_OR_UNEXPECT(LoadLen(nullptr), listpacks); unique_ptr load_trace(new LoadTrace); load_trace->arr.resize(listpacks * 2); error_code ec; for (size_t i = 0; i < listpacks; ++i) { /* Get the master ID, the one we'll use as key of the radix tree * node: the entries inside the listpack itself are delta-encoded * relatively to this ID. */ RdbVariant stream_id, blob; ec = ReadStringObj(&stream_id); if (ec) return make_unexpected(ec); if (StrLen(stream_id) != sizeof(streamID)) { LOG(ERROR) << "Stream node key entry is not the size of a stream ID"; return Unexpected(errc::rdb_file_corrupted); } ec = ReadStringObj(&blob); if (ec) return make_unexpected(ec); if (StrLen(blob) == 0) { LOG(ERROR) << "Stream listpacks loading failed"; return Unexpected(errc::rdb_file_corrupted); } load_trace->arr[2 * i].rdb_var = std::move(stream_id); load_trace->arr[2 * i + 1].rdb_var = std::move(blob); } load_trace->stream_trace.reset(new StreamTrace); /* Load total number of items inside the stream. */ SET_OR_UNEXPECT(LoadLen(nullptr), load_trace->stream_trace->stream_len); /* Load the last entry ID. */ SET_OR_UNEXPECT(LoadLen(nullptr), load_trace->stream_trace->ms); SET_OR_UNEXPECT(LoadLen(nullptr), load_trace->stream_trace->seq); /* Consumer groups loading */ uint64_t cgroups_count; SET_OR_UNEXPECT(LoadLen(nullptr), cgroups_count); load_trace->stream_trace->cgroup.resize(cgroups_count); for (size_t i = 0; i < cgroups_count; ++i) { auto& cgroup = load_trace->stream_trace->cgroup[i]; /* Get the consumer group name and ID. We can then create the * consumer group ASAP and populate its structure as * we read more data. */ // sds cgname; RdbVariant cgname; ec = ReadStringObj(&cgname); if (ec) return make_unexpected(ec); cgroup.name = std::move(cgname); SET_OR_UNEXPECT(LoadLen(nullptr), cgroup.ms); SET_OR_UNEXPECT(LoadLen(nullptr), cgroup.seq); /* Load the global PEL for this consumer group, however we'll * not yet populate the NACK structures with the message * owner, since consumers for this group and their messages will * be read as a next step. So for now leave them not resolved * and later populate it. */ uint64_t pel_size; SET_OR_UNEXPECT(LoadLen(nullptr), pel_size); cgroup.pel_arr.resize(pel_size); for (size_t j = 0; j < pel_size; ++j) { auto& pel = cgroup.pel_arr[j]; error_code ec = FetchBuf(pel.rawid.size(), pel.rawid.data()); if (ec) { LOG(ERROR) << "Stream PEL ID loading failed."; return make_unexpected(ec); } // streamNACK* nack = streamCreateNACK(NULL); // auto cleanup2 = absl::Cleanup([&] { streamFreeNACK(nack); }); SET_OR_UNEXPECT(FetchInt(), pel.delivery_time); SET_OR_UNEXPECT(LoadLen(nullptr), pel.delivery_count); /*if (!raxTryInsert(cgroup->pel, rawid, sizeof(rawid), nack, NULL)) { LOG(ERROR) << "Duplicated global PEL entry loading stream consumer group"; return Unexpected(errc::duplicate_key); } std::move(cleanup2).Cancel();*/ } /* Now that we loaded our global PEL, we need to load the * consumers and their local PELs. */ uint64_t consumers_num; SET_OR_UNEXPECT(LoadLen(nullptr), consumers_num); cgroup.cons_arr.resize(consumers_num); for (size_t j = 0; j < consumers_num; ++j) { auto& consumer = cgroup.cons_arr[j]; ec = ReadStringObj(&consumer.name); if (ec) return make_unexpected(ec); SET_OR_UNEXPECT(FetchInt(), consumer.seen_time); /* Load the PEL about entries owned by this specific * consumer. */ SET_OR_UNEXPECT(LoadLen(nullptr), pel_size); consumer.nack_arr.resize(pel_size); for (size_t k = 0; k < pel_size; ++k) { auto& nack = consumer.nack_arr[k]; // unsigned char rawid[sizeof(streamID)]; error_code ec = FetchBuf(nack.size(), nack.data()); if (ec) { LOG(ERROR) << "Stream PEL ID loading failed."; return make_unexpected(ec); } /*streamNACK* nack = (streamNACK*)raxFind(cgroup->pel, rawid, sizeof(rawid)); if (nack == raxNotFound) { LOG(ERROR) << "Consumer entry not found in group global PEL"; return Unexpected(errc::rdb_file_corrupted); }*/ /* Set the NACK consumer, that was left to NULL when * loading the global PEL. Then set the same shared * NACK structure also in the consumer-specific PEL. */ /* nack->consumer = consumer; if (!raxTryInsert(consumer->pel, rawid, sizeof(rawid), nack, NULL)) { LOG(ERROR) << "Duplicated consumer PEL entry loading a stream consumer group"; streamFreeNACK(nack); return Unexpected(errc::duplicate_key); }*/ } } // while (consumers_num) } // while (cgroup_num) return OpaqueObj{std::move(load_trace), RDB_TYPE_STREAM_LISTPACKS}; } template io::Result RdbLoaderBase::FetchInt() { auto ec = EnsureRead(sizeof(T)); if (ec) return make_unexpected(ec); char buf[16]; mem_buf_->ReadAndConsume(sizeof(T), buf); return base::LE::LoadT>(buf); } // -------------- RdbLoader ---------------------------- struct RdbLoader::ObjSettings { long long now; // current epoch time in ms. int64_t expiretime = 0; // expire epoch time in ms bool has_expired = false; void Reset() { expiretime = 0; has_expired = false; } void SetExpire(int64_t val) { expiretime = val; has_expired = (val <= now); } ObjSettings() = default; }; RdbLoader::RdbLoader(Service* service) : service_{service}, script_mgr_{service == nullptr ? nullptr : service->script_mgr()} { shard_buf_.reset(new ItemsBuf[shard_set->size()]); } RdbLoader::~RdbLoader() { while (true) { Item* item = item_queue_.Pop(); if (item == nullptr) break; delete item; } } error_code RdbLoader::Load(io::Source* src) { CHECK(!src_ && src); absl::Time start = absl::Now(); src_ = src; IoBuf::Bytes bytes = mem_buf_->AppendBuffer(); io::Result read_sz = src_->ReadAtLeast(bytes, 9); if (!read_sz) return read_sz.error(); bytes_read_ = *read_sz; if (bytes_read_ < 9) { return RdbError(errc::wrong_signature); } mem_buf_->CommitWrite(bytes_read_); { auto cb = mem_buf_->InputBuffer(); if (memcmp(cb.data(), "REDIS", 5) != 0) { return RdbError(errc::wrong_signature); } char buf[64] = {0}; ::memcpy(buf, cb.data() + 5, 4); int rdbver = atoi(buf); if (rdbver < 5 || rdbver > RDB_VERSION) { // We accept starting from 5. LOG(ERROR) << "RDB Version " << rdbver << " is not supported"; return RdbError(errc::bad_version); } mem_buf_->ConsumeInput(9); } int type; /* Key-specific attributes, set by opcodes before the key type. */ ObjSettings settings; settings.now = mstime(); size_t keys_loaded = 0; auto cleanup = absl::Cleanup([&] { FinishLoad(start, &keys_loaded); }); while (!stop_early_.load(memory_order_relaxed)) { /* Read type. */ SET_OR_RETURN(FetchType(), type); /* Handle special types. */ if (type == RDB_OPCODE_EXPIRETIME) { LOG(ERROR) << "opcode RDB_OPCODE_EXPIRETIME not supported"; return RdbError(errc::invalid_encoding); } if (type == RDB_OPCODE_EXPIRETIME_MS) { int64_t val; /* EXPIRETIME_MS: milliseconds precision expire times introduced * with RDB v3. Like EXPIRETIME but no with more precision. */ SET_OR_RETURN(FetchInt(), val); settings.SetExpire(val); continue; /* Read next opcode. */ } if (type == RDB_OPCODE_FREQ) { /* FREQ: LFU frequency. */ FetchInt(); // IGNORE continue; /* Read next opcode. */ } if (type == RDB_OPCODE_IDLE) { /* IDLE: LRU idle time. */ uint64_t idle; SET_OR_RETURN(LoadLen(nullptr), idle); // ignore (void)idle; continue; /* Read next opcode. */ } if (type == RDB_OPCODE_EOF) { /* EOF: End of file, exit the main loop. */ break; } if (type == RDB_OPCODE_FULLSYNC_END) { if (full_sync_cut_cb) full_sync_cut_cb(); continue; } if (type == RDB_OPCODE_SELECTDB) { unsigned dbid = 0; /* SELECTDB: Select the specified database. */ SET_OR_RETURN(LoadLen(nullptr), dbid); if (dbid > GetFlag(FLAGS_dbnum)) { LOG(WARNING) << "database id " << dbid << " exceeds dbnum limit. Try increasing the flag."; return RdbError(errc::bad_db_index); } VLOG(1) << "Select DB: " << dbid; for (unsigned i = 0; i < shard_set->size(); ++i) { // we should flush pending items before switching dbid. FlushShardAsync(i); // Active database if not existed before. shard_set->Add(i, [dbid] { EngineShard::tlocal()->db_slice().ActivateDb(dbid); }); } cur_db_index_ = dbid; continue; /* Read next opcode. */ } if (type == RDB_OPCODE_RESIZEDB) { /* RESIZEDB: Hint about the size of the keys in the currently * selected data base, in order to avoid useless rehashing. */ uint64_t db_size, expires_size; SET_OR_RETURN(LoadLen(nullptr), db_size); SET_OR_RETURN(LoadLen(nullptr), expires_size); ResizeDb(db_size, expires_size); continue; /* Read next opcode. */ } if (type == RDB_OPCODE_AUX) { RETURN_ON_ERR(HandleAux()); continue; /* Read type again. */ } if (type == RDB_OPCODE_MODULE_AUX) { LOG(ERROR) << "Modules are not supported"; return RdbError(errc::feature_not_supported); } if (type == RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START || type == RDB_OPCODE_COMPRESSED_LZ4_BLOB_START) { RETURN_ON_ERR(HandleCompressedBlob(type)); continue; } if (type == RDB_OPCODE_COMPRESSED_BLOB_END) { RETURN_ON_ERR(HandleCompressedBlobFinish()); continue; } if (type == RDB_OPCODE_JOURNAL_BLOB) { // We should flush all changes on the current db before applying incremental changes. for (unsigned i = 0; i < shard_set->size(); ++i) { FlushShardAsync(i); } RETURN_ON_ERR(HandleJournalBlob(service_, cur_db_index_)); continue; } if (!rdbIsObjectType(type)) { return RdbError(errc::invalid_rdb_type); } ++keys_loaded; RETURN_ON_ERR(LoadKeyValPair(type, &settings)); settings.Reset(); } // main load loop if (stop_early_) { return *ec_; } /* Verify the checksum if RDB version is >= 5 */ RETURN_ON_ERR(VerifyChecksum()); return kOk; } void RdbLoader::FinishLoad(absl::Time start_time, size_t* keys_loaded) { fibers_ext::BlockingCounter bc(shard_set->size()); for (unsigned i = 0; i < shard_set->size(); ++i) { // Flush the remaining items. FlushShardAsync(i); // Send sentinel callbacks to ensure that all previous messages have been processed. shard_set->Add(i, [bc]() mutable { bc.Dec(); }); } bc.Wait(); // wait for sentinels to report. absl::Duration dur = absl::Now() - start_time; load_time_ = double(absl::ToInt64Milliseconds(dur)) / 1000; keys_loaded_ = *keys_loaded; } std::error_code RdbLoaderBase::EnsureRead(size_t min_sz) { // In the flow of reading compressed data, we store the uncompressed data to in uncompressed // buffer. When parsing entries we call ensure read with 9 bytes to read the length of // key/value. If the key/value is very small (less than 9 bytes) the remainded data in // uncompressed buffer might contain less than 9 bytes. We need to make sure that we dont read // from sink to the uncompressed buffer and therefor in this flow we return here. if (mem_buf_ != &origin_mem_buf_) return std::error_code{}; if (mem_buf_->InputLen() >= min_sz) return std::error_code{}; return EnsureReadInternal(min_sz); } error_code RdbLoaderBase::EnsureReadInternal(size_t min_sz) { DCHECK_LT(mem_buf_->InputLen(), min_sz); auto out_buf = mem_buf_->AppendBuffer(); CHECK_GT(out_buf.size(), min_sz); // If limit was applied we do not want to read more than needed // important when reading from sockets. if (bytes_read_ + out_buf.size() > source_limit_) { out_buf = out_buf.subspan(0, source_limit_ - bytes_read_); } io::Result res = src_->ReadAtLeast(out_buf, min_sz); if (!res) return res.error(); if (*res < min_sz) return RdbError(errc::rdb_file_corrupted); bytes_read_ += *res; DCHECK_LE(bytes_read_, source_limit_); mem_buf_->CommitWrite(*res); return kOk; } auto RdbLoaderBase::LoadLen(bool* is_encoded) -> io::Result { if (is_encoded) *is_encoded = false; // Every RDB file with rdbver >= 5 has 8-bytes checksum at the end, // so we can ensure we have 9 bytes to read up until that point. error_code ec = EnsureRead(9); if (ec) return make_unexpected(ec); uint64_t res = 0; uint8_t first = mem_buf_->InputBuffer()[0]; int type = (first & 0xC0) >> 6; mem_buf_->ConsumeInput(1); if (type == RDB_ENCVAL) { /* Read a 6 bit encoding type. */ if (is_encoded) *is_encoded = true; res = first & 0x3F; } else if (type == RDB_6BITLEN) { /* Read a 6 bit len. */ res = first & 0x3F; } else if (type == RDB_14BITLEN) { res = ((first & 0x3F) << 8) | mem_buf_->InputBuffer()[0]; mem_buf_->ConsumeInput(1); } else if (first == RDB_32BITLEN) { /* Read a 32 bit len. */ res = absl::big_endian::Load32(mem_buf_->InputBuffer().data()); mem_buf_->ConsumeInput(4); } else if (first == RDB_64BITLEN) { /* Read a 64 bit len. */ res = absl::big_endian::Load64(mem_buf_->InputBuffer().data()); mem_buf_->ConsumeInput(8); } else { LOG(ERROR) << "Bad length encoding " << type << " in rdbLoadLen()"; return Unexpected(errc::rdb_file_corrupted); } return res; } void RdbLoaderBase::AllocateDecompressOnce(int op_type) { if (decompress_impl_) { return; } if (op_type == RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START) { decompress_impl_.reset(new ZstdDecompress()); } else if (op_type == RDB_OPCODE_COMPRESSED_LZ4_BLOB_START) { decompress_impl_.reset(new Lz4Decompress()); } else { CHECK(false) << "Decompressor allocation should not be done"; } } error_code RdbLoaderBase::HandleCompressedBlob(int op_type) { AllocateDecompressOnce(op_type); // Fetch uncompress blob string res; SET_OR_RETURN(FetchGenericString(), res); // Decompress blob and switch membuf pointer // Last type in the compressed blob is RDB_OPCODE_COMPRESSED_BLOB_END // in which we will switch back to the origin membuf (HandleCompressedBlobFinish) string_view uncompressed_blob; SET_OR_RETURN(decompress_impl_->Decompress(res), mem_buf_); return kOk; } error_code RdbLoaderBase::HandleCompressedBlobFinish() { CHECK_NE(&origin_mem_buf_, mem_buf_); CHECK_EQ(mem_buf_->InputLen(), size_t(0)); mem_buf_ = &origin_mem_buf_; return kOk; } error_code RdbLoaderBase::HandleJournalBlob(Service* service, DbIndex dbid) { // Read the number of entries in the journal blob. size_t num_entries; bool _encoded; SET_OR_RETURN(LoadLen(&_encoded), num_entries); // Read the journal blob. string journal_blob; SET_OR_RETURN(FetchGenericString(), journal_blob); io::BytesSource bs{io::Buffer(journal_blob)}; journal_reader_.SetDb(dbid); // Parse and exectue in loop. size_t done = 0; JournalExecutor ex{service}; while (done < num_entries) { journal::ParsedEntry entry{}; SET_OR_RETURN(journal_reader_.ReadEntry(&bs), entry); ex.Execute(std::move(entry)); done++; } return std::error_code{}; } error_code RdbLoader::HandleAux() { /* AUX: generic string-string fields. Use to add state to RDB * which is backward compatible. Implementations of RDB loading * are required to skip AUX fields they don't understand. * * An AUX field is composed of two strings: key and value. */ string auxkey, auxval; SET_OR_RETURN(FetchGenericString(), auxkey); SET_OR_RETURN(FetchGenericString(), auxval); if (!auxkey.empty() && auxkey[0] == '%') { /* All the fields with a name staring with '%' are considered * information fields and are logged at startup with a log * level of NOTICE. */ LOG(INFO) << "RDB '" << auxkey << "': " << auxval; } else if (auxkey == "repl-stream-db") { // TODO } else if (auxkey == "repl-id") { // TODO } else if (auxkey == "repl-offset") { // TODO } else if (auxkey == "lua") { ServerState* ss = ServerState::tlocal(); Interpreter& script = ss->GetInterpreter(); string_view body{auxval}; string result; Interpreter::AddResult add_result = script.AddFunction(body, &result); if (add_result == Interpreter::ADD_OK) { if (script_mgr_) script_mgr_->InsertFunction(result, body); } else if (add_result == Interpreter::COMPILE_ERR) { LOG(ERROR) << "Error when compiling lua scripts"; } } else if (auxkey == "redis-ver") { LOG(INFO) << "Loading RDB produced by version " << auxval; } else if (auxkey == "ctime") { int64_t ctime; if (absl::SimpleAtoi(auxval, &ctime)) { time_t age = time(NULL) - ctime; if (age < 0) age = 0; LOG(INFO) << "RDB age " << strings::HumanReadableElapsedTime(age); } } else if (auxkey == "used-mem") { long long usedmem; if (absl::SimpleAtoi(auxval, &usedmem)) { LOG(INFO) << "RDB memory usage when created " << strings::HumanReadableNumBytes(usedmem); } } else if (auxkey == "aof-preamble") { long long haspreamble; if (absl::SimpleAtoi(auxval, &haspreamble) && haspreamble) { LOG(INFO) << "RDB has an AOF tail"; } } else if (auxkey == "redis-bits") { /* Just ignored. */ } else { /* We ignore fields we don't understand, as by AUX field * contract. */ LOG(WARNING) << "Unrecognized RDB AUX field: '" << auxkey << "'"; } return kOk; } error_code RdbLoader::VerifyChecksum() { uint64_t expected; SET_OR_RETURN(FetchInt(), expected); io::Bytes cur_buf = mem_buf_->InputBuffer(); VLOG(1) << "VerifyChecksum: input buffer len " << cur_buf.size() << ", expected " << expected; return kOk; } void RdbLoader::FlushShardAsync(ShardId sid) { auto& out_buf = shard_buf_[sid]; if (out_buf.empty()) return; auto cb = [indx = this->cur_db_index_, this, ib = std::move(out_buf)] { this->LoadItemsBuffer(indx, ib); }; shard_set->Add(sid, std::move(cb)); } std::error_code RdbLoaderBase::FromOpaque(const OpaqueObj& opaque, CompactObj* pv) { OpaqueObjLoader visitor(opaque.rdb_type, pv); std::visit(visitor, opaque.obj); return visitor.ec(); } void RdbLoader::LoadItemsBuffer(DbIndex db_ind, const ItemsBuf& ib) { DbSlice& db_slice = EngineShard::tlocal()->db_slice(); DbContext db_cntx{.db_index = db_ind, .time_now_ms = GetCurrentTimeMs()}; for (const auto* item : ib) { PrimeValue pv; if (ec_ = FromOpaque(item->val, &pv); ec_) { stop_early_ = true; break; } if (item->expire_ms > 0 && db_cntx.time_now_ms >= item->expire_ms) continue; auto [it, added] = db_slice.AddOrUpdate(db_cntx, item->key, std::move(pv), item->expire_ms); if (!added) { LOG(WARNING) << "RDB has duplicated key '" << item->key << "' in DB " << db_ind; } } for (auto* item : ib) { item_queue_.Push(item); } } void RdbLoader::ResizeDb(size_t key_num, size_t expire_num) { DCHECK_LT(key_num, 1U << 31); DCHECK_LT(expire_num, 1U << 31); } error_code RdbLoader::LoadKeyValPair(int type, ObjSettings* settings) { // We return the item in LoadItemsBuffer. Item* item = item_queue_.Pop(); if (item == nullptr) { item = new Item; } // Read key SET_OR_RETURN(ReadKey(), item->key); // Read value error_code ec = ReadObj(type, &item->val); if (ec) { VLOG(1) << "ReadObj error " << ec << " for key " << item->key; return ec; } /* Check if the key already expired. This function is used when loading * an RDB file from disk, either at startup, or when an RDB was * received from the master. In the latter case, the master is * responsible for key expiry. If we would expire keys here, the * snapshot taken by the master may not be reflected on the slave. * Similarly if the RDB is the preamble of an AOF file, we want to * load all the keys as they are, since the log of operations later * assume to work in an exact keyspace state. */ // TODO: check rdbflags&RDBFLAGS_AOF_PREAMBLE logic in rdb.c bool should_expire = settings->has_expired; // TODO: to implement if (should_expire) { // decrRefCount(val); } else { ShardId sid = Shard(item->key, shard_set->size()); item->expire_ms = settings->expiretime; auto& out_buf = shard_buf_[sid]; out_buf.emplace_back(item); constexpr size_t kBufSize = 128; if (out_buf.size() >= kBufSize) { FlushShardAsync(sid); } } return kOk; } } // namespace dfly