From 9fb8444d24734e68d6777a05c367c783db6544fc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 20 Jan 2021 18:16:54 +0800 Subject: [PATCH] fix(imperative): catch python exception in c++ GitOrigin-RevId: 16a2abfdad35c52d50f34783d29c2d503ab29568 --- imperative/python/src/tensor.cpp | 10 ++++++---- imperative/python/src/trace.cpp | 9 ++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index e8484d9a6..276280f46 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -240,10 +240,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { pyf = cpp_apply_const_with_tracing; } - auto ret = py::reinterpret_steal( - PyObject_Call(pyf, tup.ptr(), nullptr)); - auto py_ret = py::reinterpret_borrow(ret); - if (auto* t = try_cast(py_ret[0].ptr())) { + auto py_ret = PyObject_Call(pyf, tup.ptr(), nullptr); + if (!py_ret) throw py::error_already_set(); + auto py_list = py::reinterpret_steal(py_ret); + if (auto* t = try_cast(py_list[0].ptr())) { m_tensor = t->m_tensor; } return; @@ -389,6 +389,7 @@ PyObject* TensorWrapper::device() { PyObject* TensorWrapper::numpy() { if (m_tensor->m_trace_info.compiled_info != nullptr) { PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr); + if (!np_val) throw py::error_already_set(); if (np_val == Py_None) { throw TraceReadError("value of this tensor is not read in trace"); } @@ -478,6 +479,7 @@ PyObject* TensorWrapper::detach() { PyObject* TensorWrapper::_dev_tensor(){ if (m_tensor->m_trace_info.compiled_info != nullptr) { auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr); + if (!dev_tensor) throw py::error_already_set(); if (dev_tensor == Py_None) { throw TraceReadError("raw data of this tensor is not read in trace"); } diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp index 3a477e244..9571c5cda 100644 --- a/imperative/python/src/trace.cpp +++ b/imperative/python/src/trace.cpp @@ -31,9 +31,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { } py::object ret = py::reinterpret_steal( PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); - if (!ret) { - throw py::value_error("invalid py object call"); - } + if (!ret) throw py::error_already_set(); // assumption: python function always returns PyList auto tup = py::reinterpret_borrow(ret); @@ -58,8 +56,9 @@ apply_result_t apply_trace(ApplyContext& ctx) { for (size_t i = 0; i < ctx.nargs; i++) { args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); } - auto ret = py::reinterpret_steal( - PyObject_Call(pyf, args.ptr(), nullptr)); + auto pyout = PyObject_Call(pyf, args.ptr(), nullptr); + if (!pyout) throw py::error_already_set(); + auto ret = py::reinterpret_steal(pyout); // assumption: python function always returns PyList auto tup = py::reinterpret_borrow(ret); -- GitLab