Commit d4ada69d authored by Megvii Engine Team

refactor(mge): trace exception in compiled info

GitOrigin-RevId: 508f5463b9d7b0aaf601bcf8fc88a5673d6cb0e7
Parent c9c3429a
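
How the pieces below fit together (a summary of the diff, plus a stand-in sketch): the Python-side CompiledTensorProxy no longer raises TraceMismatchError when a shape, value, or device tensor was not read while the trace was being recorded; it now returns None and leaves the error to the C++ layer, where TensorWrapper checks for Py_None and throws the new TraceReadError. The sketch below is not MegEngine code; CompiledInfoSketch and TensorWrapperSketch are invented stand-ins that only mirror that division of responsibility.

class TraceReadError(RuntimeError):
    # stand-in for the C++ TraceReadError, which reaches Python as a RuntimeError
    pass


class CompiledInfoSketch:
    # plays the role of CompiledTensorProxy: signal "not read in trace"
    # by returning None instead of raising
    def __init__(self, value=None, value_read=False):
        self._value = value
        self._value_read = value_read

    def numpy(self):
        if not self._value_read:
            return None  # the wrapper layer turns this into an error
        return self._value


class TensorWrapperSketch:
    # plays the role of TensorWrapper::numpy() in the C++ hunks further down
    def __init__(self, compiled_info):
        self._compiled_info = compiled_info

    def numpy(self):
        val = self._compiled_info.numpy()
        if val is None:
            raise TraceReadError("value of this tensor is not read in trace")
        return val


wrapper = TensorWrapperSketch(CompiledInfoSketch())
try:
    wrapper.numpy()
except RuntimeError as exc:  # matches the broadened `except RuntimeError:` in the first hunk
    print(exc)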
@@ -414,7 +414,7 @@ class trace:
         for x in escaped_tensors:
             try:
                 assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
-            except TraceMismatchError:
+            except RuntimeError:
                 # TraceMismatchError thrown in do_exit
                 pass
         self._graph.wait()
@@ -954,7 +954,8 @@ class CompiledTensorProxy:
             elif self.__info.data_read:
                 self.__shape = self._dev_tensor().shape
             else:
-                raise TraceMismatchError("shape of this tensor is not read in trace")
+                # c++ will throw TraceReadError
+                return None
         return self.__shape
 
     def numpy(self):
@@ -964,7 +965,8 @@ class CompiledTensorProxy:
             elif self.__info.data_read:
                 self.__value = self._dev_tensor().numpy()
             else:
-                raise TraceMismatchError("value of this tensor is not read in trace")
+                # c++ will throw TraceReadError
+                return None
             if self._isscalar:
                 self.__value = self.__value.squeeze()
         return self.__value
@@ -972,7 +974,8 @@ class CompiledTensorProxy:
 
     def _dev_tensor(self):
         if self.__data is None:
             if not self.__info.data_read:
-                raise TraceMismatchError("raw data of this tensor is not read in trace")
+                # c++ will throw TraceReadError
+                return None
             self.__data = self.__info.data_reader.get_value()
         return self.__data
...
@@ -316,7 +316,11 @@ PyObject* TensorWrapper::shape() {
         if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
             return PyTuple_New(0);
         }
-        return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
+        PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
+        if (shp == Py_None) {
+            throw TraceReadError("shape of this tensor is not read in trace");
+        }
+        return shp;
     }
     if (m_tensor->m_trace_info.recording && !skip_tracing) {
         PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
@@ -367,6 +371,9 @@ PyObject* TensorWrapper::device() {
 PyObject* TensorWrapper::numpy() {
     if (m_tensor->m_trace_info.compiled_info != nullptr) {
         PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
+        if (np_val == Py_None) {
+            throw TraceReadError("value of this tensor is not read in trace");
+        }
         if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
             np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
         }
@@ -445,9 +452,14 @@ PyObject* TensorWrapper::detach() {
 PyObject* TensorWrapper::_dev_tensor(){
     if (m_tensor->m_trace_info.compiled_info != nullptr) {
         auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
+        if (dev_tensor == Py_None) {
+            throw TraceReadError("raw data of this tensor is not read in trace");
+        }
         auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
         auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
         m_tensor->m_handle = std::move(SharedHandle(sh));
+        return dev_tensor;
     }
     if (m_tensor->m_trace_info.recording && !skip_tracing) {
         PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
...
@@ -10,9 +10,19 @@
  */
 
 #include "./tensor.h"
+#include <stdexcept>
 
 namespace mgb::imperative::python {
 
+class TraceReadError : public std::exception {
+public:
+    explicit TraceReadError(const char * m) : message{m} {}
+    const char * what() const noexcept override { return message.c_str(); }
+
+private:
+    std::string message = "";
+};
+
 apply_result_t apply_trace(ApplyContext& ctx);
 
 } // namespace mgb::imperative::python
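
A note on how TraceReadError reaches Python (my reading of the diff, not something it states): the class derives from std::exception and no custom translator is registered here, so it presumably relies on pybind11's default translation of an uncaught std::exception into a Python RuntimeError carrying the what() text; the change from `except TraceMismatchError:` to `except RuntimeError:` in the first hunk is consistent with that. A minimal stand-in showing the caller-side pattern (read_compiled_value is hypothetical, not a MegEngine API):

def read_compiled_value():
    # hypothetical stand-in for a read that the C++ layer rejects because the
    # value was never read while the trace was recorded; after translation the
    # Python caller sees a plain RuntimeError
    raise RuntimeError("value of this tensor is not read in trace")


try:
    read_compiled_value()
except RuntimeError as exc:
    # the what() message of the C++ exception ends up in the Python error text
    assert "not read in trace" in str(exc)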
@@ -311,7 +311,6 @@ def test_trace_warp_perspective():
     f(x, M)
 
 
-@pytest.mark.skip(reason="skip")
 def test_raise_on_trace():
     step_count = 0
     catch_count = 0
...