diff --git a/docs/api_status.md b/docs/api_status.md index 01386a991..814338a41 100644 --- a/docs/api_status.md +++ b/docs/api_status.md @@ -119,7 +119,7 @@ with respect to Memcached and Redis APIs. - [X] SETEX - [X] APPEND - [X] PREPEND (dragonfly specific) - - [ ] BITCOUNT + - [x] BITCOUNT - [ ] BITFIELD - [ ] BITOP - [ ] BITPOS diff --git a/src/server/bitops_family.cc b/src/server/bitops_family.cc index 3a8076b3b..29d3086c9 100644 --- a/src/server/bitops_family.cc +++ b/src/server/bitops_family.cc @@ -1,4 +1,4 @@ -// Copyright 2022, Roman Gershman. All rights reserved. +// Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // @@ -26,7 +26,7 @@ namespace dfly { using namespace facade; namespace { -static const int32_t OFFSET_FACTOR = 8; // number of bits in byte +const int32_t OFFSET_FACTOR = 8; // number of bits in byte // The following is the list of the functions that would handle the // commands that handle the bit operations @@ -40,8 +40,13 @@ void SetBit(CmdArgList args, ConnectionContext* cntx); OpResult ReadValue(const OpArgs& op_args, std::string_view key); OpResult ReadValueBitsetAt(const OpArgs& op_args, std::string_view key, uint32_t offset); +OpResult CountBitsForValue(const OpArgs& op_args, std::string_view key, int64_t start, + int64_t end, bool bit_value); std::string GetString(EngineShard* shard, const PrimeValue& pv); bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry); +std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end); +std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits); +std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end); // ------------------------------------------------------------------------- // // Bits manipulation functions @@ -57,7 +62,7 @@ constexpr int32_t GetByteIndex(uint32_t offset) noexcept { return offset / OFFSET_FACTOR; } -uint8_t GetByteValue(const std::string& str, uint32_t offset) { +uint8_t GetByteValue(std::string_view str, uint32_t offset) { return static_cast(str[GetByteIndex(offset)]); } @@ -65,6 +70,86 @@ constexpr bool CheckBitStatus(uint8_t byte, uint32_t offset) { return byte & (0x1 << offset); } +constexpr std::uint8_t CountBitsRange(std::uint8_t byte, std::uint8_t from, uint8_t to) { + int count = 0; + for (int i = from; i < to; i++) { + count += CheckBitStatus(byte, GetNormalizedBitIndex(i)); + } + return count; +} + +// Count the number of bits that are on, on bytes boundaries: i.e. Start and end are the indices for +// bytes locations inside str CountBitSetByByteIndices +std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end) { + if (start >= end) { + return 0; + } + end = std::min(end, at.size()); // don't overflow + std::uint32_t count = + std::accumulate(std::next(at.begin(), start), std::next(at.begin(), end), 0, + [](auto counter, uint8_t ch) { return counter + absl::popcount(ch); }); + return count; +} + +// Count the number of bits that are on, on bits boundaries: i.e. Start and end are the indices for +// bits locations inside str +std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end) { + auto first_byte_index = GetByteIndex(start); + auto last_byte_index = GetByteIndex(end); + if (start % OFFSET_FACTOR == 0 && end % OFFSET_FACTOR == 0) { + return CountBitSetByByteIndices(at, first_byte_index, last_byte_index); + } + const auto last_bit_first_byte = + first_byte_index != last_byte_index ? OFFSET_FACTOR : GetBitIndex(end); + const auto first_byte = GetByteValue(at, start); + std::uint32_t count = CountBitsRange(first_byte, GetBitIndex(start), last_bit_first_byte); + if (first_byte_index < last_byte_index) { + first_byte_index++; + const auto last_byte = GetByteValue(at, end); + count += CountBitsRange(last_byte, 0, GetBitIndex(end)); + count += CountBitSetByByteIndices(at, first_byte_index, last_byte_index); + } + return count; +} + +// General purpose function to count the number of bits that are on. +// The parameters for start, end and bits are defaulted to the start of the string, +// end of the string and bits are false. +// Note that when bits is false, it means that we are looking on byte boundaries. +std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits) { + const int32_t size = bits ? str.size() * OFFSET_FACTOR : str.size(); + + auto NormalizedOffset = [size](int32_t orig) { + if (orig < 0) { + orig = size + orig; + } + return orig; + }; + + if (start > 0 && end > 0 && end < start) { + return 0; // for illegal range with positive we just return 0 + } + + if (start < 0 && end < 0 && start > end) { + return 0; // for illegal range with negative we just return 0 + } + + start = NormalizedOffset(start); + if (end > 0 && end < start) { + return 0; + } + end = NormalizedOffset(end); + if (start > end) { + std::swap(start, end); // we're going backward + } + if (end > size) { + end = size; // don't overflow + } + ++end; + return bits ? CountBitSetByBitIndices(str, start, end) + : CountBitSetByByteIndices(str, start, end); +} + // return true if bit is on bool GetBitValue(const std::string& entry, uint32_t offset) { const auto byte_val{GetByteValue(entry, offset)}; @@ -159,7 +244,45 @@ void BitPos(CmdArgList args, ConnectionContext* cntx) { } void BitCount(CmdArgList args, ConnectionContext* cntx) { - (*cntx)->SendLong(0); + // Support for the command BITCOUNT + // See details at https://redis.io/commands/bitcount/ + // Please note that if the key don't exists, it would return 0 + + if (args.size() == 3 || args.size() > 5) { + return (*cntx)->SendError(kSyntaxErr); + } + // return (*cntx)->SendLong(0); + std::string_view key = ArgS(args, 1); + bool as_bit = false; + int64_t start = 0; + int64_t end = std::numeric_limits::max(); + if (args.size() >= 4) { + if (absl::SimpleAtoi(ArgS(args, 2), &start) == 0 || + absl::SimpleAtoi(ArgS(args, 3), &end) == 0) { + return (*cntx)->SendError(kInvalidIntErr); + } + if (args.size() == 5) { + ToUpper(&args[4]); + as_bit = ArgS(args, 4) == "BIT"; + } + } + auto cb = [&](Transaction* t, EngineShard* shard) { + return CountBitsForValue(t->GetOpArgs(shard), key, start, end, as_bit); + }; + Transaction* trans = cntx->transaction; + OpResult result = trans->ScheduleSingleHopT(std::move(cb)); + if (result) { + (*cntx)->SendLong(result.value()); + } else { + switch (result.status()) { + case OpStatus::WRONG_TYPE: + (*cntx)->SendError(kWrongTypeErr); + break; + default: + (*cntx)->SendLong(0); + break; + } + } } void BitField(CmdArgList args, ConnectionContext* cntx) { @@ -285,6 +408,23 @@ OpResult ReadValue(const OpArgs& op_args, std::string_view key) { return GetString(op_args.shard, pv); } +OpResult CountBitsForValue(const OpArgs& op_args, std::string_view key, int64_t start, + int64_t end, bool bit_value) { + OpResult result = ReadValue(op_args, key); + + if (result) { + if (result.value().empty()) { + return 0; + } + if (end == std::numeric_limits::max()) { + end = result.value().size(); + } + return CountBitSet(result.value(), start, end, bit_value); + } else { + return result.status(); + } +} + } // namespace void BitOpsFamily::Register(CommandRegistry* registry) { diff --git a/src/server/bitops_family.h b/src/server/bitops_family.h index fe368eae1..09fcf7e63 100644 --- a/src/server/bitops_family.h +++ b/src/server/bitops_family.h @@ -1,4 +1,4 @@ -// Copyright 2021, Roman Gershman. All rights reserved. +// Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // diff --git a/src/server/bitops_family_test.cc b/src/server/bitops_family_test.cc index 9235fc193..1e1ab47d6 100644 --- a/src/server/bitops_family_test.cc +++ b/src/server/bitops_family_test.cc @@ -1,4 +1,4 @@ -// Copyright 2022, Roman Gershman. All rights reserved. +// Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // @@ -76,4 +76,54 @@ TEST_F(BitOpsFamilyTest, SetBitMissingKey) { } } +const int32_t EXPECTED_VALUES_BYTES_BIT_COUNT[] = { // got this from redis 0 as start index + 4, 7, 11, 14, 17, 21, 21, 21, 21}; + +const int32_t BYTES_EXPECTED_VALUE_LEN = + sizeof(EXPECTED_VALUES_BYTES_BIT_COUNT) / sizeof(EXPECTED_VALUES_BYTES_BIT_COUNT[0]); + +TEST_F(BitOpsFamilyTest, BitCountByte) { + // This would run without the bit flag - meaning it count on bytes boundaries + auto resp = Run({"set", "foo", "farbar"}); + EXPECT_EQ(resp, "OK"); + EXPECT_EQ(0, CheckedInt({"bitcount", "foo2"})); // on none existing key we are expecting 0 + + for (int32_t i = 0; i < BYTES_EXPECTED_VALUE_LEN; i++) { + EXPECT_EQ(EXPECTED_VALUES_BYTES_BIT_COUNT[i], + CheckedInt({"bitcount", "foo", "0", std::to_string(i)})); + } + EXPECT_EQ(21, CheckedInt({"bitcount", "foo"})); // the total number of bits in this value +} + +TEST_F(BitOpsFamilyTest, BitCountByteSubRange) { + // This test test using some sub ranges of bit count on bytes + auto resp = Run({"set", "foo", "farbar"}); + EXPECT_EQ(resp, "OK"); + EXPECT_EQ(3, CheckedInt({"bitcount", "foo", "1", "1"})); + EXPECT_EQ(7, CheckedInt({"bitcount", "foo", "1", "2"})); + EXPECT_EQ(4, CheckedInt({"bitcount", "foo", "2", "2"})); + EXPECT_EQ(0, CheckedInt({"bitcount", "foo", "3", "2"})); // illegal range + EXPECT_EQ(10, CheckedInt({"bitcount", "foo", "-3", "-1"})); + EXPECT_EQ(13, CheckedInt({"bitcount", "foo", "-5", "-2"})); + EXPECT_EQ(0, CheckedInt({"bitcount", "foo", "-1", "-2"})); // illegal range +} + +TEST_F(BitOpsFamilyTest, BitCountByteBitSubRange) { + // This test test using some sub ranges of bit count on bytes + auto resp = Run({"set", "foo", "abcdef"}); + EXPECT_EQ(resp, "OK"); + resp = Run({"bitcount", "foo", "bar", "BIT"}); + ASSERT_THAT(resp, ErrArg("value is not an integer or out of range")); + + EXPECT_EQ(1, CheckedInt({"bitcount", "foo", "1", "1", "BIT"})); + EXPECT_EQ(2, CheckedInt({"bitcount", "foo", "1", "2", "BIT"})); + EXPECT_EQ(1, CheckedInt({"bitcount", "foo", "2", "2", "BIT"})); + EXPECT_EQ(0, CheckedInt({"bitcount", "foo", "3", "2", "bit"})); // illegal range + EXPECT_EQ(2, CheckedInt({"bitcount", "foo", "-3", "-1", "bit"})); + EXPECT_EQ(2, CheckedInt({"bitcount", "foo", "-5", "-2", "bit"})); + EXPECT_EQ(4, CheckedInt({"bitcount", "foo", "1", "9", "bit"})); + EXPECT_EQ(7, CheckedInt({"bitcount", "foo", "2", "19", "bit"})); + EXPECT_EQ(0, CheckedInt({"bitcount", "foo", "-1", "-2", "bit"})); // illegal range +} + } // end of namespace dfly