# test_trace_module.py
from collections import OrderedDict

import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import Apply, CallFunction, CallMethod, Constant


class MyModule1(M.Module):
    """Fixture that constructs a ``Tensor`` from its input inside ``forward``.

    Returns ``(x + 2, y)`` where ``y`` is ``Tensor(x)`` incremented with
    ``+=``; the traced module's outputs are compared against eager ones.
    """

    def forward(self, x):
        y = Tensor(x)
        y += 1
        x = x + 2
        return x, y


class MyModule2(M.Module):
    """Fixture that builds a ``Tensor`` from a list mixing ints and the input.

    Returns ``(x + 2, y)`` where ``y`` is ``Tensor([1, x, 1])`` incremented
    with ``+=``; the traced module's outputs are compared against eager ones.
    """

    def forward(self, x):
        y = Tensor([1, x, 1])
        y += 1
        x = x + 2
        return x, y


class MyModule3(M.Module):
    """Fixture that keeps its submodules in a plain Python list.

    The list holds two ``Elemwise("ADD")`` modules, a nested ``OrderedDict``
    of two more ``ADD`` modules (keys ``"a"`` and ``"b"``), and two
    ``Elemwise("RELU")`` modules. The traced result is inspected through
    ``tm3.modules`` to check that this container layout survives tracing.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): this attribute likely shadows the inherited
        # ``Module.modules()`` method on instances; the name is kept because
        # the traced module is inspected through ``tm3.modules``.
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
            M.Elemwise("RELU"),
            M.Elemwise("RELU"),
        ]

    def forward(self, a, b):
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        # Key order must be preserved ("a" then "b").
        assert list(self.modules[2].keys()) == ["a", "b"]
        for _, m in self.modules[2].items():
            y = m(x, y)
        # The remaining entries are the unary RELU modules.
        for m in self.modules[3:]:
            y = m(y)
        return y


class MyModule4(M.Module):
    """Fixture that stores a plain function (``F.add``) as an attribute.

    ``forward(x, y)`` adds its inputs through that attribute; tracing is
    expected to record the call as a single ``CallFunction`` expr.
    """

    def __init__(self):
        super().__init__()
        # A bare function attribute, not a submodule.
        self.add = F.add

    def forward(self, x, y):
        # Parameter names matter: callers pass ``x=`` / ``y=`` as keywords.
        return self.add(x, y)


class MyModule5(M.Module):
    """Fixture whose returned tensor is given the explicit name ``"result"``.

    ``forward`` computes ``x * (x + x)``.
    """

    def forward(self, x):
        a = x + x
        b = x * a
        # Name the output tensor; presumably carried over to the traced
        # graph's output node — confirm against test_rename.
        b.name = "result"
        return b


def test_trace_module():
    """End-to-end checks for ``trace_module`` on the fixture modules above.

    Covers:
    - ``Tensor`` construction inside ``forward`` (MyModule1/2);
    - list / nested ``OrderedDict`` submodule containers (MyModule3);
    - every positional/keyword mix of arguments, both when tracing and when
      calling the traced module (MyModule4);
    - re-tracing an already traced module;
    - a traced module used as a submodule of another traced module.
    """
    enable_expr_checker()

    # Tensor(...) built inside forward: traced outputs must match eager ones.
    x = Tensor(1)
    m1 = MyModule1()
    tm1 = trace_module(m1, x)

    m2 = MyModule2()
    tm2 = trace_module(m2, x)
    inp = Tensor(2)
    gt = m1(inp)
    output = tm1(inp)
    for got, want in zip(output, gt):
        np.testing.assert_equal(got.numpy(), want.numpy())

    gt1 = m2(inp)
    output1 = tm2(inp)
    for got, want in zip(output1, gt1):
        np.testing.assert_equal(got.numpy(), want.numpy())

    # Submodules stored in a list / nested OrderedDict keep their structure.
    a, b = Tensor(1), Tensor(2)
    m3 = MyModule3()
    gt = m3(a, b)
    tm3 = trace_module(m3, a, b)
    out = tm3(a, b)
    np.testing.assert_equal(out.numpy(), gt.numpy())
    assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
    assert isinstance(tm3.modules.__dict__["2"], TracedModule)
    assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
    assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)

    def check_add(traced):
        # a + b == 3 no matter how (x, y) are passed at call time.
        np.testing.assert_equal(traced(a, b).numpy(), 3)
        np.testing.assert_equal(traced(a, y=b).numpy(), 3)
        np.testing.assert_equal(traced(x=a, y=b).numpy(), 3)

    # Trace with each positional/keyword mix, then call with each mix.
    m4 = MyModule4()
    tm4 = trace_module(m4, a, b)
    check_add(tm4)
    tm4 = trace_module(m4, a, y=b)
    check_add(tm4)
    tm4 = trace_module(m4, x=a, y=b)
    check_add(tm4)

    # Re-trace the (last) traced module with each argument mix.
    tm5 = trace_module(tm4, a, b)
    check_add(tm5)
    tm5 = trace_module(tm4, a, y=b)
    check_add(tm5)
    tm5 = trace_module(tm4, x=a, y=b)
    check_add(tm5)

    # The single F.add call is recorded as exactly one CallFunction expr.
    assert len(tm4.graph._exprs) == 1
    assert isinstance(tm4.graph._exprs[0], CallFunction)

    # A traced module nested as a submodule becomes a non-top module
    # without an argspec.
    class OuterModule(Module):
        def __init__(self):
            super().__init__()
            self.m1 = tm4

        def forward(self, x, y):
            return self.m1(x, y)

    tm6 = trace_module(OuterModule(), a, b)
    assert tm6.m1.argspec is None
    assert tm6.m1._is_top is False


def test_trace_module_2():
    """Trace a raw ``apply`` of a builtin op on a tensor's shape.

    The traced graph is expected to contain, in order:
    - an ``Apply`` of ``GetVarShape`` (from ``x.shape``);
    - a ``Constant`` (presumably for ``Tensor(1)``);
    - an ``Apply`` of ``Elemwise``.
    """

    class Model(M.Module):
        def forward(self, x):
            shape = x.shape
            return apply(builtin.Elemwise(mode="ADD"), shape, Tensor(1))

    traced_model = trace_module(Model(), Tensor([1]))
    exprs = traced_model.graph._exprs

    # Separate asserts so a failure shows which check broke.
    assert isinstance(exprs[0], Apply)
    assert isinstance(exprs[0].opdef, builtin.GetVarShape)
    assert isinstance(exprs[1], Constant)
    assert isinstance(exprs[2], Apply)
    assert isinstance(exprs[2].opdef, builtin.Elemwise)
    # An input of shape (2,) gives shape + 1 == 3.
    assert int(traced_model(Tensor([1, 2]))[0]) == 3


def test_rename():
    """Trace ``MyModule5`` and inspect the expr that produces its output.

    ``MyModule5.forward`` sets ``b.name = "result"`` on its returned tensor.
    Here the graph output is expected to come from a ``CallMethod`` expr.
    """
    model = MyModule5()
    tm_model = trace_module(model, Tensor(1))
    # The output is ``x * a``, presumably recorded as a Tensor method call
    # (``__mul__``) and hence a CallMethod expr — confirm.
    assert isinstance(tm_model.graph.outputs[0].expr, CallMethod)