From d4ada69d3b4a749982057f8618b91dd8f24d1d35 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 8 Jan 2021 10:02:14 +0800 Subject: [PATCH] refactor(mge): trace exception in compiled info GitOrigin-RevId: 508f5463b9d7b0aaf601bcf8fc88a5673d6cb0e7 --- imperative/python/megengine/jit/tracing.py | 11 +++++++---- imperative/python/src/tensor.cpp | 14 +++++++++++++- imperative/python/src/trace.h | 10 ++++++++++ imperative/python/test/unit/test_tracing.py | 1 - 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 7e1573077..05ae3d961 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -414,7 +414,7 @@ class trace: for x in escaped_tensors: try: assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) - except TraceMismatchError: + except RuntimeError: # TraceMismatchError thrown in do_exit pass self._graph.wait() @@ -954,7 +954,8 @@ class CompiledTensorProxy: elif self.__info.data_read: self.__shape = self._dev_tensor().shape else: - raise TraceMismatchError("shape of this tensor is not read in trace") + # c++ will throw TraceReadError + return None return self.__shape def numpy(self): @@ -964,7 +965,8 @@ class CompiledTensorProxy: elif self.__info.data_read: self.__value = self._dev_tensor().numpy() else: - raise TraceMismatchError("value of this tensor is not read in trace") + # c++ will throw TraceReadError + return None if self._isscalar: self.__value = self.__value.squeeze() return self.__value @@ -972,7 +974,8 @@ class CompiledTensorProxy: def _dev_tensor(self): if self.__data is None: if not self.__info.data_read: - raise TraceMismatchError("raw data of this tensor is not read in trace") + # c++ will throw TraceReadError + return None self.__data = self.__info.data_reader.get_value() return self.__data diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index f487ee8be..6faa030af 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -316,7 +316,11 @@ PyObject* TensorWrapper::shape() { if (m_tensor->m_flags & Tensor::Flags::SCALAR) { return PyTuple_New(0); } - return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape"); + PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape"); + if (shp == Py_None) { + throw TraceReadError("shape of this tensor is not read in trace"); + } + return shp; } if (m_tensor->m_trace_info.recording && !skip_tracing) { PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr()); @@ -367,6 +371,9 @@ PyObject* TensorWrapper::device() { PyObject* TensorWrapper::numpy() { if (m_tensor->m_trace_info.compiled_info != nullptr) { PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr); + if (np_val == Py_None) { + throw TraceReadError("value of this tensor is not read in trace"); + } if (m_tensor->m_flags & Tensor::Flags::SCALAR) { np_val = PyArray_Squeeze(reinterpret_cast(np_val)); } @@ -445,9 +452,14 @@ PyObject* TensorWrapper::detach() { PyObject* TensorWrapper::_dev_tensor(){ if (m_tensor->m_trace_info.compiled_info != nullptr) { auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr); + if (dev_tensor == Py_None) { + throw TraceReadError("raw data of this tensor is not read in trace"); + } auto py_dev_tensor = py::reinterpret_borrow(dev_tensor); auto sh = interpreter_for_py->put(py_dev_tensor.cast()); m_tensor->m_handle = std::move(SharedHandle(sh)); + + return dev_tensor; } if (m_tensor->m_trace_info.recording && !skip_tracing) { PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr()); diff --git a/imperative/python/src/trace.h b/imperative/python/src/trace.h index c81ccf857..ab56e8603 100644 --- a/imperative/python/src/trace.h +++ b/imperative/python/src/trace.h @@ -10,9 +10,19 @@ */ #include "./tensor.h" +#include namespace mgb::imperative::python { +class TraceReadError : public std::exception { +public: + explicit TraceReadError(const char * m) : message{m} {} + const char * what() const noexcept override {return message.c_str();} +private: + std::string message = ""; +}; + + apply_result_t apply_trace(ApplyContext& ctx); } // namespace mgb::imperative::python diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index ec4924e14..4eaeb3da2 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -311,7 +311,6 @@ def test_trace_warp_perspective(): f(x, M) -@pytest.mark.skip(reason="skip") def test_raise_on_trace(): step_count = 0 catch_count = 0 -- GitLab