feat(server): Support Resp3 (#975)

Accept hello 3 command and switch to resp3 response format.

---------

Signed-off-by: ashotland <ari@dragonflydb.io>
This commit is contained in:
ashotland 2023-03-22 12:18:29 +02:00 committed by GitHub
parent a2fdbc59c2
commit 39174f398a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 414 additions and 67 deletions

View file

@ -88,6 +88,8 @@ const char* RespExpr::TypeName(Type t) {
return "string";
case INT64:
return "int";
case DOUBLE:
return "double";
case ARRAY:
return "array";
case NIL_ARRAY:
@ -155,6 +157,9 @@ ostream& operator<<(ostream& os, const facade::RespExpr& e) {
case RespExpr::INT64:
os << "i" << get<int64_t>(e.u);
break;
case RespExpr::DOUBLE:
os << "d" << get<int64_t>(e.u);
break;
case RespExpr::STRING:
os << "'" << ToSV(get<RespExpr::Buffer>(e.u)) << "'";
break;

View file

@ -35,6 +35,12 @@ bool RespMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listen
*listener << "\nActual : " << actual << " expected: " << exp_int_;
return false;
}
} else if (type_ == RespExpr::DOUBLE) {
auto actual = get<double>(e.u);
if (exp_double_ != actual) {
*listener << "\nActual : " << actual << " expected: " << exp_double_;
return false;
}
} else if (type_ == RespExpr::ARRAY) {
size_t len = get<RespVec*>(e.u)->size();
if (len != size_t(exp_int_)) {

View file

@ -18,6 +18,8 @@ class RespMatcher {
RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64) : type_(t), exp_int_(val) {
}
RespMatcher(double_t val, RespExpr::Type t = RespExpr::DOUBLE) : type_(t), exp_double_(val) {
}
using is_gtest_matcher = void;
bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const;
@ -31,6 +33,7 @@ class RespMatcher {
std::string exp_str_;
int64_t exp_int_;
double_t exp_double_;
};
class RespTypeMatcher {
@ -58,8 +61,12 @@ inline ::testing::PolymorphicMatcher<RespMatcher> IntArg(int64_t ival) {
return ::testing::MakePolymorphicMatcher(RespMatcher(ival));
}
inline ::testing::PolymorphicMatcher<RespMatcher> DoubleArg(double_t dval) {
return ::testing::MakePolymorphicMatcher(RespMatcher(dval));
}
inline ::testing::PolymorphicMatcher<RespMatcher> ArrLen(size_t len) {
return ::testing::MakePolymorphicMatcher(RespMatcher(len, RespExpr::ARRAY));
return ::testing::MakePolymorphicMatcher(RespMatcher((int64_t)len, RespExpr::ARRAY));
}
inline ::testing::PolymorphicMatcher<RespTypeMatcher> ArgType(RespExpr::Type t) {

View file

@ -92,6 +92,26 @@ template <> class OpResult<void> : public OpResultBase {
using OpResultBase::OpResultBase;
};
template <typename V> class OpResultTyped : public OpResult<V> {
public:
OpResultTyped(V v) : OpResult<V>(std::move(v)) {
}
OpResultTyped(OpStatus st = OpStatus::OK) : OpResult<V>(st) {
}
void setType(int type) {
type_ = type;
}
int type() const {
return type_;
}
private:
int type_ = -1;
};
inline bool operator==(OpStatus st, const OpResultBase& ob) {
return ob.operator==(st);
}

View file

@ -40,11 +40,12 @@ auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> R
while (state_ != CMD_COMPLETE_S) {
last_consumed_ = 0;
switch (state_) {
case MAP_LEN_S:
case ARRAY_LEN_S:
last_result_ = ConsumeArrayLen(str);
break;
case PARSE_ARG_S:
if (str.size() < 4) {
if (str.size() == 0 || (str.size() < 4 && str[0] != '_')) {
last_result_ = INPUT_PENDING;
} else {
last_result_ = ParseArg(str);
@ -99,12 +100,18 @@ void RedisParser::InitStart(uint8_t prefix_b, RespExpr::Vec* res) {
case ':':
case '+':
case '-':
case '_': // Resp3 NULL
case ',': // Resp3 DOUBLE
state_ = PARSE_ARG_S;
parse_stack_.emplace_back(1, cached_expr_); // expression of length 1.
break;
case '*':
case '~': // Resp3 SET
state_ = ARRAY_LEN_S;
break;
case '%': // Resp3 MAP
state_ = MAP_LEN_S;
break;
default:
state_ = INLINE_S;
break;
@ -231,6 +238,11 @@ auto RedisParser::ConsumeArrayLen(Buffer str) -> Result {
int64_t len;
Result res = ParseNum(str, &len);
if (state_ == MAP_LEN_S) {
// Map starts with %N followed by an array of 2*N elements.
// Even elements are keys, odd elements are values.
len *= 2;
}
switch (res) {
case INPUT_PENDING:
return INPUT_PENDING;
@ -284,6 +296,15 @@ auto RedisParser::ConsumeArrayLen(Buffer str) -> Result {
auto RedisParser::ParseArg(Buffer str) -> Result {
char c = str[0];
if (c == '_') { // Resp3 NIL
state_ = FINISH_ARG_S;
cached_expr_->emplace_back(RespExpr::NIL);
cached_expr_->back().u = Buffer{};
last_consumed_ += 3; // '_','\r','\n'
return OK;
}
if (c == '$') {
int64_t len;
@ -301,7 +322,7 @@ auto RedisParser::ParseArg(Buffer str) -> Result {
LOG(ERROR) << "Unexpected result " << res;
}
if (len < 0) {
if (len < 0) { // Resp2 NIL
state_ = FINISH_ARG_S;
cached_expr_->emplace_back(RespExpr::NIL);
} else {
@ -349,6 +370,19 @@ auto RedisParser::ParseArg(Buffer str) -> Result {
cached_expr_->emplace_back(RespExpr::INT64);
cached_expr_->back().u = ival;
} else if (c == ',') {
DCHECK(!server_mode_);
if (!eol) {
return str.size() < 32 ? INPUT_PENDING : BAD_DOUBLE;
}
double_t dval;
std::string_view tok{s, size_t((eol - s) - 1)};
if (eol[-1] != '\r' || !absl::SimpleAtod(tok, &dval))
return BAD_INT;
cached_expr_->emplace_back(RespExpr::DOUBLE);
cached_expr_->back().u = dval;
} else {
return BAD_STRING;
}
@ -441,4 +475,4 @@ void RedisParser::ExtendLastString(Buffer str) {
buf_stash_.back() = std::move(nb);
}
} // namespace dfly
} // namespace facade

View file

@ -18,7 +18,7 @@ namespace facade {
*/
class RedisParser {
public:
enum Result { OK, INPUT_PENDING, BAD_ARRAYLEN, BAD_BULKLEN, BAD_STRING, BAD_INT };
enum Result { OK, INPUT_PENDING, BAD_ARRAYLEN, BAD_BULKLEN, BAD_STRING, BAD_INT, BAD_DOUBLE };
using Buffer = RespExpr::Buffer;
explicit RedisParser(bool server_mode = true) : server_mode_(server_mode) {
@ -73,6 +73,7 @@ class RedisParser {
INIT_S = 0,
INLINE_S,
ARRAY_LEN_S,
MAP_LEN_S,
PARSE_ARG_S, // Parse [$:+-]string\r\n
BULK_STR_S,
FINISH_ARG_S,

View file

@ -169,6 +169,10 @@ char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len)
RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
}
void RedisReplyBuilder::SetResp3(bool is_resp3) {
is_resp3_ = is_resp3;
}
void RedisReplyBuilder::SendError(string_view str, string_view err_type) {
if (err_type.empty()) {
err_type = str;
@ -201,10 +205,15 @@ void RedisReplyBuilder::SendSetSkipped() {
SendNull();
}
void RedisReplyBuilder::SendNull() {
constexpr char kNullStr[] = "$-1\r\n";
const char* RedisReplyBuilder::NullString() {
if (is_resp3_) {
return "_\r\n";
}
return "$-1\r\n";
}
iovec v[] = {IoVec(kNullStr)};
void RedisReplyBuilder::SendNull() {
iovec v[] = {IoVec(NullString())};
Send(v, ABSL_ARRAYSIZE(v));
}
@ -265,13 +274,45 @@ void RedisReplyBuilder::SendLong(long num) {
SendRaw(str);
}
void RedisReplyBuilder::SendScoredArray(const std::vector<std::pair<std::string, double>>& arr,
bool with_scores) {
if (!with_scores) {
StartArray(arr.size());
for (const auto& p : arr) {
SendBulkString(p.first);
}
return;
}
if (!is_resp3_) { // RESP2 formats withscores as a flat array.
StartArray(arr.size() * 2);
for (const auto& p : arr) {
SendBulkString(p.first);
SendDouble(p.second);
}
return;
}
// Resp3 formats withscores as array of (key, score) pairs.
StartArray(arr.size());
for (const auto& p : arr) {
StartArray(2);
SendBulkString(p.first);
SendDouble(p.second);
}
}
void RedisReplyBuilder::SendDouble(double val) {
char buf[64];
StringBuilder sb(buf, sizeof(buf));
CHECK(dfly_conv.ToShortest(val, &sb));
SendBulkString(sb.Finalize());
if (!is_resp3_) {
SendBulkString(sb.Finalize());
} else {
// RESP3
string str = absl::StrCat(",", sb.Finalize(), kCRLF);
SendRaw(str);
}
}
void RedisReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
@ -281,7 +322,7 @@ void RedisReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
StrAppend(&res, "$", resp[i]->value.size(), kCRLF);
res.append(resp[i]->value).append(kCRLF);
} else {
res.append("$-1\r\n");
res.append(NullString());
}
}
@ -312,7 +353,7 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
return;
}
SendStringArr(arr.data(), arr.size());
SendStringCollection(arr.data(), arr.size(), CollectionType::ARRAY);
}
// This implementation a bit complicated because it uses vectorized
@ -321,19 +362,67 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
// We limit the vector length to 256 and when it fills up we flush it to the socket and continue
// iterating.
void RedisReplyBuilder::SendStringArr(absl::Span<const string> arr) {
if (arr.empty()) {
SendRaw("*0\r\n");
return;
}
SendStringArr(arr.data(), arr.size());
SendStringCollection(arr.data(), arr.size(), CollectionType::ARRAY);
}
void RedisReplyBuilder::SendStringArrayAsMap(absl::Span<const std::string_view> arr) {
SendStringCollection(arr.data(), arr.size(), CollectionType::MAP);
}
void RedisReplyBuilder::SendStringArrayAsMap(absl::Span<const std::string> arr) {
SendStringCollection(arr.data(), arr.size(), CollectionType::MAP);
}
void RedisReplyBuilder::SendStringArrayAsSet(absl::Span<const std::string_view> arr) {
SendStringCollection(arr.data(), arr.size(), CollectionType::SET);
}
void RedisReplyBuilder::SendStringArrayAsSet(absl::Span<const std::string> arr) {
SendStringCollection(arr.data(), arr.size(), CollectionType::SET);
}
void RedisReplyBuilder::StartArray(unsigned len) {
SendRaw(absl::StrCat("*", len, kCRLF));
}
void RedisReplyBuilder::SendStringArr(StrPtr str_ptr, uint32_t len) {
DVLOG(2) << "Sending array of " << len << " strings.";
void RedisReplyBuilder::StartMap(unsigned num_pairs) {
if (is_resp3_) {
SendRaw(absl::StrCat("%", num_pairs, kCRLF));
return;
}
// Flatten for Resp2.
StartArray(num_pairs * 2);
}
void RedisReplyBuilder::StartSet(unsigned num_elements) {
if (is_resp3_) {
SendRaw(absl::StrCat("~", num_elements, kCRLF));
}
// Flatten for Resp2.
StartArray(num_elements);
}
void RedisReplyBuilder::SendStringCollection(StrPtr str_ptr, uint32_t len, CollectionType type) {
string type_char = "*";
size_t header_len = len;
if (is_resp3_) {
switch (type) {
case CollectionType::ARRAY:
break;
case CollectionType::MAP:
type_char[0] = '%';
header_len = 0.5 * len; // Each key value pair counts as one.
break;
case CollectionType::SET:
type_char[0] = '~';
break;
}
}
if (header_len == 0) {
SendRaw(absl::StrCat(type_char, "0\r\n"));
return;
}
// When vector length is too long, Send returns EMSGSIZE.
size_t vec_len = std::min<size_t>(256u, len);
@ -342,8 +431,8 @@ void RedisReplyBuilder::SendStringArr(StrPtr str_ptr, uint32_t len) {
absl::FixedArray<char, 64> meta((vec_len + 1) * 16);
char* next = meta.data();
*next++ = '*';
next = absl::numbers_internal::FastIntToBuffer(len, next);
*next++ = type_char[0];
next = absl::numbers_internal::FastIntToBuffer(header_len, next);
*next++ = '\r';
*next++ = '\n';
vec[0] = IoVec(string_view{meta.data(), size_t(next - meta.data())});

View file

@ -118,6 +118,8 @@ class RedisReplyBuilder : public SinkReplyBuilder {
public:
RedisReplyBuilder(::io::Sink* stream);
void SetResp3(bool is_resp3);
void SendError(std::string_view str, std::string_view type = std::string_view{}) override;
void SendMGetResponse(const OptResp* resp, uint32_t count) override;
void SendSimpleString(std::string_view str) override;
@ -135,13 +137,23 @@ class RedisReplyBuilder : public SinkReplyBuilder {
virtual void SendStringArr(absl::Span<const std::string_view> arr);
virtual void SendStringArr(absl::Span<const std::string> arr);
virtual void SendStringArrayAsMap(absl::Span<const std::string_view> arr);
virtual void SendStringArrayAsMap(absl::Span<const std::string> arr);
virtual void SendStringArrayAsSet(absl::Span<const std::string_view> arr);
virtual void SendStringArrayAsSet(absl::Span<const std::string> arr);
virtual void SendNull();
virtual void SendScoredArray(const std::vector<std::pair<std::string, double>>& arr,
bool with_scores);
virtual void SendDouble(double val);
virtual void SendBulkString(std::string_view str);
virtual void StartArray(unsigned len);
virtual void StartMap(unsigned num_pairs);
virtual void StartSet(unsigned num_elements);
static char* FormatDouble(double val, char* dest, unsigned dest_len);
@ -150,8 +162,17 @@ class RedisReplyBuilder : public SinkReplyBuilder {
static std::string_view StatusToMsg(OpStatus status);
private:
enum CollectionType {
ARRAY,
SET,
MAP,
};
using StrPtr = std::variant<const std::string_view*, const std::string*>;
void SendStringArr(StrPtr str_ptr, uint32_t len);
void SendStringCollection(StrPtr str_ptr, uint32_t len, CollectionType type);
bool is_resp3_ = false;
const char* NullString();
};
class ReqSerializer {

View file

@ -656,4 +656,100 @@ TEST_F(RedisReplyBuilderTest, TestBatchMode) {
absl::StrCat(kBulkStringStart, "0"), std::string_view{}));
}
TEST_F(RedisReplyBuilderTest, TestResp3Double) {
builder_->SetResp3(true);
builder_->SendDouble(5.5);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(str(), ",5.5\r\n");
}
TEST_F(RedisReplyBuilderTest, TestResp3NullString) {
builder_->SetResp3(true);
builder_->SendNull();
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "_\r\n");
}
TEST_F(RedisReplyBuilderTest, TestSendStringArrayAsMap) {
const std::vector<std::string> map_array{"k1", "v1", "k2", "v2"};
builder_->SetResp3(false);
builder_->SendStringArrayAsMap(map_array);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "*4\r\n$2\r\nk1\r\n$2\r\nv1\r\n$2\r\nk2\r\n$2\r\nv2\r\n")
<< "SendStringArrayAsMap Resp2 Failed.";
builder_->SetResp3(true);
builder_->SendStringArrayAsMap(map_array);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "%2\r\n$2\r\nk1\r\n$2\r\nv1\r\n$2\r\nk2\r\n$2\r\nv2\r\n")
<< "SendStringArrayAsMap Resp3 Failed.";
}
TEST_F(RedisReplyBuilderTest, TestSendStringArrayAsSet) {
const std::vector<std::string> set_array{"e1", "e2", "e3"};
builder_->SetResp3(false);
builder_->SendStringArrayAsSet(set_array);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "SendStringArrayAsSet Resp2 Failed.";
builder_->SetResp3(true);
builder_->SendStringArrayAsSet(set_array);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "~3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "SendStringArrayAsSet Resp3 Failed.";
}
TEST_F(RedisReplyBuilderTest, TestSendScoredArray) {
const std::vector<std::pair<std::string, double>> scored_array{
{"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}};
builder_->SetResp3(false);
builder_->SendScoredArray(scored_array, false);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "Resp2 WITHOUT scores failed.";
builder_->SetResp3(true);
builder_->SendScoredArray(scored_array, false);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "Resp3 WITHOUT scores failed.";
builder_->SetResp3(false);
builder_->SendScoredArray(scored_array, true);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(),
"*6\r\n$2\r\ne1\r\n$3\r\n1.1\r\n$2\r\ne2\r\n$3\r\n2.2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n")
<< "Resp3 WITHSCORES failed.";
builder_->SetResp3(true);
builder_->SendScoredArray(scored_array, true);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(),
"*3\r\n*2\r\n$2\r\ne1\r\n,1.1\r\n*2\r\n$2\r\ne2\r\n,2.2\r\n*2\r\n$2\r\ne3\r\n,3.3\r\n")
<< "Resp3 WITHSCORES failed.";
}
TEST_F(RedisReplyBuilderTest, TestSendMGetResponse) {
std::vector<SinkReplyBuilder::OptResp> mget_res(3);
auto& v = mget_res[0].emplace();
v.value = "v1";
v = mget_res[2].emplace();
v.value = "v3";
builder_->SetResp3(false);
builder_->SendMGetResponse(&mget_res[0], 3);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\nv3\r\n$-1\r\n$0\r\n\r\n")
<< "Resp2 SendMGetResponse failed.";
builder_->SetResp3(true);
builder_->SendMGetResponse(&mget_res[0], 3);
ASSERT_TRUE(builder_->err_count().empty());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\nv3\r\n_\r\n$0\r\n\r\n") << "Resp3 SendMGetResponse failed.";
}
} // namespace facade

View file

@ -16,13 +16,13 @@ class RespExpr {
public:
using Buffer = absl::Span<uint8_t>;
enum Type : uint8_t { STRING, ARRAY, INT64, NIL, NIL_ARRAY, ERROR };
enum Type : uint8_t { STRING, ARRAY, INT64, DOUBLE, NIL, NIL_ARRAY, ERROR };
using Vec = std::vector<RespExpr>;
Type type;
bool has_support; // whether pointers in this item are supported by the external storage.
std::variant<int64_t, Buffer, Vec*> u;
std::variant<int64_t, double, Buffer, Vec*> u;
RespExpr(Type t = NIL) : type(t), has_support(false) {
}

View file

@ -228,10 +228,15 @@ TEST_F(DflyEngineTest, Hello) {
ArgType(RespExpr::STRING), "proto", IntArg(2), "id",
ArgType(RespExpr::INT64), "mode", "standalone", "role", "master"));
resp = Run({"hello", "3"});
ASSERT_THAT(resp, ArrLen(14));
EXPECT_THAT(resp.GetVec(),
ElementsAre("server", "redis", "version", "6.2.11", "dfly_version",
ArgType(RespExpr::STRING), "proto", IntArg(3), "id",
ArgType(RespExpr::INT64), "mode", "standalone", "role", "master"));
// These are valid arguments to HELLO, however as they are not yet supported the implementation
// is degraded to 'unknown command'.
EXPECT_THAT(Run({"hello", "3"}),
ErrArg("ERR unknown command 'HELLO' with args beginning with: `3`"));
EXPECT_THAT(
Run({"hello", "2", "AUTH", "uname", "pwd"}),
ErrArg("ERR unknown command 'HELLO' with args beginning with: `2`, `AUTH`, `uname`, `pwd`"));

View file

@ -929,8 +929,8 @@ template <typename F> bool Iterate(const PrimeValue& pv, F&& func) {
}
// Create a SortEntryList from given key
OpResult<SortEntryList> OpFetchSortEntries(const OpArgs& op_args, std::string_view key,
bool alpha) {
OpResultTyped<SortEntryList> OpFetchSortEntries(const OpArgs& op_args, std::string_view key,
bool alpha) {
using namespace container_utils;
auto [it, _] = op_args.shard->db_slice().FindExt(op_args.db_cntx, key);
@ -947,7 +947,9 @@ OpResult<SortEntryList> OpFetchSortEntries(const OpArgs& op_args, std::string_vi
});
},
result);
return success ? OpResult{std::move(result)} : OpStatus::WRONG_TYPE;
auto res = OpResultTyped{std::move(result)};
res.setType(it->second.ObjType());
return success ? res : OpStatus::WRONG_TYPE;
}
void GenericFamily::Sort(CmdArgList args, ConnectionContext* cntx) {
@ -978,18 +980,19 @@ void GenericFamily::Sort(CmdArgList args, ConnectionContext* cntx) {
}
}
OpResult<SortEntryList> entries =
OpResultTyped<SortEntryList> fetch_result =
cntx->transaction->ScheduleSingleHopT([&](Transaction* t, EngineShard* shard) {
return OpFetchSortEntries(t->GetOpArgs(shard), key, alpha);
});
if (entries.status() == OpStatus::WRONG_TYPE)
if (fetch_result.status() == OpStatus::WRONG_TYPE)
return (*cntx)->SendError("One or more scores can't be converted into double");
if (!entries.ok())
if (!fetch_result.ok())
return (*cntx)->SendEmptyArray();
auto sort_call = [cntx, bounds, reversed](auto& entries) {
auto result_type = fetch_result.type();
auto sort_call = [cntx, bounds, reversed, result_type](auto& entries) {
if (bounds) {
auto sort_it = entries.begin() + std::min(bounds->first + bounds->second, entries.size());
std::partial_sort(entries.begin(), sort_it, entries.end(),
@ -1009,12 +1012,17 @@ void GenericFamily::Sort(CmdArgList args, ConnectionContext* cntx) {
end_it = entries.begin() + std::min(bounds->first + bounds->second, entries.size());
}
(*cntx)->StartArray(std::distance(start_it, end_it));
if (result_type == OBJ_SET || result_type == OBJ_ZSET) {
(*cntx)->StartSet(std::distance(start_it, end_it));
} else {
(*cntx)->StartArray(std::distance(start_it, end_it));
}
for (auto it = start_it; it != end_it; ++it) {
(*cntx)->SendBulkString(it->key);
}
};
std::visit(std::move(sort_call), entries.value());
std::visit(std::move(sort_call), fetch_result.value());
}
void GenericFamily::Restore(CmdArgList args, ConnectionContext* cntx) {

View file

@ -697,7 +697,11 @@ void HGetGeneric(CmdArgList args, ConnectionContext* cntx, uint8_t getall_mask)
OpResult<vector<string>> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result) {
(*cntx)->SendStringArr(absl::Span<const string>{*result});
if (getall_mask == (VALUES | FIELDS)) {
(*cntx)->SendStringArrayAsMap(absl::Span<const string>{*result});
} else {
(*cntx)->SendStringArr(absl::Span<const string>{*result});
}
} else {
(*cntx)->SendError(result.status());
}
@ -945,7 +949,7 @@ void HSetFamily::HScan(CmdArgList args, ConnectionContext* cntx) {
if (result.status() != OpStatus::WRONG_TYPE) {
(*cntx)->StartArray(2);
(*cntx)->SendBulkString(absl::StrCat(cursor));
(*cntx)->StartArray(result->size());
(*cntx)->StartArray(result->size()); // Within scan the page type is array
for (const auto& k : *result) {
(*cntx)->SendBulkString(k);
}

View file

@ -27,6 +27,14 @@ class HSetFamilyTest : public BaseFamilyTest {
protected:
};
class HestFamilyTestProtocolVersioned : public HSetFamilyTest,
public ::testing::WithParamInterface<string> {
protected:
};
INSTANTIATE_TEST_CASE_P(HestFamilyTestProtocolVersioned, HestFamilyTestProtocolVersioned,
::testing::Values("2", "3"));
TEST_F(HSetFamilyTest, Hash) {
robj* obj = createHashObject();
sds field = sdsnew("field");
@ -70,8 +78,12 @@ TEST_F(HSetFamilyTest, HSet) {
EXPECT_EQ(1, CheckedInt({"hset", "small", "", "565323349817"}));
}
TEST_F(HSetFamilyTest, Get) {
auto resp = Run({"hset", "x", "a", "1", "b", "2", "c", "3"});
TEST_P(HestFamilyTestProtocolVersioned, Get) {
auto resp = Run({"hello", GetParam()});
EXPECT_THAT(resp.GetVec()[6], "proto");
EXPECT_THAT(resp.GetVec()[7], IntArg(atoi(GetParam().c_str())));
resp = Run({"hset", "x", "a", "1", "b", "2", "c", "3"});
EXPECT_THAT(resp, IntArg(3));
resp = Run({"hmget", "unkwn", "a", "c"});

View file

@ -213,6 +213,10 @@ class InterpreterReplier : public RedisReplyBuilder {
void SendSimpleString(std::string_view str) final;
void SendMGetResponse(const OptResp* resp, uint32_t count) final;
void SendSimpleStrArr(const string_view* arr, uint32_t count) final;
void SendStringArrayAsMap(absl::Span<const std::string_view> arr) final;
void SendStringArrayAsMap(absl::Span<const std::string> arr) final;
void SendStringArrayAsSet(absl::Span<const std::string_view> arr) final;
void SendStringArrayAsSet(absl::Span<const std::string> arr) final;
void SendNullArray() final;
void SendStringArr(absl::Span<const string_view> arr) final;
@ -340,6 +344,22 @@ void InterpreterReplier::SendSimpleStrArr(const string_view* arr, uint32_t count
explr_->OnArrayEnd();
}
void InterpreterReplier::SendStringArrayAsMap(absl::Span<const string_view> arr) {
SendStringArr(arr);
}
void InterpreterReplier::SendStringArrayAsMap(absl::Span<const string> arr) {
SendStringArr(arr);
}
void InterpreterReplier::SendStringArrayAsSet(absl::Span<const string_view> arr) {
SendStringArr(arr);
}
void InterpreterReplier::SendStringArrayAsSet(absl::Span<const string> arr) {
SendStringArr(arr);
}
void InterpreterReplier::SendNullArray() {
SendSimpleStrArr(nullptr, 0);
PostItem();

View file

@ -1318,7 +1318,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) {
string_view param = ArgS(args, 2);
string_view res[2] = {param, "tbd"};
return (*cntx)->SendStringArr(res);
return (*cntx)->SendStringArrayAsMap(res);
} else if (sub_cmd == "RESETSTAT") {
shard_set->pool()->Await([](auto*) {
auto* stats = ServerState::tl_connection_stats();
@ -1689,21 +1689,29 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) {
}
void ServerFamily::Hello(CmdArgList args, ConnectionContext* cntx) {
// Allow calling this commands with no arguments or protover=2
// technically that is all that is supported at the moment.
// For all other cases degrade to 'unknown command' so that clients
// checking for the existence of the command to detect if RESP3 is
// supported or whether authentication can be performed using HELLO
// will gracefully fallback to RESP2 and using the AUTH command explicitly.
// If no arguments are provided default to RESP2.
// AUTH and SETNAME options are not supported.
bool is_resp3 = false;
if (args.size() > 1) {
string_view proto_version = ArgS(args, 1);
if (proto_version != "2" || args.size() > 2) {
is_resp3 = proto_version == "3";
bool valid_proto_version = proto_version == "2" || is_resp3;
if (!valid_proto_version || args.size() > 2) {
(*cntx)->SendError(UnknownCmd("HELLO", args.subspan(1)));
return;
}
}
(*cntx)->StartArray(14);
int proto_version = 2;
if (is_resp3) {
proto_version = 3;
(*cntx)->SetResp3(true);
} else {
// Issuing hello 2 again is valid and should switch back to RESP2
(*cntx)->SetResp3(false);
}
(*cntx)->StartMap(7);
(*cntx)->SendBulkString("server");
(*cntx)->SendBulkString("redis");
(*cntx)->SendBulkString("version");
@ -1711,7 +1719,7 @@ void ServerFamily::Hello(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendBulkString("dfly_version");
(*cntx)->SendBulkString(GetVersion());
(*cntx)->SendBulkString("proto");
(*cntx)->SendLong(2);
(*cntx)->SendLong(proto_version);
(*cntx)->SendBulkString("id");
(*cntx)->SendLong(cntx->owner()->GetClientId());
(*cntx)->SendBulkString("mode");

View file

@ -1205,7 +1205,7 @@ void SPop(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendBulkString(result.value().front());
}
} else { // SPOP key cnt
(*cntx)->SendStringArr(*result);
(*cntx)->SendStringArrayAsSet(*result);
}
return;
}
@ -1241,7 +1241,7 @@ void SDiff(CmdArgList args, ConnectionContext* cntx) {
if (cntx->conn_state.script_info) { // sort under script
sort(arr.begin(), arr.end());
}
(*cntx)->SendStringArr(arr);
(*cntx)->SendStringArrayAsSet(arr);
}
void SDiffStore(CmdArgList args, ConnectionContext* cntx) {
@ -1309,7 +1309,7 @@ void SMembers(CmdArgList args, ConnectionContext* cntx) {
if (cntx->conn_state.script_info) { // sort under script
sort(svec.begin(), svec.end());
}
(*cntx)->SendStringArr(*result);
(*cntx)->SendStringArrayAsSet(*result);
} else {
(*cntx)->SendError(result.status());
}
@ -1331,7 +1331,7 @@ void SInter(CmdArgList args, ConnectionContext* cntx) {
if (cntx->conn_state.script_info) { // sort under script
sort(arr.begin(), arr.end());
}
(*cntx)->SendStringArr(arr);
(*cntx)->SendStringArrayAsSet(arr);
} else {
(*cntx)->SendError(result.status());
}
@ -1394,7 +1394,7 @@ void SUnion(CmdArgList args, ConnectionContext* cntx) {
if (cntx->conn_state.script_info) { // sort under script
sort(arr.begin(), arr.end());
}
(*cntx)->SendStringArr(arr);
(*cntx)->SendStringArrayAsSet(arr);
} else {
(*cntx)->SendError(unionset.status());
}
@ -1473,7 +1473,7 @@ void SScan(CmdArgList args, ConnectionContext* cntx) {
if (result.status() != OpStatus::WRONG_TYPE) {
(*cntx)->StartArray(2);
(*cntx)->SendBulkString(absl::StrCat(cursor));
(*cntx)->StartArray(result->size());
(*cntx)->StartArray(result->size()); // Within scan the return page is of type array
for (const auto& k : *result) {
(*cntx)->SendBulkString(k);
}

View file

@ -734,7 +734,7 @@ void StreamFamily::XInfo(CmdArgList args, ConnectionContext* cntx) {
string_view arr[8] = {"name", ginfo.name, "consumers", an1.Piece(),
"pending", an2.Piece(), "last-delivered-id", last_id};
(*cntx)->SendStringArr(absl::Span<string_view>{arr, 8});
(*cntx)->SendStringArrayAsMap(absl::Span<string_view>{arr, 8});
}
return;
}

View file

@ -663,4 +663,14 @@ TEST_F(StringFamilyTest, ClThrottle) {
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));
}
TEST_F(StringFamilyTest, SetMGetWithNilResp3) {
Run({"hello", "3"});
EXPECT_EQ(Run({"set", "key", "val"}), "OK");
EXPECT_EQ(Run({"get", "key"}), "val");
RespExpr resp = Run({"mget", "key", "nonexist"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
EXPECT_THAT(resp.GetVec(), ElementsAre("val", ArgType(RespExpr::NIL)));
}
} // namespace dfly

View file

@ -1578,7 +1578,7 @@ void ZSetFamily::ZMScore(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(kWrongTypeErr);
}
(*cntx)->StartArray(result->size());
(*cntx)->StartArray(result->size()); // Array return type.
const MScoreResponse& array = result.value();
for (const auto& p : array) {
if (p) {
@ -1614,7 +1614,7 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
if (result.status() != OpStatus::WRONG_TYPE) {
(*cntx)->StartArray(2);
(*cntx)->SendBulkString(absl::StrCat(cursor));
(*cntx)->StartArray(result->size());
(*cntx)->StartArray(result->size()); // Within scan the returned page is of type array.
for (const auto& k : *result) {
(*cntx)->SendBulkString(k);
}
@ -1649,16 +1649,7 @@ void ZSetFamily::OutputScoredArrayResult(const OpResult<ScoredArray>& result,
LOG_IF(WARNING, !result && result.status() != OpStatus::KEY_NOTFOUND)
<< "Unexpected status " << result.status();
(*cntx)->StartArray(result->size() * (params.with_scores ? 2 : 1));
const ScoredArray& array = result.value();
for (const auto& p : array) {
(*cntx)->SendBulkString(p.first);
if (params.with_scores) {
(*cntx)->SendDouble(p.second);
}
}
(*cntx)->SendScoredArray(result.value(), params.with_scores);
}
void ZSetFamily::ZRemRangeGeneric(string_view key, const ZRangeSpec& range_spec,

View file

@ -477,4 +477,14 @@ TEST_F(ZSetFamilyTest, ZPopMax) {
resp = Run({"zpopmax", "key", "1"});
ASSERT_THAT(resp, ArrLen(0));
}
TEST_F(ZSetFamilyTest, Resp3) {
Run({"hello", "3"});
Run({"zadd", "x", "1", "a", "2", "b"});
auto resp = Run({"zrange", "x", "0", "-1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
ASSERT_THAT(resp.GetVec()[0].GetVec(), ElementsAre("a", DoubleArg(1)));
ASSERT_THAT(resp.GetVec()[1].GetVec(), ElementsAre("b", DoubleArg(2)));
}
} // namespace dfly