fix: report errors from commands with redis.call (#1108)

Redis call now directly reports erros
This commit is contained in:
Vladislav 2023-04-18 17:29:07 +03:00 committed by GitHub
parent 1382ed1c37
commit 77e18f0463
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 34 deletions

View file

@ -44,6 +44,27 @@ int EVPDigest(const void* data, size_t datalen, unsigned char* md, size_t* mdlen
return ret;
}
/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
* with a single "err" field set to the error string. Note that this
* table is never a valid reply by proper commands, since the returned
* tables are otherwise always indexed by integers, never by strings. */
void PushError(lua_State* lua, string_view error, bool trace = true) {
lua_Debug dbg;
lua_newtable(lua);
lua_pushstring(lua, "err");
/* Attempt to figure out where this function was called, if possible */
if (trace && lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
string msg = absl::StrCat(dbg.source, ": ", dbg.currentline, ": ", error);
lua_pushlstring(lua, msg.c_str(), msg.size());
} else {
lua_pushlstring(lua, error.data(), error.size());
}
lua_settable(lua, -3);
}
class RedisTranslator : public ObjectExplorer {
public:
RedisTranslator(lua_State* lua) : lua_(lua) {
@ -58,6 +79,8 @@ class RedisTranslator : public ObjectExplorer {
void OnStatus(std::string_view str) final;
void OnError(std::string_view str) final;
bool HasError();
private:
void ArrayPre() {
}
@ -68,8 +91,9 @@ class RedisTranslator : public ObjectExplorer {
}
}
vector<unsigned> array_index_;
lua_State* lua_;
bool has_error_{false};
vector<unsigned> array_index_{};
};
void RedisTranslator::OnBool(bool b) {
@ -123,11 +147,8 @@ void RedisTranslator::OnStatus(std::string_view str) {
}
void RedisTranslator::OnError(std::string_view str) {
CHECK(array_index_.empty()) << "unexpected error";
lua_newtable(lua_);
lua_pushstring(lua_, "err");
lua_pushlstring(lua_, str.data(), str.size());
lua_settable(lua_, -3);
has_error_ = true;
PushError(lua_, str, false);
}
void RedisTranslator::OnArrayStart(unsigned len) {
@ -144,6 +165,10 @@ void RedisTranslator::OnArrayEnd() {
ArrayPost();
}
bool RedisTranslator::HasError() {
return has_error_;
}
void RunSafe(lua_State* lua, string_view buf, const char* name) {
CHECK_EQ(0, luaL_loadbuffer(lua, buf.data(), buf.size(), name));
int err = lua_pcall(lua, 0, 0, 0);
@ -181,27 +206,6 @@ void SetGlobalArrayInternal(lua_State* lua, const char* name, MutSliceSpan args)
lua_setglobal(lua, name);
}
/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
* with a single "err" field set to the error string. Note that this
* table is never a valid reply by proper commands, since the returned
* tables are otherwise always indexed by integers, never by strings. */
void PushError(lua_State* lua, const char* error) {
lua_Debug dbg;
lua_newtable(lua);
lua_pushstring(lua, "err");
/* Attempt to figure out where this function was called, if possible */
if (lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
string msg = absl::StrCat(dbg.source, ": ", dbg.currentline, ": ", error);
lua_pushstring(lua, msg.c_str());
} else {
lua_pushstring(lua, error);
}
lua_settable(lua, -3);
}
/* In case the error set into the Lua stack by PushError() was generated
* by the non-error-trapping version of redis.pcall(), which is redis.call(),
* this function will raise the Lua error so that the execution of the
@ -717,12 +721,17 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async) {
lua_pop(lua_, argc);
RedisTranslator translator(lua_);
redis_func_(CallArgs{MutSliceSpan{args}, &buffer, &translator, async});
cmd_depth_--;
// Raise error for regular 'call' command if needed.
if (raise_error && translator.HasError()) {
// error is already on top of stack
return RaiseError(lua_);
}
if (!async)
DCHECK_EQ(1, lua_gettop(lua_));
cmd_depth_--;
return 1;
}

View file

@ -276,19 +276,19 @@ TEST_F(InterpreterTest, Call) {
};
intptr_.SetRedisFunc(cb);
ASSERT_TRUE(Execute("local var = redis.call('string'); return {type(var), var}"));
ASSERT_TRUE(Execute("local var = redis.pcall('string'); return {type(var), var}"));
EXPECT_EQ("[str(string) str(foo)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('double'); return {type(var), var}"));
EXPECT_TRUE(Execute("local var = redis.pcall('double'); return {type(var), var}"));
EXPECT_EQ("[str(number) d(3.1415)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('int'); return {type(var), var}"));
EXPECT_TRUE(Execute("local var = redis.pcall('int'); return {type(var), var}"));
EXPECT_EQ("[str(number) i(42)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('err'); return {type(var), var}"));
EXPECT_TRUE(Execute("local var = redis.pcall('err'); return {type(var), var}"));
EXPECT_EQ("[str(table) err(myerr)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('status'); return {type(var), var}"));
EXPECT_TRUE(Execute("local var = redis.pcall('status'); return {type(var), var}"));
EXPECT_EQ("[str(table) status(mystatus)]", ser_.res);
}

View file

@ -201,3 +201,21 @@ async def test_golang_asynq_script(async_pool, num_queues=10, num_tasks=100):
for job in jobs:
await job
ERROR_CALL_SCRIPT = """
redis.call('ECHO', 'I', 'want', 'an', 'error')
"""
ERROR_PCALL_SCRIPT = """
redis.pcall('ECHO', 'I', 'want', 'an', 'error')
"""
@pytest.mark.asyncio
async def test_eval_error_propagation(async_client):
assert await async_client.eval(ERROR_PCALL_SCRIPT, 0) is None
try:
await async_client.eval(ERROR_CALL_SCRIPT, 0)
assert False, "Eval must have thrown an error"
except aioredis.RedisError as e:
pass