test_forward_call.py 2.6 KB
Newer Older
G
ggpolar 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test forward_call module."""
import ast
import textwrap
from unittest.mock import patch

from mindinsight.mindconverter.forward_call import ForwardCall


class TestForwardCall:
    """Test the class of ForwardCall.

    ``source`` is a small PyTorch-style network used as a fixture. It is only
    parsed with ``ast.parse``; it is never executed, so the ``a`` package it
    imports does not need to exist.
    """
    source = textwrap.dedent("""\
        import a
        import a.nn as nn
        import a.nn.functional as F
        class TestNet:
            def __init__(self):
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)

            def forward(self, x):
                out = self.forward1(x)
                return out

            def forward1(self, x):
                out = F.relu(self.conv1(x))
                out = F.max_pool2d(out, 2)
                out = F.relu(self.conv2(out))
                out = F.max_pool2d(out, 2)
                out = out.view(out.size(0), -1)
                out = F.relu(self.fc1(out))
                out = F.relu(self.fc2(out))
                out = self.fc3(out)
                return out
    """)

    @patch.object(ForwardCall, 'process')
    def test_process(self, mock_process):
        """Test the function of visit ast tree to find out forward functions.

        ``ForwardCall.process`` is patched out so that constructing the object
        with the placeholder path ``"mock"`` does no real processing
        (presumably ``process`` would read and parse that file — confirm in
        ForwardCall). The fixture AST is then fed to ``visit`` directly.
        """
        mock_process.return_value = None
        forward_call = ForwardCall("mock")
        forward_call.visit(ast.parse(self.source))

        # Each callee is listed exactly once: F.relu and F.max_pool2d are
        # called several times in the fixture but appear once here, i.e. the
        # test expects de-duplicated names. 'TestNet.forward' (the entry
        # point) replaces a previously duplicated 'TestNet.forward1' entry.
        # NOTE(review): confirm ForwardCall records the forward function itself.
        expect_calls = ['TestNet.forward',
                        'TestNet.forward1',
                        'F.relu',
                        'TestNet.conv1',
                        'F.max_pool2d',
                        'TestNet.conv2',
                        'out.view',
                        'out.size',
                        'TestNet.fc1',
                        'TestNet.fc2',
                        'TestNet.fc3',
                        ]

        # Order-insensitive comparison. sorted() returns a new list and
        # accepts any iterable (list, set, or dict keys). The previous
        # `a.sort() == b.sort()` compared None with None and could never fail.
        assert sorted(forward_call.calls) == sorted(expect_calls)