Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
38083e05
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
38083e05
编写于
8月 14, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix coredump missing return statement after while loop
上级
406ce735
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
281 addition
and
35 deletion
+281
-35
mindspore/_extends/parse/__init__.py
mindspore/_extends/parse/__init__.py
+2
-2
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+14
-0
mindspore/ccsrc/pipeline/jit/parse/parse.cc
mindspore/ccsrc/pipeline/jit/parse/parse.cc
+30
-5
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+1
-0
mindspore/core/ir/manager.cc
mindspore/core/ir/manager.cc
+5
-1
tests/ut/cpp/pipeline/parse/parser_test.cc
tests/ut/cpp/pipeline/parse/parser_test.cc
+0
-15
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+23
-0
tests/ut/python/ops/test_ops_check.py
tests/ut/python/ops/test_ops_check.py
+5
-12
tests/ut/python/pipeline/parse/test_grammar_constraints.py
tests/ut/python/pipeline/parse/test_grammar_constraints.py
+201
-0
未找到文件。
mindspore/_extends/parse/__init__.py
浏览文件 @
38083e05
...
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
...
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes
,
get_dataclass_methods
,
get_obj_id
,
get_dataclass_attributes
,
get_dataclass_methods
,
get_obj_id
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_parse_method_of_class
,
get_scope_name
,
get_parse_method_of_class
,
get_scope_name
,
is_class_member
,
parse_cb
,
resolve_symbol
,
convert_to_ms_tensor
)
is_class_member
,
parse_cb
,
resolve_symbol
,
convert_to_ms_tensor
,
get_object_description
)
from
.serialize
import
*
from
.serialize
import
*
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
...
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
...
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type'
,
'get_obj_id'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_obj_type'
,
'get_obj_id'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_class_member_namespace_symbol'
,
'get_obj_id'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_class_member_namespace_symbol'
,
'get_obj_id'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
,
'convert_to_ms_tensor'
]
'create_slice_obj'
,
'convert_to_ms_tensor'
,
'get_object_description'
]
mindspore/_extends/parse/parser.py
浏览文件 @
38083e05
...
@@ -322,6 +322,20 @@ def convert_to_ms_tensor(data):
...
@@ -322,6 +322,20 @@ def convert_to_ms_tensor(data):
return
MsTensor
(
data
)
return
MsTensor
(
data
)
def
get_object_description
(
obj
,
fname
,
fline
):
"""return method or funcition description for error report, include location, class name, etc."""
if
isinstance
(
obj
,
types
.
MethodType
):
obj_cls
=
obj
.
__self__
.
__class__
class_name
=
f
'
{
obj_cls
.
__module__
}
.
{
obj_cls
.
__qualname__
}
'
cls_fname
=
inspect
.
getfile
(
obj_cls
)
_
,
cls_fline
=
inspect
.
getsourcelines
(
obj_cls
)
class_loc
=
f
'
{
cls_fname
}
:
{
cls_fline
}
'
return
f
"bound method '
{
obj
.
__name__
}
' at
{
fname
}
:
{
fline
}
of <
{
class_name
}
at
{
class_loc
}
object>"
if
isinstance
(
obj
,
(
types
.
FunctionType
,
ast
.
FunctionDef
)):
return
f
"function '
{
obj
.
name
}
' at
{
fname
}
:
{
fline
}
"
return
str
(
obj
)
class
Parser
:
class
Parser
:
"""
"""
Parser python code to ast tree.
Parser python code to ast tree.
...
...
mindspore/ccsrc/pipeline/jit/parse/parse.cc
浏览文件 @
38083e05
...
@@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() {
...
@@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() {
RemoveUnnecessaryPhis
();
RemoveUnnecessaryPhis
();
MS_EXCEPTION_IF_NULL
(
pFnBlock
);
MS_EXCEPTION_IF_NULL
(
pFnBlock
);
// check whether the functions refered by this function and itself are missing 'return' statement
auto
mng
=
Manage
(
pFnBlock
->
func_graph
(),
false
);
for
(
auto
func_graph
:
mng
->
func_graphs
())
{
if
(
func_graph
->
get_return
()
!=
nullptr
)
{
continue
;
}
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
ret
[
0
],
ret
[
1
]);
MS_EXCEPTION
(
TypeError
)
<<
"Missing return statement in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
// clear manager info after checking missing return
for
(
auto
fg
:
mng
->
func_graphs
())
{
fg
->
ClearAllManagerInfo
();
}
return
pFnBlock
->
func_graph
();
return
pFnBlock
->
func_graph
();
}
}
...
@@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
...
@@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
(
void
)
ParseStatements
(
pFunBlock
,
funcObj
);
(
void
)
ParseStatements
(
pFunBlock
,
funcObj
);
if
(
current_fg
->
get_return
()
==
nullptr
)
{
if
(
current_fg
->
get_return
()
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Graph return node is null, loc:"
<<
GetLocation
(
node
)
->
ToString
(
);
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
errcode_
=
PARSE_NO_RETURN
;
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
node
,
ret
[
0
],
ret
[
1
])
;
return
pFunBlock
;
MS_EXCEPTION
(
TypeError
)
<<
"Missing return statement in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
}
GenerateArgsDefaultValueForFunction
(
pFunBlock
,
node
);
GenerateArgsDefaultValueForFunction
(
pFunBlock
,
node
);
return
pFunBlock
;
return
pFunBlock
;
...
@@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
...
@@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
}
}
auto
filename
=
location
[
0
].
cast
<
std
::
string
>
();
auto
filename
=
location
[
0
].
cast
<
std
::
string
>
();
auto
line_no
=
location
[
1
].
cast
<
int
>
();
auto
line_no
=
location
[
1
].
cast
<
int
>
();
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
;
auto
fn_loc
=
block
->
func_graph
()
->
debug_info
()
->
location
();
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
fn_loc
->
file_name
(),
fn_loc
->
line
());
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
<<
" in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
}
}
}
...
@@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
...
@@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
auto
filename
=
ret
[
0
].
cast
<
std
::
string
>
();
auto
filename
=
ret
[
0
].
cast
<
std
::
string
>
();
auto
line_no
=
ret
[
1
].
cast
<
int
>
();
auto
line_no
=
ret
[
1
].
cast
<
int
>
();
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
;
auto
fn_loc
=
block
->
func_graph
()
->
debug_info
()
->
location
();
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
fn_loc
->
file_name
(),
fn_loc
->
line
());
MS_LOG
(
EXCEPTION
)
<<
"Unsupported syntax '"
<<
node_name
<<
"' at "
<<
filename
<<
":"
<<
line_no
<<
" in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
}
}
}
...
...
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
浏览文件 @
38083e05
...
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
...
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const
char
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
[]
=
"get_class_member_namespace_symbol"
;
const
char
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
[]
=
"get_class_member_namespace_symbol"
;
const
char
PYTHON_MOD_GET_PARSE_METHOD
[]
=
"get_parse_method_of_class"
;
const
char
PYTHON_MOD_GET_PARSE_METHOD
[]
=
"get_parse_method_of_class"
;
const
char
PYTHON_MOD_GET_BPROP_METHOD
[]
=
"get_bprop_method_of_class"
;
const
char
PYTHON_MOD_GET_BPROP_METHOD
[]
=
"get_bprop_method_of_class"
;
const
char
PYTHON_MOD_GET_OBJECT_DESCRIPTION
[]
=
"get_object_description"
;
const
char
PYTHON_MOD_CONVERT_TO_MS_TENSOR
[]
=
"convert_to_ms_tensor"
;
const
char
PYTHON_MOD_CONVERT_TO_MS_TENSOR
[]
=
"convert_to_ms_tensor"
;
const
char
PYTHON_PARSE_GET_ARGS
[]
=
"get_args"
;
const
char
PYTHON_PARSE_GET_ARGS
[]
=
"get_args"
;
...
...
mindspore/core/ir/manager.cc
浏览文件 @
38083e05
...
@@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
...
@@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
FuncGraphSetPtr
func_graphs_to_check
=
std
::
make_shared
<
FuncGraphSet
>
();
FuncGraphSetPtr
func_graphs_to_check
=
std
::
make_shared
<
FuncGraphSet
>
();
while
(
!
nodes_ordered
.
empty
())
{
while
(
!
nodes_ordered
.
empty
())
{
AnfNodePtr
node
=
nodes_ordered
.
pop
();
AnfNodePtr
node
=
nodes_ordered
.
pop
();
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
==
nullptr
)
{
// Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor
MS_LOG
(
WARNING
)
<<
"Node to be dropped is nullptr"
;
continue
;
}
if
(
!
all_nodes_
.
contains
(
node
))
{
if
(
!
all_nodes_
.
contains
(
node
))
{
continue
;
continue
;
}
}
...
...
tests/ut/cpp/pipeline/parse/parser_test.cc
浏览文件 @
38083e05
...
@@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) {
...
@@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) {
ASSERT_TRUE
(
nullptr
!=
func_graph
);
ASSERT_TRUE
(
nullptr
!=
func_graph
);
}
}
TEST_F
(
TestParser
,
TestParseGraphFailure
)
{
GetPythonFunction
(
"get_no_return_fn"
);
// create parser
std
::
shared_ptr
<
ParseAst
>
ast
=
std
::
make_shared
<
ParseAst
>
(
fn
);
bool
succ
=
ast
->
InitParseAstInfo
();
ASSERT_TRUE
(
succ
=
true
);
std
::
shared_ptr
<
Parser
>
parser
=
std
::
make_shared
<
Parser
>
(
ast
);
// parse ast to graph
FuncGraphPtr
func_graph
=
parser
->
ParseFuncGraph
();
ASSERT_EQ
(
PARSE_NO_RETURN
,
parser
->
errcode
());
ASSERT_TRUE
(
nullptr
==
func_graph
);
}
TEST_F
(
TestParser
,
TestParseGraphIf
)
{
TEST_F
(
TestParser
,
TestParseGraphIf
)
{
GetPythonFunction
(
"test_if"
);
GetPythonFunction
(
"test_if"
);
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
38083e05
...
@@ -689,3 +689,26 @@ def test_while_concat():
...
@@ -689,3 +689,26 @@ def test_while_concat():
x
=
Tensor
(
np
.
arange
(
10
*
2
*
3
).
reshape
(
10
,
2
,
3
).
astype
(
np
.
float32
))
x
=
Tensor
(
np
.
arange
(
10
*
2
*
3
).
reshape
(
10
,
2
,
3
).
astype
(
np
.
float32
))
net
=
Net
(
x
)
net
=
Net
(
x
)
net
(
x
)
net
(
x
)
def
test_tensor_all_construct_lack_branch
():
class
NetConditionLackBranch
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetConditionLackBranch
,
self
).
__init__
()
self
.
logicaland
=
P
.
LogicalAnd
()
self
.
logicalor
=
P
.
LogicalOr
()
def
construct
(
self
,
input1
,
input2
):
if
input1
.
all
():
return
self
.
logicaland
(
input1
,
input2
)
while
input1
.
any
():
return
self
.
logicalor
(
input1
,
input2
)
# NOTICE: here missing return statement, default return None
input_np_1
=
np
.
random
.
choice
([
True
],
size
=
(
2
,
3
,
4
,
5
))
input_tensor_1
=
Tensor
(
input_np_1
)
input_np_2
=
np
.
random
.
choice
([
True
,
False
],
size
=
(
2
,
3
,
4
,
5
))
input_tensor_2
=
Tensor
(
input_np_2
)
net
=
NetConditionLackBranch
()
with
pytest
.
raises
(
Exception
):
net
(
input_tensor_1
,
input_tensor_2
)
tests/ut/python/ops/test_ops_check.py
浏览文件 @
38083e05
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
import
functools
import
functools
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindspore
import
Tensor
...
@@ -62,13 +63,9 @@ def test_net_without_construct():
...
@@ -62,13 +63,9 @@ def test_net_without_construct():
""" test_net_without_construct """
""" test_net_without_construct """
net
=
NetMissConstruct
()
net
=
NetMissConstruct
()
inp
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
))
inp
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
))
try
:
with
pytest
.
raises
(
RuntimeError
)
as
err
:
_executor
.
compile
(
net
,
inp
)
_executor
.
compile
(
net
,
inp
)
except
RuntimeError
as
err
:
assert
"Unsupported syntax 'Raise' at "
in
str
(
err
.
value
)
if
str
(
err
).
find
(
"Unsupported syntax 'Raise' at "
)
>=
0
:
print
(
str
(
err
))
else
:
raise
err
class
NetWithRaise
(
nn
.
Cell
):
class
NetWithRaise
(
nn
.
Cell
):
...
@@ -87,13 +84,9 @@ def test_net_with_raise():
...
@@ -87,13 +84,9 @@ def test_net_with_raise():
""" test_net_with_raise """
""" test_net_with_raise """
net
=
NetWithRaise
()
net
=
NetWithRaise
()
inp
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
))
inp
=
Tensor
(
np
.
ones
([
1
,
1
,
32
,
32
]).
astype
(
np
.
float32
))
try
:
with
pytest
.
raises
(
RuntimeError
)
as
err
:
_executor
.
compile
(
net
,
inp
)
_executor
.
compile
(
net
,
inp
)
except
RuntimeError
as
err
:
assert
"Unsupported syntax 'Raise' at "
in
str
(
err
.
value
)
if
str
(
err
).
find
(
"Unsupported syntax 'Raise' at "
)
>=
0
:
print
(
str
(
err
))
else
:
raise
err
class
NetAddN
(
nn
.
Cell
):
class
NetAddN
(
nn
.
Cell
):
...
...
tests/ut/python/pipeline/parse/test_grammar_constraints.py
0 → 100644
浏览文件 @
38083e05
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test mindspore grammar constraints
1. funtion must have return statement
2. raise statement can not be used
"""
# pylint: disable=R1705, R1710, W0223
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
dtype
as
mstype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
def
test_missing_return
():
class
NetMissReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetMissReturn
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
if
y
==
1
:
return
3
elif
y
==
2
:
for
i
in
range
(
z
):
return
i
+
z
i
=
0
while
i
<
z
:
return
i
+
z
def
g
(
u
):
return
x
+
u
# here method 'construct' misses a return statement
g
(
y
)
else
:
return
7
else
:
return
5
net
=
NetMissReturn
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
TypeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Missing return statement in bound method 'construct'"
in
str
(
er
.
value
)
def
test_nest_function_missing_return
():
class
NetNestFuncMissReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetNestFuncMissReturn
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
if
y
==
1
:
return
3
elif
y
==
2
:
for
i
in
range
(
z
):
return
i
+
z
i
=
0
while
i
<
z
:
return
i
+
z
def
g
(
u
):
x
+=
u
# nested function 'g' misses a return a statement
return
g
(
y
)
else
:
return
7
else
:
return
5
net
=
NetNestFuncMissReturn
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
TypeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Missing return statement in function 'g'"
in
str
(
er
.
value
)
def
test_raise_in_method
():
class
NetRaiseInMethod
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetRaiseInMethod
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
# add not support grammar 'raise' here
raise
ValueError
(
'Illegal case'
)
else
:
return
y
+
z
net
=
NetRaiseInMethod
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Unsupported syntax 'Raise' at"
in
str
(
er
.
value
)
def
test_raise_in_nested_function
():
class
NetNestRaise
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetNestRaise
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
elif
x
==
20
:
def
nest_fn
(
u
):
if
u
>
0
:
# add not support grammar 'raise' here
raise
ValueError
(
'Illegal case'
)
return
u
+
z
+
1
return
nest_fn
(
y
)
else
:
return
y
+
z
net
=
NetNestRaise
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
x
,
y
,
z
)
assert
"Unsupported syntax 'Raise' at "
in
str
(
er
.
value
)
def
test_nest_branch_with_return
():
class
NetBranchWithReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetBranchWithReturn
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
if
x
==
1
:
return
10
else
:
return
5
context
.
set_context
(
save_graphs
=
True
)
net
=
NetBranchWithReturn
()
x
=
Tensor
(
0
,
mstype
.
int32
)
y
=
Tensor
(
5
,
mstype
.
int32
)
z
=
Tensor
(
2
,
mstype
.
int32
)
net
(
x
,
y
,
z
)
def
test_any_with_no_return
():
class
NetAnyNoReturn
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetAnyNoReturn
,
self
).
__init__
()
def
construct
(
self
,
inp
):
result
=
inp
.
any
()
if
result
:
return
6
np_input
=
np
.
arange
(
2
*
3
*
4
).
reshape
((
2
,
3
,
4
)).
astype
(
np
.
bool_
)
tensor
=
Tensor
(
np_input
)
net
=
NetAnyNoReturn
()
with
pytest
.
raises
(
TypeError
)
as
er
:
net
(
tensor
)
assert
"Missing return statement in bound method 'construct'"
in
str
(
er
.
value
)
def
test_missing_construct
():
class
NetMissConstruct
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetMissConstruct
,
self
).
__init__
()
def
construct1
(
self
,
inp
):
return
5
np_input
=
np
.
arange
(
2
*
3
*
4
).
reshape
((
2
,
3
,
4
)).
astype
(
np
.
bool_
)
tensor
=
Tensor
(
np_input
)
net
=
NetMissConstruct
()
with
pytest
.
raises
(
RuntimeError
)
as
er
:
net
(
tensor
)
assert
"Unsupported syntax 'Raise' at "
in
str
(
er
.
value
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录