diff --git a/tests/ut/mindconverter/test_converter.py b/tests/ut/mindconverter/test_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..9674cfc613f32702c1445d87869705195f269a8f --- /dev/null +++ b/tests/ut/mindconverter/test_converter.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test Converter""" +from mindinsight.mindconverter.converter import Converter + + +class TestConverter: + """Test Converter""" + + converter_ins = Converter() + + def test_judge_forward(self): + """test judge_forward""" + name1 = 'conv1' + forward_list = {'conv1', 'relu'} + result1 = self.converter_ins.judge_forward(name1, forward_list) + assert result1 is True + + name2 = 'self.forward' + result2 = self.converter_ins.judge_forward(name2, forward_list) + assert result2 is True + + def test_find_left_parentheses(self): + """test find_left_parentheses""" + code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), + nn.ReLU(), + nn.ReLU(True), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5, stride=1, padding=0), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2))''' + right_index = len(code) - 1 + left_index = code.index('nn.Conv2d') + result = self.converter_ins.find_left_parentheses(code, right_index) + assert result == left_index - 1 + + def test_find_api(self): + """test find_api""" + code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), + nn.ReLU(), + nn.ReLU(True), + nn.MaxPool2d(2, 2), # TODO padding + nn.Conv2d(6, 16, 5, stride=1, padding=0), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2))''' + index = 0 + is_forward = False + result = self.converter_ins.find_api(code, index, is_forward) + assert result == 'nn.Sequential' + + def test_get_call_name(self): + """test get_call_name""" + code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))''' + end = len(code) + call_name, index = self.converter_ins.get_call_name(code, end) + + assert call_name == '' + assert index == -1 + + def test_find_right_parentheses(self): + """test find_right_parentheses""" + code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), + nn.ReLU(), + nn.ReLU(True), + nn.MaxPool2d(2, 2), # TODO padding + nn.Conv2d(6, 16, 5, stride=1, padding=0), + nn.ReLU(inplace=False), + nn.MaxPool2d(2, 2))''' + left_index = 0 + result = self.converter_ins.find_right_parentheses(code, left_index) + assert_index = len(code) - 1 + assert result == assert_index