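"""Parity tests for tensor-manipulation ops between MegEngine's imperative
execution and XLA compilation via ``jit.xla_trace``."""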
import platform

import numpy as np
import pytest

import megengine.functional as F
import megengine.jit as jit
from megengine import is_cuda_available, tensor
from megengine.autodiff.grad_manager import GradManager


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_broadcast_to():
    def tester(ishape, tgtshape, dtype=None):
        dtype = dtype or np.float32
        inp = tensor(np.random.randn(*ishape), dtype=dtype)
        dout = tensor(np.random.randn(*tgtshape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, dout):
            gm.attach([inp])
            with gm:
                out = F.broadcast_to(inp, tgtshape)
                gm.backward(out, dout)
            return [out, inp.grad]

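        # the first call of an xla_trace-ed function runs imperatively and
        # records the trace; the second call executes the compiled XLA
        # program, so the two sets of results should match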
        mge_rsts = func(inp, dout)
        xla_rsts = func(inp, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester((1, 1, 1), (1, 1, 1, 1))
    tester((1, 1, 1, 1), (1, 1, 1, 1))
    tester((1, 1, 1, 1), (4, 5, 6, 7))
    tester((1, 1, 1), (4, 5, 6, 7))
    tester((5, 6, 7), (4, 5, 6, 7))
    tester((1, 6, 1), (4, 5, 6, 7))
    tester((1, 5, 6, 7), (4, 5, 6, 7))
    tester((1,), (4, 5, 1, 7))
    tester((4, 5, 3, 1), (4, 5, 3, 7))
    tester((4, 5, 3, 7), (4, 5, 3, 7))


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_reshape():
    def tester(ishape, tgt_shape, dtype=None):
        dtype = dtype or np.float32
        inp = tensor(np.random.randn(*ishape), dtype=dtype)
        oshape = F.reshape(inp, tgt_shape).shape
        dout = tensor(np.random.randn(*oshape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, dout):
            gm.attach([inp])
            with gm:
                out = F.reshape(inp, tgt_shape)
                gm.backward(out, dout)
            return [out, inp.grad]

        mge_rsts = func(inp, dout)
        xla_rsts = func(inp, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester((1,), (1,))
    tester((1,), (1, 1, 1, 1))
    tester((2, 3, 4), (24,))
    tester((2, 3, 4), (2, 12))
    tester((2, 3, 4), (4, 3, 2))
    tester((2, 1, 4), (8, 1))
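    # a -1 target lets reshape infer that dimension from the element count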
    tester((2, 1, 4), (-1))
    tester((2, 1, 4), (-1, 2))


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_transpose():
    def tester(ishape, permutation, dtype=None):
        dtype = dtype or np.float32
        inp = tensor(np.random.randn(*ishape), dtype=dtype)
        oshape = F.transpose(inp, permutation).shape
        dout = tensor(np.random.randn(*oshape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, dout):
            gm.attach([inp])
            with gm:
                out = F.transpose(inp, permutation)
                gm.backward(out, dout)
            return [out, inp.grad]

        mge_rsts = func(inp, dout)
        xla_rsts = func(inp, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester((1,), (0,))
    tester((2, 3, 4), (0, 2, 1))
    tester((2, 3, 4), (2, 0, 1))
    tester((2, 3, 1), (0, 1, 2))
    tester((2, 3, 1, 4), (3, 1, 0, 2))


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_expand_dims():
    def tester(ishape, axis, dtype=None):
        dtype = dtype or np.float32
        inp = tensor(np.random.randn(*ishape), dtype=dtype)
        oshape = F.expand_dims(inp, axis).shape
        dout = tensor(np.random.randn(*oshape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, dout):
            gm.attach([inp])
            with gm:
                out = F.expand_dims(inp, axis)
                gm.backward(out, dout)
            return [out, inp.grad]

        mge_rsts = func(inp, dout)
        xla_rsts = func(inp, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester((2, 1, 4), 0)
    tester((2, 3, 4), 1)
    tester((2, 3, 4, 5), -1)


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_concat():
    def tester(*ishapes, axis, dtype=None):
        dtype = dtype or np.float32
        inps = [tensor(np.random.randn(*ishape), dtype=dtype) for ishape in ishapes]
        oshape = F.concat(inps, axis=axis).shape
        dout = tensor(np.random.randn(*oshape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(*inps, dout):
            gm.attach(inps)
            with gm:
                out = F.concat(inps, axis=axis)
                gm.backward(out, dout)
            rets = [inp.grad for inp in inps] + [out]
            return rets

        mge_rsts = func(*inps, dout=dout)
        xla_rsts = func(*inps, dout=dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester((6, 5, 4), (6, 3, 4), (6, 1, 4), axis=1)
    tester((6, 5, 2), (6, 5, 1), axis=-1)
    tester((2, 5, 4), (6, 5, 4), axis=0)
    tester((1, 5, 4), (1, 5, 4), axis=0)
    tester((6, 5, 1), axis=-1)


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_split():
    def tester(ishape, axis, nsplit_or_sections, dtype=None):
        dtype = dtype or np.float32
        inp = tensor(np.random.randn(*ishape), dtype=dtype)
        oshapes = [o.shape for o in F.split(inp, nsplit_or_sections, axis)]
        douts = [tensor(np.random.randn(*oshape), dtype=dtype) for oshape in oshapes]

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, douts):
            gm.attach([inp])
            with gm:
                outs = list(F.split(inp, nsplit_or_sections, axis))
                gm.backward(outs, douts)
            rets = outs + [inp.grad]
            return rets

        mge_rsts = func(inp, douts)
        xla_rsts = func(inp, douts)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

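    # an int asks for that many equal parts; a list gives the indices at which
    # to split along the axis (numpy.split semantics)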
    tester((32, 16, 8), -2, 5)
    tester((32, 16, 8), 0, [8, 14, 27])
    tester((32, 16, 8), 1, 1)
    tester((32, 16, 8), 1, 16)


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_fill_and_fill_like():
    def tester(ref_shape, value, dtype=None):
        dtype = dtype or np.float32
        ref = tensor(np.random.randn(*ref_shape), dtype=dtype)

        @jit.xla_trace(without_host=True)
        def func(ref):
            return (
                F.full_like(ref, value),
                F.full(ref.shape, value, dtype=dtype),
                F.ones_like(ref),
                F.ones(ref.shape, dtype=dtype),
                F.zeros_like(ref),
                F.zeros(ref.shape, dtype=dtype),
            )

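        # these ops create constants with no gradient path, so only the
        # forward outputs are compared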
        mge_rst = func(ref)
        xla_rst = func(ref)
        for mge, xla in zip(mge_rst, xla_rst):
            np.testing.assert_allclose(mge.numpy(), xla.numpy(), atol=1e-5)

    tester((1,), 0.1)
    tester((16,), 0.1)
    tester((1, 16), 0.1)
    tester((32, 16), 0.1)
    tester((32, 16), 0)
    tester((1, 1, 16), 1)