提交 32ad9265 编写于 作者: M Megvii Engine Team

feat(mge): improve presentation of async errors

GitOrigin-RevId: af5426c37d311f63ba7e56406d0e6994c8e24143
上级 e6a8b025
......@@ -11,6 +11,7 @@
#pragma once
#include <exception>
#include <stdexcept>
#include <vector>
#include <utility>
......@@ -69,10 +70,43 @@ inline int cvt_retint(int ret) {
struct py_err_set : std::exception {};
#define HANDLE_ALL_EXC(RET) catch(py_err_set&) {return RET;} \
catch(pybind11::error_already_set& e) {e.restore(); return RET;} \
catch(pybind11::builtin_exception& e) {e.set_error(); return RET;} \
catch(std::exception& e) {PyErr_SetString(PyExc_RuntimeError, e.what()); return RET;}
// refer to pybind11 for the following exception handling helper
inline void pybind11_translate_exception(std::exception_ptr last_exception) {
auto &registered_exception_translators = pybind11::detail::get_internals().registered_exception_translators;
for (auto& translator : registered_exception_translators) {
try {
translator(last_exception);
} catch (...) {
last_exception = std::current_exception();
continue;
}
return;
}
PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!");
}
inline void pybind11_translate_exception() {
pybind11_translate_exception(std::current_exception());
}
#if defined(__GNUG__) && !defined(__clang__)
#define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND catch (::abi::__forced_unwind&) {throw;}
#else
#define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND
#endif
#define PYEXT17_TRANSLATE_EXC \
catch(::pyext17::py_err_set&) {} \
catch(::pybind11::error_already_set& e) {e.restore();} \
PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
catch(...) {::pyext17::pybind11_translate_exception();}
#define PYEXT17_TRANSLATE_EXC_RET(RET) \
catch(::pyext17::py_err_set&) {return RET;} \
catch(::pybind11::error_already_set& e) {e.restore(); return RET;} \
PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
catch(...) {::pyext17::pybind11_translate_exception(); return RET;};
template <typename T>
struct wrap {
......@@ -134,7 +168,7 @@ private:
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
CVT_RET_PYOBJ((inst->*f)());
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
};
......@@ -146,7 +180,7 @@ private:
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
CVT_RET_PYOBJ((inst->*f)(args, kwargs));
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
};
......@@ -159,7 +193,7 @@ private:
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
CVT_RET_PYOBJ((inst->*f)(args, nargs));
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#else
static constexpr int flags = METH_VARARGS;
......@@ -170,7 +204,7 @@ private:
auto size = PyTuple_GET_SIZE(args);
try {
CVT_RET_PYOBJ((inst->*f)(arr, size));
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#endif
};
......@@ -183,7 +217,7 @@ private:
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
CVT_RET_PYOBJ((inst->*f)(obj));
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
};
......@@ -209,7 +243,7 @@ private:
} else {
static_assert(!std::is_same_v<F, F>);
}
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
};
......@@ -230,7 +264,7 @@ private:
} else {
static_assert(!std::is_same_v<F, F>);
}
} HANDLE_ALL_EXC(-1)
} PYEXT17_TRANSLATE_EXC_RET(-1)
}
static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr;
......@@ -314,7 +348,7 @@ private:
} else {
new(inst) T();
}
} HANDLE_ALL_EXC(nullptr)
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
free_guard.self = nullptr;
return self;
}
......@@ -464,7 +498,7 @@ public:
new(inst) T(std::forward<Args>(args)...);
return self;
}
struct caster {
static constexpr auto name = T::tp_name;
......@@ -493,4 +527,3 @@ public:
#undef HAS_MEMBER
#undef CVT_RET_PYOBJ
#undef CVT_RET_INT
#undef HANDLE_ALL_EXC
......@@ -26,8 +26,11 @@
#include "./graph_rt.h"
#include "./helper.h"
#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>
#include <range/v3/all.hpp>
#include <string>
......@@ -230,10 +233,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
}
return ret.release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
......@@ -391,7 +391,7 @@ void TensorWrapper::set_handle(PyObject* dest) {
PyObject* TensorWrapper::shape() {
// if it's tracing compiled mode, get value from compiled_info
// if it's tracing compiled mode, get value from compiled_info
if (m_tensor->m_trace_info.compiled_info != nullptr) {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
......@@ -821,10 +821,7 @@ PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) {
try {
PyArray_Descr* res = _dtype_promotion(args, nargs);
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
......@@ -835,10 +832,7 @@ PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
try {
CompNode cn = _get_device(args, nargs);
return py::cast(cn).release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
} PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#ifdef METH_FASTCALL
......@@ -865,6 +859,34 @@ void init_tensor(py::module m) {
static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel();
interpreter_for_py = sl_interpreter_for_py.get();
static py::exception<interpreter::AsyncError> py_async_error(m, "AsyncError", PyExc_RuntimeError);
py::register_exception_translator([](std::exception_ptr p) {
try {
if (p) std::rethrow_exception(p);
} catch (const interpreter::AsyncError& e) {
pyext17::pybind11_translate_exception(e.nested_ptr());
if (PyErr_Occurred()) {
PyObject *exc, *val, *tb;
PyErr_Fetch(&exc, &val, &tb);
PyErr_NormalizeException(&exc, &val, &tb);
if (tb) {
PyException_SetTraceback(val, tb);
}
auto val2 = py_async_error.py::object::operator()(
"An async error is reported. See above for the actual cause."
" Hint: This is where it is reported, not where it happened."
" You may call `megengine.core.set_option('async_level', 0)` to get better error reporting."
);
PyException_SetCause(val2.ptr(), val); // PyException_SetCause steals reference
Py_XDECREF(exc);
Py_XDECREF(tb);
PyErr_Restore(py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
} else {
py_async_error("Unkown async error");
}
}
});
auto* tensor_type = TensorWrapper::wrap_t::type()
.def<&TensorWrapper::numpy>("numpy")
.def_getset<&TensorWrapper::shape>("shape")
......@@ -932,7 +954,7 @@ void init_tensor(py::module m) {
if (v->is_scalar) {
return py::object(py::array(np_val).squeeze());
}
return np_val;
return np_val;
})
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
......
......@@ -7,6 +7,7 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.core2 import (
AsyncError,
_set_drop_flag,
_set_swap_flag,
config_async_level,
......@@ -98,3 +99,25 @@ def test_regression_2870():
with pytest.raises(RuntimeError):
y.numpy()
(x + x).numpy()
# NOTE: DO NOT REMOVE THIS TEST
# This is also a compatibility test for
# mge.core.set_option('async_level', 0).
# If you change the canonical API to set async level,
# update the error message of AsyncError as well.
def test_async_error():
orig_lvl = mge.core.get_option("async_level")
try:
mge.core.set_option("async_level", 1)
x = F.utils._simulate_error()
try:
x.numpy()
except AsyncError as e:
assert isinstance(e.__cause__, RuntimeError)
mge.core.set_option("async_level", 0)
with pytest.raises(RuntimeError):
F.utils._simulate_error()
finally:
mge.core.set_option("async_level", orig_lvl)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册