diff --git a/src/server/list_family.cc b/src/server/list_family.cc index 163a2f254..eed2dfdf2 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -15,11 +15,11 @@ extern "C" { #include "server/blocking_controller.h" #include "server/command_registry.h" #include "server/conn_context.h" +#include "server/container_utils.h" #include "server/engine_shard_set.h" #include "server/error.h" #include "server/server_state.h" #include "server/transaction.h" -#include "server/container_utils.h" /** * The number of entries allowed per internal list node can be specified @@ -552,9 +552,7 @@ void ListFamily::RPopLPush(CmdArgList args, ConnectionContext* cntx) { void ListFamily::LLen(CmdArgList args, ConnectionContext* cntx) { auto key = ArgS(args, 1); - auto cb = [&](Transaction* t, EngineShard* shard) { - return OpLen(t->GetOpArgs(shard), key); - }; + auto cb = [&](Transaction* t, EngineShard* shard) { return OpLen(t->GetOpArgs(shard), key); }; OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); if (result) { (*cntx)->SendLong(result.value()); @@ -565,6 +563,64 @@ void ListFamily::LLen(CmdArgList args, ConnectionContext* cntx) { } } +void ListFamily::LPos(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 1); + string_view elem = ArgS(args, 2); + + int rank = 1; + uint32_t count = 1; + uint32_t max_len = 0; + bool skip_count = true; + + for (size_t i = 3; i < args.size(); i++) { + ToUpper(&args[i]); + const auto& arg_v = ArgS(args, i); + if (arg_v == "RANK") { + if (!absl::SimpleAtoi(ArgS(args, (i + 1)), &rank) || rank == 0) { + return (*cntx)->SendError(kInvalidIntErr); + } + } + if (arg_v == "COUNT") { + if (!absl::SimpleAtoi(ArgS(args, (i + 1)), &count)) { + return (*cntx)->SendError(kInvalidIntErr); + } + skip_count = false; + } + if (arg_v == "MAXLEN") { + if (!absl::SimpleAtoi(ArgS(args, (i + 1)), &max_len)) { + return (*cntx)->SendError(kInvalidIntErr); + } + } + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpPos(t->GetOpArgs(shard), key, elem, rank, count, max_len); + }; + + Transaction* trans = cntx->transaction; + OpResult> result = trans->ScheduleSingleHopT(std::move(cb)); + + if (result.status() == OpStatus::WRONG_TYPE) { + return (*cntx)->SendError(result.status()); + } else if (result.status() == OpStatus::INVALID_VALUE) { + return (*cntx)->SendError(result.status()); + } + + if (skip_count) { + if (result->empty()) { + (*cntx)->SendNull(); + } else { + (*cntx)->SendLong((*result)[0]); + } + } else { + (*cntx)->StartArray(result->size()); + const auto& array = result.value(); + for (const auto& v : array) { + (*cntx)->SendLong(v); + } + } +} + void ListFamily::LIndex(CmdArgList args, ConnectionContext* cntx) { std::string_view key = ArgS(args, 1); std::string_view index_str = ArgS(args, 2); @@ -853,6 +909,50 @@ OpResult ListFamily::OpIndex(const OpArgs& op_args, std::string_view key return str; } +OpResult> ListFamily::OpPos(const OpArgs& op_args, std::string_view key, + std::string_view element, int rank, int count, + int max_len) { + OpResult it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST); + if (!it_res.ok()) + return it_res.status(); + + int direction = AL_START_HEAD; + if (rank < 0) { + rank = -rank; + direction = AL_START_TAIL; + } + + quicklist* ql = GetQL(it_res.value()->second); + quicklistIter* ql_iter = quicklistGetIterator(ql, direction); + quicklistEntry entry; + + int index = 0; + int matched = 0; + vector matches; + string str; + + while (quicklistNext(ql_iter, &entry) && (max_len == 0 || index < max_len)) { + if (entry.value) { + str.assign(reinterpret_cast(entry.value), entry.sz); + } else { + str = absl::StrCat(entry.longval); + } + if (str == element) { + matched++; + auto k = (direction == AL_START_TAIL) ? ql->count - index - 1 : index; + if (matched >= rank) { + matches.push_back(k); + if (count && matched - rank + 1 >= count) { + break; + } + } + } + index++; + } + quicklistReleaseIterator(ql_iter); + return matches; +} + OpResult ListFamily::OpInsert(const OpArgs& op_args, string_view key, string_view pivot, string_view elem, int insert_param) { auto& db_slice = op_args.shard->db_slice(); @@ -1021,10 +1121,13 @@ OpResult ListFamily::OpRange(const OpArgs& op_args, std::string_view } StringVec str_vec; - container_utils::IterateList(res.value()->second, [&str_vec](container_utils::ContainerEntry ce) { - str_vec.emplace_back(ce.ToString()); - return true; - }, start, end); + container_utils::IterateList( + res.value()->second, + [&str_vec](container_utils::ContainerEntry ce) { + str_vec.emplace_back(ce.ToString()); + return true; + }, + start, end); return str_vec; } @@ -1044,6 +1147,7 @@ void ListFamily::Register(CommandRegistry* registry) { << CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop) << CI{"BRPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BRPop) << CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen) + << CI{"LPOS", CO::READONLY | CO::FAST, -3, 1, 1, 1}.HFUNC(LPos) << CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex) << CI{"LINSERT", CO::WRITE, 5, 1, 1, 1}.HFUNC(LInsert) << CI{"LRANGE", CO::READONLY, 4, 1, 1, 1}.HFUNC(LRange) diff --git a/src/server/list_family.h b/src/server/list_family.h index 4e281b8be..00df8f479 100644 --- a/src/server/list_family.h +++ b/src/server/list_family.h @@ -29,6 +29,7 @@ class ListFamily { static void BLPop(CmdArgList args, ConnectionContext* cntx); static void BRPop(CmdArgList args, ConnectionContext* cntx); static void LLen(CmdArgList args, ConnectionContext* cntx); + static void LPos(CmdArgList args, ConnectionContext* cntx); static void LIndex(CmdArgList args, ConnectionContext* cntx); static void LInsert(CmdArgList args, ConnectionContext* cntx); static void LTrim(CmdArgList args, ConnectionContext* cntx); @@ -45,6 +46,9 @@ class ListFamily { static OpResult OpLen(const OpArgs& op_args, std::string_view key); static OpResult OpIndex(const OpArgs& op_args, std::string_view key, long index); + static OpResult> OpPos(const OpArgs& op_args, std::string_view key, + std::string_view element, int rank = 0, int count = -1, + int max_len = 0); static OpResult OpInsert(const OpArgs& op_args, std::string_view key, std::string_view pivot, std::string_view elem, int insert_param); @@ -56,7 +60,6 @@ class ListFamily { static OpResult OpRange(const OpArgs& op_args, std::string_view key, long start, long end); - }; } // namespace dfly diff --git a/src/server/list_family_test.cc b/src/server/list_family_test.cc index 761792e6f..f9d9eb67a 100644 --- a/src/server/list_family_test.cc +++ b/src/server/list_family_test.cc @@ -439,4 +439,25 @@ TEST_F(ListFamilyTest, Lset) { ASSERT_THAT(Run({"lset", kKey2, "1", "foo"}), ErrArg("index out of range")); } +TEST_F(ListFamilyTest, LPos) { + auto resp = Run({"rpush", kKey1, "1", "a", "b", "1", "1", "a", "1"}); + ASSERT_THAT(resp, IntArg(7)); + + ASSERT_THAT(Run({"lpos", kKey1, "1"}), IntArg(0)); + + ASSERT_THAT(Run({"lpos", kKey1, "f"}), ArgType(RespExpr::NIL)); + ASSERT_THAT(Run({"lpos", kKey1, "1", "COUNT", "-1"}), ArgType(RespExpr::ERROR)); + ASSERT_THAT(Run({"lpos", kKey1, "1", "MAXLEN", "-1"}), ArgType(RespExpr::ERROR)); + ASSERT_THAT(Run({"lpos", kKey1, "1", "RANK", "0"}), ArgType(RespExpr::ERROR)); + + resp = Run({"lpos", kKey1, "a", "RANK", "-1", "COUNT", "2"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(5), IntArg(1))); + + resp = Run({"lpos", kKey1, "1", "COUNT", "0"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(3), IntArg(4), IntArg(6))); + + resp = Run({"lpos", kKey1, "1", "COUNT", "0", "MAXLEN", "5"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(3), IntArg(4))); +} + } // namespace dfly