test_tracing.py 2.6 KB
Newer Older
M
Megvii Engine Team 已提交
1 2
import io

M
Megvii Engine Team 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import exclude_from_trace, trace


def test_trace():
    """Compare a traced negate function against its ``__wrapped__`` original.

    For each ``symbolic`` setting, the reference output is computed once via
    ``f.__wrapped__`` and the traced ``f`` is then called three times; every
    call must produce an equal array.
    """
    for symbolic in (False, True):

        @trace(symbolic=symbolic)
        def f(x):
            negate = ops.Elemwise(mode="negate")
            (out,) = apply(negate, x)
            return out

        data = as_raw_tensor([1]).numpy()
        expected = f.__wrapped__(as_raw_tensor(data)).numpy()

        # Call the traced function repeatedly; each call is checked.
        for _ in range(3):
            actual = f(as_raw_tensor(data)).numpy()
            np.testing.assert_equal(actual, expected)


def test_exclude_from_trace():
    """Check a traced function whose ``exclude_from_trace`` body varies per call.

    Inside the ``exclude_from_trace()`` context an extra negate is applied
    only on odd iterations of the outer loop. On every iteration the traced
    result must equal the result of ``f.__wrapped__`` for the same iteration.
    """
    for symbolic in (False, True):

        @trace(symbolic=symbolic)
        def f(x):
            negate = ops.Elemwise(mode="negate")
            (x,) = apply(negate, x)
            with exclude_from_trace():
                # ``step`` is the loop counter of the enclosing test,
                # read through the closure at call time.
                if step % 2 != 0:
                    (x,) = apply(negate, x)
            (x,) = apply(negate, x)
            return x

        data = as_raw_tensor([1]).numpy()

        for step in range(3):
            expected = f.__wrapped__(as_raw_tensor(data)).numpy()
            actual = f(as_raw_tensor(data)).numpy()
            np.testing.assert_equal(actual, expected)


def test_print_in_trace():
    """Check that reading a tensor value inside a traced function works.

    The function stores ``x.numpy()`` of the intermediate tensor into a
    variable of the enclosing scope. On each of three iterations both the
    returned value and the stored intermediate value from the traced call
    must match those produced by ``f.__wrapped__``.
    """
    for symbolic in (False,):  # cannot read value in symbolic mode

        @trace(symbolic=symbolic)
        def f(x):
            nonlocal captured
            negate = ops.Elemwise(mode="negate")
            (x,) = apply(negate, x)
            captured = x.numpy()
            (x,) = apply(negate, x)
            return x

        captured = None
        data = as_raw_tensor([1]).numpy()

        for _ in range(3):
            expected = f.__wrapped__(as_raw_tensor(data)).numpy()
            # Keep the value stored by the ``__wrapped__`` call and reset the
            # slot so the traced call below has to fill it again.
            reference_capture, captured = captured, None
            actual = f(as_raw_tensor(data)).numpy()
            np.testing.assert_equal(actual, expected)
            np.testing.assert_equal(reference_capture, captured)
M
Megvii Engine Team 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84


def test_dump():
    """Trace a negate function with ``capture_as_const=True`` and dump it.

    The traced function is first called three times and checked against
    ``f.__wrapped__``; afterwards ``f.dump`` is called with an in-memory
    ``io.BytesIO`` object. The dumped bytes are not inspected.
    """

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        negate = ops.Elemwise(mode="negate")
        (out,) = apply(negate, x)
        return out

    data = as_raw_tensor([1]).numpy()
    expected = f.__wrapped__(as_raw_tensor(data)).numpy()

    for _ in range(3):
        actual = f(as_raw_tensor(data)).numpy()
        np.testing.assert_equal(actual, expected)

    buffer = io.BytesIO()
    f.dump(buffer)
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103


def test_trace_profiler():
    """Check that a function traced with ``profiling=True`` yields a profile.

    For each ``symbolic`` setting the traced function is called twice, then
    ``f.get_profile()`` must return a mapping with a truthy ``"profiler"``
    entry.
    """
    for symbolic in (False, True):

        @trace(symbolic=symbolic, profiling=True)
        def f(x):
            negate = ops.Elemwise(mode="negate")
            (out,) = apply(negate, x)
            return out

        data = as_raw_tensor([1]).numpy()
        # Call through ``__wrapped__`` once; the result is not checked here.
        f.__wrapped__(as_raw_tensor(data)).numpy()

        for _ in range(2):  # XXX: has to run twice
            f(as_raw_tensor(data))

        profile = f.get_profile()
        assert profile.get("profiler")