feat(server): Implement LPOS command #368 (#379)

Signed-off-by: Elle Y
This commit is contained in:
Elle Y 2022-10-12 21:26:15 -07:00 committed by GitHub
parent 45a5f30cdd
commit 0e2f918f58
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 137 additions and 9 deletions

View file

@ -15,11 +15,11 @@ extern "C" {
#include "server/blocking_controller.h" #include "server/blocking_controller.h"
#include "server/command_registry.h" #include "server/command_registry.h"
#include "server/conn_context.h" #include "server/conn_context.h"
#include "server/container_utils.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "server/error.h" #include "server/error.h"
#include "server/server_state.h" #include "server/server_state.h"
#include "server/transaction.h" #include "server/transaction.h"
#include "server/container_utils.h"
/** /**
* The number of entries allowed per internal list node can be specified * The number of entries allowed per internal list node can be specified
@ -552,9 +552,7 @@ void ListFamily::RPopLPush(CmdArgList args, ConnectionContext* cntx) {
void ListFamily::LLen(CmdArgList args, ConnectionContext* cntx) { void ListFamily::LLen(CmdArgList args, ConnectionContext* cntx) {
auto key = ArgS(args, 1); auto key = ArgS(args, 1);
auto cb = [&](Transaction* t, EngineShard* shard) { auto cb = [&](Transaction* t, EngineShard* shard) { return OpLen(t->GetOpArgs(shard), key); };
return OpLen(t->GetOpArgs(shard), key);
};
OpResult<uint32_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); OpResult<uint32_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result) { if (result) {
(*cntx)->SendLong(result.value()); (*cntx)->SendLong(result.value());
@ -565,6 +563,64 @@ void ListFamily::LLen(CmdArgList args, ConnectionContext* cntx) {
} }
} }
void ListFamily::LPos(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view elem = ArgS(args, 2);
int rank = 1;
uint32_t count = 1;
uint32_t max_len = 0;
bool skip_count = true;
for (size_t i = 3; i < args.size(); i++) {
ToUpper(&args[i]);
const auto& arg_v = ArgS(args, i);
if (arg_v == "RANK") {
if (!absl::SimpleAtoi(ArgS(args, (i + 1)), &rank) || rank == 0) {
return (*cntx)->SendError(kInvalidIntErr);
}
}
if (arg_v == "COUNT") {
if (!absl::SimpleAtoi(ArgS(args, (i + 1)), &count)) {
return (*cntx)->SendError(kInvalidIntErr);
}
skip_count = false;
}
if (arg_v == "MAXLEN") {
if (!absl::SimpleAtoi(ArgS(args, (i + 1)), &max_len)) {
return (*cntx)->SendError(kInvalidIntErr);
}
}
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpPos(t->GetOpArgs(shard), key, elem, rank, count, max_len);
};
Transaction* trans = cntx->transaction;
OpResult<vector<uint32_t>> result = trans->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(result.status());
} else if (result.status() == OpStatus::INVALID_VALUE) {
return (*cntx)->SendError(result.status());
}
if (skip_count) {
if (result->empty()) {
(*cntx)->SendNull();
} else {
(*cntx)->SendLong((*result)[0]);
}
} else {
(*cntx)->StartArray(result->size());
const auto& array = result.value();
for (const auto& v : array) {
(*cntx)->SendLong(v);
}
}
}
void ListFamily::LIndex(CmdArgList args, ConnectionContext* cntx) { void ListFamily::LIndex(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1); std::string_view key = ArgS(args, 1);
std::string_view index_str = ArgS(args, 2); std::string_view index_str = ArgS(args, 2);
@ -853,6 +909,50 @@ OpResult<string> ListFamily::OpIndex(const OpArgs& op_args, std::string_view key
return str; return str;
} }
OpResult<vector<uint32_t>> ListFamily::OpPos(const OpArgs& op_args, std::string_view key,
std::string_view element, int rank, int count,
int max_len) {
OpResult<PrimeIterator> it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_LIST);
if (!it_res.ok())
return it_res.status();
int direction = AL_START_HEAD;
if (rank < 0) {
rank = -rank;
direction = AL_START_TAIL;
}
quicklist* ql = GetQL(it_res.value()->second);
quicklistIter* ql_iter = quicklistGetIterator(ql, direction);
quicklistEntry entry;
int index = 0;
int matched = 0;
vector<uint32_t> matches;
string str;
while (quicklistNext(ql_iter, &entry) && (max_len == 0 || index < max_len)) {
if (entry.value) {
str.assign(reinterpret_cast<char*>(entry.value), entry.sz);
} else {
str = absl::StrCat(entry.longval);
}
if (str == element) {
matched++;
auto k = (direction == AL_START_TAIL) ? ql->count - index - 1 : index;
if (matched >= rank) {
matches.push_back(k);
if (count && matched - rank + 1 >= count) {
break;
}
}
}
index++;
}
quicklistReleaseIterator(ql_iter);
return matches;
}
OpResult<int> ListFamily::OpInsert(const OpArgs& op_args, string_view key, string_view pivot, OpResult<int> ListFamily::OpInsert(const OpArgs& op_args, string_view key, string_view pivot,
string_view elem, int insert_param) { string_view elem, int insert_param) {
auto& db_slice = op_args.shard->db_slice(); auto& db_slice = op_args.shard->db_slice();
@ -1021,10 +1121,13 @@ OpResult<StringVec> ListFamily::OpRange(const OpArgs& op_args, std::string_view
} }
StringVec str_vec; StringVec str_vec;
container_utils::IterateList(res.value()->second, [&str_vec](container_utils::ContainerEntry ce) { container_utils::IterateList(
str_vec.emplace_back(ce.ToString()); res.value()->second,
return true; [&str_vec](container_utils::ContainerEntry ce) {
}, start, end); str_vec.emplace_back(ce.ToString());
return true;
},
start, end);
return str_vec; return str_vec;
} }
@ -1044,6 +1147,7 @@ void ListFamily::Register(CommandRegistry* registry) {
<< CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop) << CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop)
<< CI{"BRPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BRPop) << CI{"BRPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BRPop)
<< CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen) << CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen)
<< CI{"LPOS", CO::READONLY | CO::FAST, -3, 1, 1, 1}.HFUNC(LPos)
<< CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex) << CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex)
<< CI{"LINSERT", CO::WRITE, 5, 1, 1, 1}.HFUNC(LInsert) << CI{"LINSERT", CO::WRITE, 5, 1, 1, 1}.HFUNC(LInsert)
<< CI{"LRANGE", CO::READONLY, 4, 1, 1, 1}.HFUNC(LRange) << CI{"LRANGE", CO::READONLY, 4, 1, 1, 1}.HFUNC(LRange)

View file

@ -29,6 +29,7 @@ class ListFamily {
static void BLPop(CmdArgList args, ConnectionContext* cntx); static void BLPop(CmdArgList args, ConnectionContext* cntx);
static void BRPop(CmdArgList args, ConnectionContext* cntx); static void BRPop(CmdArgList args, ConnectionContext* cntx);
static void LLen(CmdArgList args, ConnectionContext* cntx); static void LLen(CmdArgList args, ConnectionContext* cntx);
static void LPos(CmdArgList args, ConnectionContext* cntx);
static void LIndex(CmdArgList args, ConnectionContext* cntx); static void LIndex(CmdArgList args, ConnectionContext* cntx);
static void LInsert(CmdArgList args, ConnectionContext* cntx); static void LInsert(CmdArgList args, ConnectionContext* cntx);
static void LTrim(CmdArgList args, ConnectionContext* cntx); static void LTrim(CmdArgList args, ConnectionContext* cntx);
@ -45,6 +46,9 @@ class ListFamily {
static OpResult<uint32_t> OpLen(const OpArgs& op_args, std::string_view key); static OpResult<uint32_t> OpLen(const OpArgs& op_args, std::string_view key);
static OpResult<std::string> OpIndex(const OpArgs& op_args, std::string_view key, long index); static OpResult<std::string> OpIndex(const OpArgs& op_args, std::string_view key, long index);
static OpResult<std::vector<uint32_t>> OpPos(const OpArgs& op_args, std::string_view key,
std::string_view element, int rank = 0, int count = -1,
int max_len = 0);
static OpResult<int> OpInsert(const OpArgs& op_args, std::string_view key, std::string_view pivot, static OpResult<int> OpInsert(const OpArgs& op_args, std::string_view key, std::string_view pivot,
std::string_view elem, int insert_param); std::string_view elem, int insert_param);
@ -56,7 +60,6 @@ class ListFamily {
static OpResult<StringVec> OpRange(const OpArgs& op_args, std::string_view key, long start, static OpResult<StringVec> OpRange(const OpArgs& op_args, std::string_view key, long start,
long end); long end);
}; };
} // namespace dfly } // namespace dfly

View file

@ -439,4 +439,25 @@ TEST_F(ListFamilyTest, Lset) {
ASSERT_THAT(Run({"lset", kKey2, "1", "foo"}), ErrArg("index out of range")); ASSERT_THAT(Run({"lset", kKey2, "1", "foo"}), ErrArg("index out of range"));
} }
TEST_F(ListFamilyTest, LPos) {
auto resp = Run({"rpush", kKey1, "1", "a", "b", "1", "1", "a", "1"});
ASSERT_THAT(resp, IntArg(7));
ASSERT_THAT(Run({"lpos", kKey1, "1"}), IntArg(0));
ASSERT_THAT(Run({"lpos", kKey1, "f"}), ArgType(RespExpr::NIL));
ASSERT_THAT(Run({"lpos", kKey1, "1", "COUNT", "-1"}), ArgType(RespExpr::ERROR));
ASSERT_THAT(Run({"lpos", kKey1, "1", "MAXLEN", "-1"}), ArgType(RespExpr::ERROR));
ASSERT_THAT(Run({"lpos", kKey1, "1", "RANK", "0"}), ArgType(RespExpr::ERROR));
resp = Run({"lpos", kKey1, "a", "RANK", "-1", "COUNT", "2"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(5), IntArg(1)));
resp = Run({"lpos", kKey1, "1", "COUNT", "0"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(3), IntArg(4), IntArg(6)));
resp = Run({"lpos", kKey1, "1", "COUNT", "0", "MAXLEN", "5"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(3), IntArg(4)));
}
} // namespace dfly } // namespace dfly