提交 04524b6b 编写于 作者: F fary86

Fix coredump caused by function call depth too large

上级 0d1a7ac6
...@@ -110,6 +110,8 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -110,6 +110,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
.def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
.def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.")
.def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
......
...@@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr ...@@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const AnfNodePtr &func_node = fg->get_return(); const AnfNodePtr &func_node = fg->get_return();
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
<< ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr ret_base = nullptr; AbstractBasePtr ret_base = nullptr;
engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << ".";
}
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it; const auto &node = *it;
...@@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr ...@@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
} }
engine->DecreaseFunctionCallDepth();
MS_EXCEPTION_IF_NULL(ret_base); MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
......
...@@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac ...@@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
// Running the analyzer. // Running the analyzer.
ResetFunctionCallDepth();
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
MS_EXCEPTION_IF_NULL(root_context); MS_EXCEPTION_IF_NULL(root_context);
MS_EXCEPTION_IF_NULL(root_context->func_graph()); MS_EXCEPTION_IF_NULL(root_context->func_graph());
......
...@@ -185,7 +185,9 @@ struct PartialAppHasher { ...@@ -185,7 +185,9 @@ struct PartialAppHasher {
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public: public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {
function_call_depth_ = 0;
}
~AnalysisEngine() = default; ~AnalysisEngine() = default;
// func_graph: The func_graph to analyze. // func_graph: The func_graph to analyze.
...@@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { ...@@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnalysisCache cache_; AnalysisCache cache_;
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void ResetFunctionCallDepth() { function_call_depth_ = 0; }
void IncreaseFunctionCallDepth() { function_call_depth_++; }
void DecreaseFunctionCallDepth() {
if (function_call_depth_ == 0) {
MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
}
function_call_depth_--;
}
unsigned int function_call_depth() { return function_call_depth_; }
private: private:
void SetUndeterminedFlag(const EvaluatorPtr &evaluator); void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
...@@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { ...@@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const ConfigPtrList &args_conf_list); const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list); const ConfigPtrList &args_conf_list);
// record current depth of function call statck
unsigned int function_call_depth_;
#ifdef DEBUG #ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_; std::vector<AnfNodePtr> compute_conf_stack_;
......
...@@ -234,6 +234,17 @@ class _Context: ...@@ -234,6 +234,17 @@ class _Context:
if not success: if not success:
raise RuntimeError("Device id set failed!!!") raise RuntimeError("Device id set failed!!!")
@property
def max_call_depth(self):
return self._context_handle.get_max_call_depth()
@max_call_depth.setter
def max_call_depth(self, max_call_depth):
if max_call_depth <= 0:
raise ValueError(
"Max call depth must be greater than 0, but got {}".format(max_call_depth))
self._context_handle.set_max_call_depth(max_call_depth)
@property @property
def enable_auto_mixed_precision(self): def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag() return self._context_handle.get_auto_mixed_precision_flag()
...@@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs): ...@@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs):
full_batch (bool): Whether to load the whole batch on each device. Default: False. full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving. data parallel training in the benefit of time and memory saving.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises: Raises:
...@@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs): ...@@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(parameter_broadcast=False) >>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(max_call_depth=80)
""" """
_set_auto_parallel_context(**kwargs) _set_auto_parallel_context(**kwargs)
...@@ -532,7 +545,7 @@ def reset_auto_parallel_context(): ...@@ -532,7 +545,7 @@ def reset_auto_parallel_context():
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
enable_sparse=bool) enable_sparse=bool, max_call_depth=int)
def set_context(**kwargs): def set_context(**kwargs):
""" """
Sets context for running environment. Sets context for running environment.
......
...@@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { ...@@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
} else { } else {
device_id_ = 0; device_id_ = 0;
} }
max_call_depth_ = MAX_CALL_DEPTH_DEFAULT;
backend_policy_ = policy_map_[policy]; backend_policy_ = policy_map_[policy];
device_target_ = target; device_target_ = target;
execution_mode_ = kPynativeMode; execution_mode_ = kPynativeMode;
......
...@@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend"; ...@@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend";
const char kDavinciInferenceDevice[] = "AscendInference"; const char kDavinciInferenceDevice[] = "AscendInference";
const char kDavinciDevice[] = "Davinci"; const char kDavinciDevice[] = "Davinci";
const char KNpuLog[] = "_npu_log"; const char KNpuLog[] = "_npu_log";
const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
// The default max available device memory is 1024GB. // The default max available device memory is 1024GB.
const float kDefaultMaxDeviceMemory = 1024; const float kDefaultMaxDeviceMemory = 1024;
...@@ -80,6 +82,13 @@ class MsContext { ...@@ -80,6 +82,13 @@ class MsContext {
uint32_t device_id() const { return device_id_; } uint32_t device_id() const { return device_id_; }
bool set_device_id(uint32_t device_id); bool set_device_id(uint32_t device_id);
// uint32_t max_call_depth_
uint32_t max_call_depth() const { return max_call_depth_; }
inline bool set_max_call_depth(uint32_t max_call_depth) {
max_call_depth_ = max_call_depth;
return true;
}
bool save_graphs_flag() const { return save_graphs_flag_; } bool save_graphs_flag() const { return save_graphs_flag_; }
void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; }
...@@ -171,6 +180,7 @@ class MsContext { ...@@ -171,6 +180,7 @@ class MsContext {
MsBackendPolicy backend_policy_; MsBackendPolicy backend_policy_;
std::string device_target_; std::string device_target_;
uint32_t device_id_; uint32_t device_id_;
uint32_t max_call_depth_;
int execution_mode_; int execution_mode_;
bool enable_pynative_infer_; bool enable_pynative_infer_;
bool enable_pynative_hook_; bool enable_pynative_hook_;
......
...@@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break(): ...@@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break():
x = self.flatten(x + elem1) x = self.flatten(x + elem1)
return x return x
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=2000)
t = Tensor(np.ones([2, 3], dtype=np.float32)) t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net() net = Net()
net(t) net(t)
context.set_context(max_call_depth=old_max_call_depth)
def test_mixed_precision_cast(): def test_mixed_precision_cast():
...@@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive(): ...@@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive():
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(i, input1) net(i, input1)
def test_recursive_call():
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Dense(10, 10) # padding=0
#self.net2 = Net2()
def construct(self, x):
net2 = Net2()
x = net2(x)
out = self.fc(x)
return out
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.net = Net()
self.fc = nn.Dense(10, 10)
def construct(self, x):
x = self.net(x)
out = self.fc(x)
return out
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=80)
input_data = Tensor(np.identity(10).astype(np.float32))
net = Net2()
with pytest.raises(RuntimeError):
net(input_data)
context.set_context(max_call_depth=old_max_call_depth)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册