fix: buffer overrun when passing a long command name from Lua (#2008)

Also, a few additional changes that do not affect functionality:
1. Make sure the arguments passed to DispatchCommand are `\0`-delimited
   during pipelining.
2. Extend the Lua malloc hook to call precise functions, to help with CPU profiling.
3. Reuse the arguments buffer (saving allocations) when calling Dragonfly commands from Lua scripts.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2023-10-11 19:19:48 +03:00 committed by GitHub
parent c6f8f3882a
commit ec87114f66
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 56 additions and 36 deletions

View file

@ -7,8 +7,8 @@ endif()
add_third_party(
lua
URL https://github.com/lua/lua/archive/refs/tags/v5.4.4.tar.gz
PATCH_COMMAND patch -p1 -i "${CMAKE_SOURCE_DIR}/patches/lua-v5.4.4.patch"
GIT_REPOSITORY https://github.com/dragonflydb/lua
GIT_TAG Dragonfly-5.4.6
CONFIGURE_COMMAND echo
BUILD_IN_SOURCE 1
BUILD_COMMAND ${DFLY_TOOLS_MAKE} all

View file

@ -113,7 +113,7 @@ void RedisTranslator::OnString(std::string_view str) {
}
void RedisTranslator::OnDouble(double d) {
static constexpr double kConvertEps = std::numeric_limits<double>::epsilon();
const double kConvertEps = std::numeric_limits<double>::epsilon();
double fractpart, intpart;
fractpart = modf(d, &intpart);
@ -351,14 +351,14 @@ int RedisStatusReplyCommand(lua_State* lua) {
return SingleFieldTable(lua, "ok");
}
// const char* kInstanceKey = "_INSTANCE";
// See https://www.lua.org/manual/5.3/manual.html#lua_Alloc
void* mimalloc_glue(void* ud, void* ptr, size_t osize, size_t nsize) {
(void)ud;
(void)osize; /* not used */
if (nsize == 0) {
mi_free(ptr);
mi_free_size(ptr, osize);
return nullptr;
} else if (ptr == nullptr) {
return mi_malloc(nsize);
} else {
return mi_realloc(ptr, nsize);
}
@ -730,12 +730,16 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
cmd_depth_++;
int argc = lua_gettop(lua_);
#define RETURN_ERROR(err) \
{ \
PushError(lua_, err); \
cmd_depth_--; \
return raise_error ? RaiseError(lua_) : 1; \
}
/* Require at least one argument */
if (argc == 0) {
PushError(lua_, "Please specify at least one argument for redis.call()");
cmd_depth_--;
return raise_error ? RaiseError(lua_) : 1;
RETURN_ERROR("Please specify at least one argument for redis.call()");
}
size_t blob_len = 0;
@ -755,28 +759,28 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
}
continue;
case LUA_TSTRING:
blob_len += lua_rawlen(lua_, idx);
blob_len += lua_rawlen(lua_, idx) + 1;
continue;
default:
PushError(lua_, "Lua redis() command arguments must be strings or integers");
cmd_depth_--;
return raise_error ? RaiseError(lua_) : 1;
RETURN_ERROR("Lua redis() command arguments must be strings or integers");
}
}
char name_buffer[32]; // backing storage for cmd name
string buffer(blob_len + 4, '\0'); // backing storage for args
char name_buffer[32]; // backing storage for cmd name
absl::FixedArray<absl::Span<char>, 4> args(argc);
char* cur = buffer.data();
char* end = cur + blob_len;
// Copy command name to name_buffer and set it as first arg.
unsigned len = lua_rawlen(lua_, 1);
DCHECK_LT(len, ABSL_ARRAYSIZE(name_buffer));
memcpy(name_buffer, lua_tostring(lua_, 1), len);
args[0] = {name_buffer, len};
unsigned name_len = lua_rawlen(lua_, 1);
if (name_len >= sizeof(name_buffer)) {
RETURN_ERROR("Lua redis() command name too long");
}
memcpy(name_buffer, lua_tostring(lua_, 1), name_len);
args[0] = {name_buffer, name_len};
buffer_.resize(blob_len + 4, '\0'); // backing storage for args
char* cur = buffer_.data();
char* end = cur + blob_len;
for (int idx = 2; idx <= argc; idx++) {
size_t len = 0;
switch (lua_type(lua_, idx)) {
@ -794,7 +798,7 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
break;
case LUA_TSTRING:
len = lua_rawlen(lua_, idx);
memcpy(cur, lua_tostring(lua_, idx), len);
memcpy(cur, lua_tostring(lua_, idx), len + 1); // + 1 for null terminator
};
args[idx - 1] = {cur, len};
@ -805,9 +809,16 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
* and this way we guaranty we will have room on the stack for the result. */
lua_pop(lua_, argc);
RedisTranslator translator(lua_);
redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async, raise_error, &raise_error});
redis_func_(
CallArgs{MutSliceSpan{args}, &buffer_, &translator, async, raise_error, &raise_error});
cmd_depth_--;
// Shrink reusable buffer if it's too big.
if (buffer_.capacity() > 128) {
buffer_.clear();
buffer_.shrink_to_fit();
}
// Raise error for regular 'call' command if needed.
if (raise_error && translator.HasError()) {
// error is already on top of stack

View file

@ -130,6 +130,7 @@ class Interpreter {
lua_State* lua_;
unsigned cmd_depth_ = 0;
RedisFunc redis_func_;
std::string buffer_;
};
// Manages an internal interpreter pool. This allows multiple connections residing on the same

View file

@ -77,19 +77,21 @@ class InterpreterTest : public ::testing::Test {
CHECK_EQ(0, lua_pcall(lua(), 0, num_results, 0));
}
void SetGlobalArray(const char* name, vector<string> vec);
void SetGlobalArray(const char* name, const vector<string_view>& vec);
bool Execute(string_view script);
Interpreter intptr_;
TestSerializer ser_;
string error_;
vector<unique_ptr<string>> strings_;
};
void InterpreterTest::SetGlobalArray(const char* name, vector<string> vec) {
void InterpreterTest::SetGlobalArray(const char* name, const vector<string_view>& vec) {
vector<MutableSlice> slices(vec.size());
for (size_t i = 0; i < vec.size(); ++i) {
slices[i] = MutableSlice{vec[i]};
strings_.emplace_back(new string(vec[i]));
slices[i] = MutableSlice{*strings_.back()};
}
intptr_.SetGlobalArray(name, MutSliceSpan{slices});
}
@ -318,6 +320,10 @@ TEST_F(InterpreterTest, ArgKeys) {
SetGlobalArray("KEYS", {"key1", "key2"});
EXPECT_TRUE(Execute("return {ARGV[1], KEYS[1], KEYS[2]}"));
EXPECT_EQ("[str(foo) str(key1) str(key2)]", ser_.res);
SetGlobalArray("INTKEYS", {"123456", "1"});
EXPECT_TRUE(Execute("return INTKEYS[1] + 0")) << error_;
EXPECT_EQ("i(123456)", ser_.res);
}
TEST_F(InterpreterTest, Modules) {

View file

@ -159,8 +159,9 @@ void Connection::PipelineMessage::SetArgs(const RespVec& args) {
size_t s = buf.size();
if (s)
memcpy(next, buf.data(), s);
next[s] = '\0';
this->args[i] = MutableSlice(next, s);
next += s;
next += (s + 1);
}
}
@ -959,7 +960,7 @@ Connection::PipelineMessagePtr Connection::FromArgs(RespVec args, mi_heap_t* hea
size_t backed_sz = 0;
for (const auto& arg : args) {
CHECK_EQ(RespExpr::STRING, arg.type);
backed_sz += arg.GetBuf().size();
backed_sz += arg.GetBuf().size() + 1; // for '\0'
}
DCHECK(backed_sz);

View file

@ -21,13 +21,13 @@ using namespace facade;
StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args, facade::ReplyMode mode)
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
size_t total_size = 0;
for (auto args : args)
total_size += args.size();
for (auto args : args) {
total_size += args.size() + 1; // +1 for null terminator
}
buffer_.resize(total_size);
char* next = buffer_.data();
for (unsigned i = 0; i < args.size(); i++) {
memcpy(next, args[i].data(), args[i].size());
memcpy(next, args[i].data(), args[i].size() + 1);
sizes_[i] = args[i].size();
next += args[i].size();
}

View file

@ -371,6 +371,7 @@ void InterpreterReplier::PostItem() {
void InterpreterReplier::SendError(string_view str, std::string_view type) {
DCHECK(array_len_.empty());
DVLOG(1) << "Lua/df_call error " << str;
explr_->OnError(str);
}
@ -1377,7 +1378,7 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) {
DCHECK(cntx->transaction);
DVLOG(1) << "CallFromScript " << cntx->transaction->DebugId() << " " << ArgS(ca.args, 0);
DVLOG(1) << "CallFromScript " << ArgS(ca.args, 0);
InterpreterReplier replier(ca.translator);
facade::SinkReplyBuilder* orig = cntx->Inject(&replier);