提交 faa697ce 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!155 add mindconverter ut

Merge pull request !155 from ggpolar/br_wzk_dev
...@@ -29,7 +29,7 @@ class ForwardCall(ast.NodeVisitor): ...@@ -29,7 +29,7 @@ class ForwardCall(ast.NodeVisitor):
self.module_name = os.path.basename(filename).replace('.py', '') self.module_name = os.path.basename(filename).replace('.py', '')
self.name_stack = [] self.name_stack = []
self.forward_stack = [] self.forward_stack = []
self.calls = [] self.calls = set()
self.process() self.process()
def process(self): def process(self):
...@@ -68,7 +68,7 @@ class ForwardCall(ast.NodeVisitor): ...@@ -68,7 +68,7 @@ class ForwardCall(ast.NodeVisitor):
self.forward_stack.append(func_name) self.forward_stack.append(func_name)
if node.name == 'forward': if node.name == 'forward':
self.calls.append(func_name) self.calls.add(func_name)
self.generic_visit(node) self.generic_visit(node)
...@@ -85,12 +85,12 @@ class ForwardCall(ast.NodeVisitor): ...@@ -85,12 +85,12 @@ class ForwardCall(ast.NodeVisitor):
if isinstance(node.func, ast.Name): if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']: if func_name not in ['super', 'str', 'repr']:
if self.forward_stack: if self.forward_stack:
self.calls.append(func_name) self.calls.add(func_name)
self.visit(node.func) self.visit(node.func)
else: else:
if self.forward_stack: if self.forward_stack:
if 'self' in func_name: if 'self' in func_name:
self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
else: else:
self.calls.append(func_name) self.calls.add(func_name)
self.visit(node.func) self.visit(node.func)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test config module."""
from collections import OrderedDict
import pytest
from mindinsight.mindconverter.config import APIPt, REQUIRED
class TestAPIBase:
"""Test the class of APIPt."""
function_name = "func"
@pytest.mark.parametrize('parameters', ['(out.size(0), -1', '(2, 1, 0)'])
def test_parse_args_exception(self, parameters):
"""Test parse arguments exception"""
parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED)
api_parser = APIPt(self.function_name, parameters_spec)
with pytest.raises(ValueError):
api_parser.parse_args(api_parser.name, parameters)
def test_parse_single_arg(self):
"""Test parse one argument"""
source = '(1)'
parameters_spec = OrderedDict(in_channels=REQUIRED)
api_parser = APIPt(self.function_name, parameters_spec)
parsed_args = api_parser.parse_args(api_parser.name, source)
assert parsed_args['in_channels'] == '1'
def test_parse_args(self):
"""Test parse multiple arguments"""
source = '(1, 2)'
parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED)
api_parser = APIPt(self.function_name, parameters_spec)
parsed_args = api_parser.parse_args(api_parser.name, source)
assert parsed_args['in_channels'] == '1'
assert parsed_args['out_channels'] == '2'
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test forward_call module."""
import ast
import textwrap
from unittest.mock import patch
from mindinsight.mindconverter.forward_call import ForwardCall
class TestForwardCall:
"""Test the class of ForwardCall."""
source = textwrap.dedent("""\
import a
import a.nn as nn
import a.nn.functional as F
class TestNet:
def __init__(self):
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = self.forward1(out)
return out
def forward1(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
""")
@patch.object(ForwardCall, 'process')
def test_process(self, mock_process):
"""Test the function of visit ast tree to find out forward functions."""
mock_process.return_value = None
forward_call = ForwardCall("mock")
forward_call.visit(ast.parse(self.source))
expect_calls = ['TestNet.forward1',
'TestNet.forward1',
'F.relu',
'TestNet.conv1',
'F.max_pool2d',
'TestNet.conv2',
'out.view',
'out.size',
'TestNet.fc1',
'TestNet.fc2',
'TestNet.fc3',
]
assert [forward_call.calls].sort() == expect_calls.sort()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册