提交 d7a312d0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2544 Fix the bug of pynative mode catching the exception.

Merge pull request !2544 from rick_sanchez/huangyong
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <unordered_set> #include <unordered_set>
#include <algorithm> #include <algorithm>
#include "debug/trace.h"
#include "ir/tensor_py.h" #include "ir/tensor_py.h"
#include "ir/param_value_py.h" #include "ir/param_value_py.h"
#include "utils/any.h" #include "utils/any.h"
...@@ -66,6 +67,42 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; ...@@ -66,6 +67,42 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_; std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_; ResourcePtr PynativeExecutor::resource_;
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
try {
(executor->*method)(args...);
} catch (const py::error_already_set &ex) {
// print function call stack info before release
std::ostringstream oss;
trace::TraceGraphEval();
trace::GetEvalStackInfo(oss);
// call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
// these info from screen, no need to open log file to find these info
py::print(oss.str());
MS_LOG(ERROR) << oss.str();
PynativeExecutor::GetInstance()->Clean();
// re-throw this exception to Python interpreter to handle it
throw(py::error_already_set(ex));
} catch (const py::type_error &ex) {
PynativeExecutor::GetInstance()->Clean();
throw py::type_error(ex);
} catch (const py::value_error &ex) {
PynativeExecutor::GetInstance()->Clean();
throw py::value_error(ex);
} catch (const py::index_error &ex) {
PynativeExecutor::GetInstance()->Clean();
throw py::index_error(ex);
} catch (const std::exception &ex) {
PynativeExecutor::GetInstance()->Clean();
// re-throw this exception to Python interpreter to handle it
throw(std::runtime_error(ex.what()));
} catch (...) {
PynativeExecutor::GetInstance()->Clean();
std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
}
}
inline ValuePtr PyAttrValue(const py::object &obj) { inline ValuePtr PyAttrValue(const py::object &obj) {
ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj); ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
if (!converted_ret) { if (!converted_ret) {
...@@ -144,7 +181,7 @@ std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args, ...@@ -144,7 +181,7 @@ std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
} }
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
py::list *out_args_list) { py::list *const out_args_list) {
auto &py_args = *out_args; auto &py_args = *out_args;
py::tuple input_mask(args.size()); py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
...@@ -564,7 +601,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { ...@@ -564,7 +601,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return node; return node;
} }
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) { py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true); mindspore::parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy; MsBackendPolicy backend_policy;
...@@ -603,7 +640,7 @@ py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) { ...@@ -603,7 +640,7 @@ py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return result; return result;
} }
py::tuple RunOp(const py::args &args) { py::tuple RunOpInner(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start" << args.size(); MS_LOG(DEBUG) << "RunOp start" << args.size();
py::list args_input = args[PY_INPUTS]; py::list args_input = args[PY_INPUTS];
...@@ -623,7 +660,42 @@ py::tuple RunOp(const py::args &args) { ...@@ -623,7 +660,42 @@ py::tuple RunOp(const py::args &args) {
return value_ret; return value_ret;
} }
} }
return RunOp(op_exec_info, args_input); return RunOpInner(op_exec_info, args_input);
}
py::tuple RunOp(const py::args &args) {
try {
return RunOpInner(args);
} catch (const py::error_already_set &ex) {
// print function call stack info before release
std::ostringstream oss;
trace::TraceGraphEval();
trace::GetEvalStackInfo(oss);
// call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
// these info from screen, no need to open log file to find these info
py::print(oss.str());
MS_LOG(ERROR) << oss.str();
PynativeExecutor::GetInstance()->Clean();
// re-throw this exception to Python interpreter to handle it
throw(py::error_already_set(ex));
} catch (const py::type_error &ex) {
PynativeExecutor::GetInstance()->Clean();
throw py::type_error(ex);
} catch (const py::value_error &ex) {
PynativeExecutor::GetInstance()->Clean();
throw py::value_error(ex);
} catch (const py::index_error &ex) {
PynativeExecutor::GetInstance()->Clean();
throw py::index_error(ex);
} catch (const std::exception &ex) {
PynativeExecutor::GetInstance()->Clean();
// re-throw this exception to Python interpreter to handle it
throw(std::runtime_error(ex.what()));
} catch (...) {
PynativeExecutor::GetInstance()->Clean();
std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
}
} }
void ClearPyNativeSession() { session = nullptr; } void ClearPyNativeSession() { session = nullptr; }
...@@ -632,7 +704,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); } ...@@ -632,7 +704,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell); auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) { if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Newgraph already compiled"; MS_LOG(DEBUG) << "Newgraph already compiled";
...@@ -753,7 +825,7 @@ void PynativeExecutor::Popp() { ...@@ -753,7 +825,7 @@ void PynativeExecutor::Popp() {
graph_p_.pop(); graph_p_.pop();
} }
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetId(cell); auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) { if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Endgraph already compiled"; MS_LOG(DEBUG) << "Endgraph already compiled";
...@@ -892,8 +964,8 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args ...@@ -892,8 +964,8 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
return args_spec; return args_spec;
} }
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) { const py::args &args) {
MS_LOG(INFO) << "GradNet start" << args.size(); MS_LOG(INFO) << "GradNet start" << args.size();
std::size_t size = args.size(); std::size_t size = args.size();
...@@ -939,8 +1011,10 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c ...@@ -939,8 +1011,10 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
} }
void PynativeExecutor::Clear(const std::string &flag) { void PynativeExecutor::Clear(const std::string &flag) {
if (flag == "resource") { if (!flag.empty()) {
MS_LOG(INFO) << "Clear res"; MS_LOG(INFO) << "Clear res";
(void)graph_map_.erase(flag);
(void)cell_graph_map_.erase(flag);
Clean(); Clean();
// Maybe exit in the pynative runing op, so need reset pynative flag. // Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
...@@ -949,6 +1023,7 @@ void PynativeExecutor::Clear(const std::string &flag) { ...@@ -949,6 +1023,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
} }
return; return;
} }
MS_LOG(INFO) << "Clear"; MS_LOG(INFO) << "Clear";
top_g_ = nullptr; top_g_ = nullptr;
curr_g_ = nullptr; curr_g_ = nullptr;
...@@ -1010,6 +1085,19 @@ FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr ...@@ -1010,6 +1085,19 @@ FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr
return df_builder_; return df_builder_;
} }
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
}
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
......
...@@ -46,7 +46,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ...@@ -46,7 +46,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple RunOp(const py::args &args); py::tuple RunOp(const py::args &args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
py::list *out_args_list); py::list *const out_args_list);
void ClearPyNativeSession(); void ClearPyNativeSession();
...@@ -68,11 +68,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { ...@@ -68,11 +68,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
return executor_; return executor_;
} }
void NewGraph(const py::object &cell, const py::args &args); void NewGraph(const py::object &cell, const py::args &args);
void NewGraphInner(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args); void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights); std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args);
void Clear(const std::string &flag = ""); void Clear(const std::string &flag = "");
void Clean(); void Clean();
void ClearRes(); void ClearRes();
......
...@@ -186,7 +186,7 @@ class Cell: ...@@ -186,7 +186,7 @@ class Cell:
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
def __del__(self): def __del__(self):
_pynative_exec.clear("resource") _pynative_exec.clear(str(id(self)))
if hasattr(self, "_create_time"): if hasattr(self, "_create_time"):
_executor.del_net_res(str(self._create_time)) _executor.del_net_res(str(self._create_time))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册