From ad5aa66350ffa38465262d16c7ac301f8ee2bc50 Mon Sep 17 00:00:00 2001
From: Roman Gershman
Date: Wed, 7 May 2025 19:58:43 +0300
Subject: [PATCH] chore: Add HuffmanDecoder class (#5078)

Signed-off-by: Roman Gershman
---
Review notes (free-form area, ignored by `git am`):
* Everything written in angle brackets was missing from this copy of the patch. The template
  arguments were reconstructed: unique_ptr<HUF_DTable[]>, unique_ptr<uint8_t[]>,
  array<uint8_t, 100> dest and reinterpret_cast<const char*> follow from the surrounding code;
  array<unsigned, 256> hist, array<uint8_t, 100> encoded, array<char, 100> decode_dest and
  unique_ptr<HUF_CElt[]> are assumptions - confirm against the tree. The e-mail addresses in
  the From/Signed-off-by lines could not be reconstructed and are left as found.
* HuffmanDecoder::Decode logs with LOG(ERROR) and returns false, so the documented
  "false if the data is corrupted" result is actually reachable.
* The definition of Decode uses the same parameter name as its declaration (decoded_size); the
  .clang-tidy change may no longer be needed for this file - verify against the rest of the tree.
* HuffmanEncoder::GetNBits delegates to BitCount instead of duplicating its body.

 .clang-tidy                |  2 +-
 src/core/dfly_core_test.cc | 34 +++++++++++++++++++++++++++++---
 src/core/huff_coder.cc     | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 src/core/huff_coder.h      | 25 +++++++++++++++++++++++++
 4 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/.clang-tidy b/.clang-tidy
index be27a8182..17e45ba8b 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -56,7 +56,7 @@ Checks: >
   readability-duplicate-include,
   readability-function-size,
   readability-identifier-naming,
-  readability-inconsistent-declaration-parameter-name,
+  -readability-inconsistent-declaration-parameter-name,
   readability-make-member-function-const,
   readability-misplaced-array-index,
   readability-named-parameter,
diff --git a/src/core/dfly_core_test.cc b/src/core/dfly_core_test.cc
index b5f907c2b..f5f22274f 100644
--- a/src/core/dfly_core_test.cc
+++ b/src/core/dfly_core_test.cc
@@ -192,6 +192,7 @@ TEST_F(StringMatchTest, Special) {
 class HuffCoderTest : public ::testing::Test {
  protected:
   HuffmanEncoder encoder_;
+  HuffmanDecoder decoder_;
   string error_msg_;
   const string_view good_table_{
       "\x1b\x10\xd8\n\n\x19\xc6\x0c\xc3\x30\x0c\x43\x1e\x93\xe4\x11roB\xf6\xde\xbb\x18V\xc2Zk\x03"sv};
@@ -220,12 +221,39 @@ TEST_F(HuffCoderTest, Encode) {
 
   string data("x:23xx");
 
-  uint8_t dest[100];
-  uint32_t dest_size = sizeof(dest);
-  ASSERT_TRUE(encoder_.Encode(data, dest, &dest_size, &error_msg_));
+  array<uint8_t, 100> dest;
+  uint32_t dest_size = dest.size();
+  ASSERT_TRUE(encoder_.Encode(data, dest.data(), &dest_size, &error_msg_));
   ASSERT_EQ(3, dest_size);
 }
 
+TEST_F(HuffCoderTest, Decode) {
+  array<unsigned, 256> hist;
+  hist.fill(1);
+  hist['a'] = 100;
+  hist['b'] = 50;
+
+  ASSERT_TRUE(encoder_.Build(hist.data(), hist.size() - 1, &error_msg_));
+  string data("aab");
+
+  array<uint8_t, 100> encoded{0};
+  uint32_t encoded_size = encoded.size();
+  ASSERT_TRUE(encoder_.Encode(data, encoded.data(), &encoded_size, &error_msg_));
+  ASSERT_EQ(1, encoded_size);
+
+  EXPECT_EQ(2, encoder_.GetNBits('a'));
+  EXPECT_EQ(3, encoder_.GetNBits('b'));
+
+  string bindata = encoder_.Export();
+  ASSERT_TRUE(decoder_.Load(bindata, &error_msg_)) << error_msg_;
+
+  const char* src_ptr = reinterpret_cast<const char*>(encoded.data());
+  array<char, 100> decode_dest{0};
+  size_t decoded_size = data.size();
+  ASSERT_TRUE(decoder_.Decode({src_ptr, encoded_size}, decoded_size, decode_dest.data()));
+  ASSERT_EQ("aab", string_view(decode_dest.data(), decoded_size));
+}
+
 using benchmark::DoNotOptimize;
 
 // Parse Double benchmarks
diff --git a/src/core/huff_coder.cc b/src/core/huff_coder.cc
index aecd2acc7..c74d92e14 100644
--- a/src/core/huff_coder.cc
+++ b/src/core/huff_coder.cc
@@ -78,6 +78,11 @@ bool HuffmanEncoder::Encode(std::string_view data, uint8_t* dest, uint32_t* dest
   return true;
 }
 
+unsigned HuffmanEncoder::GetNBits(uint8_t symbol) const {
+  // Same query as BitCount(); delegate so the two accessors cannot diverge.
+  return BitCount(symbol);
+}
+
 unsigned HuffmanEncoder::BitCount(uint8_t symbol) const {
   DCHECK(huf_ctable_);
   return HUF_getNbBitsFromCTable(huf_ctable_.get(), symbol);
@@ -107,4 +112,45 @@ string HuffmanEncoder::Export() const {
   return res;
 }
 
+// Builds the decoding table from binary_data (a table serialized by HuffmanEncoder::Export()).
+// On failure stores the reason in *error_msg, releases the table and returns false.
+bool HuffmanDecoder::Load(std::string_view binary_data, std::string* error_msg) {
+  DCHECK(!huf_dtable_);
+  huf_dtable_.reset(new HUF_DTable[HUF_DTABLE_SIZE(HUF_TABLELOG_MAX)]);
+  // NOTE(review): presumably the header word zstd expects to be pre-set in a DTable (compare
+  // HUF_CREATE_STATIC_DTABLEX1 in zstd's huf.h) - confirm.
+  huf_dtable_[0] = (HUF_TABLELOG_MAX - 1) * 0x01000001;
+
+  constexpr size_t kWspSize = HUF_DECOMPRESS_WORKSPACE_SIZE;
+  unique_ptr<uint8_t[]> wrksp(new uint8_t[kWspSize]);
+
+  size_t res = HUF_readDTableX1_wksp(huf_dtable_.get(), binary_data.data(), binary_data.size(),
+                                     wrksp.get(), kWspSize, 0);
+  if (HUF_isError(res)) {
+    *error_msg = HUF_getErrorName(res);
+    huf_dtable_.reset();
+    return false;
+  }
+  if (res != binary_data.size()) {  // anything but the full input size is treated as corruption
+    *error_msg = "Corrupted data";
+    huf_dtable_.reset();
+    return false;
+  }
+  return true;
+}
+
+// Decodes src into decoded_size bytes at dest. Errors are logged and reported via the return
+// value, as documented in huff_coder.h, so corrupted input cannot abort the process.
+bool HuffmanDecoder::Decode(std::string_view src, size_t decoded_size, char* dest) const {
+  DCHECK(huf_dtable_);
+  size_t res = HUF_decompress1X_usingDTable(dest, decoded_size, src.data(), src.size(),
+                                            huf_dtable_.get(), 1);
+
+  if (HUF_isError(res)) {
+    LOG(ERROR) << "Failed to decompress: " << HUF_getErrorName(res);
+    return false;
+  }
+  return true;
+}
+
 }  // namespace dfly
diff --git a/src/core/huff_coder.h b/src/core/huff_coder.h
index d9bf13b1e..11e9f9511 100644
--- a/src/core/huff_coder.h
+++ b/src/core/huff_coder.h
@@ -40,6 +40,9 @@ class HuffmanEncoder {
     return table_max_symbol_;
   }
 
+  // Returns the code length, in bits, assigned to symbol; same value as BitCount().
+  unsigned GetNBits(uint8_t symbol) const;
+
  private:
   using HUF_CElt = size_t;
   std::unique_ptr<HUF_CElt[]> huf_ctable_;
@@ -47,4 +50,26 @@ class HuffmanEncoder {
   uint8_t num_bits_ = 0;
 };
 
+// Decoding counterpart of HuffmanEncoder: loads a serialized table and decodes buffers with it.
+class HuffmanDecoder {
+ public:
+  // Loads the table from binary_data. Returns true on success; otherwise returns false and
+  // stores the reason in *error_msg.
+  bool Load(std::string_view binary_data, std::string* error_msg);
+
+  // Returns true if a table is currently loaded.
+  bool valid() const {
+    return bool(huf_dtable_);
+  }
+
+  // decoded_size should be the *precise* size of the decoded data, otherwise the function will
+  // fail. dest should point to a buffer of at least decoded_size bytes.
+  // Returns true if decompression was successful, false if the data is corrupted.
+  bool Decode(std::string_view src, size_t decoded_size, char* dest) const;
+
+ private:
+  using HUF_DTable = uint32_t;
+  std::unique_ptr<HUF_DTable[]> huf_dtable_;
+};
+
 }  // namespace dfly