diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index 6e4d05cef..c464001d2 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -16,6 +16,7 @@ extern "C" { #include "base/flags.h" #include "base/logging.h" +#include "core/qlist.h" #include "redis/rdb.h" #include "server/acl/acl_commands_def.h" #include "server/blocking_controller.h" @@ -1434,10 +1435,34 @@ OpResultTyped OpFetchSortEntries(const OpArgs& op_args, std::stri return success ? res : OpStatus::INVALID_NUMERIC_RESULT; } +template +OpResult OpStore(const OpArgs& op_args, std::string_view key, IteratorBegin&& start_it, + IteratorEnd&& end_it) { + uint32_t len = 0; + + QList* ql_v2 = CompactObj::AllocateMR(); + QList::Where where = QList::TAIL; + for (auto it = start_it; it != end_it; ++it) { + ql_v2->Push(it->key, where); + } + len = ql_v2->Size(); + + PrimeValue pv; + pv.InitRobj(OBJ_LIST, kEncodingQL2, ql_v2); + + // This would overwrite existing value if any with new list. + // Set the expiry at 300 seconds. + auto op_res = op_args.GetDbSlice().AddOrUpdate(op_args.db_cntx, key, std::move(pv), 300000); + RETURN_ON_BAD_STATUS(op_res); + + return len; +} + void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) { std::string_view key = ArgS(args, 0); bool alpha = false; bool reversed = false; + std::optional store_key; std::optional> bounds; auto* builder = cmd_cntx.rb; for (size_t i = 1; i < args.size(); i++) { @@ -1459,16 +1484,38 @@ void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) { } bounds = {offset, limit}; i += 2; + } else if (arg == "STORE") { + if (i + 1 >= args.size()) { + return builder->SendError(kSyntaxErr); + } + store_key = ArgS(args, i + 1); + i += 1; } else { LOG_EVERY_T(ERROR, 1) << "Unsupported option " << arg; return builder->SendError(kSyntaxErr, kSyntaxErrType); } } - OpResultTyped fetch_result = - cmd_cntx.tx->ScheduleSingleHopT([&](Transaction* t, EngineShard* shard) { - return OpFetchSortEntries(t->GetOpArgs(shard), key, alpha); - }); + ShardId source_sid = Shard(key, shard_set->size()); + OpResultTyped fetch_result; + auto fetch_cb = [&](Transaction* t, EngineShard* shard) { + ShardId shard_id = shard->shard_id(); + if (shard_id == source_sid) { + fetch_result = OpFetchSortEntries(t->GetOpArgs(shard), key, alpha); + } + return fetch_result.status(); + }; + + if (store_key) { + cmd_cntx.tx->Execute(std::move(fetch_cb), false); + } else { + cmd_cntx.tx->Execute(std::move(fetch_cb), true); + } + + // OpResultTyped fetch_result = + // cmd_cntx.tx->ScheduleSingleHopT([&](Transaction* t, EngineShard* shard) { + // return OpFetchSortEntries(t->GetOpArgs(shard), key, alpha); + // }); if (fetch_result == OpStatus::WRONG_TYPE) return builder->SendError(fetch_result.status()); @@ -1481,7 +1528,7 @@ void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) { return rb->SendEmptyArray(); auto result_type = fetch_result.type(); - auto sort_call = [builder, bounds, reversed, result_type](auto& entries) { + auto sort_call = [builder, bounds, reversed, result_type, store_key, cmd_cntx](auto& entries) { using value_t = typename std::decay_t::value_type; auto cmp = reversed ? &value_t::greater : &value_t::less; if (bounds) { @@ -1500,11 +1547,29 @@ void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) { bool is_set = (result_type == OBJ_SET || result_type == OBJ_ZSET); auto* rb = static_cast(builder); - rb->StartCollection(std::distance(start_it, end_it), - is_set ? RedisReplyBuilder::SET : RedisReplyBuilder::ARRAY); + if (store_key) { + ShardId dest_sid = Shard(store_key.value(), shard_set->size()); + OpResult store_len; + auto store_callback = [&](Transaction* t, EngineShard* shard) { + ShardId shard_id = shard->shard_id(); + if (shard_id == dest_sid) { + store_len = OpStore(t->GetOpArgs(shard), store_key.value(), start_it, end_it); + } + return store_len.status(); + }; + cmd_cntx.tx->Execute(std::move(store_callback), true); + if (store_len) { + rb->SendLong(store_len.value()); + } else { + rb->SendError(store_len.status()); + } + } else { + rb->StartCollection(std::distance(start_it, end_it), + is_set ? RedisReplyBuilder::SET : RedisReplyBuilder::ARRAY); - for (auto it = start_it; it != end_it; ++it) { - rb->SendBulkString(it->key); + for (auto it = start_it; it != end_it; ++it) { + rb->SendBulkString(it->key); + } } }; @@ -1948,7 +2013,7 @@ void GenericFamily::Register(CommandRegistry* registry) { << CI{"DUMP", CO::READONLY, 2, 1, 1, acl::kDump}.HFUNC(Dump) << CI{"UNLINK", CO::WRITE, -2, 1, -1, acl::kUnlink}.HFUNC(Unlink) << CI{"STICK", CO::WRITE, -2, 1, -1, acl::kStick}.HFUNC(Stick) - << CI{"SORT", CO::READONLY, -2, 1, 1, acl::kSort}.HFUNC(Sort) + << CI{"SORT", CO::WRITE, -2, 1, -1, acl::kSort}.HFUNC(Sort) << CI{"MOVE", CO::WRITE | CO::GLOBAL_TRANS | CO::NO_AUTOJOURNAL, 3, 1, 1, acl::kMove}.HFUNC( Move) << CI{"RESTORE", CO::WRITE, -4, 1, 1, acl::kRestore}.HFUNC(Restore) diff --git a/src/server/generic_family_test.cc b/src/server/generic_family_test.cc index f6a138f3b..ada6a334d 100644 --- a/src/server/generic_family_test.cc +++ b/src/server/generic_family_test.cc @@ -755,6 +755,84 @@ TEST_F(GenericFamilyTest, SortBug3636) { ASSERT_THAT(resp, ArrLen(17)); } +TEST_F(GenericFamilyTest, SortStore) { + // Test list sort with params + Run({"del", "list-1"}); + Run({"del", "list-2"}); + Run({"lpush", "list-1", "3.5", "1.2", "10.1", "2.20", "200"}); + // numeric + auto resp = Run({"sort", "list-1", "store", "list-2"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), + ElementsAre("1.2", "2.20", "3.5", "10.1", "200")); + + // string + resp = Run({"sort", "list-1", "ALPHA", "store", "list-2"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), + ElementsAre("1.2", "10.1", "2.20", "200", "3.5")); + + // desc numeric + resp = Run({"sort", "list-1", "DESC", "store", "list-2"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), + ElementsAre("200", "10.1", "3.5", "2.20", "1.2")); + + // desc string + resp = Run({"sort", "list-1", "ALPHA", "DESC", "store", "list-2"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), + ElementsAre("3.5", "200", "2.20", "10.1", "1.2")); + + // limits + resp = Run({"sort", "list-1", "LIMIT", "0", "5", "store", "list-2"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), + ElementsAre("1.2", "2.20", "3.5", "10.1", "200")); + resp = Run({"sort", "list-1", "LIMIT", "0", "10", "store", "list-2"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), + ElementsAre("1.2", "2.20", "3.5", "10.1", "200")); + resp = Run({"sort", "list-1", "LIMIT", "2", "2", "store", "list-2"}); + EXPECT_EQ(2, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), ElementsAre("3.5", "10.1")); + resp = Run({"sort", "list-1", "LIMIT", "1", "1", "store", "list-2"}); + EXPECT_EQ(1, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}), "2.20"); + resp = Run({"sort", "list-1", "LIMIT", "4", "2", "store", "list-2"}); + EXPECT_EQ(1, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}), "200"); + resp = Run({"sort", "list-1", "LIMIT", "5", "2", "store", "list-2"}); + EXPECT_EQ(0, resp); + ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}), ArrLen(0)); + + // Test set sort + Run({"del", "set-1"}); + Run({"del", "list-3"}); + Run({"sadd", "set-1", "5.3", "4.4", "60", "99.9", "100", "9"}); + resp = Run({"sort", "set-1", "store", "list-3"}); + EXPECT_EQ(6, resp); + ASSERT_THAT(Run({"lrange", "list-3", "0", "-1"}).GetVec(), + ElementsAre("4.4", "5.3", "9", "60", "99.9", "100")); + + // Test sorted set sort + Run({"del", "zset-1"}); + Run({"del", "list-4"}); + Run({"zadd", "zset-1", "0", "3.3", "0", "30.1", "0", "8.2"}); + resp = Run({"sort", "zset-1", "store", "list-4"}); + EXPECT_EQ(3, resp); + ASSERT_THAT(Run({"lrange", "list-4", "0", "-1"}).GetVec(), ElementsAre("3.3", "8.2", "30.1")); + + // Same key overwrite. + Run({"del", "list-1"}); + Run({"del", "list-2"}); + Run({"lpush", "list-1", "3.5", "1.2", "10.1", "2.20", "200"}); + resp = Run({"sort", "list-1", "store", "list-1"}); + EXPECT_EQ(5, resp); + ASSERT_THAT(Run({"lrange", "list-1", "0", "-1"}).GetVec(), + ElementsAre("1.2", "2.20", "3.5", "10.1", "200")); +} + TEST_F(GenericFamilyTest, TimeNoKeys) { auto resp = Run({"time"}); EXPECT_THAT(resp, ArrLen(2));