mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 10:25:47 +02:00
fix: buffer overrun when passing long command name from lua (#2008)
Also, few additional changes that do not affect functionality. 1. make sure passed arguments to DispatchCommand are `\0` delimited during pipelining. 2. extend lua malloc hook to call precise functions - to help with cpu profiling. 3. reuse arguments buffer (save allocations) when calling Dragonfly command from lua scripts. Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
parent
c6f8f3882a
commit
ec87114f66
7 changed files with 56 additions and 36 deletions
|
@ -7,8 +7,8 @@ endif()
|
|||
|
||||
add_third_party(
|
||||
lua
|
||||
URL https://github.com/lua/lua/archive/refs/tags/v5.4.4.tar.gz
|
||||
PATCH_COMMAND patch -p1 -i "${CMAKE_SOURCE_DIR}/patches/lua-v5.4.4.patch"
|
||||
GIT_REPOSITORY https://github.com/dragonflydb/lua
|
||||
GIT_TAG Dragonfly-5.4.6
|
||||
CONFIGURE_COMMAND echo
|
||||
BUILD_IN_SOURCE 1
|
||||
BUILD_COMMAND ${DFLY_TOOLS_MAKE} all
|
||||
|
|
|
@ -113,7 +113,7 @@ void RedisTranslator::OnString(std::string_view str) {
|
|||
}
|
||||
|
||||
void RedisTranslator::OnDouble(double d) {
|
||||
static constexpr double kConvertEps = std::numeric_limits<double>::epsilon();
|
||||
const double kConvertEps = std::numeric_limits<double>::epsilon();
|
||||
|
||||
double fractpart, intpart;
|
||||
fractpart = modf(d, &intpart);
|
||||
|
@ -351,14 +351,14 @@ int RedisStatusReplyCommand(lua_State* lua) {
|
|||
return SingleFieldTable(lua, "ok");
|
||||
}
|
||||
|
||||
// const char* kInstanceKey = "_INSTANCE";
|
||||
|
||||
// See https://www.lua.org/manual/5.3/manual.html#lua_Alloc
|
||||
void* mimalloc_glue(void* ud, void* ptr, size_t osize, size_t nsize) {
|
||||
(void)ud;
|
||||
(void)osize; /* not used */
|
||||
if (nsize == 0) {
|
||||
mi_free(ptr);
|
||||
mi_free_size(ptr, osize);
|
||||
return nullptr;
|
||||
} else if (ptr == nullptr) {
|
||||
return mi_malloc(nsize);
|
||||
} else {
|
||||
return mi_realloc(ptr, nsize);
|
||||
}
|
||||
|
@ -730,12 +730,16 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
|
|||
cmd_depth_++;
|
||||
int argc = lua_gettop(lua_);
|
||||
|
||||
#define RETURN_ERROR(err) \
|
||||
{ \
|
||||
PushError(lua_, err); \
|
||||
cmd_depth_--; \
|
||||
return raise_error ? RaiseError(lua_) : 1; \
|
||||
}
|
||||
|
||||
/* Require at least one argument */
|
||||
if (argc == 0) {
|
||||
PushError(lua_, "Please specify at least one argument for redis.call()");
|
||||
cmd_depth_--;
|
||||
|
||||
return raise_error ? RaiseError(lua_) : 1;
|
||||
RETURN_ERROR("Please specify at least one argument for redis.call()");
|
||||
}
|
||||
|
||||
size_t blob_len = 0;
|
||||
|
@ -755,28 +759,28 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
|
|||
}
|
||||
continue;
|
||||
case LUA_TSTRING:
|
||||
blob_len += lua_rawlen(lua_, idx);
|
||||
blob_len += lua_rawlen(lua_, idx) + 1;
|
||||
continue;
|
||||
default:
|
||||
PushError(lua_, "Lua redis() command arguments must be strings or integers");
|
||||
cmd_depth_--;
|
||||
return raise_error ? RaiseError(lua_) : 1;
|
||||
RETURN_ERROR("Lua redis() command arguments must be strings or integers");
|
||||
}
|
||||
}
|
||||
|
||||
char name_buffer[32]; // backing storage for cmd name
|
||||
string buffer(blob_len + 4, '\0'); // backing storage for args
|
||||
char name_buffer[32]; // backing storage for cmd name
|
||||
absl::FixedArray<absl::Span<char>, 4> args(argc);
|
||||
|
||||
char* cur = buffer.data();
|
||||
char* end = cur + blob_len;
|
||||
|
||||
// Copy command name to name_buffer and set it as first arg.
|
||||
unsigned len = lua_rawlen(lua_, 1);
|
||||
DCHECK_LT(len, ABSL_ARRAYSIZE(name_buffer));
|
||||
memcpy(name_buffer, lua_tostring(lua_, 1), len);
|
||||
args[0] = {name_buffer, len};
|
||||
unsigned name_len = lua_rawlen(lua_, 1);
|
||||
if (name_len >= sizeof(name_buffer)) {
|
||||
RETURN_ERROR("Lua redis() command name too long");
|
||||
}
|
||||
|
||||
memcpy(name_buffer, lua_tostring(lua_, 1), name_len);
|
||||
args[0] = {name_buffer, name_len};
|
||||
buffer_.resize(blob_len + 4, '\0'); // backing storage for args
|
||||
|
||||
char* cur = buffer_.data();
|
||||
char* end = cur + blob_len;
|
||||
for (int idx = 2; idx <= argc; idx++) {
|
||||
size_t len = 0;
|
||||
switch (lua_type(lua_, idx)) {
|
||||
|
@ -794,7 +798,7 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
|
|||
break;
|
||||
case LUA_TSTRING:
|
||||
len = lua_rawlen(lua_, idx);
|
||||
memcpy(cur, lua_tostring(lua_, idx), len);
|
||||
memcpy(cur, lua_tostring(lua_, idx), len + 1); // + 1 for null terminator
|
||||
};
|
||||
|
||||
args[idx - 1] = {cur, len};
|
||||
|
@ -805,9 +809,16 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
|
|||
* and this way we guaranty we will have room on the stack for the result. */
|
||||
lua_pop(lua_, argc);
|
||||
RedisTranslator translator(lua_);
|
||||
redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async, raise_error, &raise_error});
|
||||
redis_func_(
|
||||
CallArgs{MutSliceSpan{args}, &buffer_, &translator, async, raise_error, &raise_error});
|
||||
cmd_depth_--;
|
||||
|
||||
// Shrink reusable buffer if it's too big.
|
||||
if (buffer_.capacity() > 128) {
|
||||
buffer_.clear();
|
||||
buffer_.shrink_to_fit();
|
||||
}
|
||||
|
||||
// Raise error for regular 'call' command if needed.
|
||||
if (raise_error && translator.HasError()) {
|
||||
// error is already on top of stack
|
||||
|
|
|
@ -130,6 +130,7 @@ class Interpreter {
|
|||
lua_State* lua_;
|
||||
unsigned cmd_depth_ = 0;
|
||||
RedisFunc redis_func_;
|
||||
std::string buffer_;
|
||||
};
|
||||
|
||||
// Manages an internal interpreter pool. This allows multiple connections residing on the same
|
||||
|
|
|
@ -77,19 +77,21 @@ class InterpreterTest : public ::testing::Test {
|
|||
CHECK_EQ(0, lua_pcall(lua(), 0, num_results, 0));
|
||||
}
|
||||
|
||||
void SetGlobalArray(const char* name, vector<string> vec);
|
||||
void SetGlobalArray(const char* name, const vector<string_view>& vec);
|
||||
|
||||
bool Execute(string_view script);
|
||||
|
||||
Interpreter intptr_;
|
||||
TestSerializer ser_;
|
||||
string error_;
|
||||
vector<unique_ptr<string>> strings_;
|
||||
};
|
||||
|
||||
void InterpreterTest::SetGlobalArray(const char* name, vector<string> vec) {
|
||||
void InterpreterTest::SetGlobalArray(const char* name, const vector<string_view>& vec) {
|
||||
vector<MutableSlice> slices(vec.size());
|
||||
for (size_t i = 0; i < vec.size(); ++i) {
|
||||
slices[i] = MutableSlice{vec[i]};
|
||||
strings_.emplace_back(new string(vec[i]));
|
||||
slices[i] = MutableSlice{*strings_.back()};
|
||||
}
|
||||
intptr_.SetGlobalArray(name, MutSliceSpan{slices});
|
||||
}
|
||||
|
@ -318,6 +320,10 @@ TEST_F(InterpreterTest, ArgKeys) {
|
|||
SetGlobalArray("KEYS", {"key1", "key2"});
|
||||
EXPECT_TRUE(Execute("return {ARGV[1], KEYS[1], KEYS[2]}"));
|
||||
EXPECT_EQ("[str(foo) str(key1) str(key2)]", ser_.res);
|
||||
|
||||
SetGlobalArray("INTKEYS", {"123456", "1"});
|
||||
EXPECT_TRUE(Execute("return INTKEYS[1] + 0")) << error_;
|
||||
EXPECT_EQ("i(123456)", ser_.res);
|
||||
}
|
||||
|
||||
TEST_F(InterpreterTest, Modules) {
|
||||
|
|
|
@ -159,8 +159,9 @@ void Connection::PipelineMessage::SetArgs(const RespVec& args) {
|
|||
size_t s = buf.size();
|
||||
if (s)
|
||||
memcpy(next, buf.data(), s);
|
||||
next[s] = '\0';
|
||||
this->args[i] = MutableSlice(next, s);
|
||||
next += s;
|
||||
next += (s + 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -959,7 +960,7 @@ Connection::PipelineMessagePtr Connection::FromArgs(RespVec args, mi_heap_t* hea
|
|||
size_t backed_sz = 0;
|
||||
for (const auto& arg : args) {
|
||||
CHECK_EQ(RespExpr::STRING, arg.type);
|
||||
backed_sz += arg.GetBuf().size();
|
||||
backed_sz += arg.GetBuf().size() + 1; // for '\0'
|
||||
}
|
||||
DCHECK(backed_sz);
|
||||
|
||||
|
|
|
@ -21,13 +21,13 @@ using namespace facade;
|
|||
StoredCmd::StoredCmd(const CommandId* cid, CmdArgList args, facade::ReplyMode mode)
|
||||
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
|
||||
size_t total_size = 0;
|
||||
for (auto args : args)
|
||||
total_size += args.size();
|
||||
|
||||
for (auto args : args) {
|
||||
total_size += args.size() + 1; // +1 for null terminator
|
||||
}
|
||||
buffer_.resize(total_size);
|
||||
char* next = buffer_.data();
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
memcpy(next, args[i].data(), args[i].size());
|
||||
memcpy(next, args[i].data(), args[i].size() + 1);
|
||||
sizes_[i] = args[i].size();
|
||||
next += args[i].size();
|
||||
}
|
||||
|
|
|
@ -371,6 +371,7 @@ void InterpreterReplier::PostItem() {
|
|||
|
||||
void InterpreterReplier::SendError(string_view str, std::string_view type) {
|
||||
DCHECK(array_len_.empty());
|
||||
DVLOG(1) << "Lua/df_call error " << str;
|
||||
explr_->OnError(str);
|
||||
}
|
||||
|
||||
|
@ -1377,7 +1378,7 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
|
|||
|
||||
void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) {
|
||||
DCHECK(cntx->transaction);
|
||||
DVLOG(1) << "CallFromScript " << cntx->transaction->DebugId() << " " << ArgS(ca.args, 0);
|
||||
DVLOG(1) << "CallFromScript " << ArgS(ca.args, 0);
|
||||
|
||||
InterpreterReplier replier(ca.translator);
|
||||
facade::SinkReplyBuilder* orig = cntx->Inject(&replier);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue