SSCAN command support count and match parameters (#466)

SSCAN command support count and match parameters fixes #426

Signed-off-by: adi_holden <adi@dragonflydb.io>

Signed-off-by: adi_holden <adi@dragonflydb.io>
This commit is contained in:
adiholden 2022-11-09 23:33:54 +02:00 committed by GitHub
parent 91ab423e6a
commit 22f8554680
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 157 additions and 118 deletions

View file

@ -12,6 +12,7 @@
extern "C" {
#include "redis/object.h"
#include "redis/rdb.h"
#include "redis/util.h"
#include "redis/zmalloc.h"
}
@ -193,4 +194,46 @@ TieredStats& TieredStats::operator+=(const TieredStats& o) {
return *this;
}
OpResult<ScanOpts> ScanOpts::TryFrom(CmdArgList args) {
ScanOpts scan_opts;
for (unsigned i = 0; i < args.size(); i += 2) {
ToUpper(&args[i]);
string_view opt = ArgS(args, i);
if (i + 1 == args.size()) {
return facade::OpStatus::SYNTAX_ERR;
}
if (opt == "COUNT") {
if (!absl::SimpleAtoi(ArgS(args, i + 1), &scan_opts.limit)) {
return facade::OpStatus::INVALID_INT;
}
if (scan_opts.limit == 0)
scan_opts.limit = 1;
else if (scan_opts.limit > 4096)
scan_opts.limit = 4096;
} else if (opt == "MATCH") {
scan_opts.pattern = ArgS(args, i + 1);
if (scan_opts.pattern == "*")
scan_opts.pattern = string_view{};
} else if (opt == "TYPE") {
ToLower(&args[i + 1]);
scan_opts.type_filter = ArgS(args, i + 1);
} else if (opt == "BUCKET") {
if (!absl::SimpleAtoi(ArgS(args, i + 1), &scan_opts.bucket_id)) {
return facade::OpStatus::INVALID_INT;
}
} else {
return facade::OpStatus::SYNTAX_ERR;
}
}
return scan_opts;
}
bool ScanOpts::Matches(std::string_view val_name) const {
if (pattern.empty())
return true;
return stringmatchlen(pattern.data(), pattern.size(), val_name.data(), val_name.size(), 0) == 1;
}
} // namespace dfly

View file

@ -4,11 +4,11 @@
#pragma once
#include <boost/fiber/mutex.hpp>
#include <absl/strings/ascii.h>
#include <absl/strings/str_cat.h>
#include <absl/types/span.h>
#include <boost/fiber/mutex.hpp>
#include <string_view>
#include <vector>
@ -32,6 +32,7 @@ using facade::ArgS;
using facade::CmdArgList;
using facade::CmdArgVec;
using facade::MutableSlice;
using facade::OpResult;
using ArgSlice = absl::Span<const std::string_view>;
using StringVec = std::vector<std::string>;
@ -196,4 +197,14 @@ using AggregateStatus = AggregateValue<facade::OpStatus>;
static_assert(facade::OpStatus::OK == facade::OpStatus{},
"Default intitialization should be OK value");
struct ScanOpts {
std::string_view pattern;
size_t limit = 10;
std::string_view type_filter;
unsigned bucket_id = UINT_MAX;
bool Matches(std::string_view val_name) const;
static OpResult<ScanOpts> TryFrom(CmdArgList args);
};
} // namespace dfly

View file

@ -406,14 +406,6 @@ OpStatus Renamer::UpdateDest(Transaction* t, EngineShard* es) {
return OpStatus::OK;
}
struct ScanOpts {
string_view pattern;
string_view type_filter;
size_t limit = 10;
unsigned bucket_id = UINT_MAX;
};
OpStatus OpPersist(const OpArgs& op_args, string_view key) {
auto& db_slice = op_args.shard->db_slice();
auto [it, expire_it] = db_slice.FindExt(op_args.db_cntx, key);
@ -514,15 +506,11 @@ bool ScanCb(const OpArgs& op_args, PrimeIterator it, const ScanOpts& opts, Strin
return false;
}
if (opts.pattern.empty()) {
res->push_back(it->first.ToString());
} else {
string str = it->first.ToString();
if (stringmatchlen(opts.pattern.data(), opts.pattern.size(), str.data(), str.size(), 0) != 1)
return false;
res->push_back(std::move(str));
string str = it->first.ToString();
if (!opts.Matches(str)) {
return false;
}
res->push_back(std::move(str));
return true;
}
@ -1219,42 +1207,16 @@ void GenericFamily::Scan(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError("invalid cursor");
}
ScanOpts scan_opts;
for (unsigned i = 2; i < args.size(); i += 2) {
if (i + 1 == args.size()) {
return (*cntx)->SendError(kSyntaxErr);
}
ToUpper(&args[i]);
string_view opt = ArgS(args, i);
if (opt == "COUNT") {
if (!absl::SimpleAtoi(ArgS(args, i + 1), &scan_opts.limit)) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (scan_opts.limit == 0)
scan_opts.limit = 1;
else if (scan_opts.limit > 4096)
scan_opts.limit = 4096;
} else if (opt == "MATCH") {
scan_opts.pattern = ArgS(args, i + 1);
if (scan_opts.pattern == "*")
scan_opts.pattern = string_view{};
} else if (opt == "TYPE") {
ToLower(&args[i + 1]);
scan_opts.type_filter = ArgS(args, i + 1);
} else if (opt == "BUCKET") {
if (!absl::SimpleAtoi(ArgS(args, i + 1), &scan_opts.bucket_id)) {
return (*cntx)->SendError(kInvalidIntErr);
}
} else {
return (*cntx)->SendError(kSyntaxErr);
}
OpResult<ScanOpts> ops = ScanOpts::TryFrom(args.subspan(2));
if (!ops) {
DVLOG(1) << "Scan invalid args - return " << ops << " to the user";
return (*cntx)->SendError(ops.status());
}
ScanOpts scan_op = ops.value();
StringVec keys;
cursor = ScanGeneric(cursor, scan_opts, &keys, cntx);
cursor = ScanGeneric(cursor, scan_op, &keys, cntx);
(*cntx)->StartArray(2);
(*cntx)->SendSimpleString(absl::StrCat(cursor));

View file

@ -26,48 +26,6 @@ using namespace facade;
namespace {
struct ScanOpts {
string_view pattern;
size_t limit = 10;
constexpr bool Matches(std::string_view val_name) const {
if (pattern.empty())
return true;
return stringmatchlen(pattern.data(), pattern.size(), val_name.data(), val_name.size(), 0) == 1;
}
static OpResult<ScanOpts> TryFrom(CmdArgList args);
};
OpResult<ScanOpts> ScanOpts::TryFrom(CmdArgList args) {
ScanOpts scan_opts;
for (unsigned i = 3; i < args.size(); i += 2) {
ToUpper(&args[i]);
string_view opt = ArgS(args, i);
if (i + 1 == args.size()) {
return OpStatus::SYNTAX_ERR;
}
if (opt == "COUNT") {
if (!absl::SimpleAtoi(ArgS(args, i + 1), &scan_opts.limit)) {
return OpStatus::INVALID_INT;
}
if (scan_opts.limit == 0)
scan_opts.limit = 1;
else if (scan_opts.limit > 4096)
scan_opts.limit = 4096;
} else if (opt == "MATCH") {
scan_opts.pattern = ArgS(args, i + 1);
if (scan_opts.pattern == "*")
scan_opts.pattern = string_view{};
} else {
return OpStatus::SYNTAX_ERR;
}
}
return scan_opts;
}
constexpr size_t kMaxListPackLen = 1024;
using IncrByParam = std::variant<double, int64_t>;
using OptStr = std::optional<std::string>;
@ -881,7 +839,7 @@ void HSetFamily::HScan(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(kSyntaxErr);
}
OpResult<ScanOpts> ops = ScanOpts::TryFrom(args);
OpResult<ScanOpts> ops = ScanOpts::TryFrom(args.subspan(3));
if (!ops) {
DVLOG(1) << "HScan invalid args - return " << ops << " to the user";
return (*cntx)->SendError(ops.status());

View file

@ -209,14 +209,9 @@ void InitSet(ArgSlice vals, CompactObj* set) {
}
}
void ScanCallback(void* privdata, const dictEntry* de) {
StringVec* sv = (StringVec*)privdata;
sds key = (sds)de->key;
sv->push_back(string(key, sdslen(key)));
}
uint64_t ScanStrSet(const DbContext& db_context, const CompactObj& co, uint64_t curs,
unsigned count, StringVec* res) {
const ScanOpts& scan_op, StringVec* res) {
uint32_t count = scan_op.limit;
long maxiterations = count * 10;
if (IsDenseEncoding(co)) {
@ -224,13 +219,36 @@ uint64_t ScanStrSet(const DbContext& db_context, const CompactObj& co, uint64_t
set->set_time(TimeNowSecRel(db_context.time_now_ms));
do {
curs = set->Scan(curs, [&](const sds ptr) { res->push_back(std::string(ptr, sdslen(ptr))); });
auto scan_callback = [&](const sds ptr) {
string_view str{ptr, sdslen(ptr)};
if (scan_op.Matches(str)) {
res->push_back(std::string(str));
}
};
curs = set->Scan(curs, scan_callback);
} while (curs && maxiterations-- && res->size() < count);
} else {
DCHECK_EQ(co.Encoding(), kEncodingStrMap);
using PrivateDataRef = std::tuple<StringVec*, const ScanOpts&>;
PrivateDataRef private_data_ref(res, scan_op);
void* private_data = &private_data_ref;
dict* ds = (dict*)co.RObjPtr();
auto scan_callback = [](void* private_data, const dictEntry* de) {
StringVec* sv = std::get<0>(*(PrivateDataRef*)private_data);
const ScanOpts& scan_op = std::get<1>(*(PrivateDataRef*)private_data);
sds key = (sds)de->key;
auto len = sdslen(key);
if (scan_op.Matches(std::string_view(key, len))) {
sv->emplace_back(key, len);
}
};
do {
curs = dictScan(ds, curs, ScanCallback, NULL, res);
curs = dictScan(ds, curs, scan_callback, NULL, private_data);
} while (curs && maxiterations-- && res->size() < count);
}
@ -290,9 +308,8 @@ bool IsInSet(const DbContext& db_context, const SetType& st, string_view member)
}
}
void FindInSet(StringVec& memberships,
const DbContext& db_context, const SetType& st,
const vector<string_view>& members) {
void FindInSet(StringVec& memberships, const DbContext& db_context, const SetType& st,
const vector<string_view>& members) {
for (const auto& member : members) {
bool status = IsInSet(db_context, st, member);
memberships.emplace_back(to_string(status));
@ -967,7 +984,8 @@ OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count
return result;
}
OpResult<StringVec> OpScan(const OpArgs& op_args, string_view key, uint64_t* cursor) {
OpResult<StringVec> OpScan(const OpArgs& op_args, string_view key, uint64_t* cursor,
const ScanOpts& scan_op) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
@ -975,18 +993,20 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, string_view key, uint64_t* cur
PrimeIterator it = find_res.value();
StringVec res;
uint32_t count = 10;
if (it->second.Encoding() == kEncodingIntSet) {
intset* is = (intset*)it->second.RObjPtr();
int64_t intele;
uint32_t pos = 0;
while (intsetGet(is, pos++, &intele)) {
res.push_back(absl::StrCat(intele));
std::string int_str = absl::StrCat(intele);
if (scan_op.Matches(int_str)) {
res.push_back(int_str);
}
}
*cursor = 0;
} else {
*cursor = ScanStrSet(op_args.db_cntx, it->second, *cursor, count, &res);
*cursor = ScanStrSet(op_args.db_cntx, it->second, *cursor, scan_op, &res);
}
return res;
@ -1404,12 +1424,22 @@ void SScan(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError("invalid cursor");
}
if (args.size() > 3) {
return (*cntx)->SendError("scan options are not supported yet");
// SSCAN key cursor [MATCH pattern] [COUNT count]
if (args.size() > 7) {
DVLOG(1) << "got " << args.size() << " this is more than it should be";
return (*cntx)->SendError(kSyntaxErr);
}
OpResult<ScanOpts> ops = ScanOpts::TryFrom(args.subspan(3));
if (!ops) {
DVLOG(1) << "SScan invalid args - return " << ops << " to the user";
return (*cntx)->SendError(ops.status());
}
ScanOpts scan_op = ops.value();
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpScan(t->GetOpArgs(shard), key, &cursor);
return OpScan(t->GetOpArgs(shard), key, &cursor, scan_op);
};
OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));

View file

@ -138,26 +138,26 @@ TEST_F(SetFamilyTest, SMIsMember) {
auto resp = Run({"smismember", "foo"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"smismember", "foo1", "a", "b"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("0", "0"));
resp = Run({"smismember", "foo", "a", "c"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "0"));
resp = Run({"smismember", "foo", "a", "b"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "1"));
resp = Run({"smismember", "foo", "d", "e"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), ElementsAre("0", "0"));
resp = Run({"smismember", "foo", "b"});
EXPECT_THAT(resp, "1");
resp = Run({"smismember", "foo", "x"});
EXPECT_THAT(resp, "0");
}
@ -167,4 +167,39 @@ TEST_F(SetFamilyTest, Empty) {
ASSERT_THAT(resp, ArrLen(0));
}
TEST_F(SetFamilyTest, SScan) {
// Test for int set
for (int i = 0; i < 15; i++) {
Run({"sadd", "myintset", absl::StrCat(i)});
}
// Note that even though this limit by 4, it would return more because
// all fields are on intlist
auto resp = Run({"sscan", "myintset", "0", "count", "4"});
auto vec = StrArray(resp.GetVec()[1]);
EXPECT_THAT(vec.size(), 15);
resp = Run({"sscan", "myintset", "0", "match", "1*"});
vec = StrArray(resp.GetVec()[1]);
EXPECT_THAT(vec, UnorderedElementsAre("1", "10", "11", "12", "13", "14"));
// test string set
for (int i = 0; i < 15; i++) {
Run({"sadd", "mystrset", absl::StrCat("str-", i)});
}
resp = Run({"sscan", "mystrset", "0", "count", "5"});
vec = StrArray(resp.GetVec()[1]);
EXPECT_THAT(vec.size(), 5);
resp = Run({"sscan", "mystrset", "0", "match", "str-1*", "count", "3"});
vec = StrArray(resp.GetVec()[1]);
EXPECT_THAT(vec, IsSubsetOf({"str-1", "str-10", "str-11", "str-12", "str-13", "str-14"}));
// nothing should match this
resp = Run({"sscan", "mystrset", "0", "match", "1*"});
vec = StrArray(resp.GetVec()[1]);
EXPECT_THAT(vec.size(), 0);
}
} // namespace dfly