diff --git a/mindinsight/mindconverter/forward_call.py b/mindinsight/mindconverter/forward_call.py index 61e401331b6b8faf02c2befd98a4114ab4dec9cc..07e52707300a275ef5523a2a513f0bf6bf3105aa 100644 --- a/mindinsight/mindconverter/forward_call.py +++ b/mindinsight/mindconverter/forward_call.py @@ -29,7 +29,7 @@ class ForwardCall(ast.NodeVisitor): self.module_name = os.path.basename(filename).replace('.py', '') self.name_stack = [] self.forward_stack = [] - self.calls = [] + self.calls = set() self.process() def process(self): @@ -68,7 +68,7 @@ class ForwardCall(ast.NodeVisitor): self.forward_stack.append(func_name) if node.name == 'forward': - self.calls.append(func_name) + self.calls.add(func_name) self.generic_visit(node) @@ -85,12 +85,12 @@ class ForwardCall(ast.NodeVisitor): if isinstance(node.func, ast.Name): if func_name not in ['super', 'str', 'repr']: if self.forward_stack: - self.calls.append(func_name) + self.calls.add(func_name) self.visit(node.func) else: if self.forward_stack: if 'self' in func_name: - self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') + self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') else: - self.calls.append(func_name) + self.calls.add(func_name) self.visit(node.func) diff --git a/tests/ut/mindconverter/__init__.py b/tests/ut/mindconverter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e30774307ca2107b3a81c071ad33c042ef924790 --- /dev/null +++ b/tests/ut/mindconverter/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/tests/ut/mindconverter/test_config.py b/tests/ut/mindconverter/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..30c1010484b432b49ea860810e90aff3e571e858 --- /dev/null +++ b/tests/ut/mindconverter/test_config.py @@ -0,0 +1,52 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test config module.""" +from collections import OrderedDict + +import pytest + +from mindinsight.mindconverter.config import APIPt, REQUIRED + + +class TestAPIBase: + """Test the class of APIPt.""" + function_name = "func" + + @pytest.mark.parametrize('parameters', ['(out.size(0), -1', '(2, 1, 0)']) + def test_parse_args_exception(self, parameters): + """Test parse arguments exception""" + parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED) + api_parser = APIPt(self.function_name, parameters_spec) + with pytest.raises(ValueError): + api_parser.parse_args(api_parser.name, parameters) + + def test_parse_single_arg(self): + """Test parse one argument""" + source = '(1)' + parameters_spec = OrderedDict(in_channels=REQUIRED) + api_parser = APIPt(self.function_name, parameters_spec) + parsed_args = api_parser.parse_args(api_parser.name, source) + + assert parsed_args['in_channels'] == '1' + + def test_parse_args(self): + """Test parse multiple arguments""" + source = '(1, 2)' + parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED) + api_parser = APIPt(self.function_name, parameters_spec) + parsed_args = api_parser.parse_args(api_parser.name, source) + + assert parsed_args['in_channels'] == '1' + assert parsed_args['out_channels'] == '2' diff --git a/tests/ut/mindconverter/test_forward_call.py b/tests/ut/mindconverter/test_forward_call.py new file mode 100644 index 0000000000000000000000000000000000000000..67f459600c9620a8062583e1d2e0c00e89d7a68e --- /dev/null +++ b/tests/ut/mindconverter/test_forward_call.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test forward_call module.""" +import ast +import textwrap +from unittest.mock import patch + +from mindinsight.mindconverter.forward_call import ForwardCall + + +class TestForwardCall: + """Test the class of ForwardCall.""" + source = textwrap.dedent("""\ + import a + import a.nn as nn + import a.nn.functional as F + class TestNet: + def __init__(self): + self.conv1 = nn.Conv2d(3, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + out = self.forward1(out) + return out + + def forward1(self, x): + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + out = self.fc3(out) + return out + """) + + @patch.object(ForwardCall, 'process') + def test_process(self, mock_process): + """Test the function of visit ast tree to find out forward functions.""" + mock_process.return_value = None + forward_call = ForwardCall("mock") + forward_call.visit(ast.parse(self.source)) + + expect_calls = ['TestNet.forward1', + 'TestNet.forward1', + 'F.relu', + 'TestNet.conv1', + 'F.max_pool2d', + 'TestNet.conv2', + 'out.view', + 'out.size', + 'TestNet.fc1', + 'TestNet.fc2', + 'TestNet.fc3', + ] + + assert [forward_call.calls].sort() == expect_calls.sort()