diff --git a/src/core/bloom.cc b/src/core/bloom.cc
index ca55fa19f..179cd19b5 100644
--- a/src/core/bloom.cc
+++ b/src/core/bloom.cc
@@ -8,6 +8,7 @@
 #include
 #include
+#include
 #include

 #include "base/logging.h"
diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc
index ae7c12249..53dca4677 100644
--- a/src/facade/dragonfly_connection.cc
+++ b/src/facade/dragonfly_connection.cc
@@ -275,22 +275,6 @@ struct Connection::Shutdown {
   }
 };

-Connection::PubMessage::PubMessage(string pattern, shared_ptr<char[]> buf, size_t channel_len,
-                                   size_t message_len)
-    : pattern{std::move(pattern)},
-      buf{std::move(buf)},
-      channel_len{channel_len},
-      message_len{message_len} {
-}
-
-string_view Connection::PubMessage::Channel() const {
-  return {buf.get(), channel_len};
-}
-
-string_view Connection::PubMessage::Message() const {
-  return {buf.get() + channel_len, message_len};
-}
-
 void Connection::PipelineMessage::SetArgs(const RespVec& args) {
   auto* next = storage.data();
   for (size_t i = 0; i < args.size(); ++i) {
@@ -361,7 +345,7 @@ size_t Connection::PipelineMessage::StorageCapacity() const {
 size_t Connection::MessageHandle::UsedMemory() const {
   struct MessageSize {
     size_t operator()(const PubMessagePtr& msg) {
-      return sizeof(PubMessage) + (msg->channel_len + msg->message_len);
+      return sizeof(PubMessage) + (msg->channel.size() + msg->message.size());
     }
     size_t operator()(const PipelineMessagePtr& msg) {
       return sizeof(PipelineMessage) + msg->args.capacity() * sizeof(MutableSlice) +
@@ -449,8 +433,8 @@ void Connection::DispatchOperations::operator()(const PubMessage& pub_msg) {
     arr[i++] = "pmessage";
     arr[i++] = pub_msg.pattern;
   }
-  arr[i++] = pub_msg.Channel();
-  arr[i++] = pub_msg.Message();
+  arr[i++] = pub_msg.channel;
+  arr[i++] = pub_msg.message;
   rbuilder->SendStringArr(absl::Span<string_view>{arr.data(), i},
                           RedisReplyBuilder::CollectionType::PUSH);
 }
diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h
index 2635f1197..951864e97 100644
--- a/src/facade/dragonfly_connection.h
+++ b/src/facade/dragonfly_connection.h
@@ -71,15 +71,9 @@ class Connection : public util::Connection {

   // PubSub message, either incoming message for active subscription or reply for new subscription.
   struct PubMessage {
-    std::string pattern{};            // non-empty for pattern subscriber
-    std::shared_ptr<char[]> buf;      // stores channel name and message
-    size_t channel_len, message_len;  // lengths in buf
-
-    std::string_view Channel() const;
-    std::string_view Message() const;
-
-    PubMessage(std::string pattern, std::shared_ptr<char[]> buf, size_t channel_len,
-               size_t message_len);
+    std::string pattern{};              // non-empty for pattern subscriber
+    std::shared_ptr<char[]> buf;        // stores channel name and message
+    std::string_view channel, message;  // channel and message parts from buf
   };

   // Pipeline message, accumulated Redis command to be executed.
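Reviewer note: with the new PubMessage layout, a publisher packs the channel and the payload into one refcounted buffer, and every copy of the message only carries two views into it, replacing the old channel_len/message_len bookkeeping. A minimal standalone sketch of the idea (DemoMsg and MakeDemoMsg are illustrative names, not part of this patch):

#include <cstring>
#include <memory>
#include <string>
#include <string_view>

// Illustrative stand-in for Connection::PubMessage after this patch.
struct DemoMsg {
  std::string pattern;          // non-empty for pattern subscribers
  std::shared_ptr<char[]> buf;  // single allocation: channel bytes followed by payload bytes
  std::string_view channel, message;
};

// Pack channel + payload once; all copies of the result share ownership of buf,
// and the views stay valid for as long as any copy is alive.
DemoMsg MakeDemoMsg(std::string_view channel, std::string_view payload) {
  auto buf = std::shared_ptr<char[]>{new char[channel.size() + payload.size()]};
  std::memcpy(buf.get(), channel.data(), channel.size());
  std::memcpy(buf.get() + channel.size(), payload.data(), payload.size());
  DemoMsg msg{"", buf,
              {buf.get(), channel.size()},
              {buf.get() + channel.size(), payload.size()}};
  return msg;
}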
diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc
index d0f3b9df7..80b296a54 100644
--- a/src/facade/reply_builder.cc
+++ b/src/facade/reply_builder.cc
@@ -496,7 +496,7 @@ void RedisReplyBuilder::SendMGetResponse(MGetResponse resp) {

 void RedisReplyBuilder::SendSimpleStrArr(StrSpan arr) {
   string res = absl::StrCat("*", arr.Size(), kCRLF);
-  for (std::string_view str : arr)
+  for (string_view str : arr)
     StrAppend(&res, "+", str, kCRLF);

   SendRaw(res);
diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt
index 4533d8bad..62449efc4 100644
--- a/src/server/CMakeLists.txt
+++ b/src/server/CMakeLists.txt
@@ -32,7 +32,7 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc
             common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc
             server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc
             serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc
-            ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc)
+            ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc channel_store.cc)

 SET(DF_SEARCH_SRCS search/search_family.cc search/doc_index.cc search/doc_accessors.cc
     search/aggregator.cc)
@@ -43,7 +43,7 @@ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
   cxx_test(tiered_storage_test dfly_test_lib LABELS DFLY)
 endif()

-add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc channel_store.cc
+add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc
             config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc generic_family.cc
             hset_family.cc http_api.cc json_family.cc list_family.cc main_service.cc
             memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc
diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc
index 822f82d9d..d1d7dbe83 100644
--- a/src/server/channel_store.cc
+++ b/src/server/channel_store.cc
@@ -10,6 +10,8 @@ extern "C" {
 #include "redis/util.h"
 }

+#include
+
 #include "base/logging.h"
 #include "server/engine_shard_set.h"
 #include "server/server_state.h"
@@ -23,6 +25,32 @@ bool Matches(string_view pattern, string_view channel) {
   return stringmatchlen(pattern.data(), pattern.size(), channel.data(), channel.size(), 0) == 1;
 }

+// Builds a functor that sends the given messages to a single connection.
+auto BuildSender(string_view channel, facade::ArgRange messages) {
+  absl::FixedArray<string_view> views(messages.Size());
+  size_t messages_size = accumulate(messages.begin(), messages.end(), 0,
+                                    [](int sum, string_view str) { return sum + str.size(); });
+  auto buf = shared_ptr<char[]>{new char[channel.size() + messages_size]};
+  {
+    memcpy(buf.get(), channel.data(), channel.size());
+    char* ptr = buf.get() + channel.size();
+
+    size_t i = 0;
+    for (string_view message : messages) {
+      memcpy(ptr, message.data(), message.size());
+      views[i++] = {ptr, message.size()};
+      ptr += message.size();
+    }
+  }
+
+  return [channel, buf = std::move(buf), views = std::move(views)](facade::Connection* conn,
+                                                                   string pattern) {
+    string_view channel_view{buf.get(), channel.size()};
+    for (std::string_view message_view : views)
+      conn->SendPubMessageAsync({std::move(pattern), buf, channel_view, message_view});
+  };
+}
+
 }  // namespace

 bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) {
@@ -95,6 +123,39 @@ void ChannelStore::Destroy() {

 ChannelStore::ControlBlock ChannelStore::control_block;

+unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange messages) const {
+  vector<Subscriber> subscribers = FetchSubscribers(channel);
+  if (subscribers.empty())
+    return 0;
+
+  // Make sure none of the threads' publish buffer limits is reached. We don't reserve memory ahead
+  // and don't prevent the buffer from possibly filling up, but the approach is good enough for
+  // limiting fast producers. Most importantly, because we block here, we can use DispatchBrief below.
+  optional<uint32_t> last_thread;
+  for (auto& sub : subscribers) {
+    DCHECK_LE(last_thread.value_or(0), sub.Thread());
+    if (last_thread && *last_thread == sub.Thread())  // skip same thread
+      continue;
+
+    if (sub.EnsureMemoryBudget())  // Invalid pointers are skipped
+      last_thread = sub.Thread();
+  }
+
+  auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
+  auto cb = [subscribers_ptr, send = BuildSender(channel, messages)](unsigned idx, auto*) {
+    auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
+                          ChannelStore::Subscriber::ByThreadId);
+    while (it != subscribers_ptr->end() && it->Thread() == idx) {
+      if (auto* ptr = it->Get(); ptr)
+        send(ptr, it->pattern);
+      it++;
+    }
+  };
+  shard_set->pool()->DispatchBrief(std::move(cb));
+
+  return subscribers_ptr->size();
+}
+
 vector<ChannelStore::Subscriber> ChannelStore::FetchSubscribers(string_view channel) const {
   vector<Subscriber> res;
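For context: the fan-out in SendMessages relies on FetchSubscribers returning subscribers sorted by thread id, so the callback running on each proactor thread can locate its own contiguous slice with lower_bound and ignore everyone else. A standalone sketch of that slicing, with illustrative types rather than Dragonfly's:

#include <algorithm>
#include <iostream>
#include <vector>

struct Sub {
  unsigned thread;  // proactor thread that owns the subscriber's connection
  int id;
};

int main() {
  // Sorted by thread, as ChannelStore::Subscriber::ByThread guarantees.
  std::vector<Sub> subs = {{0, 1}, {0, 2}, {2, 3}, {2, 4}, {3, 5}};

  for (unsigned thread_idx = 0; thread_idx < 4; ++thread_idx) {
    // Same idea as the lower_bound + ByThreadId comparator inside the DispatchBrief callback.
    auto it = std::lower_bound(subs.begin(), subs.end(), thread_idx,
                               [](const Sub& s, unsigned t) { return s.thread < t; });
    for (; it != subs.end() && it->thread == thread_idx; ++it)
      std::cout << "thread " << thread_idx << " delivers to subscriber " << it->id << "\n";
  }
  return 0;
}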
diff --git a/src/server/channel_store.h b/src/server/channel_store.h
index 7f09d4b50..2a67606c5 100644
--- a/src/server/channel_store.h
+++ b/src/server/channel_store.h
@@ -54,6 +54,9 @@ class ChannelStore {

   ChannelStore();

+  // Send messages to the channel; blocks on connection backpressure.
+  unsigned SendMessages(std::string_view channel, facade::ArgRange messages) const;
+
   // Fetch all subscribers for channel, including matching patterns.
   std::vector<Subscriber> FetchSubscribers(std::string_view channel) const;

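Note for reviewers: both call sites introduced by this patch funnel into the new SendMessages API, PUBLISH with a single message and the expiry path with a whole batch of keys. A rough sketch of those two shapes, using absl::Span as a stand-in for facade::ArgRange (an assumption; the real method accepts ArgRange, and its return value is the subscriber count rather than the message count):

#include <absl/types/span.h>

#include <iostream>
#include <string>
#include <string_view>
#include <vector>

// Illustrative stand-in for ChannelStore::SendMessages, only to show the call shapes.
unsigned SendMessagesDemo(std::string_view channel, absl::Span<const std::string_view> messages) {
  for (std::string_view m : messages)
    std::cout << channel << " <- " << m << "\n";
  return static_cast<unsigned>(messages.size());  // demo only: the real API returns #subscribers
}

int main() {
  // PUBLISH path: exactly one message.
  std::string_view single[] = {"hello"};
  SendMessagesDemo("news", single);

  // Expired-keys path: the batch accumulated by DeleteExpiredStep.
  std::vector<std::string> expired = {"k10", "k11", "k12"};
  std::vector<std::string_view> views(expired.begin(), expired.end());
  SendMessagesDemo("__keyevent@0__:expired", views);
  return 0;
}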
diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc
index 6d5aed565..4fc84e43a 100644
--- a/src/server/db_slice.cc
+++ b/src/server/db_slice.cc
@@ -9,6 +9,7 @@
 #include "base/flags.h"
 #include "base/logging.h"
 #include "generic_family.h"
+#include "server/channel_store.h"
 #include "server/cluster/cluster_defs.h"
 #include "server/engine_shard_set.h"
 #include "server/error.h"
@@ -33,6 +34,9 @@ ABSL_FLAG(double, table_growth_margin, 0.4,
           "Prevents table from growing if number of free slots x average object size x this ratio "
           "is larger than memory budget.");

+ABSL_FLAG(std::string, notify_keyspace_events, "",
+          "notify-keyspace-events. Only Ex is supported for now");
+
 namespace dfly {

 using namespace std;
@@ -204,9 +208,7 @@ unsigned PrimeEvictionPolicy::Evict(const PrimeTable::HotspotBuckets& eb, PrimeT
   // log the evicted keys to journal.
   if (auto journal = db_slice_->shard_owner()->journal(); journal) {
-    ArgSlice delete_args(&key, 1);
-    journal->RecordEntry(0, journal::Op::EXPIRED, cntx_.db_index, 1, cluster::KeySlot(key),
-                         Payload("DEL", delete_args), false);
+    RecordExpiry(cntx_.db_index, key);
   }
   db_slice_->PerformDeletion(DbSlice::Iterator(last_slot_it, StringOrView::FromView(key)),
                              table);
@@ -268,6 +270,13 @@ DbSlice::DbSlice(uint32_t index, bool caching_mode, EngineShard* owner)
   CreateDb(0);
   expire_base_[0] = expire_base_[1] = 0;
   soft_budget_limit_ = (0.3 * max_memory_limit / shard_set->size());
+
+  std::string keyspace_events = GetFlag(FLAGS_notify_keyspace_events);
+  if (!keyspace_events.empty() && keyspace_events != "Ex") {
+    LOG(ERROR) << "Only Ex is currently supported";
+    exit(0);
+  }
+  expired_keys_events_recording_ = !keyspace_events.empty();
 }

 DbSlice::~DbSlice() {
@@ -1041,11 +1050,15 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato
                << ", expire table size: " << db->expire.size()
                << ", prime table size: " << db->prime.size() << util::fb2::GetStacktrace();
   }

+  // Replicate expiry
   if (auto journal = owner_->journal(); journal) {
     RecordExpiry(cntx.db_index, key);
   }

+  if (expired_keys_events_recording_)
+    db->expired_keys_events_.emplace_back(key);
+
   auto obj_type = it->second.ObjType();
   if (doc_del_cb_ && (obj_type == OBJ_JSON || obj_type == OBJ_HASH)) {
     doc_del_cb_(key, cntx, it->second);
@@ -1160,6 +1173,13 @@ auto DbSlice::DeleteExpiredStep(const Context& cntx, unsigned count) -> DeleteEx
     }
   }

+  // Send and clear accumulated expired key events.
+  if (auto& events = db_arr_[cntx.db_index]->expired_keys_events_; !events.empty()) {
+    ChannelStore* store = ServerState::tlocal()->channel_store();
+    store->SendMessages(absl::StrCat("__keyevent@", cntx.db_index, "__:expired"), events);
+    events.clear();
+  }
+
   return result;
 }
@@ -1188,6 +1208,8 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes
   string tmp;
   int32_t starting_segment_id = rand() % num_segments;
   size_t used_memory_before = owner_->UsedMemory();
+
+  bool record_keys = owner_->journal() != nullptr || expired_keys_events_recording_;
   vector<string> keys_to_journal;

   {
@@ -1216,9 +1238,8 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes
         if (lt.Find(LockTag(key)).has_value())
           continue;

-        if (auto journal = owner_->journal(); journal) {
-          keys_to_journal.push_back(string(key));
-        }
+        if (record_keys)
+          keys_to_journal.emplace_back(key);

         PerformDeletion(Iterator(evict_it, StringOrView::FromView(key)), db_table.get());
         ++evicted;
@@ -1236,12 +1257,12 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes
   finish:
     // send the deletion to the replicas.
     // fiber preemption could happen in this phase.
-    if (auto journal = owner_->journal(); journal) {
-      for (string_view key : keys_to_journal) {
-        ArgSlice delete_args(&key, 1);
-        journal->RecordEntry(0, journal::Op::EXPIRED, db_ind, 1, cluster::KeySlot(key),
-                             Payload("DEL", delete_args), false);
-      }
+    for (string_view key : keys_to_journal) {
+      if (auto journal = owner_->journal(); journal)
+        RecordExpiry(db_ind, key);
+
+      if (expired_keys_events_recording_)
+        db_table->expired_keys_events_.emplace_back(key);
     }

     auto time_finish = absl::GetCurrentTimeNanos();
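Note for reviewers: the expiry integration is deliberately two-phase. ExpireIfNeeded only appends the key to the per-table vector on the hot path, and DeleteExpiredStep later publishes the accumulated batch to __keyevent@<db>__:expired and clears it. A minimal sketch of that accumulate-then-flush shape (the class and its names are illustrative, not the actual DbSlice/DbTable code):

#include <functional>
#include <string>
#include <string_view>
#include <vector>

// Illustrative accumulate-then-flush buffer for expired-key notifications.
class ExpiredEventBuffer {
 public:
  using PublishFn = std::function<void(const std::vector<std::string>&)>;

  explicit ExpiredEventBuffer(PublishFn publish) : publish_(std::move(publish)) {
  }

  // Hot path (ExpireIfNeeded analogue): just remember the key.
  void Record(std::string_view key) {
    events_.emplace_back(key);
  }

  // Periodic path (DeleteExpiredStep analogue): publish one batch, then clear.
  void Flush() {
    if (events_.empty())
      return;
    publish_(events_);
    events_.clear();
  }

 private:
  std::vector<std::string> events_;
  PublishFn publish_;
};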
diff --git a/src/server/db_slice.h b/src/server/db_slice.h
index 6e5184a67..72c113820 100644
--- a/src/server/db_slice.h
+++ b/src/server/db_slice.h
@@ -569,6 +569,9 @@ class DbSlice {
   // Registered by shard indices on when first document index is created.
   DocDeletionCallback doc_del_cb_;

+  // Whether to record expired keys to DbTable::expired_keys_events_ for keyspace notifications.
+  bool expired_keys_events_recording_ = true;
+
   struct Hash {
     size_t operator()(const facade::Connection::WeakRef& c) const {
       return std::hash<uint32_t>()(c.GetClientId());
diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc
index 2d7219f42..aa1558935 100644
--- a/src/server/dragonfly_test.cc
+++ b/src/server/dragonfly_test.cc
@@ -514,8 +514,8 @@ TEST_F(DflyEngineTest, PSubscribe) {
   ASSERT_EQ(1, SubscriberMessagesLen("IO1"));

   const auto& msg = GetPublishedMessage("IO1", 0);
-  EXPECT_EQ("foo", msg.Message());
-  EXPECT_EQ("ab", msg.Channel());
+  EXPECT_EQ("foo", msg.message);
+  EXPECT_EQ("ab", msg.channel);
   EXPECT_EQ("a*", msg.pattern);
 }
diff --git a/src/server/main_service.cc b/src/server/main_service.cc
index 399d25a05..1b38e35ed 100644
--- a/src/server/main_service.cc
+++ b/src/server/main_service.cc
@@ -2243,49 +2243,10 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {

 void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
   string_view channel = ArgS(args, 0);
-  string_view msg = ArgS(args, 1);
+  string_view messages[] = {ArgS(args, 1)};

   auto* cs = ServerState::tlocal()->channel_store();
-  vector<ChannelStore::Subscriber> subscribers = cs->FetchSubscribers(channel);
-  int num_published = subscribers.size();
-  if (!subscribers.empty()) {
-    // Make sure neither of the threads limits is reached.
-    // This check actually doesn't reserve any memory ahead and doesn't prevent the buffer
-    // from eventually filling up, especially if multiple clients are unblocked simultaneously,
-    // but is generally good enough to limit too fast producers.
-    // Most importantly, this approach allows not blocking and not awaiting in the dispatch below,
-    // thus not adding any overhead to backpressure checks.
-    optional<uint32_t> last_thread;
-    for (auto& sub : subscribers) {
-      DCHECK_LE(last_thread.value_or(0), sub.Thread());
-      if (last_thread && *last_thread == sub.Thread())  // skip same thread
-        continue;
-
-      if (sub.EnsureMemoryBudget())  // Invalid pointers are skipped
-        last_thread = sub.Thread();
-    }
-
-    auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
-    auto buf = shared_ptr<char[]>{new char[channel.size() + msg.size()]};
-    memcpy(buf.get(), channel.data(), channel.size());
-    memcpy(buf.get() + channel.size(), msg.data(), msg.size());
-
-    auto cb = [subscribers_ptr, buf, channel, msg](unsigned idx, util::ProactorBase*) {
-      auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
-                            ChannelStore::Subscriber::ByThreadId);
-
-      while (it != subscribers_ptr->end() && it->Thread() == idx) {
-        if (auto* ptr = it->Get(); ptr) {
-          ptr->SendPubMessageAsync(
-              {std::move(it->pattern), std::move(buf), channel.size(), msg.size()});
-        }
-        it++;
-      }
-    };
-    shard_set->pool()->DispatchBrief(std::move(cb));
-  }
-
-  cntx->SendLong(num_published);
+  cntx->SendLong(cs->SendMessages(channel, messages));
 }

 void Service::Subscribe(CmdArgList args, ConnectionContext* cntx) {
diff --git a/src/server/table.h b/src/server/table.h
index 97f2e612e..cfc44d0b4 100644
--- a/src/server/table.h
+++ b/src/server/table.h
@@ -123,6 +123,9 @@ struct DbTable : boost::intrusive_ref_counter<DbTable, boost::thread_unsafe_counter> {
   absl::flat_hash_map<std::string, std::vector<ConnectionState::ExecInfo*>> watched_keys;

+  // Keyspace notifications: list of expired keys since last batch of messages was published.
+  mutable std::vector<std::string> expired_keys_events_;
+
   mutable DbTableStats stats;
   std::vector<SlotStats> slots_stats;
   ExpireTable::Cursor expire_cursor;
diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py
index 9eaf81af2..cda6d6d32 100755
--- a/tests/dragonfly/connection_test.py
+++ b/tests/dragonfly/connection_test.py
@@ -410,6 +410,32 @@ async def test_subscribers_with_active_publisher(df_server: DflyInstance, max_co
     await async_pool.disconnect()


+@dfly_args({"notify_keyspace_events": "Ex"})
+async def test_keyspace_events(async_client: aioredis.Redis):
+    pclient = async_client.pubsub()
+    await pclient.subscribe("__keyevent@0__:expired")
+
+    keys = []
+    for i in range(10, 50):
+        keys.append(f"k{i}")
+        await async_client.set(keys[-1], "X", px=200 + i * 10)
+
+    # We don't support immediate expiration:
+    # keys += ['immediate']
+    # await async_client.set(keys[-1], 'Y', exat=123)  # expired 50 years ago
+
+    events = []
+    async for message in pclient.listen():
+        if message["type"] == "subscribe":
+            continue
+
+        events.append(message)
+        if len(events) >= len(keys):
+            break
+
+    assert set(ev["data"] for ev in events) == set(keys)
+
+
 async def test_big_command(df_server, size=8 * 1024):
     reader, writer = await asyncio.open_connection("127.0.0.1", df_server.port)
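End-to-end note: the feature is enabled with --notify_keyspace_events=Ex (the only value the new flag accepts for now), and events are delivered on the per-database channel built in DeleteExpiredStep, which is the channel the Python test subscribes to. A tiny sketch of that channel-name construction:

#include <absl/strings/str_cat.h>

#include <iostream>
#include <string>

int main() {
  // Same format DeleteExpiredStep uses; db 0 yields "__keyevent@0__:expired".
  unsigned db_index = 0;
  std::string channel = absl::StrCat("__keyevent@", db_index, "__:expired");
  std::cout << channel << "\n";
  return 0;
}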