提交 38083e05 编写于 作者: F fary86

Fix coredump missing return statement after while loop

上级 406ce735
......@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor)
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
......@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj', 'convert_to_ms_tensor']
'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description']
......@@ -322,6 +322,20 @@ def convert_to_ms_tensor(data):
return MsTensor(data)
def get_object_description(obj, fname, fline):
"""return method or funcition description for error report, include location, class name, etc."""
if isinstance(obj, types.MethodType):
obj_cls = obj.__self__.__class__
class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}'
cls_fname = inspect.getfile(obj_cls)
_, cls_fline = inspect.getsourcelines(obj_cls)
class_loc = f'{cls_fname}:{cls_fline}'
return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
if isinstance(obj, (types.FunctionType, ast.FunctionDef)):
return f"function '{obj.name}' at {fname}:{fline}"
return str(obj)
class Parser:
"""
Parser python code to ast tree.
......
......@@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() {
RemoveUnnecessaryPhis();
MS_EXCEPTION_IF_NULL(pFnBlock);
// check whether the functions refered by this function and itself are missing 'return' statement
auto mng = Manage(pFnBlock->func_graph(), false);
for (auto func_graph : mng->func_graphs()) {
if (func_graph->get_return() != nullptr) {
continue;
}
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
py::str desc =
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
}
// clear manager info after checking missing return
for (auto fg : mng->func_graphs()) {
fg->ClearAllManagerInfo();
}
return pFnBlock->func_graph();
}
......@@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
(void)ParseStatements(pFunBlock, funcObj);
if (current_fg->get_return() == nullptr) {
MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString();
errcode_ = PARSE_NO_RETURN;
return pFunBlock;
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
}
GenerateArgsDefaultValueForFunction(pFunBlock, node);
return pFunBlock;
......@@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
}
auto filename = location[0].cast<std::string>();
auto line_no = location[1].cast<int>();
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
auto fn_loc = block->func_graph()->debug_info()->location();
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
fn_loc->file_name(), fn_loc->line());
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
<< desc.cast<std::string>() << ".";
}
}
......@@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
auto filename = ret[0].cast<std::string>();
auto line_no = ret[1].cast<int>();
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
auto fn_loc = block->func_graph()->debug_info()->location();
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
fn_loc->file_name(), fn_loc->line());
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
<< desc.cast<std::string>() << ".";
}
}
......
......@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_PARSE_GET_ARGS[] = "get_args";
......
......@@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
while (!nodes_ordered.empty()) {
AnfNodePtr node = nodes_ordered.pop();
MS_EXCEPTION_IF_NULL(node);
if (node == nullptr) {
// Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor
MS_LOG(WARNING) << "Node to be dropped is nullptr";
continue;
}
if (!all_nodes_.contains(node)) {
continue;
}
......
......@@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) {
ASSERT_TRUE(nullptr != func_graph);
}
TEST_F(TestParser, TestParseGraphFailure) {
GetPythonFunction("get_no_return_fn");
// create parser
std::shared_ptr<ParseAst> ast = std::make_shared<ParseAst>(fn);
bool succ = ast->InitParseAstInfo();
ASSERT_TRUE(succ = true);
std::shared_ptr<Parser> parser = std::make_shared<Parser>(ast);
// parse ast to graph
FuncGraphPtr func_graph = parser->ParseFuncGraph();
ASSERT_EQ(PARSE_NO_RETURN, parser->errcode());
ASSERT_TRUE(nullptr == func_graph);
}
TEST_F(TestParser, TestParseGraphIf) {
GetPythonFunction("test_if");
......
......@@ -689,3 +689,26 @@ def test_while_concat():
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
net = Net(x)
net(x)
def test_tensor_all_construct_lack_branch():
class NetConditionLackBranch(nn.Cell):
def __init__(self):
super(NetConditionLackBranch, self).__init__()
self.logicaland = P.LogicalAnd()
self.logicalor = P.LogicalOr()
def construct(self, input1, input2):
if input1.all():
return self.logicaland(input1, input2)
while input1.any():
return self.logicalor(input1, input2)
# NOTICE: here missing return statement, default return None
input_np_1 = np.random.choice([True], size=(2, 3, 4, 5))
input_tensor_1 = Tensor(input_np_1)
input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5))
input_tensor_2 = Tensor(input_np_2)
net = NetConditionLackBranch()
with pytest.raises(Exception):
net(input_tensor_1, input_tensor_2)
......@@ -16,6 +16,7 @@
import functools
import logging
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
......@@ -62,13 +63,9 @@ def test_net_without_construct():
""" test_net_without_construct """
net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
try:
with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp)
except RuntimeError as err:
if str(err).find("Unsupported syntax 'Raise' at ") >= 0:
print(str(err))
else:
raise err
assert "Unsupported syntax 'Raise' at " in str(err.value)
class NetWithRaise(nn.Cell):
......@@ -87,13 +84,9 @@ def test_net_with_raise():
""" test_net_with_raise """
net = NetWithRaise()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
try:
with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp)
except RuntimeError as err:
if str(err).find("Unsupported syntax 'Raise' at ") >= 0:
print(str(err))
else:
raise err
assert "Unsupported syntax 'Raise' at " in str(err.value)
class NetAddN(nn.Cell):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test mindspore grammar constraints
1. funtion must have return statement
2. raise statement can not be used
"""
# pylint: disable=R1705, R1710, W0223
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE)
def test_missing_return():
class NetMissReturn(nn.Cell):
def __init__(self):
super(NetMissReturn, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
if y == 1:
return 3
elif y == 2:
for i in range(z):
return i + z
i = 0
while i < z:
return i + z
def g(u):
return x + u
# here method 'construct' misses a return statement
g(y)
else:
return 7
else:
return 5
net = NetMissReturn()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(TypeError) as er:
net(x, y, z)
assert "Missing return statement in bound method 'construct'" in str(er.value)
def test_nest_function_missing_return():
class NetNestFuncMissReturn(nn.Cell):
def __init__(self):
super(NetNestFuncMissReturn, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
if y == 1:
return 3
elif y == 2:
for i in range(z):
return i + z
i = 0
while i < z:
return i + z
def g(u):
x += u
# nested function 'g' misses a return a statement
return g(y)
else:
return 7
else:
return 5
net = NetNestFuncMissReturn()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(TypeError) as er:
net(x, y, z)
assert "Missing return statement in function 'g'" in str(er.value)
def test_raise_in_method():
class NetRaiseInMethod(nn.Cell):
def __init__(self):
super(NetRaiseInMethod, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
# add not support grammar 'raise' here
raise ValueError('Illegal case')
else:
return y + z
net = NetRaiseInMethod()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(RuntimeError) as er:
net(x, y, z)
assert "Unsupported syntax 'Raise' at" in str(er.value)
def test_raise_in_nested_function():
class NetNestRaise(nn.Cell):
def __init__(self):
super(NetNestRaise, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
elif x == 20:
def nest_fn(u):
if u > 0:
# add not support grammar 'raise' here
raise ValueError('Illegal case')
return u + z + 1
return nest_fn(y)
else:
return y + z
net = NetNestRaise()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(RuntimeError) as er:
net(x, y, z)
assert "Unsupported syntax 'Raise' at " in str(er.value)
def test_nest_branch_with_return():
class NetBranchWithReturn(nn.Cell):
def __init__(self):
super(NetBranchWithReturn, self).__init__()
def construct(self, x, y, z):
if x == 1:
return 10
else:
return 5
context.set_context(save_graphs=True)
net = NetBranchWithReturn()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
net(x, y, z)
def test_any_with_no_return():
class NetAnyNoReturn(nn.Cell):
def __init__(self):
super(NetAnyNoReturn, self).__init__()
def construct(self, inp):
result = inp.any()
if result:
return 6
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetAnyNoReturn()
with pytest.raises(TypeError) as er:
net(tensor)
assert "Missing return statement in bound method 'construct'" in str(er.value)
def test_missing_construct():
class NetMissConstruct(nn.Cell):
def __init__(self):
super(NetMissConstruct, self).__init__()
def construct1(self, inp):
return 5
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
with pytest.raises(RuntimeError) as er:
net(tensor)
assert "Unsupported syntax 'Raise' at " in str(er.value)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册