import platform

import numpy as np
import pytest

import megengine.functional as F
import megengine.jit as jit
import megengine.tensor as tensor
from megengine import is_cuda_available
from megengine.autodiff.grad_manager import GradManager


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_matmul():
    """Check that ``F.matmul`` and its gradients agree between two runs of an xla_trace.

    Each case builds random ``lhs``/``rhs`` operands and an upstream gradient
    ``dout``. It then calls the same ``jit.xla_trace``-decorated function
    twice and compares the forward output and both operand gradients of the
    first call against the second with ``atol=1e-5``.

    The cases cover 1-D vectors, 2-D matrices, and batched 3-D to 5-D
    operands whose batch dimensions broadcast against each other, under
    various ``transpose`` flag combinations.
    """
    # A local, fixed-seed generator keeps failures reproducible without
    # mutating numpy's global RNG state that other tests may rely on.
    rng = np.random.RandomState(0)

    def tester(lhs_shape, rhs_shape, lhs_transpose, rhs_transpose, dtype=None):
        """Compare matmul output and lhs/rhs gradients for a single case."""
        lhs = tensor(0.1 * rng.randn(*lhs_shape), dtype=dtype)
        rhs = tensor(0.1 * rng.randn(*rhs_shape), dtype=dtype)
        # Eager matmul, used only to learn the output shape for ``dout``.
        out = F.matmul(lhs, rhs, lhs_transpose, rhs_transpose)
        dout = tensor(0.1 * rng.randn(*out.shape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(lhs, rhs, dout):
            gm.attach([lhs, rhs])
            with gm:
                out = F.matmul(lhs, rhs, lhs_transpose, rhs_transpose)
                gm.backward(out, dout)
            return out, lhs.grad, rhs.grad

        # Per the variable naming, the first call is expected to execute through
        # MegEngine (while tracing) and the second through the XLA-compiled trace.
        mge_rsts = func(lhs, rhs, dout)
        # Materialize the first run's output before issuing the second call.
        mge_rsts[0].numpy()
        xla_rsts = func(lhs, rhs, dout)

        # zip() below would silently drop trailing results on a length mismatch.
        assert len(mge_rsts) == len(xla_rsts)
        case = (lhs_shape, rhs_shape, lhs_transpose, rhs_transpose)
        # Result order is (out, lhs.grad, rhs.grad), matching func's return.
        for idx, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)):
            np.testing.assert_allclose(
                mge_rst.numpy(),
                xla_rst.numpy(),
                atol=1e-5,
                err_msg=f"result #{idx} mismatch for case {case}",
            )

    # 1-D operands (vector-vector, matrix-vector, vector-matrix).
    tester((5,), (5,), False, False)
    tester((4, 5), (5,), False, False)
    tester((5,), (5, 6), False, False)
    tester((5, 4), (5,), True, False)

    # Plain 2-D matrices with every transpose combination.
    tester((4, 5), (5, 6), False, False)
    tester((4, 5), (6, 5), False, True)
    tester((5, 4), (5, 6), True, False)
    tester((5, 4), (6, 5), True, True)

    # Batched lhs against a 2-D rhs.
    tester((2, 3, 4, 5), (5, 6), False, False)
    tester((2, 3, 4, 5), (6, 5), False, True)
    tester((2, 1, 5, 4), (5, 6), True, False)
    tester((2, 1, 5, 4), (6, 5), True, True)
    tester((1, 5, 4), (5, 6), True, False)
    tester((1, 5, 4), (6, 5), True, True)

    # 2-D lhs against a batched rhs.
    tester((4, 5), (2, 3, 5, 6), False, False)
    tester((4, 5), (2, 3, 6, 5), False, True)
    tester((5, 4), (2, 1, 5, 6), True, False)
    tester((5, 4), (2, 1, 6, 5), True, True)
    tester((5, 4), (1, 5, 6), True, False)
    tester((5, 4), (1, 6, 5), True, True)

    # Equal-rank batched operands.
    tester((1, 4, 5), (1, 5, 6), False, False)
    tester((1, 5, 4), (1, 5, 6), True, False)
    tester((3, 4, 5), (3, 5, 6), False, False)
    tester((3, 5, 4), (3, 6, 5), True, True)

    # Higher-rank lhs with broadcasting batch dimensions.
    tester((5, 3, 2, 7, 8), (3, 2, 8, 9), False, False)
    tester((5, 1, 2, 7, 8), (1, 2, 9, 8), False, True)
    tester((5, 3, 2, 8, 7), (3, 1, 8, 9), True, False)
    tester((5, 3, 2, 8, 7), (1, 2, 9, 8), True, True)
    tester((5, 3, 2, 8, 7), (1, 8, 9), True, False)
    tester((5, 3, 1, 8, 7), (1, 9, 8), True, True)

    # Higher-rank rhs with broadcasting batch dimensions.
    tester((3, 2, 7, 8), (4, 3, 2, 8, 9), False, False)
    tester((3, 1, 7, 8), (4, 3, 1, 9, 8), False, True)
    tester((3, 1, 8, 7), (4, 3, 2, 8, 9), True, False)
    tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True)
    tester((1, 8, 7), (4, 3, 2, 8, 9), True, False)
    tester((1, 8, 7), (4, 3, 1, 9, 8), True, True)