Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
73dec436
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
7
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73dec436
编写于
5月 23, 2020
作者:
G
ggpolar
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mindconverter ut
上级
a432241f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
144 addition
and
5 deletion
+144
-5
mindinsight/mindconverter/forward_call.py
mindinsight/mindconverter/forward_call.py
+5
-5
tests/ut/mindconverter/__init__.py
tests/ut/mindconverter/__init__.py
+14
-0
tests/ut/mindconverter/test_config.py
tests/ut/mindconverter/test_config.py
+52
-0
tests/ut/mindconverter/test_forward_call.py
tests/ut/mindconverter/test_forward_call.py
+73
-0
未找到文件。
mindinsight/mindconverter/forward_call.py
浏览文件 @
73dec436
...
...
@@ -29,7 +29,7 @@ class ForwardCall(ast.NodeVisitor):
self
.
module_name
=
os
.
path
.
basename
(
filename
).
replace
(
'.py'
,
''
)
self
.
name_stack
=
[]
self
.
forward_stack
=
[]
self
.
calls
=
[]
self
.
calls
=
set
()
self
.
process
()
def
process
(
self
):
...
...
@@ -68,7 +68,7 @@ class ForwardCall(ast.NodeVisitor):
self
.
forward_stack
.
append
(
func_name
)
if
node
.
name
==
'forward'
:
self
.
calls
.
a
ppen
d
(
func_name
)
self
.
calls
.
a
d
d
(
func_name
)
self
.
generic_visit
(
node
)
...
...
@@ -85,12 +85,12 @@ class ForwardCall(ast.NodeVisitor):
if
isinstance
(
node
.
func
,
ast
.
Name
):
if
func_name
not
in
[
'super'
,
'str'
,
'repr'
]:
if
self
.
forward_stack
:
self
.
calls
.
a
ppen
d
(
func_name
)
self
.
calls
.
a
d
d
(
func_name
)
self
.
visit
(
node
.
func
)
else
:
if
self
.
forward_stack
:
if
'self'
in
func_name
:
self
.
calls
.
a
ppen
d
(
f
'
{
self
.
get_current_namespace
()
}
.
{
func_name
.
split
(
"."
)[
-
1
]
}
'
)
self
.
calls
.
a
d
d
(
f
'
{
self
.
get_current_namespace
()
}
.
{
func_name
.
split
(
"."
)[
-
1
]
}
'
)
else
:
self
.
calls
.
a
ppen
d
(
func_name
)
self
.
calls
.
a
d
d
(
func_name
)
self
.
visit
(
node
.
func
)
tests/ut/mindconverter/__init__.py
0 → 100644
浏览文件 @
73dec436
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
tests/ut/mindconverter/test_config.py
0 → 100644
浏览文件 @
73dec436
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test config module."""
from
collections
import
OrderedDict
import
pytest
from
mindinsight.mindconverter.config
import
APIPt
,
REQUIRED
class
TestAPIBase
:
"""Test the class of APIPt."""
function_name
=
"func"
@
pytest
.
mark
.
parametrize
(
'parameters'
,
[
'(out.size(0), -1'
,
'(2, 1, 0)'
])
def
test_parse_args_exception
(
self
,
parameters
):
"""Test parse arguments exception"""
parameters_spec
=
OrderedDict
(
in_channels
=
REQUIRED
,
out_channels
=
REQUIRED
)
api_parser
=
APIPt
(
self
.
function_name
,
parameters_spec
)
with
pytest
.
raises
(
ValueError
):
api_parser
.
parse_args
(
api_parser
.
name
,
parameters
)
def
test_parse_single_arg
(
self
):
"""Test parse one argument"""
source
=
'(1)'
parameters_spec
=
OrderedDict
(
in_channels
=
REQUIRED
)
api_parser
=
APIPt
(
self
.
function_name
,
parameters_spec
)
parsed_args
=
api_parser
.
parse_args
(
api_parser
.
name
,
source
)
assert
parsed_args
[
'in_channels'
]
==
'1'
def
test_parse_args
(
self
):
"""Test parse multiple arguments"""
source
=
'(1, 2)'
parameters_spec
=
OrderedDict
(
in_channels
=
REQUIRED
,
out_channels
=
REQUIRED
)
api_parser
=
APIPt
(
self
.
function_name
,
parameters_spec
)
parsed_args
=
api_parser
.
parse_args
(
api_parser
.
name
,
source
)
assert
parsed_args
[
'in_channels'
]
==
'1'
assert
parsed_args
[
'out_channels'
]
==
'2'
tests/ut/mindconverter/test_forward_call.py
0 → 100644
浏览文件 @
73dec436
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test forward_call module."""
import
ast
import
textwrap
from
unittest.mock
import
patch
from
mindinsight.mindconverter.forward_call
import
ForwardCall
class
TestForwardCall
:
"""Test the class of ForwardCall."""
source
=
textwrap
.
dedent
(
"""
\
import a
import a.nn as nn
import a.nn.functional as F
class TestNet:
def __init__(self):
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = self.forward1(out)
return out
def forward1(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
"""
)
@
patch
.
object
(
ForwardCall
,
'process'
)
def
test_process
(
self
,
mock_process
):
"""Test the function of visit ast tree to find out forward functions."""
mock_process
.
return_value
=
None
forward_call
=
ForwardCall
(
"mock"
)
forward_call
.
visit
(
ast
.
parse
(
self
.
source
))
expect_calls
=
[
'TestNet.forward1'
,
'TestNet.forward1'
,
'F.relu'
,
'TestNet.conv1'
,
'F.max_pool2d'
,
'TestNet.conv2'
,
'out.view'
,
'out.size'
,
'TestNet.fc1'
,
'TestNet.fc2'
,
'TestNet.fc3'
,
]
assert
[
forward_call
.
calls
].
sort
()
==
expect_calls
.
sort
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录