Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
38083e05
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
38083e05
编写于
8月 14, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix coredump missing return statement after while loop
上级
406ce735
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
281 addition
and
35 deletion
+281
-35
mindspore/_extends/parse/__init__.py
mindspore/_extends/parse/__init__.py
+2
-2
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+14
-0
mindspore/ccsrc/pipeline/jit/parse/parse.cc
mindspore/ccsrc/pipeline/jit/parse/parse.cc
+30
-5
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+1
-0
mindspore/core/ir/manager.cc
mindspore/core/ir/manager.cc
+5
-1
tests/ut/cpp/pipeline/parse/parser_test.cc
tests/ut/cpp/pipeline/parse/parser_test.cc
+0
-15
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+23
-0
tests/ut/python/ops/test_ops_check.py
tests/ut/python/ops/test_ops_check.py
+5
-12
tests/ut/python/pipeline/parse/test_grammar_constraints.py
tests/ut/python/pipeline/parse/test_grammar_constraints.py
+201
-0
未找到文件。
mindspore/_extends/parse/__init__.py
浏览文件 @
38083e05
...
...
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes
,
get_dataclass_methods
,
get_obj_id
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_parse_method_of_class
,
get_scope_name
,
is_class_member
,
parse_cb
,
resolve_symbol
,
convert_to_ms_tensor
)
is_class_member
,
parse_cb
,
resolve_symbol
,
convert_to_ms_tensor
,
get_object_description
)
from
.serialize
import
*
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
...
...
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type'
,
'get_obj_id'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_class_member_namespace_symbol'
,
'get_obj_id'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
,
'convert_to_ms_tensor'
]
'create_slice_obj'
,
'convert_to_ms_tensor'
,
'get_object_description'
]
mindspore/_extends/parse/parser.py
浏览文件 @
38083e05
...
...
@@ -322,6 +322,20 @@ def convert_to_ms_tensor(data):
return
MsTensor
(
data
)
def
get_object_description
(
obj
,
fname
,
fline
):
"""return method or funcition description for error report, include location, class name, etc."""
if
isinstance
(
obj
,
types
.
MethodType
):
obj_cls
=
obj
.
__self__
.
__class__
class_name
=
f
'
{
obj_cls
.
__module__
}
.
{
obj_cls
.
__qualname__
}
'
cls_fname
=
inspect
.
getfile
(
obj_cls
)
_
,
cls_fline
=
inspect
.
getsourcelines
(
obj_cls
)
class_loc
=
f
'
{
cls_fname
}
:
{
cls_fline
}
'
return
f
"bound method '
{
obj
.
__name__
}
' at
{
fname
}
:
{
fline
}
of <
{
class_name
}
at
{
class_loc
}
object>"
if
isinstance
(
obj
,
(
types
.
FunctionType
,
ast
.
FunctionDef
)):
return
f
"function '
{
obj
.
name
}
' at
{
fname
}
:
{
fline
}
"
return
str
(
obj
)
class
Parser
:
"""
Parser python code to ast tree.
...
...
mindspore/ccsrc/pipeline/jit/parse/parse.cc
浏览文件 @
38083e05
...
...
@@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() {
RemoveUnnecessaryPhis
();
MS_EXCEPTION_IF_NULL
(
pFnBlock
);
// check whether the functions refered by this function and itself are missing 'return' statement
auto
mng
=
Manage
(
pFnBlock
->
func_graph
(),
false
);
for
(
auto
func_graph
:
mng
->
func_graphs
())
{
if
(
func_graph
->
get_return
()
!=
nullptr
)
{
continue
;
}
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
ret
[
0
],
ret
[
1
]);
MS_EXCEPTION
(
TypeError
)
<<
"Missing return statement in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
// clear manager info after checking missing return
for
(
auto
fg
:
mng
->
func_graphs
())
{
fg
->
ClearAllManagerInfo
();
}
return
pFnBlock
->
func_graph
();
}
...
...
@@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
(
void
)
ParseStatements
(
pFunBlock
,
funcObj
);
if
(
current_fg
->
get_return
()
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Graph return node is null, loc:"
<<
GetLocation
(
node
)
->
ToString
(
);
errcode_
=
PARSE_NO_RETURN
;
return
pFunBlock
;
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
node
,
ret
[
0
],
ret
[
1
])
;
MS_EXCEPTION
(
TypeError
)
<<
"Missing return statement in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
GenerateArgsDefaultValueForFunction
(
pFunBlock
,
node
);
return
pFunBlock
;
...
...
@@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
}
auto
filename
=
location
[
0
].
cast
<
std
::
string
>
();
auto
line_no
=
location
[
1
].
cast
<
int
>
();
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
;
auto
fn_loc
=
block
->
func_graph
()
->
debug_info
()
->
location
();
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
fn_loc
->
file_name
(),
fn_loc
->
line
());
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
<<
" in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
}
...
...
@@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
auto
filename
=
ret
[
0
].
cast
<
std
::
string
>
();
auto
line_no
=
ret
[
1
].
cast
<
int
>
();
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
;
auto
fn_loc
=
block
->
func_graph
()
->
debug_info
()
->
location
();
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
fn_loc
->
file_name
(),
fn_loc
->
line
());
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
<<
" in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
}
...
...
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
浏览文件 @
38083e05
...
...
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const
char
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
[]
=
"get_class_member_namespace_symbol"
;
const
char
PYTHON_MOD_GET_PARSE_METHOD
[]
=
"get_parse_method_of_class"
;
const
char
PYTHON_MOD_GET_BPROP_METHOD
[]
=
"get_bprop_method_of_class"
;
const
char
PYTHON_MOD_GET_OBJECT_DESCRIPTION
[]
=
"get_object_description"
;
const
char
PYTHON_MOD_CONVERT_TO_MS_TENSOR
[]
=
"convert_to_ms_tensor"
;
const
char
PYTHON_PARSE_GET_ARGS
[]
=
"get_args"
;
...
...
mindspore/core/ir/manager.cc
浏览文件 @
38083e05
...
...
@@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
FuncGraphSetPtr
func_graphs_to_check
=
std
::
make_shared
<
FuncGraphSet
>
();
while
(
!
nodes_ordered
.
empty
())
{
AnfNodePtr
node
=
nodes_ordered
.
pop
();
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
==
nullptr
)
{
// Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor
MS_LOG
(
WARNING
)
<<
"Node to be dropped is nullptr"
;
continue
;
}
if
(
!
all_nodes_
.
contains
(
node
))
{
continue
;
}
...
...
tests/ut/cpp/pipeline/parse/parser_test.cc
浏览文件 @
38083e05
...
...
@@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) {
ASSERT_TRUE
(
nullptr
!=
func_graph
);
}
TEST_F
(
TestParser
,
TestParseGraphFailure
)
{
GetPythonFunction
(
"get_no_return_fn"
);
// create parser
std
::
shared_ptr
<
ParseAst
>
ast
=
std
::
make_shared
<
ParseAst
>
(
fn
);
bool
succ
=
ast
->
InitParseAstInfo
();
ASSERT_TRUE
(
succ
=
true
);
std
::
shared_ptr
<
Parser
>
parser
=
std
::
make_shared
<
Parser
>
(
ast
);
// parse ast to graph
FuncGraphPtr
func_graph
=
parser
->
ParseFuncGraph
();
ASSERT_EQ
(
PARSE_NO_RETURN
,
parser
->
errcode
());
ASSERT_TRUE
(
nullptr
==
func_graph
);
}
TEST_F
(
TestParser
,
TestParseGraphIf
)
{
GetPythonFunction
(
"test_if"
);
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
38083e05
...
...
@@ -689,3 +689,26 @@ def test_while_concat():
x
=
Tensor
(
np
.
arange
(
10
*
2
*
3
).
reshape
(
10
,
2
,
3
).
astype
(
np
.
float32
))
net
=
Net
(
x
)
net
(
x
)
def
test_tensor_all_construct_lack_branch
():
class
NetConditionLackBranch
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetConditionLackBranch
,
self
).
__init__
()
self
.
logicaland
=
P
.
LogicalAnd
()
self
.
logicalor
=
P
.
LogicalOr
()
def
construct
(
self
,
input1
,
input2
):
if
input1
.
all
():
return
self
.
logicaland
(
input1
,
input2
)
while
input1
.
any
():
return
self
.
logicalor
(
input1
,
input2
)
# NOTICE: here missing return statement, default return None
input_np_1
=
np
.
random
.
choice
([
True
],
size
=
(
2
,
3
,
4
,
5
))
input_tensor_1
=
Tensor
(
input_np_1
)
input_np_2
=
np
.
random
.
choice
([
True
,
False
],
size
=
(
2
,
3
,
4
,
5
))
input_tensor_2
=
Tensor
(
input_np_2
)
net
=
NetConditionLackBranch
()
with
pytest
.
raises
(
Exception
):
net
(
input_tensor_1
,
input_tensor_2
)
tests/ut/python/ops/test_ops_check.py
浏览文件 @
38083e05
...
...
@@ -16,6 +16,7 @@
import
functools
import
logging
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
from
mindspore
import
Tensor
...
...
@@ -62,13 +63,9 @@ def test_net_without_construct():
""" test_net_without_construct """
net
=
NetMissConstruct
()
inp
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
))
try
:
with
pytest
.
raises
(
RuntimeError
)
as
err
:
_executor
.
compile
(
net
,
inp
)
except
RuntimeError
as
err
:
if
str
(
err
).
find
(
"Unsupported syntax 'Raise' at "
)
>=
0
:
print
(
str
(
err
))
else
:
raise
err
assert
"Unsupported syntax 'Raise' at "
in
str
(
err
.
value
)
class
NetWithRaise
(
nn
.
Cell
):
...
...
@@ -87,13 +84,9 @@ def test_net_with_raise():
""" test_net_with_raise """
net
=
NetWithRaise
()
inp
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
))
try
:
with
pytest
.
raises
(
RuntimeError
)
as
err
:
_executor
.
compile
(
net
,
inp
)
except
RuntimeError
as
err
:
if
str
(
err
).
find
(
"Unsupported syntax 'Raise' at "
)
>=
0
:
print
(
str
(
err
))
else
:
raise
err
assert
"Unsupported syntax 'Raise' at "
in
str
(
err
.
value
)
class
NetAddN
(
nn
.
Cell
):
...
...
tests/ut/python/pipeline/parse/test_grammar_constraints.py
0 → 100644
浏览文件 @
38083e05
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test mindspore grammar constraints
1. funtion must have return statement
2. raise statement can not be used
"""
# pylint: disable=R1705, R1710, W0223
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
dtype
as
mstype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
def
test_missing_return
():
class
NetMissReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetMissReturn
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
if
y
==
1
:
return
3
elif
y
==
2
:
for
i
in
range
(
z
):
return
i
+
z
i
=
0
while
i
<
z
:
return
i
+
z
def
g
(
u
):
return
x
+
u
# here method 'construct' misses a return statement
g
(
y
)
else
:
return
7
else
:
return
5
net
=
NetMissReturn
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
TypeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Missing return statement in bound method 'construct'"
in
str
(
er
.
value
)
def
test_nest_function_missing_return
():
class
NetNestFuncMissReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetNestFuncMissReturn
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
if
y
==
1
:
return
3
elif
y
==
2
:
for
i
in
range
(
z
):
return
i
+
z
i
=
0
while
i
<
z
:
return
i
+
z
def
g
(
u
):
x
+=
u
# nested function 'g' misses a return a statement
return
g
(
y
)
else
:
return
7
else
:
return
5
net
=
NetNestFuncMissReturn
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
TypeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Missing return statement in function 'g'"
in
str
(
er
.
value
)
def
test_raise_in_method
():
class
NetRaiseInMethod
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetRaiseInMethod
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
# add not support grammar 'raise' here
raise
ValueError
(
'Illegal case'
)
else
:
return
y
+
z
net
=
NetRaiseInMethod
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Unsupported syntax 'Raise' at"
in
str
(
er
.
value
)
def
test_raise_in_nested_function
():
class
NetNestRaise
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetNestRaise
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
def
nest_fn
(
u
):
if
u
>
0
:
# add not support grammar 'raise' here
raise
ValueError
(
'Illegal case'
)
return
u
+
z
+
1
return
nest_fn
(
y
)
else
:
return
y
+
z
net
=
NetNestRaise
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Unsupported syntax 'Raise' at "
in
str
(
er
.
value
)
def
test_nest_branch_with_return
():
class
NetBranchWithReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetBranchWithReturn
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
else
:
return
5
context
.
set_context
(
save_graphs
=
True
)
net
=
NetBranchWithReturn
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
net
(
x
,
y
,
z
)
def
test_any_with_no_return
():
class
NetAnyNoReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetAnyNoReturn
,
self
).
__init__
()
def
construct
(
self
,
inp
):
result
=
inp
.
any
()
if
result
:
return
6
np_input
=
np
.
arange
(
2
*
3
*
4
).
reshape
((
2
,
3
,
4
)).
astype
(
np
.
bool_
)
tensor
=
Tensor
(
np_input
)
net
=
NetAnyNoReturn
()
with
pytest
.
raises
(
TypeError
)
as
er
:
net
(
tensor
)
assert
"Missing return statement in bound method 'construct'"
in
str
(
er
.
value
)
def
test_missing_construct
():
class
NetMissConstruct
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetMissConstruct
,
self
).
__init__
()
def
construct1
(
self
,
inp
):
return
5
np_input
=
np
.
arange
(
2
*
3
*
4
).
reshape
((
2
,
3
,
4
)).
astype
(
np
.
bool_
)
tensor
=
Tensor
(
np_input
)
net
=
NetMissConstruct
()
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
tensor
)
assert
"Unsupported syntax 'Raise' at "
in
str
(
er
.
value
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录