test_trivial.py 548 字节
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
import numpy as np

import megengine.jit as jit
import megengine.tensor as tensor


def test_get_var_shape():
    def tester(shape):
        x = tensor(np.random.randn(*shape).astype("float32"))

        @jit.trace(without_host=True, use_xla=True)
        def func(x):
            return x.shape

        mge_rst = func(x)
        xla_rst = func(x)
        np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester((2, 3, 4, 5))
    tester((1, 2, 3))
    tester((1,))


if __name__ == "__main__":
    test_get_var_shape()