提交 d4ada69d 编写于 作者: M Megvii Engine Team

refactor(mge): trace exception in compiled info

GitOrigin-RevId: 508f5463b9d7b0aaf601bcf8fc88a5673d6cb0e7
上级 c9c3429a
......@@ -414,7 +414,7 @@ class trace:
for x in escaped_tensors:
try:
assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
except TraceMismatchError:
except RuntimeError:
# TraceMismatchError thrown in do_exit
pass
self._graph.wait()
......@@ -954,7 +954,8 @@ class CompiledTensorProxy:
elif self.__info.data_read:
self.__shape = self._dev_tensor().shape
else:
raise TraceMismatchError("shape of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
return self.__shape
def numpy(self):
......@@ -964,7 +965,8 @@ class CompiledTensorProxy:
elif self.__info.data_read:
self.__value = self._dev_tensor().numpy()
else:
raise TraceMismatchError("value of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
if self._isscalar:
self.__value = self.__value.squeeze()
return self.__value
......@@ -972,7 +974,8 @@ class CompiledTensorProxy:
def _dev_tensor(self):
if self.__data is None:
if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
self.__data = self.__info.data_reader.get_value()
return self.__data
......
......@@ -316,7 +316,11 @@ PyObject* TensorWrapper::shape() {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
if (shp == Py_None) {
throw TraceReadError("shape of this tensor is not read in trace");
}
return shp;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
......@@ -367,6 +371,9 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() {
if (m_tensor->m_trace_info.compiled_info != nullptr) {
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (np_val == Py_None) {
throw TraceReadError("value of this tensor is not read in trace");
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
}
......@@ -445,9 +452,14 @@ PyObject* TensorWrapper::detach() {
PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.compiled_info != nullptr) {
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace");
}
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh));
return dev_tensor;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
......
......@@ -10,9 +10,19 @@
*/
#include "./tensor.h"
#include <stdexcept>
namespace mgb::imperative::python {
class TraceReadError : public std::exception {
public:
explicit TraceReadError(const char * m) : message{m} {}
const char * what() const noexcept override {return message.c_str();}
private:
std::string message = "";
};
apply_result_t apply_trace(ApplyContext& ctx);
} // namespace mgb::imperative::python
......@@ -311,7 +311,6 @@ def test_trace_warp_perspective():
f(x, M)
@pytest.mark.skip(reason="skip")
def test_raise_on_trace():
step_count = 0
catch_count = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册