From 38083e055a9776719f0d1cbb66ad07bea72ecab5 Mon Sep 17 00:00:00 2001 From: fary86 Date: Fri, 14 Aug 2020 00:17:25 +0800 Subject: [PATCH] Fix coredump missing return statement after while loop --- mindspore/_extends/parse/__init__.py | 4 +- mindspore/_extends/parse/parser.py | 14 ++ mindspore/ccsrc/pipeline/jit/parse/parse.cc | 35 ++- .../ccsrc/pipeline/jit/parse/parse_base.h | 1 + mindspore/core/ir/manager.cc | 6 +- tests/ut/cpp/pipeline/parse/parser_test.cc | 15 -- tests/ut/python/ops/test_control_ops.py | 23 ++ tests/ut/python/ops/test_ops_check.py | 17 +- .../parse/test_grammar_constraints.py | 201 ++++++++++++++++++ 9 files changed, 281 insertions(+), 35 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_grammar_constraints.py diff --git a/mindspore/_extends/parse/__init__.py b/mindspore/_extends/parse/__init__.py index cd13d329a..17f7eab2b 100644 --- a/mindspore/_extends/parse/__init__.py +++ b/mindspore/_extends/parse/__init__.py @@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_module_namespace, get_obj_type, get_object_key, get_parse_method_of_class, get_scope_name, - is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor) + is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description) from .serialize import * __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', @@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', - 'create_slice_obj', 'convert_to_ms_tensor'] + 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description'] diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 30f48f826..1b4f76f77 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -322,6 +322,20 @@ def convert_to_ms_tensor(data): return MsTensor(data) +def get_object_description(obj, fname, fline): + """return method or funcition description for error report, include location, class name, etc.""" + if isinstance(obj, types.MethodType): + obj_cls = obj.__self__.__class__ + class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}' + cls_fname = inspect.getfile(obj_cls) + _, cls_fline = inspect.getsourcelines(obj_cls) + class_loc = f'{cls_fname}:{cls_fline}' + return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>" + if isinstance(obj, (types.FunctionType, ast.FunctionDef)): + return f"function '{obj.name}' at {fname}:{fline}" + return str(obj) + + class Parser: """ Parser python code to ast tree. diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index be75d6ac2..614610ee6 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() { RemoveUnnecessaryPhis(); MS_EXCEPTION_IF_NULL(pFnBlock); + + // check whether the functions refered by this function and itself are missing 'return' statement + auto mng = Manage(pFnBlock->func_graph(), false); + for (auto func_graph : mng->func_graphs()) { + if (func_graph->get_return() != nullptr) { + continue; + } + py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + py::str desc = + python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]); + MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast() << "."; + } + // clear manager info after checking missing return + for (auto fg : mng->func_graphs()) { + fg->ClearAllManagerInfo(); + } + return pFnBlock->func_graph(); } @@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo (void)ParseStatements(pFunBlock, funcObj); if (current_fg->get_return() == nullptr) { - MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); - errcode_ = PARSE_NO_RETURN; - return pFunBlock; + py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]); + MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast() << "."; } GenerateArgsDefaultValueForFunction(pFunBlock, node); return pFunBlock; @@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py: } auto filename = location[0].cast(); auto line_no = location[1].cast(); - MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; + auto fn_loc = block->func_graph()->debug_info()->location(); + py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), + fn_loc->file_name(), fn_loc->line()); + MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in " + << desc.cast() << "."; } } @@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); auto filename = ret[0].cast(); auto line_no = ret[1].cast(); - MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; + auto fn_loc = block->func_graph()->debug_info()->location(); + py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), + fn_loc->file_name(), fn_loc->line()); + MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in " + << desc.cast() << "."; } } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index d2c8d7a2f..ddc774e3d 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; +const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description"; const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; const char PYTHON_PARSE_GET_ARGS[] = "get_args"; diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc index 2970d22ee..ce239d28e 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector & FuncGraphSetPtr func_graphs_to_check = std::make_shared(); while (!nodes_ordered.empty()) { AnfNodePtr node = nodes_ordered.pop(); - MS_EXCEPTION_IF_NULL(node); + if (node == nullptr) { + // Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor + MS_LOG(WARNING) << "Node to be dropped is nullptr"; + continue; + } if (!all_nodes_.contains(node)) { continue; } diff --git a/tests/ut/cpp/pipeline/parse/parser_test.cc b/tests/ut/cpp/pipeline/parse/parser_test.cc index f1d908711..17c35673a 100644 --- a/tests/ut/cpp/pipeline/parse/parser_test.cc +++ b/tests/ut/cpp/pipeline/parse/parser_test.cc @@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) { ASSERT_TRUE(nullptr != func_graph); } -TEST_F(TestParser, TestParseGraphFailure) { - GetPythonFunction("get_no_return_fn"); - - // create parser - std::shared_ptr ast = std::make_shared(fn); - bool succ = ast->InitParseAstInfo(); - ASSERT_TRUE(succ = true); - std::shared_ptr parser = std::make_shared(ast); - - // parse ast to graph - FuncGraphPtr func_graph = parser->ParseFuncGraph(); - ASSERT_EQ(PARSE_NO_RETURN, parser->errcode()); - ASSERT_TRUE(nullptr == func_graph); -} - TEST_F(TestParser, TestParseGraphIf) { GetPythonFunction("test_if"); diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index ac31420f6..53f0222ab 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -689,3 +689,26 @@ def test_while_concat(): x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32)) net = Net(x) net(x) + + +def test_tensor_all_construct_lack_branch(): + class NetConditionLackBranch(nn.Cell): + def __init__(self): + super(NetConditionLackBranch, self).__init__() + self.logicaland = P.LogicalAnd() + self.logicalor = P.LogicalOr() + + def construct(self, input1, input2): + if input1.all(): + return self.logicaland(input1, input2) + while input1.any(): + return self.logicalor(input1, input2) + # NOTICE: here missing return statement, default return None + + input_np_1 = np.random.choice([True], size=(2, 3, 4, 5)) + input_tensor_1 = Tensor(input_np_1) + input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5)) + input_tensor_2 = Tensor(input_np_2) + net = NetConditionLackBranch() + with pytest.raises(Exception): + net(input_tensor_1, input_tensor_2) diff --git a/tests/ut/python/ops/test_ops_check.py b/tests/ut/python/ops/test_ops_check.py index c7bcb555e..beb1d4eb0 100644 --- a/tests/ut/python/ops/test_ops_check.py +++ b/tests/ut/python/ops/test_ops_check.py @@ -16,6 +16,7 @@ import functools import logging import numpy as np +import pytest import mindspore.context as context from mindspore import Tensor @@ -62,13 +63,9 @@ def test_net_without_construct(): """ test_net_without_construct """ net = NetMissConstruct() inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) - try: + with pytest.raises(RuntimeError) as err: _executor.compile(net, inp) - except RuntimeError as err: - if str(err).find("Unsupported syntax 'Raise' at ") >= 0: - print(str(err)) - else: - raise err + assert "Unsupported syntax 'Raise' at " in str(err.value) class NetWithRaise(nn.Cell): @@ -87,13 +84,9 @@ def test_net_with_raise(): """ test_net_with_raise """ net = NetWithRaise() inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) - try: + with pytest.raises(RuntimeError) as err: _executor.compile(net, inp) - except RuntimeError as err: - if str(err).find("Unsupported syntax 'Raise' at ") >= 0: - print(str(err)) - else: - raise err + assert "Unsupported syntax 'Raise' at " in str(err.value) class NetAddN(nn.Cell): diff --git a/tests/ut/python/pipeline/parse/test_grammar_constraints.py b/tests/ut/python/pipeline/parse/test_grammar_constraints.py new file mode 100644 index 000000000..d93a54039 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_grammar_constraints.py @@ -0,0 +1,201 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test mindspore grammar constraints +1. funtion must have return statement +2. raise statement can not be used +""" +# pylint: disable=R1705, R1710, W0223 +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE) + +def test_missing_return(): + class NetMissReturn(nn.Cell): + def __init__(self): + super(NetMissReturn, self).__init__() + + def construct(self, x, y, z): + if x == 1: + return 10 + elif x == 20: + if y == 1: + return 3 + elif y == 2: + for i in range(z): + return i + z + i = 0 + while i < z: + return i + z + def g(u): + return x + u + # here method 'construct' misses a return statement + g(y) + else: + return 7 + else: + return 5 + + net = NetMissReturn() + x = Tensor(0, mstype.int32) + y = Tensor(5, mstype.int32) + z = Tensor(2, mstype.int32) + with pytest.raises(TypeError) as er: + net(x, y, z) + assert "Missing return statement in bound method 'construct'" in str(er.value) + + +def test_nest_function_missing_return(): + class NetNestFuncMissReturn(nn.Cell): + def __init__(self): + super(NetNestFuncMissReturn, self).__init__() + + def construct(self, x, y, z): + if x == 1: + return 10 + elif x == 20: + if y == 1: + return 3 + elif y == 2: + for i in range(z): + return i + z + i = 0 + while i < z: + return i + z + def g(u): + x += u + # nested function 'g' misses a return a statement + return g(y) + else: + return 7 + else: + return 5 + + net = NetNestFuncMissReturn() + x = Tensor(0, mstype.int32) + y = Tensor(5, mstype.int32) + z = Tensor(2, mstype.int32) + with pytest.raises(TypeError) as er: + net(x, y, z) + assert "Missing return statement in function 'g'" in str(er.value) + + +def test_raise_in_method(): + class NetRaiseInMethod(nn.Cell): + def __init__(self): + super(NetRaiseInMethod, self).__init__() + + def construct(self, x, y, z): + if x == 1: + return 10 + elif x == 20: + # add not support grammar 'raise' here + raise ValueError('Illegal case') + else: + return y + z + + net = NetRaiseInMethod() + x = Tensor(0, mstype.int32) + y = Tensor(5, mstype.int32) + z = Tensor(2, mstype.int32) + with pytest.raises(RuntimeError) as er: + net(x, y, z) + assert "Unsupported syntax 'Raise' at" in str(er.value) + + +def test_raise_in_nested_function(): + class NetNestRaise(nn.Cell): + def __init__(self): + super(NetNestRaise, self).__init__() + + def construct(self, x, y, z): + if x == 1: + return 10 + elif x == 20: + def nest_fn(u): + if u > 0: + # add not support grammar 'raise' here + raise ValueError('Illegal case') + return u + z + 1 + return nest_fn(y) + else: + return y + z + + net = NetNestRaise() + x = Tensor(0, mstype.int32) + y = Tensor(5, mstype.int32) + z = Tensor(2, mstype.int32) + with pytest.raises(RuntimeError) as er: + net(x, y, z) + assert "Unsupported syntax 'Raise' at " in str(er.value) + + +def test_nest_branch_with_return(): + class NetBranchWithReturn(nn.Cell): + def __init__(self): + super(NetBranchWithReturn, self).__init__() + + def construct(self, x, y, z): + if x == 1: + return 10 + else: + return 5 + + context.set_context(save_graphs=True) + net = NetBranchWithReturn() + x = Tensor(0, mstype.int32) + y = Tensor(5, mstype.int32) + z = Tensor(2, mstype.int32) + net(x, y, z) + + +def test_any_with_no_return(): + class NetAnyNoReturn(nn.Cell): + def __init__(self): + super(NetAnyNoReturn, self).__init__() + + def construct(self, inp): + result = inp.any() + if result: + return 6 + + np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) + tensor = Tensor(np_input) + net = NetAnyNoReturn() + with pytest.raises(TypeError) as er: + net(tensor) + assert "Missing return statement in bound method 'construct'" in str(er.value) + + +def test_missing_construct(): + class NetMissConstruct(nn.Cell): + def __init__(self): + super(NetMissConstruct, self).__init__() + + def construct1(self, inp): + return 5 + + np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) + tensor = Tensor(np_input) + net = NetMissConstruct() + with pytest.raises(RuntimeError) as er: + net(tensor) + assert "Unsupported syntax 'Raise' at " in str(er.value) -- GitLab