# test_elemwise.py
import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.jit as jit
import megengine.tensor as tensor
from megengine.autodiff.grad_manager import GradManager


def test_elemwise():
    """Check elementwise ops by calling an XLA-traced function twice.

    For every case a function wrapped in ``jit.trace(..., use_xla=True)`` is
    invoked two times with the same inputs, and the corresponding outputs
    (forward result and, when requested, input gradients) are compared with
    ``np.testing.assert_allclose``.

    NOTE(review): the original local names suggest that the first call yields
    the plain MegEngine result and the second call the XLA result -- confirm
    against the ``jit.trace`` documentation.
    """
    np.random.seed(123)
    mge.random.seed(123)

    def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5):
        """Run one elementwise case and compare the two traced invocations.

        Args:
            felemwise: callable applied to the unpacked input tensors.
            *inp_shapes: one shape tuple per input tensor.
            backward: when true, also back-propagate and compare gradients.
            dtype: dtype of the generated tensors; ``np.float32`` if falsy.
            atol: absolute tolerance passed to ``assert_allclose``.
        """
        if not dtype:
            dtype = np.float32

        # Inputs are drawn first, then the output gradient; the order of the
        # random draws is kept so the generated data stays reproducible.
        # NOTE(review): values are 0.1 * randn, so about half are negative;
        # sqrt/log/pow presumably produce NaN there, which assert_allclose
        # treats as equal by default -- confirm this is intended.
        operands = []
        for shape in inp_shapes:
            operands.append(tensor(0.1 * np.random.randn(*shape), dtype=dtype))
        out_shape = felemwise(*operands).shape
        out_grad = tensor(0.1 * np.random.randn(*out_shape), dtype=dtype)

        gm = GradManager()

        @jit.trace(without_host=True, use_xla=True)
        def func(inps, doup):
            gm.attach(inps)
            with gm:
                oup = felemwise(*inps)
                if not backward:
                    return [oup]
                gm.backward(oup, doup)
                return [oup] + [inp.grad for inp in inps]

        first_results = func(operands, out_grad)
        second_results = func(operands, out_grad)
        for first, second in zip(first_results, second_results):
            np.testing.assert_allclose(first.numpy(), second.numpy(), atol=atol)

    # Each entry: (op, input shapes, run backward?, absolute tolerance).
    # The order matters because every case consumes values from the seeded RNG.
    cases = [
        # Unary ops.
        (F.neg, [(4, 16, 12, 12)], True, 1e-5),
        (F.abs, [(2, 32, 16)], True, 1e-5),
        (F.tanh, [(4, 16, 3, 1)], False, 1e-5),
        (F.exp, [(2, 8)], True, 1e-5),
        (F.sqrt, [(32,)], True, 1e-5),
        (F.log, [(8, 8, 16)], True, 1e-5),
        (F.relu, [(1,)], True, 1e-5),
        (F.gelu, [(4, 16, 12, 12)], True, 2e-5),
        # Binary arithmetic ops, including broadcast shapes.
        (F.add, [(4, 16, 12, 12), (4, 16, 12, 12)], True, 1e-5),
        (F.sub, [(4, 16, 12, 12), (4, 16, 1, 1)], True, 1e-5),
        (F.mul, [(4, 16, 12, 12), (1, 1, 12, 12)], True, 1e-5),
        (F.div, [(4, 16, 1, 1), (4, 16, 12, 12)], False, 1e-5),
        (F.pow, [(4, 1, 12, 12), (1, 16, 12, 12)], True, 1e-5),
        # Comparison ops; forward only.
        (F.equal, [(4, 16, 12, 12), (1, 1)], False, 1e-5),
        (F.not_equal, [(4, 16, 12, 12), (4, 16, 1, 1)], False, 1e-5),
        (F.greater, [(4, 16, 1, 1), (4, 16, 12, 12)], False, 1e-5),
        (F.greater_equal, [(16, 1, 1), (4, 16, 12, 12)], False, 1e-5),
        (F.less, [(4, 16, 12, 1), (4, 16, 12, 12)], False, 1e-5),
        (F.less_equal, [(1, 1, 12, 12), (4, 16, 12, 12)], False, 1e-5),
    ]
    for op, shapes, with_grad, tol in cases:
        tester(op, *shapes, backward=with_grad, dtype=np.float32, atol=tol)


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_elemwise()