test_converter.py 3.6 KB
Newer Older
Q
quyongxiu1 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Converter"""
from mindinsight.mindconverter.converter import Converter


class TestConverter:
    """Unit tests for the source-text helpers of ``Converter``.

    Each test feeds a small snippet of PyTorch-style source text to one
    ``Converter`` helper and checks the value it returns.
    """

    def setup_method(self):
        """Create a fresh ``Converter`` before every test method.

        Building the converter here rather than as a class attribute keeps any
        state a helper leaves on the instance from leaking into later tests.
        It also moves construction from import time to test run time, so a
        failing constructor is reported per test instead of breaking collection.
        """
        self.converter_ins = Converter()

    def test_judge_forward(self):
        """judge_forward reports True for the names exercised below."""
        forward_names = {'conv1', 'relu'}

        # A name that is a member of the collection.
        result1 = self.converter_ins.judge_forward('conv1', forward_names)
        assert result1 is True

        # 'self.forward' is not in the collection, but is still expected to be True.
        result2 = self.converter_ins.judge_forward('self.forward', forward_names)
        assert result2 is True

    def test_find_left_parentheses(self):
        """find_left_parentheses maps the final ')' back to its opening '('."""
        code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
                                                  nn.ReLU(),
                                                  nn.ReLU(True),
                                                  nn.MaxPool2d(2, 2),
                                                  nn.Conv2d(6, 16, 5, stride=1, padding=0),
                                                  nn.ReLU(inplace=False),
                                                  nn.MaxPool2d(2, 2))'''
        # The snippet ends with the ')' that closes nn.Sequential(...).
        right_index = len(code) - 1
        # The first '(' in the snippet opens nn.Sequential and pairs with that ')'.
        expected_left_index = code.index('(')
        result = self.converter_ins.find_left_parentheses(code, right_index)
        assert result == expected_left_index

    def test_find_api(self):
        """find_api, starting at index 0 with is_forward=False, yields 'nn.Sequential'."""
        code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
                                  nn.ReLU(),
                                  nn.ReLU(True),
                                  nn.MaxPool2d(2, 2),  # TODO padding
                                  nn.Conv2d(6, 16, 5, stride=1, padding=0),
                                  nn.ReLU(inplace=False),
                                  nn.MaxPool2d(2, 2))'''
        index = 0
        is_forward = False
        result = self.converter_ins.find_api(code, index, is_forward)
        assert result == 'nn.Sequential'

    def test_get_call_name(self):
        """get_call_name returns ('', -1) when ``end`` is the snippet length."""
        code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))'''
        end = len(code)
        call_name, index = self.converter_ins.get_call_name(code, end)

        assert call_name == ''
        assert index == -1

    def test_find_right_parentheses(self):
        """find_right_parentheses maps the first '(' to the snippet's final ')'."""
        code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
                                          nn.ReLU(),
                                          nn.ReLU(True),
                                          nn.MaxPool2d(2, 2),  # TODO padding
                                          nn.Conv2d(6, 16, 5, stride=1, padding=0),
                                          nn.ReLU(inplace=False),
                                          nn.MaxPool2d(2, 2))'''
        # The search starts at index 0, i.e. the first character of the snippet.
        left_index = 0
        result = self.converter_ins.find_right_parentheses(code, left_index)
        # The matching ')' is the last character of the snippet.
        expected_right_index = len(code) - 1
        assert result == expected_right_index