test_modification.py 3.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module
from megengine.experimental.traced_module.expr import CallFunction, GetAttr


class MyBlock(M.Module):
    """Conv -> batch-norm -> ReLU block whose output is shifted by +1.

    Used by the tests in this file as the basic module to trace and edit.
    """

    def __init__(self, in_channels=3, channels=3):
        super().__init__()
        # 3x3 convolution, stride 1, padding 1, without a bias term.
        self.conv1 = M.Conv2d(
            in_channels, channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = M.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn1(conv1(x))) + 1``."""
        convolved = self.conv1(x)
        normalized = self.bn1(convolved)
        activated = F.relu(normalized)
        return activated + 1


class MyModule(M.Module):
    """Two default-constructed ``MyBlock`` instances applied in sequence."""

    def __init__(self):
        super().__init__()
        self.block0 = MyBlock()
        self.block1 = MyBlock()

    def forward(self, x):
        """Feed ``x`` through ``block0`` and then ``block1``."""
        return self.block1(self.block0(x))


def _init_cls(cls):
    """Build ``cls()``, run it once eagerly, then trace it.

    Returns:
        A ``(traced_module, x, y)`` tuple, where ``x`` is an all-ones tensor
        of shape ``(1, 3, 3, 3)`` and ``y`` is the output the untraced module
        produced for ``x``.
    """
    net = cls()
    sample = F.ones((1, 3, 3, 3))
    # The eager forward pass is taken before tracing, as the reference output.
    reference = net(sample)
    traced = trace_module(net, sample)
    return traced, sample, reference


def _init_block():
    """Return the ``_init_cls`` fixture tuple for a traced ``MyBlock``."""
    fixture = _init_cls(MyBlock)
    return fixture


def _init_module():
    """Return the ``_init_cls`` fixture tuple for a traced ``MyModule``."""
    fixture = _init_cls(MyModule)
    return fixture


def test_search():
    """The traced block's graph exposes exactly one ``F.relu`` call expression."""
    traced_module = _init_block()[0]
    matches = traced_module.graph.get_call_function(F.relu)
    relu_expr = matches.as_unique()
    assert isinstance(relu_expr, CallFunction)
    assert relu_expr.func == F.relu


def test_insert():
    """Insert ``F.neg`` after the ReLU and check the recompiled output."""
    traced_module, x, expect = _init_block()
    graph = traced_module.graph

    relu_expr = graph.get_call_function(F.relu).as_unique()
    relu_outputs = relu_expr.outputs
    neg_outputs = graph.insert_call_function(F.neg, relu_outputs)
    # Redirect uses of the ReLU output to the newly inserted negation.
    graph.replace_node({relu_outputs[0]: neg_outputs[0]})
    graph.compile()

    # The block now computes -relu(...) + 1, so 1 - output equals the
    # original relu(...) value, i.e. expect - 1.
    actual = traced_module(x)
    np.testing.assert_allclose(expect - 1, 1 - actual, atol=1e-6)


def test_delete():
    """Bypass the ReLU in the graph and check the recompiled output."""
    traced_module, x, expect = _init_block()
    graph = traced_module.graph

    relu_expr = graph.get_call_function(F.relu).as_unique()
    relu_output = relu_expr.outputs[0]
    relu_input = relu_expr.inputs[0]
    # Route consumers of the ReLU output straight to the ReLU input.
    graph.replace_node({relu_output: relu_input})
    graph.compile()

    # Re-applying the ReLU outside the module should recover the original
    # activation, which is expect - 1.
    without_relu = traced_module(x)
    np.testing.assert_allclose(expect - 1, F.relu(without_relu - 1), atol=1e-6)


def test_flatten():
    """Flattening ``MyModule`` leaves no ``GetAttr`` expressions in the graph."""
    traced_module, *_ = _init_module()
    flattened = traced_module.flatten()
    flattened.graph.compile()

    exprs = flattened.graph._exprs
    assert not any(isinstance(expr, GetAttr) for expr in exprs)
    # Expected number of expressions in the flattened two-block graph.
    assert len(exprs) == 12


def test_extra_block():
    """A traced module can be nested in a new module and traced again."""

    class PostProcess(M.Module):
        """Doubles its input."""

        def forward(self, x):
            return x * 2

    class Net(M.Module):
        """Runs the wrapped traced module, then ``PostProcess``."""

        def __init__(self, traced_module):
            super().__init__()
            self.post_process = PostProcess()
            self.traced_module = traced_module

        def forward(self, x):
            return self.post_process(self.traced_module(x))

    inner_traced, x, expect = _init_block()
    doubled = 2 * expect

    # Check the eager wrapper first, then the same wrapper after tracing.
    wrapper = Net(inner_traced)
    np.testing.assert_allclose(doubled, wrapper(x), atol=1e-6)

    retraced = trace_module(wrapper, x)
    np.testing.assert_allclose(doubled, retraced(x), atol=1e-6)