From 53d6b64233dac9e8d81866abc7d67e07f1ecfc92 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Mon, 16 Dec 2024 11:16:02 +0200 Subject: [PATCH] chore: factor out rdb_load utilities into separate files (#4315) * chore: factor out rdb_load utilities into separate files rdb_load.cc is huge and contains many auxiliary classes. This PR moves DecompressImpl and ErrorRdb code into detail/ It also fixes minor bugs around error conditions with de-compression: a. Do not check-fail on invalid opcode and return error_code instead. b. Print LZ4 errors correctly. Signed-off-by: Roman Gershman * chore: fixes --------- Signed-off-by: Roman Gershman --- src/server/CMakeLists.txt | 3 +- src/server/detail/decompress.cc | 173 ++++++++++++++++++++++++++ src/server/detail/decompress.h | 32 +++++ src/server/error.cc | 57 +++++++++ src/server/error.h | 2 + src/server/rdb_load.cc | 207 ++------------------------------ src/server/rdb_load.h | 7 +- 7 files changed, 278 insertions(+), 203 deletions(-) create mode 100644 src/server/detail/decompress.cc create mode 100644 src/server/detail/decompress.h create mode 100644 src/server/error.cc diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index b7d862c21..b837379bb 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -45,11 +45,12 @@ endif() add_library(dragonfly_lib bloom_family.cc config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc engine_shard.cc - engine_shard_set.cc family_utils.cc + engine_shard_set.cc error.cc family_utils.cc generic_family.cc hset_family.cc http_api.cc json_family.cc list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc protocol_client.cc snapshot.cc script_mgr.cc server_family.cc + detail/decompress.cc detail/save_stages_controller.cc detail/snapshot_storage.cc set_family.cc stream_family.cc string_family.cc diff --git a/src/server/detail/decompress.cc b/src/server/detail/decompress.cc new file mode 100644 index 000000000..ae8a6d580 --- /dev/null 
+++ b/src/server/detail/decompress.cc @@ -0,0 +1,173 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/detail/decompress.h" + +#include +#include + +#include "base/logging.h" +#include "server/error.h" +#include "server/rdb_extensions.h" + +namespace dfly { + +namespace detail { + +using io::IoBuf; +using rdb::errc; +using namespace std; + +inline auto Unexpected(errc ev) { + return nonstd::make_unexpected(RdbError(ev)); +} + +class ZstdDecompress : public DecompressImpl { + public: + ZstdDecompress() { + dctx_ = ZSTD_createDCtx(); + } + ~ZstdDecompress() { + ZSTD_freeDCtx(dctx_); + } + + io::Result Decompress(std::string_view str); + + private: + ZSTD_DCtx* dctx_; +}; + +io::Result ZstdDecompress::Decompress(std::string_view str) { + // Prepare membuf memory to uncompressed string. + auto uncomp_size = ZSTD_getFrameContentSize(str.data(), str.size()); + if (uncomp_size == ZSTD_CONTENTSIZE_UNKNOWN) { + LOG(ERROR) << "Zstd compression missing frame content size"; + return Unexpected(errc::invalid_encoding); + } + if (uncomp_size == ZSTD_CONTENTSIZE_ERROR) { + LOG(ERROR) << "Invalid ZSTD compressed string"; + return Unexpected(errc::invalid_encoding); + } + + uncompressed_mem_buf_.Reserve(uncomp_size + 1); + + // Uncompress string to membuf + IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); + if (dest.size() < uncomp_size) { + return Unexpected(errc::out_of_memory); + } + size_t const d_size = + ZSTD_decompressDCtx(dctx_, dest.data(), dest.size(), str.data(), str.size()); + if (d_size == 0 || d_size != uncomp_size) { + LOG(ERROR) << "Invalid ZSTD compressed string"; + return Unexpected(errc::rdb_file_corrupted); + } + uncompressed_mem_buf_.CommitWrite(d_size); + + // Add opcode of compressed blob end to membuf. 
+ dest = uncompressed_mem_buf_.AppendBuffer(); + if (dest.size() < 1) { + return Unexpected(errc::out_of_memory); + } + dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; + uncompressed_mem_buf_.CommitWrite(1); + + return &uncompressed_mem_buf_; +} + +class Lz4Decompress : public DecompressImpl { + public: + Lz4Decompress() { + auto result = LZ4F_createDecompressionContext(&dctx_, LZ4F_VERSION); + CHECK(!LZ4F_isError(result)); + } + ~Lz4Decompress() { + auto result = LZ4F_freeDecompressionContext(dctx_); + CHECK(!LZ4F_isError(result)); + } + + io::Result Decompress(std::string_view str); + + private: + LZ4F_dctx* dctx_; +}; + +io::Result Lz4Decompress::Decompress(std::string_view data) { + LZ4F_frameInfo_t frame_info; + size_t frame_size = data.size(); + + // Get content size from frame data + size_t consumed = frame_size; // The nb of bytes consumed from data will be written into consumed + size_t res = LZ4F_getFrameInfo(dctx_, &frame_info, data.data(), &consumed); + if (LZ4F_isError(res)) { + LOG(ERROR) << "LZ4F_getFrameInfo failed with error " << LZ4F_getErrorName(res); + return Unexpected(errc::rdb_file_corrupted); + ; + } + + if (frame_info.contentSize == 0) { + LOG(ERROR) << "Missing frame content size"; + return Unexpected(errc::rdb_file_corrupted); + } + + // reserve place for uncompressed data and end opcode + size_t reserve = frame_info.contentSize + 1; + uncompressed_mem_buf_.Reserve(reserve); + IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); + if (dest.size() < reserve) { + return Unexpected(errc::out_of_memory); + } + + // Uncompress data to membuf + string_view src = data.substr(consumed); + size_t src_size = src.size(); + + size_t ret = 1; + while (ret != 0) { + IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); + size_t dest_capacity = dest.size(); + + // It will read up to src_size bytes from src, + // and decompress data into dest, of capacity dest_capacity + // The nb of bytes consumed from src will be written into src_size + // The nb of 
bytes decompressed into dest will be written into dest_capacity + ret = LZ4F_decompress(dctx_, dest.data(), &dest_capacity, src.data(), &src_size, nullptr); + if (LZ4F_isError(ret)) { + LOG(ERROR) << "LZ4F_decompress failed with error " << LZ4F_getErrorName(ret); + return Unexpected(errc::rdb_file_corrupted); + } + consumed += src_size; + + uncompressed_mem_buf_.CommitWrite(dest_capacity); + src = src.substr(src_size); + src_size = src.size(); + } + if (consumed != frame_size) { + return Unexpected(errc::rdb_file_corrupted); + } + if (uncompressed_mem_buf_.InputLen() != frame_info.contentSize) { + return Unexpected(errc::rdb_file_corrupted); + } + + // Add opcode of compressed blob end to membuf. + dest = uncompressed_mem_buf_.AppendBuffer(); + if (dest.size() < 1) { + return Unexpected(errc::out_of_memory); + } + dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; + uncompressed_mem_buf_.CommitWrite(1); + + return &uncompressed_mem_buf_; +} + +unique_ptr DecompressImpl::CreateLZ4() { + return make_unique(); +} + +unique_ptr DecompressImpl::CreateZstd() { + return make_unique(); +} + +} // namespace detail +} // namespace dfly diff --git a/src/server/detail/decompress.h b/src/server/detail/decompress.h new file mode 100644 index 000000000..cc0a556d3 --- /dev/null +++ b/src/server/detail/decompress.h @@ -0,0 +1,32 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
+// +#pragma once + +#include + +#include "io/io.h" +#include "io/io_buf.h" + +namespace dfly { + +namespace detail { + +class DecompressImpl { + public: + static std::unique_ptr CreateLZ4(); + static std::unique_ptr CreateZstd(); + + DecompressImpl() : uncompressed_mem_buf_{1U << 14} { + } + virtual ~DecompressImpl() { + } + + virtual io::Result Decompress(std::string_view str) = 0; + + protected: + io::IoBuf uncompressed_mem_buf_; +}; + +} // namespace detail +} // namespace dfly diff --git a/src/server/error.cc b/src/server/error.cc new file mode 100644 index 000000000..2c0631c92 --- /dev/null +++ b/src/server/error.cc @@ -0,0 +1,57 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/error.h" + +#include + +using namespace std; + +namespace dfly { +namespace rdb { + +class error_category : public std::error_category { + public: + const char* name() const noexcept final { + return "dragonfly.rdbload"; + } + + string message(int ev) const final; + + error_condition default_error_condition(int ev) const noexcept final; + + bool equivalent(int ev, const error_condition& condition) const noexcept final { + return condition.value() == ev && &condition.category() == this; + } + + bool equivalent(const error_code& error, int ev) const noexcept final { + return error.value() == ev && &error.category() == this; + } +}; + +string error_category::message(int ev) const { + switch (ev) { + case errc::wrong_signature: + return "Wrong signature while trying to load from rdb file"; + case errc::out_of_memory: + return "Out of memory, or used memory is too high"; + default: + return absl::StrCat("Internal error when loading RDB file ", ev); + break; + } +} + +error_condition error_category::default_error_condition(int ev) const noexcept { + return error_condition{ev, *this}; +} + +static error_category rdb_category; + +} // namespace rdb + +error_code RdbError(rdb::errc ev) { + return 
error_code{static_cast(ev), rdb::rdb_category}; +} + +} // namespace dfly diff --git a/src/server/error.h b/src/server/error.h index 00e1218a5..59a33349e 100644 --- a/src/server/error.h +++ b/src/server/error.h @@ -79,4 +79,6 @@ enum errc { } // namespace rdb +std::error_code RdbError(rdb::errc ev); + } // namespace dfly diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index db4ddd1c7..bd926b45b 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -22,8 +22,6 @@ extern "C" { #include #include #include -#include -#include #include @@ -52,7 +50,6 @@ extern "C" { #include "server/serializer_commons.h" #include "server/server_state.h" #include "server/set_family.h" -#include "server/tiering/common.h" // for _KB literal #include "server/transaction.h" #include "strings/human_readable.h" @@ -72,55 +69,12 @@ using namespace tiering::literals; namespace { -constexpr char kErrCat[] = "dragonfly.rdbload"; - // Maximum length of each LoadTrace segment. // // Note kMaxBlobLen must be a multiple of 6 to avoid truncating elements // containing 2 or 3 items. 
constexpr size_t kMaxBlobLen = 4092; -class error_category : public std::error_category { - public: - const char* name() const noexcept final { - return kErrCat; - } - - string message(int ev) const final; - - error_condition default_error_condition(int ev) const noexcept final; - - bool equivalent(int ev, const error_condition& condition) const noexcept final { - return condition.value() == ev && &condition.category() == this; - } - - bool equivalent(const error_code& error, int ev) const noexcept final { - return error.value() == ev && &error.category() == this; - } -}; - -string error_category::message(int ev) const { - switch (ev) { - case errc::wrong_signature: - return "Wrong signature while trying to load from rdb file"; - case errc::out_of_memory: - return "Out of memory, or used memory is too high"; - default: - return absl::StrCat("Internal error when loading RDB file ", ev); - break; - } -} - -error_condition error_category::default_error_condition(int ev) const noexcept { - return error_condition{ev, *this}; -} - -error_category rdb_category; - -inline error_code RdbError(errc ev) { - return error_code{ev, rdb_category}; -} - inline auto Unexpected(errc ev) { return make_unexpected(RdbError(ev)); } @@ -239,152 +193,6 @@ bool RdbTypeAllowedEmpty(int type) { } // namespace -class DecompressImpl { - public: - DecompressImpl() : uncompressed_mem_buf_{16_KB} { - } - virtual ~DecompressImpl() { - } - virtual io::Result Decompress(std::string_view str) = 0; - - protected: - io::IoBuf uncompressed_mem_buf_; -}; - -class ZstdDecompress : public DecompressImpl { - public: - ZstdDecompress() { - dctx_ = ZSTD_createDCtx(); - } - ~ZstdDecompress() { - ZSTD_freeDCtx(dctx_); - } - - io::Result Decompress(std::string_view str); - - private: - ZSTD_DCtx* dctx_; -}; - -io::Result ZstdDecompress::Decompress(std::string_view str) { - // Prepare membuf memory to uncompressed string. 
- auto uncomp_size = ZSTD_getFrameContentSize(str.data(), str.size()); - if (uncomp_size == ZSTD_CONTENTSIZE_UNKNOWN) { - LOG(ERROR) << "Zstd compression missing frame content size"; - return Unexpected(errc::invalid_encoding); - } - if (uncomp_size == ZSTD_CONTENTSIZE_ERROR) { - LOG(ERROR) << "Invalid ZSTD compressed string"; - return Unexpected(errc::invalid_encoding); - } - - uncompressed_mem_buf_.Reserve(uncomp_size + 1); - - // Uncompress string to membuf - IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); - if (dest.size() < uncomp_size) { - return Unexpected(errc::out_of_memory); - } - size_t const d_size = - ZSTD_decompressDCtx(dctx_, dest.data(), dest.size(), str.data(), str.size()); - if (d_size == 0 || d_size != uncomp_size) { - LOG(ERROR) << "Invalid ZSTD compressed string"; - return Unexpected(errc::rdb_file_corrupted); - } - uncompressed_mem_buf_.CommitWrite(d_size); - - // Add opcode of compressed blob end to membuf. - dest = uncompressed_mem_buf_.AppendBuffer(); - if (dest.size() < 1) { - return Unexpected(errc::out_of_memory); - } - dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; - uncompressed_mem_buf_.CommitWrite(1); - - return &uncompressed_mem_buf_; -} - -class Lz4Decompress : public DecompressImpl { - public: - Lz4Decompress() { - auto result = LZ4F_createDecompressionContext(&dctx_, LZ4F_VERSION); - CHECK(!LZ4F_isError(result)); - } - ~Lz4Decompress() { - auto result = LZ4F_freeDecompressionContext(dctx_); - CHECK(!LZ4F_isError(result)); - } - - io::Result Decompress(std::string_view str); - - private: - LZ4F_dctx* dctx_; -}; - -io::Result Lz4Decompress::Decompress(std::string_view data) { - LZ4F_frameInfo_t frame_info; - size_t frame_size = data.size(); - - // Get content size from frame data - size_t consumed = frame_size; // The nb of bytes consumed from data will be written into consumed - size_t res = LZ4F_getFrameInfo(dctx_, &frame_info, data.data(), &consumed); - if (LZ4F_isError(res)) { - return make_unexpected(error_code{int(res), 
generic_category()}); - } - if (frame_info.contentSize == 0) { - LOG(ERROR) << "Missing frame content size"; - return Unexpected(errc::rdb_file_corrupted); - } - - // reserve place for uncompressed data and end opcode - size_t reserve = frame_info.contentSize + 1; - uncompressed_mem_buf_.Reserve(reserve); - IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); - if (dest.size() < reserve) { - return Unexpected(errc::out_of_memory); - } - - // Uncompress data to membuf - string_view src = data.substr(consumed); - size_t src_size = src.size(); - - size_t ret = 1; - while (ret != 0) { - IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer(); - size_t dest_capacity = dest.size(); - - // It will read up to src_size bytes from src, - // and decompress data into dest, of capacity dest_capacity - // The nb of bytes consumed from src will be written into src_size - // The nb of bytes decompressed into dest will be written into dest_capacity - ret = LZ4F_decompress(dctx_, dest.data(), &dest_capacity, src.data(), &src_size, nullptr); - if (LZ4F_isError(ret)) { - return make_unexpected(error_code{int(ret), generic_category()}); - } - consumed += src_size; - - uncompressed_mem_buf_.CommitWrite(dest_capacity); - src = src.substr(src_size); - src_size = src.size(); - } - if (consumed != frame_size) { - return Unexpected(errc::rdb_file_corrupted); - } - if (uncompressed_mem_buf_.InputLen() != frame_info.contentSize) { - return Unexpected(errc::rdb_file_corrupted); - } - - // Add opcode of compressed blob end to membuf. 
- dest = uncompressed_mem_buf_.AppendBuffer(); - if (dest.size() < 1) { - return Unexpected(errc::out_of_memory); - } - dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END; - uncompressed_mem_buf_.CommitWrite(1); - - return &uncompressed_mem_buf_; -} - class RdbLoaderBase::OpaqueObjLoader { public: OpaqueObjLoader(int rdb_type, PrimeValue* pv, LoadConfig config) @@ -2492,17 +2300,19 @@ io::Result RdbLoaderBase::LoadLen(bool* is_encoded) { return res; } -void RdbLoaderBase::AllocateDecompressOnce(int op_type) { +error_code RdbLoaderBase::AllocateDecompressOnce(int op_type) { if (decompress_impl_) { - return; + return {}; } + if (op_type == RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START) { - decompress_impl_.reset(new ZstdDecompress()); + decompress_impl_ = detail::DecompressImpl::CreateZstd(); } else if (op_type == RDB_OPCODE_COMPRESSED_LZ4_BLOB_START) { - decompress_impl_.reset(new Lz4Decompress()); + decompress_impl_ = detail::DecompressImpl::CreateLZ4(); } else { - CHECK(false) << "Decompressor allocation should not be done"; + return RdbError(errc::unsupported_operation); } + return {}; } error_code RdbLoaderBase::SkipModuleData() { @@ -2550,7 +2360,8 @@ error_code RdbLoaderBase::SkipModuleData() { } error_code RdbLoaderBase::HandleCompressedBlob(int op_type) { - AllocateDecompressOnce(op_type); + RETURN_ON_ERR(AllocateDecompressOnce(op_type)); + // Fetch uncompress blob string res; SET_OR_RETURN(FetchGenericString(), res); diff --git a/src/server/rdb_load.h b/src/server/rdb_load.h index 0ece10089..703f74231 100644 --- a/src/server/rdb_load.h +++ b/src/server/rdb_load.h @@ -14,6 +14,7 @@ extern "C" { #include "io/io.h" #include "io/io_buf.h" #include "server/common.h" +#include "server/detail/decompress.h" #include "server/journal/serializer.h" struct streamID; @@ -25,8 +26,6 @@ class ScriptMgr; class CompactObj; class Service; -class DecompressImpl; - using RdbVersion = std::uint16_t; class RdbLoaderBase { @@ -184,7 +183,7 @@ class RdbLoaderBase { std::error_code SkipModuleData(); 
std::error_code HandleCompressedBlob(int op_type); std::error_code HandleCompressedBlobFinish(); - void AllocateDecompressOnce(int op_type); + std::error_code AllocateDecompressOnce(int op_type); std::error_code HandleJournalBlob(Service* service); @@ -203,7 +202,7 @@ class RdbLoaderBase { size_t bytes_read_ = 0; size_t source_limit_ = SIZE_MAX; base::PODArray compr_buf_; - std::unique_ptr decompress_impl_; + std::unique_ptr decompress_impl_; JournalReader journal_reader_{nullptr, 0}; std::optional journal_offset_ = std::nullopt; RdbVersion rdb_version_ = RDB_VERSION;