From c7974a4e8062843dfa2c5550b9f75dfecb0acd17 Mon Sep 17 00:00:00 2001 From: adiholden Date: Sun, 4 Dec 2022 11:14:24 +0200 Subject: [PATCH] =?UTF-8?q?bug(rdb=20loader):=20When=20reading=20from=20zs?= =?UTF-8?q?td=20uncompressed=20buf=20skip=20ensure=20=E2=80=A6=20(#525)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bug(rdb loader): When reading from zstd uncompressed buf skip ensure read flow Signed-off-by: adi_holden --- src/server/rdb_load.cc | 31 ++++++++++++++++++++++++------- src/server/rdb_load.h | 7 +------ src/server/rdb_test.cc | 13 +++++++++++++ 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index 6e5c4e011..a4357fdd5 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -38,12 +38,14 @@ ABSL_DECLARE_FLAG(int32_t, list_compress_depth); ABSL_DECLARE_FLAG(uint32_t, dbnum); ABSL_DECLARE_FLAG(bool, use_set2); -#define SET_OR_RETURN(expr, dest) \ - do { \ - auto exp_val = (expr); \ - if (!exp_val) \ - return exp_val.error(); \ - dest = exp_val.value(); \ +#define SET_OR_RETURN(expr, dest) \ + do { \ + auto exp_val = (expr); \ + if (!exp_val) { \ + VLOG(1) << "Error while calling " #expr; \ + return exp_val.error(); \ + } \ + dest = exp_val.value(); \ } while (0) #define SET_OR_UNEXPECT(expr, dest) \ @@ -230,6 +232,7 @@ io::Result ZstdDecompressImpl::Decompress(std::string_view str) { LOG(ERROR) << "Invalid ZSTD compressed string"; return Unexpected(errc::invalid_encoding); } + uncompressed_mem_buf_.Reserve(uncomp_size + 1); // Uncompress string to membuf @@ -1705,6 +1708,19 @@ error_code RdbLoader::Load(io::Source* src) { return kOk; } +std::error_code RdbLoaderBase::EnsureRead(size_t min_sz) { + // In the flow of reading compressed data, we store the uncompressed data to in uncompressed + // buffer. When parsing entries we call ensure read with 9 bytes to read the length of key/value. + // If the key/value is very small (less than 9 bytes) the remainded data in uncompressed buffer + // might contain less than 9 bytes. We need to make sure that we dont read from sink to the + // uncompressed buffer and therefor in this flow we return here. + if (mem_buf_ != &origin_mem_buf_) + return std::error_code{}; + if (mem_buf_->InputLen() >= min_sz) + return std::error_code{}; + return EnsureReadInternal(min_sz); +} + error_code RdbLoaderBase::EnsureReadInternal(size_t min_sz) { DCHECK_LT(mem_buf_->InputLen(), min_sz); @@ -1792,7 +1808,8 @@ error_code RdbLoaderBase::HandleCompressedBlob() { } error_code RdbLoaderBase::HandleCompressedBlobFinish() { - // TODO validate that all uncompressed data was fetched + CHECK_NE(&origin_mem_buf_, mem_buf_); + CHECK_EQ(mem_buf_->InputLen(), size_t(0)); mem_buf_ = &origin_mem_buf_; return kOk; } diff --git a/src/server/rdb_load.h b/src/server/rdb_load.h index 77136b608..b181ade4b 100644 --- a/src/server/rdb_load.h +++ b/src/server/rdb_load.h @@ -132,12 +132,7 @@ class RdbLoaderBase { static size_t StrLen(const RdbVariant& tset); - std::error_code EnsureRead(size_t min_sz) { - if (mem_buf_->InputLen() >= min_sz) - return std::error_code{}; - - return EnsureReadInternal(min_sz); - } + std::error_code EnsureRead(size_t min_sz); std::error_code EnsureReadInternal(size_t min_sz); diff --git a/src/server/rdb_test.cc b/src/server/rdb_test.cc index ff17cb499..fafbc5e1e 100644 --- a/src/server/rdb_test.cc +++ b/src/server/rdb_test.cc @@ -159,6 +159,19 @@ TEST_F(RdbTest, ComressionModeSaveDragonflyAndReload) { } } +TEST_F(RdbTest, RdbLoaderOnReadCompressedDataShouldNotEnterEnsureReadFlow) { + SetFlag(&FLAGS_compression_mode, 2); + for (int i = 0; i < 1000; ++i) { + Run({"set", StrCat(i), "1"}); + } + RespExpr resp = Run({"save", "df"}); + ASSERT_EQ(resp, "OK"); + + auto save_info = service_->server_family().GetLastSaveInfo(); + resp = Run({"debug", "load", save_info->file_name}); + ASSERT_EQ(resp, "OK"); +} + TEST_F(RdbTest, Reload) { absl::FlagSaver fs;