From 73dec43618eed477f7f3d65aa03b165c6f598280 Mon Sep 17 00:00:00 2001 From: ggpolar Date: Sat, 23 May 2020 17:52:26 +0800 Subject: [PATCH] add mindconverter ut --- mindinsight/mindconverter/forward_call.py | 10 +-- tests/ut/mindconverter/__init__.py | 14 ++++ tests/ut/mindconverter/test_config.py | 52 +++++++++++++++ tests/ut/mindconverter/test_forward_call.py | 73 +++++++++++++++++++++ 4 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 tests/ut/mindconverter/__init__.py create mode 100644 tests/ut/mindconverter/test_config.py create mode 100644 tests/ut/mindconverter/test_forward_call.py diff --git a/mindinsight/mindconverter/forward_call.py b/mindinsight/mindconverter/forward_call.py index 61e4013..07e5270 100644 --- a/mindinsight/mindconverter/forward_call.py +++ b/mindinsight/mindconverter/forward_call.py @@ -29,7 +29,7 @@ class ForwardCall(ast.NodeVisitor): self.module_name = os.path.basename(filename).replace('.py', '') self.name_stack = [] self.forward_stack = [] - self.calls = [] + self.calls = set() self.process() def process(self): @@ -68,7 +68,7 @@ class ForwardCall(ast.NodeVisitor): self.forward_stack.append(func_name) if node.name == 'forward': - self.calls.append(func_name) + self.calls.add(func_name) self.generic_visit(node) @@ -85,12 +85,12 @@ class ForwardCall(ast.NodeVisitor): if isinstance(node.func, ast.Name): if func_name not in ['super', 'str', 'repr']: if self.forward_stack: - self.calls.append(func_name) + self.calls.add(func_name) self.visit(node.func) else: if self.forward_stack: if 'self' in func_name: - self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') + self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') else: - self.calls.append(func_name) + self.calls.add(func_name) self.visit(node.func) diff --git a/tests/ut/mindconverter/__init__.py b/tests/ut/mindconverter/__init__.py new file mode 100644 index 0000000..e307743 --- /dev/null +++ b/tests/ut/mindconverter/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/tests/ut/mindconverter/test_config.py b/tests/ut/mindconverter/test_config.py new file mode 100644 index 0000000..30c1010 --- /dev/null +++ b/tests/ut/mindconverter/test_config.py @@ -0,0 +1,52 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test config module.""" +from collections import OrderedDict + +import pytest + +from mindinsight.mindconverter.config import APIPt, REQUIRED + + +class TestAPIBase: + """Test the class of APIPt.""" + function_name = "func" + + @pytest.mark.parametrize('parameters', ['(out.size(0), -1', '(2, 1, 0)']) + def test_parse_args_exception(self, parameters): + """Test parse arguments exception""" + parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED) + api_parser = APIPt(self.function_name, parameters_spec) + with pytest.raises(ValueError): + api_parser.parse_args(api_parser.name, parameters) + + def test_parse_single_arg(self): + """Test parse one argument""" + source = '(1)' + parameters_spec = OrderedDict(in_channels=REQUIRED) + api_parser = APIPt(self.function_name, parameters_spec) + parsed_args = api_parser.parse_args(api_parser.name, source) + + assert parsed_args['in_channels'] == '1' + + def test_parse_args(self): + """Test parse multiple arguments""" + source = '(1, 2)' + parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED) + api_parser = APIPt(self.function_name, parameters_spec) + parsed_args = api_parser.parse_args(api_parser.name, source) + + assert parsed_args['in_channels'] == '1' + assert parsed_args['out_channels'] == '2' diff --git a/tests/ut/mindconverter/test_forward_call.py b/tests/ut/mindconverter/test_forward_call.py new file mode 100644 index 0000000..67f4596 --- /dev/null +++ b/tests/ut/mindconverter/test_forward_call.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test forward_call module.""" +import ast +import textwrap +from unittest.mock import patch + +from mindinsight.mindconverter.forward_call import ForwardCall + + +class TestForwardCall: + """Test the class of ForwardCall.""" + source = textwrap.dedent("""\ + import a + import a.nn as nn + import a.nn.functional as F + class TestNet: + def __init__(self): + self.conv1 = nn.Conv2d(3, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + out = self.forward1(out) + return out + + def forward1(self, x): + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + out = self.fc3(out) + return out + """) + + @patch.object(ForwardCall, 'process') + def test_process(self, mock_process): + """Test the function of visit ast tree to find out forward functions.""" + mock_process.return_value = None + forward_call = ForwardCall("mock") + forward_call.visit(ast.parse(self.source)) + + expect_calls = ['TestNet.forward1', + 'TestNet.forward1', + 'F.relu', + 'TestNet.conv1', + 'F.max_pool2d', + 'TestNet.conv2', + 'out.view', + 'out.size', + 'TestNet.fc1', + 'TestNet.fc2', + 'TestNet.fc3', + ] + + assert [forward_call.calls].sort() == expect_calls.sort() -- GitLab