chore: factor out rdb_load utilities into separate files (#4315)

* chore: factor out rdb_load utilities into separate files

rdb_load.cc is huge and contains many auxiliary classes.
This PR moves DecompressImpl and ErrorRdb code into detail/

It also fixes minor bugs around error conditions with de-compression:
a. Do not check-fail on invalid opcode and return error_code instead.
b. Print LZ4 errors correctly.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>

* chore: fixes

---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2024-12-16 11:16:02 +02:00 committed by GitHub
parent 027eff2ad3
commit 53d6b64233
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 278 additions and 203 deletions

View file

@ -45,11 +45,12 @@ endif()
add_library(dragonfly_lib bloom_family.cc add_library(dragonfly_lib bloom_family.cc
config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc engine_shard.cc config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc engine_shard.cc
engine_shard_set.cc family_utils.cc engine_shard_set.cc error.cc family_utils.cc
generic_family.cc hset_family.cc http_api.cc json_family.cc generic_family.cc hset_family.cc http_api.cc json_family.cc
list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc
protocol_client.cc protocol_client.cc
snapshot.cc script_mgr.cc server_family.cc snapshot.cc script_mgr.cc server_family.cc
detail/decompress.cc
detail/save_stages_controller.cc detail/save_stages_controller.cc
detail/snapshot_storage.cc detail/snapshot_storage.cc
set_family.cc stream_family.cc string_family.cc set_family.cc stream_family.cc string_family.cc

View file

@ -0,0 +1,173 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/detail/decompress.h"
#include <lz4frame.h>
#include <zstd.h>
#include "base/logging.h"
#include "server/error.h"
#include "server/rdb_extensions.h"
namespace dfly {
namespace detail {
using io::IoBuf;
using rdb::errc;
using namespace std;
// Converts an rdb error enumerator into the "unexpected" value that the
// Decompress() implementations below return on failure.
inline auto Unexpected(errc ev) {
  const std::error_code ec = RdbError(ev);
  return nonstd::make_unexpected(ec);
}
class ZstdDecompress : public DecompressImpl {
public:
ZstdDecompress() {
dctx_ = ZSTD_createDCtx();
}
~ZstdDecompress() {
ZSTD_freeDCtx(dctx_);
}
io::Result<io::IoBuf*> Decompress(std::string_view str);
private:
ZSTD_DCtx* dctx_;
};
// Decompresses the zstd frame in `str` into uncompressed_mem_buf_ and appends
// RDB_OPCODE_COMPRESSED_BLOB_END after the decompressed bytes.
//
// Returns a pointer to uncompressed_mem_buf_ on success. On failure returns:
//   - errc::invalid_encoding   if the frame header is invalid or does not declare
//                              the content size;
//   - errc::out_of_memory      if the buffer does not expose enough writable space;
//   - errc::rdb_file_corrupted if zstd reports an error or the decompressed size does
//                              not match the size declared in the frame header.
io::Result<io::IoBuf*> ZstdDecompress::Decompress(std::string_view str) {
  // The frame header must declare the decompressed size so the destination can be
  // sized up front.
  const auto uncomp_size = ZSTD_getFrameContentSize(str.data(), str.size());
  if (uncomp_size == ZSTD_CONTENTSIZE_UNKNOWN) {
    LOG(ERROR) << "Zstd compression missing frame content size";
    return Unexpected(errc::invalid_encoding);
  }
  if (uncomp_size == ZSTD_CONTENTSIZE_ERROR) {
    LOG(ERROR) << "Invalid ZSTD compressed string";
    return Unexpected(errc::invalid_encoding);
  }

  // One extra byte is reserved for the end-of-blob opcode appended below.
  uncompressed_mem_buf_.Reserve(uncomp_size + 1);

  // Decompress straight into the writable tail of the buffer.
  IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer();
  if (dest.size() < uncomp_size) {
    return Unexpected(errc::out_of_memory);
  }

  const size_t d_size =
      ZSTD_decompressDCtx(dctx_, dest.data(), dest.size(), str.data(), str.size());

  // zstd encodes failures in the returned size_t; report the library's own
  // description instead of a generic message.
  if (ZSTD_isError(d_size)) {
    LOG(ERROR) << "ZSTD_decompressDCtx failed with error " << ZSTD_getErrorName(d_size);
    return Unexpected(errc::rdb_file_corrupted);
  }
  if (d_size == 0 || d_size != uncomp_size) {
    LOG(ERROR) << "Invalid ZSTD compressed string";
    return Unexpected(errc::rdb_file_corrupted);
  }
  uncompressed_mem_buf_.CommitWrite(d_size);

  // Add opcode of compressed blob end to membuf.
  dest = uncompressed_mem_buf_.AppendBuffer();
  if (dest.size() < 1) {
    return Unexpected(errc::out_of_memory);
  }
  dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END;
  uncompressed_mem_buf_.CommitWrite(1);

  return &uncompressed_mem_buf_;
}
class Lz4Decompress : public DecompressImpl {
public:
Lz4Decompress() {
auto result = LZ4F_createDecompressionContext(&dctx_, LZ4F_VERSION);
CHECK(!LZ4F_isError(result));
}
~Lz4Decompress() {
auto result = LZ4F_freeDecompressionContext(dctx_);
CHECK(!LZ4F_isError(result));
}
io::Result<base::IoBuf*> Decompress(std::string_view str);
private:
LZ4F_dctx* dctx_;
};
// Decompresses the LZ4 frame in `data` into uncompressed_mem_buf_ and appends
// RDB_OPCODE_COMPRESSED_BLOB_END after the decompressed bytes.
//
// Returns a pointer to uncompressed_mem_buf_ on success. On failure returns:
//   - errc::rdb_file_corrupted if LZ4 reports an error, the frame header does not
//                              declare a content size, the frame is truncated, not all
//                              of `data` was consumed, or the amount of buffered data
//                              differs from the declared content size;
//   - errc::out_of_memory      if the buffer does not expose enough writable space.
io::Result<io::IoBuf*> Lz4Decompress::Decompress(std::string_view data) {
  LZ4F_frameInfo_t frame_info;
  const size_t frame_size = data.size();

  // Get content size from frame data.
  // On input `consumed` is the number of available bytes; on output it holds the
  // number of bytes LZ4F_getFrameInfo consumed from `data`.
  size_t consumed = frame_size;
  const size_t res = LZ4F_getFrameInfo(dctx_, &frame_info, data.data(), &consumed);
  if (LZ4F_isError(res)) {
    LOG(ERROR) << "LZ4F_getFrameInfo failed with error " << LZ4F_getErrorName(res);
    return Unexpected(errc::rdb_file_corrupted);
  }
  if (frame_info.contentSize == 0) {
    LOG(ERROR) << "Missing frame content size";
    return Unexpected(errc::rdb_file_corrupted);
  }

  // Reserve place for uncompressed data and end opcode.
  const size_t reserve = frame_info.contentSize + 1;
  uncompressed_mem_buf_.Reserve(reserve);
  IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer();
  if (dest.size() < reserve) {
    return Unexpected(errc::out_of_memory);
  }

  // Uncompress data to membuf.
  string_view src = data.substr(consumed);
  size_t ret = 1;
  while (ret != 0) {
    dest = uncompressed_mem_buf_.AppendBuffer();

    // In: bytes available in src / capacity of dest.
    // Out: bytes consumed from src / bytes written into dest.
    size_t src_size = src.size();
    size_t dest_size = dest.size();
    ret = LZ4F_decompress(dctx_, dest.data(), &dest_size, src.data(), &src_size, nullptr);
    if (LZ4F_isError(ret)) {
      LOG(ERROR) << "LZ4F_decompress failed with error " << LZ4F_getErrorName(ret);
      return Unexpected(errc::rdb_file_corrupted);
    }

    // A non-zero hint means LZ4 expects more input. If this call neither consumed nor
    // produced anything, repeating it with the same arguments cannot make progress
    // (truncated frame, or no room left in the buffer), so bail out instead of
    // spinning forever.
    if (ret != 0 && src_size == 0 && dest_size == 0) {
      LOG(ERROR) << "LZ4 frame is truncated or exceeds its declared content size";
      return Unexpected(errc::rdb_file_corrupted);
    }

    consumed += src_size;
    uncompressed_mem_buf_.CommitWrite(dest_size);
    src = src.substr(src_size);
  }

  if (consumed != frame_size) {
    return Unexpected(errc::rdb_file_corrupted);
  }
  // NOTE(review): this comparison presumes the buffer held no unread data when the
  // call started - confirm against the caller.
  if (uncompressed_mem_buf_.InputLen() != frame_info.contentSize) {
    return Unexpected(errc::rdb_file_corrupted);
  }

  // Add opcode of compressed blob end to membuf.
  dest = uncompressed_mem_buf_.AppendBuffer();
  if (dest.size() < 1) {
    return Unexpected(errc::out_of_memory);
  }
  dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END;
  uncompressed_mem_buf_.CommitWrite(1);

  return &uncompressed_mem_buf_;
}
// Factory: builds the LZ4-backed decompressor and hands ownership to the caller.
unique_ptr<DecompressImpl> DecompressImpl::CreateLZ4() {
  unique_ptr<DecompressImpl> impl = make_unique<Lz4Decompress>();
  return impl;
}
// Factory: builds the zstd-backed decompressor and hands ownership to the caller.
unique_ptr<DecompressImpl> DecompressImpl::CreateZstd() {
  unique_ptr<DecompressImpl> impl = make_unique<ZstdDecompress>();
  return impl;
}
} // namespace detail
} // namespace dfly

View file

@ -0,0 +1,32 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <memory>
#include "io/io.h"
#include "io/io_buf.h"
namespace dfly {
namespace detail {
// Abstract interface for decompressing a compressed blob into an internal buffer.
// Concrete implementations are obtained through the CreateLZ4() / CreateZstd()
// factories.
class DecompressImpl {
 public:
  // Factories returning the LZ4 and zstd implementations respectively.
  static std::unique_ptr<DecompressImpl> CreateLZ4();
  static std::unique_ptr<DecompressImpl> CreateZstd();

  // The buffer is constructed with 1U << 14 (16384), presumably its initial
  // capacity in bytes.
  DecompressImpl() : uncompressed_mem_buf_{1U << 14} {
  }

  virtual ~DecompressImpl() = default;

  // Polymorphic base: copying through a base reference would slice the derived
  // object, so copy operations are disabled.
  DecompressImpl(const DecompressImpl&) = delete;
  DecompressImpl& operator=(const DecompressImpl&) = delete;

  // Decompresses `str`. Returns a buffer pointer on success or an error code on
  // failure.
  virtual io::Result<io::IoBuf*> Decompress(std::string_view str) = 0;

 protected:
  // Destination buffer available to implementations.
  io::IoBuf uncompressed_mem_buf_;
};
} // namespace detail
} // namespace dfly

57
src/server/error.cc Normal file
View file

@ -0,0 +1,57 @@
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/error.h"
#include <absl/strings/str_cat.h>
using namespace std;
namespace dfly {
namespace rdb {
class error_category : public std::error_category {
public:
const char* name() const noexcept final {
return "dragonfly.rdbload";
}
string message(int ev) const final;
error_condition default_error_condition(int ev) const noexcept final;
bool equivalent(int ev, const error_condition& condition) const noexcept final {
return condition.value() == ev && &condition.category() == this;
}
bool equivalent(const error_code& error, int ev) const noexcept final {
return error.value() == ev && &error.category() == this;
}
};
// Returns a dedicated message for wrong_signature and out_of_memory; every other
// value gets a generic message that includes the numeric code.
string error_category::message(int ev) const {
  if (ev == errc::wrong_signature) {
    return "Wrong signature while trying to load from rdb file";
  }
  if (ev == errc::out_of_memory) {
    return "Out of memory, or used memory is too high";
  }
  return absl::StrCat("Internal error when loading RDB file ", ev);
}
// Every error value maps to the condition with the same value in this category.
error_condition error_category::default_error_condition(int ev) const noexcept {
  return error_condition(ev, *this);
}
// The single, file-local category instance referenced by RdbError() below.
static error_category rdb_category;
} // namespace rdb
// Builds a std::error_code carrying `ev` in the rdb error category.
error_code RdbError(rdb::errc ev) {
  const int code = static_cast<int>(ev);
  return error_code(code, rdb::rdb_category);
}
} // namespace dfly

View file

@ -79,4 +79,6 @@ enum errc {
} // namespace rdb } // namespace rdb
std::error_code RdbError(rdb::errc ev);
} // namespace dfly } // namespace dfly

View file

@ -22,8 +22,6 @@ extern "C" {
#include <absl/strings/match.h> #include <absl/strings/match.h>
#include <absl/strings/str_cat.h> #include <absl/strings/str_cat.h>
#include <absl/strings/str_split.h> #include <absl/strings/str_split.h>
#include <lz4frame.h>
#include <zstd.h>
#include <cstring> #include <cstring>
@ -52,7 +50,6 @@ extern "C" {
#include "server/serializer_commons.h" #include "server/serializer_commons.h"
#include "server/server_state.h" #include "server/server_state.h"
#include "server/set_family.h" #include "server/set_family.h"
#include "server/tiering/common.h" // for _KB literal
#include "server/transaction.h" #include "server/transaction.h"
#include "strings/human_readable.h" #include "strings/human_readable.h"
@ -72,55 +69,12 @@ using namespace tiering::literals;
namespace { namespace {
constexpr char kErrCat[] = "dragonfly.rdbload";
// Maximum length of each LoadTrace segment. // Maximum length of each LoadTrace segment.
// //
// Note kMaxBlobLen must be a multiple of 6 to avoid truncating elements // Note kMaxBlobLen must be a multiple of 6 to avoid truncating elements
// containing 2 or 3 items. // containing 2 or 3 items.
constexpr size_t kMaxBlobLen = 4092; constexpr size_t kMaxBlobLen = 4092;
class error_category : public std::error_category {
public:
const char* name() const noexcept final {
return kErrCat;
}
string message(int ev) const final;
error_condition default_error_condition(int ev) const noexcept final;
bool equivalent(int ev, const error_condition& condition) const noexcept final {
return condition.value() == ev && &condition.category() == this;
}
bool equivalent(const error_code& error, int ev) const noexcept final {
return error.value() == ev && &error.category() == this;
}
};
string error_category::message(int ev) const {
switch (ev) {
case errc::wrong_signature:
return "Wrong signature while trying to load from rdb file";
case errc::out_of_memory:
return "Out of memory, or used memory is too high";
default:
return absl::StrCat("Internal error when loading RDB file ", ev);
break;
}
}
error_condition error_category::default_error_condition(int ev) const noexcept {
return error_condition{ev, *this};
}
error_category rdb_category;
inline error_code RdbError(errc ev) {
return error_code{ev, rdb_category};
}
inline auto Unexpected(errc ev) { inline auto Unexpected(errc ev) {
return make_unexpected(RdbError(ev)); return make_unexpected(RdbError(ev));
} }
@ -239,152 +193,6 @@ bool RdbTypeAllowedEmpty(int type) {
} // namespace } // namespace
class DecompressImpl {
public:
DecompressImpl() : uncompressed_mem_buf_{16_KB} {
}
virtual ~DecompressImpl() {
}
virtual io::Result<io::IoBuf*> Decompress(std::string_view str) = 0;
protected:
io::IoBuf uncompressed_mem_buf_;
};
class ZstdDecompress : public DecompressImpl {
public:
ZstdDecompress() {
dctx_ = ZSTD_createDCtx();
}
~ZstdDecompress() {
ZSTD_freeDCtx(dctx_);
}
io::Result<io::IoBuf*> Decompress(std::string_view str);
private:
ZSTD_DCtx* dctx_;
};
io::Result<io::IoBuf*> ZstdDecompress::Decompress(std::string_view str) {
// Prepare membuf memory to uncompressed string.
auto uncomp_size = ZSTD_getFrameContentSize(str.data(), str.size());
if (uncomp_size == ZSTD_CONTENTSIZE_UNKNOWN) {
LOG(ERROR) << "Zstd compression missing frame content size";
return Unexpected(errc::invalid_encoding);
}
if (uncomp_size == ZSTD_CONTENTSIZE_ERROR) {
LOG(ERROR) << "Invalid ZSTD compressed string";
return Unexpected(errc::invalid_encoding);
}
uncompressed_mem_buf_.Reserve(uncomp_size + 1);
// Uncompress string to membuf
IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer();
if (dest.size() < uncomp_size) {
return Unexpected(errc::out_of_memory);
}
size_t const d_size =
ZSTD_decompressDCtx(dctx_, dest.data(), dest.size(), str.data(), str.size());
if (d_size == 0 || d_size != uncomp_size) {
LOG(ERROR) << "Invalid ZSTD compressed string";
return Unexpected(errc::rdb_file_corrupted);
}
uncompressed_mem_buf_.CommitWrite(d_size);
// Add opcode of compressed blob end to membuf.
dest = uncompressed_mem_buf_.AppendBuffer();
if (dest.size() < 1) {
return Unexpected(errc::out_of_memory);
}
dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END;
uncompressed_mem_buf_.CommitWrite(1);
return &uncompressed_mem_buf_;
}
class Lz4Decompress : public DecompressImpl {
public:
Lz4Decompress() {
auto result = LZ4F_createDecompressionContext(&dctx_, LZ4F_VERSION);
CHECK(!LZ4F_isError(result));
}
~Lz4Decompress() {
auto result = LZ4F_freeDecompressionContext(dctx_);
CHECK(!LZ4F_isError(result));
}
io::Result<base::IoBuf*> Decompress(std::string_view str);
private:
LZ4F_dctx* dctx_;
};
io::Result<base::IoBuf*> Lz4Decompress::Decompress(std::string_view data) {
LZ4F_frameInfo_t frame_info;
size_t frame_size = data.size();
// Get content size from frame data
size_t consumed = frame_size; // The nb of bytes consumed from data will be written into consumed
size_t res = LZ4F_getFrameInfo(dctx_, &frame_info, data.data(), &consumed);
if (LZ4F_isError(res)) {
return make_unexpected(error_code{int(res), generic_category()});
}
if (frame_info.contentSize == 0) {
LOG(ERROR) << "Missing frame content size";
return Unexpected(errc::rdb_file_corrupted);
}
// reserve place for uncompressed data and end opcode
size_t reserve = frame_info.contentSize + 1;
uncompressed_mem_buf_.Reserve(reserve);
IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer();
if (dest.size() < reserve) {
return Unexpected(errc::out_of_memory);
}
// Uncompress data to membuf
string_view src = data.substr(consumed);
size_t src_size = src.size();
size_t ret = 1;
while (ret != 0) {
IoBuf::Bytes dest = uncompressed_mem_buf_.AppendBuffer();
size_t dest_capacity = dest.size();
// It will read up to src_size bytes from src,
// and decompress data into dest, of capacity dest_capacity
// The nb of bytes consumed from src will be written into src_size
// The nb of bytes decompressed into dest will be written into dest_capacity
ret = LZ4F_decompress(dctx_, dest.data(), &dest_capacity, src.data(), &src_size, nullptr);
if (LZ4F_isError(ret)) {
return make_unexpected(error_code{int(ret), generic_category()});
}
consumed += src_size;
uncompressed_mem_buf_.CommitWrite(dest_capacity);
src = src.substr(src_size);
src_size = src.size();
}
if (consumed != frame_size) {
return Unexpected(errc::rdb_file_corrupted);
}
if (uncompressed_mem_buf_.InputLen() != frame_info.contentSize) {
return Unexpected(errc::rdb_file_corrupted);
}
// Add opcode of compressed blob end to membuf.
dest = uncompressed_mem_buf_.AppendBuffer();
if (dest.size() < 1) {
return Unexpected(errc::out_of_memory);
}
dest[0] = RDB_OPCODE_COMPRESSED_BLOB_END;
uncompressed_mem_buf_.CommitWrite(1);
return &uncompressed_mem_buf_;
}
class RdbLoaderBase::OpaqueObjLoader { class RdbLoaderBase::OpaqueObjLoader {
public: public:
OpaqueObjLoader(int rdb_type, PrimeValue* pv, LoadConfig config) OpaqueObjLoader(int rdb_type, PrimeValue* pv, LoadConfig config)
@ -2492,17 +2300,19 @@ io::Result<uint64_t> RdbLoaderBase::LoadLen(bool* is_encoded) {
return res; return res;
} }
void RdbLoaderBase::AllocateDecompressOnce(int op_type) { error_code RdbLoaderBase::AllocateDecompressOnce(int op_type) {
if (decompress_impl_) { if (decompress_impl_) {
return; return {};
} }
if (op_type == RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START) { if (op_type == RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START) {
decompress_impl_.reset(new ZstdDecompress()); decompress_impl_ = detail::DecompressImpl::CreateZstd();
} else if (op_type == RDB_OPCODE_COMPRESSED_LZ4_BLOB_START) { } else if (op_type == RDB_OPCODE_COMPRESSED_LZ4_BLOB_START) {
decompress_impl_.reset(new Lz4Decompress()); decompress_impl_ = detail::DecompressImpl::CreateLZ4();
} else { } else {
CHECK(false) << "Decompressor allocation should not be done"; return RdbError(errc::unsupported_operation);
} }
return {};
} }
error_code RdbLoaderBase::SkipModuleData() { error_code RdbLoaderBase::SkipModuleData() {
@ -2550,7 +2360,8 @@ error_code RdbLoaderBase::SkipModuleData() {
} }
error_code RdbLoaderBase::HandleCompressedBlob(int op_type) { error_code RdbLoaderBase::HandleCompressedBlob(int op_type) {
AllocateDecompressOnce(op_type); RETURN_ON_ERR(AllocateDecompressOnce(op_type));
// Fetch uncompress blob // Fetch uncompress blob
string res; string res;
SET_OR_RETURN(FetchGenericString(), res); SET_OR_RETURN(FetchGenericString(), res);

View file

@ -14,6 +14,7 @@ extern "C" {
#include "io/io.h" #include "io/io.h"
#include "io/io_buf.h" #include "io/io_buf.h"
#include "server/common.h" #include "server/common.h"
#include "server/detail/decompress.h"
#include "server/journal/serializer.h" #include "server/journal/serializer.h"
struct streamID; struct streamID;
@ -25,8 +26,6 @@ class ScriptMgr;
class CompactObj; class CompactObj;
class Service; class Service;
class DecompressImpl;
using RdbVersion = std::uint16_t; using RdbVersion = std::uint16_t;
class RdbLoaderBase { class RdbLoaderBase {
@ -184,7 +183,7 @@ class RdbLoaderBase {
std::error_code SkipModuleData(); std::error_code SkipModuleData();
std::error_code HandleCompressedBlob(int op_type); std::error_code HandleCompressedBlob(int op_type);
std::error_code HandleCompressedBlobFinish(); std::error_code HandleCompressedBlobFinish();
void AllocateDecompressOnce(int op_type); std::error_code AllocateDecompressOnce(int op_type);
std::error_code HandleJournalBlob(Service* service); std::error_code HandleJournalBlob(Service* service);
@ -203,7 +202,7 @@ class RdbLoaderBase {
size_t bytes_read_ = 0; size_t bytes_read_ = 0;
size_t source_limit_ = SIZE_MAX; size_t source_limit_ = SIZE_MAX;
base::PODArray<uint8_t> compr_buf_; base::PODArray<uint8_t> compr_buf_;
std::unique_ptr<DecompressImpl> decompress_impl_; std::unique_ptr<detail::DecompressImpl> decompress_impl_;
JournalReader journal_reader_{nullptr, 0}; JournalReader journal_reader_{nullptr, 0};
std::optional<uint64_t> journal_offset_ = std::nullopt; std::optional<uint64_t> journal_offset_ = std::nullopt;
RdbVersion rdb_version_ = RDB_VERSION; RdbVersion rdb_version_ = RDB_VERSION;