chore: Add HuffmanDecoder class

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2025-05-07 18:57:10 +03:00
parent 05d99769e1
commit a17a556320
No known key found for this signature in database
GPG key ID: F25B77EAF8AEBA7A
4 changed files with 91 additions and 4 deletions

View file

@ -56,7 +56,7 @@ Checks: >
readability-duplicate-include,
readability-function-size,
readability-identifier-naming,
readability-inconsistent-declaration-parameter-name,
-readability-inconsistent-declaration-parameter-name,
readability-make-member-function-const,
readability-misplaced-array-index,
readability-named-parameter,

View file

@ -192,6 +192,7 @@ TEST_F(StringMatchTest, Special) {
class HuffCoderTest : public ::testing::Test {
protected:
HuffmanEncoder encoder_;
HuffmanDecoder decoder_;
string error_msg_;
const string_view good_table_{
"\x1b\x10\xd8\n\n\x19\xc6\x0c\xc3\x30\x0c\x43\x1e\x93\xe4\x11roB\xf6\xde\xbb\x18V\xc2Zk\x03"sv};
@ -220,12 +221,39 @@ TEST_F(HuffCoderTest, Encode) {
string data("x:23xx");
uint8_t dest[100];
uint32_t dest_size = sizeof(dest);
ASSERT_TRUE(encoder_.Encode(data, dest, &dest_size, &error_msg_));
array<uint8_t, 100> dest;
uint32_t dest_size = dest.size();
ASSERT_TRUE(encoder_.Encode(data, dest.data(), &dest_size, &error_msg_));
ASSERT_EQ(3, dest_size);
}
// Round trip through both coders: build an encoder table from a skewed histogram, encode a short
// string, hand the exported table to the decoder and check that decoding restores the input.
TEST_F(HuffCoderTest, Decode) {
  // Give every symbol a non-zero count and make 'a' and 'b' far more frequent than the rest.
  array<unsigned, 256> freq;
  freq.fill(1);
  freq['a'] = 100;
  freq['b'] = 50;
  ASSERT_TRUE(encoder_.Build(freq.data(), freq.size() - 1, &error_msg_));

  // Encode the sample; the whole payload is expected to fit into a single byte.
  const string input("aab");
  array<uint8_t, 100> encoded{};
  uint32_t encoded_size = encoded.size();
  ASSERT_TRUE(encoder_.Encode(input, encoded.data(), &encoded_size, &error_msg_));
  ASSERT_EQ(1, encoded_size);

  // Code lengths the built table is expected to assign to the two frequent symbols.
  EXPECT_EQ(2, encoder_.GetNBits('a'));
  EXPECT_EQ(3, encoder_.GetNBits('b'));

  // Transfer the table to the decoder through its exported binary form.
  const string exported_table = encoder_.Export();
  ASSERT_TRUE(decoder_.Load(exported_table, &error_msg_)) << error_msg_;

  // Decode needs the exact size of the original data, so pass the input length.
  const string_view compressed(reinterpret_cast<const char*>(encoded.data()), encoded_size);
  array<char, 100> decoded{};
  const size_t expected_len = input.size();
  ASSERT_TRUE(decoder_.Decode(compressed, expected_len, decoded.data()));
  ASSERT_EQ("aab", string_view(decoded.data(), expected_len));
}
using benchmark::DoNotOptimize;
// Parse Double benchmarks

View file

@ -78,6 +78,11 @@ bool HuffmanEncoder::Encode(std::string_view data, uint8_t* dest, uint32_t* dest
return true;
}
// Returns the value HUF_getNbBitsFromCTable reports for `symbol` in the encoder's table,
// i.e. the number of bits used to encode that symbol.
// The table (huf_ctable_) must already be set; this is verified with DCHECK.
unsigned HuffmanEncoder::GetNBits(uint8_t symbol) const {
  DCHECK(huf_ctable_);

  const auto* ctable = huf_ctable_.get();
  return HUF_getNbBitsFromCTable(ctable, symbol);
}
unsigned HuffmanEncoder::BitCount(uint8_t symbol) const {
DCHECK(huf_ctable_);
return HUF_getNbBitsFromCTable(huf_ctable_.get(), symbol);
@ -107,4 +112,39 @@ string HuffmanEncoder::Export() const {
return res;
}
bool HuffmanDecoder::Load(std::string_view binary_data, std::string* error_msg) {
DCHECK(!huf_dtable_);
huf_dtable_.reset(new HUF_DTable[HUF_DTABLE_SIZE(HUF_TABLELOG_MAX)]);
huf_dtable_[0] = (HUF_TABLELOG_MAX - 1) * 0x01000001; // some sort of magic number
constexpr size_t kWspSize = HUF_DECOMPRESS_WORKSPACE_SIZE;
unique_ptr<uint8_t[]> wrksp(new uint8_t[kWspSize]);
size_t res = HUF_readDTableX1_wksp(huf_dtable_.get(), binary_data.data(), binary_data.size(),
wrksp.get(), kWspSize, 0);
if (HUF_isError(res)) {
*error_msg = HUF_getErrorName(res);
huf_dtable_.reset();
return false;
}
if (res != binary_data.size()) {
*error_msg = "Corrupted data";
huf_dtable_.reset();
return false;
}
return true;
}
bool HuffmanDecoder::Decode(std::string_view src, size_t dest_size, char* dest) const {
DCHECK(huf_dtable_);
size_t res =
HUF_decompress1X_usingDTable(dest, dest_size, src.data(), src.size(), huf_dtable_.get(), 1);
if (HUF_isError(res)) {
LOG(FATAL) << "Failed to decompress: " << HUF_getErrorName(res);
return false;
}
return true;
}
} // namespace dfly

View file

@ -40,6 +40,8 @@ class HuffmanEncoder {
return table_max_symbol_;
}
unsigned GetNBits(uint8_t symbol) const;
private:
using HUF_CElt = size_t;
std::unique_ptr<HUF_CElt[]> huf_ctable_;
@ -47,4 +49,21 @@ class HuffmanEncoder {
uint8_t num_bits_ = 0;
};
// Decodes Huffman-compressed data using a decoding table loaded from its serialized form.
class HuffmanDecoder {
 public:
  // Loads the decoding table from binary_data.
  // Returns true on success; on failure returns false and describes the problem in *error_msg.
  bool Load(std::string_view binary_data, std::string* error_msg);

  // Tells whether this decoder currently holds a decoding table.
  bool valid() const {
    return static_cast<bool>(huf_dtable_);
  }

  // Decodes src into dest.
  // decoded_size must be the *exact* size of the decoded data, otherwise decoding fails;
  // dest must point to a buffer with room for at least decoded_size bytes.
  // Returns true if decompression was successful, false if the data is corrupted.
  bool Decode(std::string_view src, size_t decoded_size, char* dest) const;

 private:
  // Element type of the decoding table.
  using HUF_DTable = uint32_t;

  // Decoding table; empty until a table is loaded.
  std::unique_ptr<HUF_DTable[]> huf_dtable_;
};
} // namespace dfly