提交 04524b6b 编写于 作者: F fary86

Fix coredump caused by function call depth too large

上级 0d1a7ac6
......@@ -110,6 +110,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
.def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
.def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.")
.def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
......
......@@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const AnfNodePtr &func_node = fg->get_return();
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
<< ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr ret_base = nullptr;
engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << ".";
}
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it;
......@@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
}
engine->DecreaseFunctionCallDepth();
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
......
......@@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
// Running the analyzer.
ResetFunctionCallDepth();
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
MS_EXCEPTION_IF_NULL(root_context);
MS_EXCEPTION_IF_NULL(root_context->func_graph());
......
......@@ -185,7 +185,9 @@ struct PartialAppHasher {
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {}
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {
function_call_depth_ = 0;
}
~AnalysisEngine() = default;
// func_graph: The func_graph to analyze.
......@@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnalysisCache cache_;
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void ResetFunctionCallDepth() { function_call_depth_ = 0; }
void IncreaseFunctionCallDepth() { function_call_depth_++; }
void DecreaseFunctionCallDepth() {
if (function_call_depth_ == 0) {
MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
}
function_call_depth_--;
}
unsigned int function_call_depth() { return function_call_depth_; }
private:
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
......@@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
// record current depth of function call statck
unsigned int function_call_depth_;
#ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_;
......
......@@ -234,6 +234,17 @@ class _Context:
if not success:
raise RuntimeError("Device id set failed!!!")
@property
def max_call_depth(self):
return self._context_handle.get_max_call_depth()
@max_call_depth.setter
def max_call_depth(self, max_call_depth):
if max_call_depth <= 0:
raise ValueError(
"Max call depth must be greater than 0, but got {}".format(max_call_depth))
self._context_handle.set_max_call_depth(max_call_depth)
@property
def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag()
......@@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs):
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises:
......@@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(max_call_depth=80)
"""
_set_auto_parallel_context(**kwargs)
......@@ -532,7 +545,7 @@ def reset_auto_parallel_context():
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
enable_sparse=bool)
enable_sparse=bool, max_call_depth=int)
def set_context(**kwargs):
"""
Sets context for running environment.
......
......@@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
} else {
device_id_ = 0;
}
max_call_depth_ = MAX_CALL_DEPTH_DEFAULT;
backend_policy_ = policy_map_[policy];
device_target_ = target;
execution_mode_ = kPynativeMode;
......
......@@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend";
const char kDavinciInferenceDevice[] = "AscendInference";
const char kDavinciDevice[] = "Davinci";
const char KNpuLog[] = "_npu_log";
const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
// The default max available device memory is 1024GB.
const float kDefaultMaxDeviceMemory = 1024;
......@@ -80,6 +82,13 @@ class MsContext {
uint32_t device_id() const { return device_id_; }
bool set_device_id(uint32_t device_id);
// uint32_t max_call_depth_
uint32_t max_call_depth() const { return max_call_depth_; }
inline bool set_max_call_depth(uint32_t max_call_depth) {
max_call_depth_ = max_call_depth;
return true;
}
bool save_graphs_flag() const { return save_graphs_flag_; }
void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; }
......@@ -171,6 +180,7 @@ class MsContext {
MsBackendPolicy backend_policy_;
std::string device_target_;
uint32_t device_id_;
uint32_t max_call_depth_;
int execution_mode_;
bool enable_pynative_infer_;
bool enable_pynative_hook_;
......
......@@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break():
x = self.flatten(x + elem1)
return x
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=2000)
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)
context.set_context(max_call_depth=old_max_call_depth)
def test_mixed_precision_cast():
......@@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive():
with pytest.raises(ValueError):
net(i, input1)
def test_recursive_call():
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Dense(10, 10) # padding=0
#self.net2 = Net2()
def construct(self, x):
net2 = Net2()
x = net2(x)
out = self.fc(x)
return out
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.net = Net()
self.fc = nn.Dense(10, 10)
def construct(self, x):
x = self.net(x)
out = self.fc(x)
return out
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=80)
input_data = Tensor(np.identity(10).astype(np.float32))
net = Net2()
with pytest.raises(RuntimeError):
net(input_data)
context.set_context(max_call_depth=old_max_call_depth)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册