提交 38083e05 编写于 作者: F fary86

Fix coredump missing return statement after while loop

上级 406ce735
...@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, ...@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor) is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
from .serialize import * from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
...@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', ...@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj', 'convert_to_ms_tensor'] 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description']
...@@ -322,6 +322,20 @@ def convert_to_ms_tensor(data): ...@@ -322,6 +322,20 @@ def convert_to_ms_tensor(data):
return MsTensor(data) return MsTensor(data)
def get_object_description(obj, fname, fline):
"""return method or funcition description for error report, include location, class name, etc."""
if isinstance(obj, types.MethodType):
obj_cls = obj.__self__.__class__
class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}'
cls_fname = inspect.getfile(obj_cls)
_, cls_fline = inspect.getsourcelines(obj_cls)
class_loc = f'{cls_fname}:{cls_fline}'
return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
if isinstance(obj, (types.FunctionType, ast.FunctionDef)):
return f"function '{obj.name}' at {fname}:{fline}"
return str(obj)
class Parser: class Parser:
""" """
Parser python code to ast tree. Parser python code to ast tree.
......
...@@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() { ...@@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() {
RemoveUnnecessaryPhis(); RemoveUnnecessaryPhis();
MS_EXCEPTION_IF_NULL(pFnBlock); MS_EXCEPTION_IF_NULL(pFnBlock);
// check whether the functions refered by this function and itself are missing 'return' statement
auto mng = Manage(pFnBlock->func_graph(), false);
for (auto func_graph : mng->func_graphs()) {
if (func_graph->get_return() != nullptr) {
continue;
}
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
py::str desc =
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
}
// clear manager info after checking missing return
for (auto fg : mng->func_graphs()) {
fg->ClearAllManagerInfo();
}
return pFnBlock->func_graph(); return pFnBlock->func_graph();
} }
...@@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo ...@@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
(void)ParseStatements(pFunBlock, funcObj); (void)ParseStatements(pFunBlock, funcObj);
if (current_fg->get_return() == nullptr) { if (current_fg->get_return() == nullptr) {
MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
errcode_ = PARSE_NO_RETURN; py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
return pFunBlock; MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
} }
GenerateArgsDefaultValueForFunction(pFunBlock, node); GenerateArgsDefaultValueForFunction(pFunBlock, node);
return pFunBlock; return pFunBlock;
...@@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py: ...@@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
} }
auto filename = location[0].cast<std::string>(); auto filename = location[0].cast<std::string>();
auto line_no = location[1].cast<int>(); auto line_no = location[1].cast<int>();
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; auto fn_loc = block->func_graph()->debug_info()->location();
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
fn_loc->file_name(), fn_loc->line());
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
<< desc.cast<std::string>() << ".";
} }
} }
...@@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object ...@@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
auto filename = ret[0].cast<std::string>(); auto filename = ret[0].cast<std::string>();
auto line_no = ret[1].cast<int>(); auto line_no = ret[1].cast<int>();
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; auto fn_loc = block->func_graph()->debug_info()->location();
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
fn_loc->file_name(), fn_loc->line());
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
<< desc.cast<std::string>() << ".";
} }
} }
......
...@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; ...@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_PARSE_GET_ARGS[] = "get_args"; const char PYTHON_PARSE_GET_ARGS[] = "get_args";
......
...@@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> & ...@@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
while (!nodes_ordered.empty()) { while (!nodes_ordered.empty()) {
AnfNodePtr node = nodes_ordered.pop(); AnfNodePtr node = nodes_ordered.pop();
MS_EXCEPTION_IF_NULL(node); if (node == nullptr) {
// Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor
MS_LOG(WARNING) << "Node to be dropped is nullptr";
continue;
}
if (!all_nodes_.contains(node)) { if (!all_nodes_.contains(node)) {
continue; continue;
} }
......
...@@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) { ...@@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) {
ASSERT_TRUE(nullptr != func_graph); ASSERT_TRUE(nullptr != func_graph);
} }
TEST_F(TestParser, TestParseGraphFailure) {
GetPythonFunction("get_no_return_fn");
// create parser
std::shared_ptr<ParseAst> ast = std::make_shared<ParseAst>(fn);
bool succ = ast->InitParseAstInfo();
ASSERT_TRUE(succ = true);
std::shared_ptr<Parser> parser = std::make_shared<Parser>(ast);
// parse ast to graph
FuncGraphPtr func_graph = parser->ParseFuncGraph();
ASSERT_EQ(PARSE_NO_RETURN, parser->errcode());
ASSERT_TRUE(nullptr == func_graph);
}
TEST_F(TestParser, TestParseGraphIf) { TEST_F(TestParser, TestParseGraphIf) {
GetPythonFunction("test_if"); GetPythonFunction("test_if");
......
...@@ -689,3 +689,26 @@ def test_while_concat(): ...@@ -689,3 +689,26 @@ def test_while_concat():
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32)) x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
net = Net(x) net = Net(x)
net(x) net(x)
def test_tensor_all_construct_lack_branch():
class NetConditionLackBranch(nn.Cell):
def __init__(self):
super(NetConditionLackBranch, self).__init__()
self.logicaland = P.LogicalAnd()
self.logicalor = P.LogicalOr()
def construct(self, input1, input2):
if input1.all():
return self.logicaland(input1, input2)
while input1.any():
return self.logicalor(input1, input2)
# NOTICE: here missing return statement, default return None
input_np_1 = np.random.choice([True], size=(2, 3, 4, 5))
input_tensor_1 = Tensor(input_np_1)
input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5))
input_tensor_2 = Tensor(input_np_2)
net = NetConditionLackBranch()
with pytest.raises(Exception):
net(input_tensor_1, input_tensor_2)
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import functools import functools
import logging import logging
import numpy as np import numpy as np
import pytest
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
...@@ -62,13 +63,9 @@ def test_net_without_construct(): ...@@ -62,13 +63,9 @@ def test_net_without_construct():
""" test_net_without_construct """ """ test_net_without_construct """
net = NetMissConstruct() net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
try: with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp) _executor.compile(net, inp)
except RuntimeError as err: assert "Unsupported syntax 'Raise' at " in str(err.value)
if str(err).find("Unsupported syntax 'Raise' at ") >= 0:
print(str(err))
else:
raise err
class NetWithRaise(nn.Cell): class NetWithRaise(nn.Cell):
...@@ -87,13 +84,9 @@ def test_net_with_raise(): ...@@ -87,13 +84,9 @@ def test_net_with_raise():
""" test_net_with_raise """ """ test_net_with_raise """
net = NetWithRaise() net = NetWithRaise()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
try: with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp) _executor.compile(net, inp)
except RuntimeError as err: assert "Unsupported syntax 'Raise' at " in str(err.value)
if str(err).find("Unsupported syntax 'Raise' at ") >= 0:
print(str(err))
else:
raise err
class NetAddN(nn.Cell): class NetAddN(nn.Cell):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test mindspore grammar constraints
1. funtion must have return statement
2. raise statement can not be used
"""
# pylint: disable=R1705, R1710, W0223
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE)
def test_missing_return():
class NetMissReturn(nn.Cell):
def __init__(self):
super(NetMissReturn, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
if y == 1:
return 3
elif y == 2:
for i in range(z):
return i + z
i = 0
while i < z:
return i + z
def g(u):
return x + u
# here method 'construct' misses a return statement
g(y)
else:
return 7
else:
return 5
net = NetMissReturn()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(TypeError) as er:
net(x, y, z)
assert "Missing return statement in bound method 'construct'" in str(er.value)
def test_nest_function_missing_return():
class NetNestFuncMissReturn(nn.Cell):
def __init__(self):
super(NetNestFuncMissReturn, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
if y == 1:
return 3
elif y == 2:
for i in range(z):
return i + z
i = 0
while i < z:
return i + z
def g(u):
x += u
# nested function 'g' misses a return a statement
return g(y)
else:
return 7
else:
return 5
net = NetNestFuncMissReturn()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(TypeError) as er:
net(x, y, z)
assert "Missing return statement in function 'g'" in str(er.value)
def test_raise_in_method():
class NetRaiseInMethod(nn.Cell):
def __init__(self):
super(NetRaiseInMethod, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
# add not support grammar 'raise' here
raise ValueError('Illegal case')
else:
return y + z
net = NetRaiseInMethod()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(RuntimeError) as er:
net(x, y, z)
assert "Unsupported syntax 'Raise' at" in str(er.value)
def test_raise_in_nested_function():
class NetNestRaise(nn.Cell):
def __init__(self):
super(NetNestRaise, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
def nest_fn(u):
if u > 0:
# add not support grammar 'raise' here
raise ValueError('Illegal case')
return u + z + 1
return nest_fn(y)
else:
return y + z
net = NetNestRaise()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(RuntimeError) as er:
net(x, y, z)
assert "Unsupported syntax 'Raise' at " in str(er.value)
def test_nest_branch_with_return():
class NetBranchWithReturn(nn.Cell):
def __init__(self):
super(NetBranchWithReturn, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
else:
return 5
context.set_context(save_graphs=True)
net = NetBranchWithReturn()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
net(x, y, z)
def test_any_with_no_return():
class NetAnyNoReturn(nn.Cell):
def __init__(self):
super(NetAnyNoReturn, self).__init__()
def construct(self, inp):
result = inp.any()
if result:
return 6
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetAnyNoReturn()
with pytest.raises(TypeError) as er:
net(tensor)
assert "Missing return statement in bound method 'construct'" in str(er.value)
def test_missing_construct():
class NetMissConstruct(nn.Cell):
def __init__(self):
super(NetMissConstruct, self).__init__()
def construct1(self, inp):
return 5
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
with pytest.raises(RuntimeError) as er:
net(tensor)
assert "Unsupported syntax 'Raise' at " in str(er.value)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册