# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from concurrent.futures import Future

import numpy as np

from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.tensor import Tensor


def test_io():
    """Round-trip a device tensor through an input callback and an output callback.

    The input callback supplies ``x`` to a fresh graph. The output callback
    receives the computed value through ``Future.set_result``. After one run
    of the compiled function, the Future's result must equal ``x``.
    """
    g = mgb_graph.Graph()
    x = Tensor(np.random.randn(3).astype("float32"), device="xpux")._dev_tensor()
    # The input var takes its comp node and dtype from the tensor it will yield.
    vx, _ = mgb_graph.input_callback(
        lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
    )
    y = Future()
    # y.set_result is the sink; y.result() below reads what it was given.
    v = mgb_graph.output_callback(y.set_result, vx)
    f = g.compile(v)
    f()

    np.testing.assert_equal(x.numpy(), y.result().numpy())


def test_io2():
    """Feed values through an InputNode/OutputNode pair across repeated runs.

    One compiled function is executed three times. Each run gets a freshly
    generated input, and its output must equal that input.
    """
    g = mgb_graph.Graph()
    # NOTE(review): 0b100 presumably selects an asynchronous execution level so
    # that execute() returns before the input is provided below — confirm.
    g.options.async_exec_level = 0b100
    dtype, device = "float32", "xpux"
    px = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
    py = mgb_graph.OutputNode(px.outputs[0])
    f = g.compile(py.outputs[0])

    for _ in range(3):
        # Execution is started before the input value is set (see note above).
        f.execute()
        x = Tensor(np.random.randn(10).astype(dtype), device=device)._dev_tensor()
        px.set_value(x)
        y = py.get_value()
        np.testing.assert_equal(x.numpy(), y.numpy())
        f.wait()


def test_attr_output():
    """Check that AttrOutputNode reports shape, dtype and device of its input.

    The same compiled function is run with inputs of different shapes. The
    reported attributes must track each new input.
    """
    g = mgb_graph.Graph()
    # NOTE(review): 0b100 presumably selects an asynchronous execution level so
    # that execute() returns before the input is provided below — confirm.
    g.options.async_exec_level = 0b100
    dtype, device = "float32", "xpux"
    px = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
    py = mgb_graph.AttrOutputNode(px.outputs[0])
    f = g.compile(py.outputs[0])

    for shape in [(2,), (3,), (5,)]:
        # Execution is started before the input value is set (see note above).
        f.execute()
        x = Tensor(np.random.randn(*shape).astype(dtype), device=device)._dev_tensor()
        px.set_value(x)
        ay = py.get_value()
        assert ay.shape == shape
        assert ay.dtype == np.dtype(dtype)
        assert ay.device == device
        f.wait()


def test_op():
    """Apply an Elemwise NEGATE op inside a graph and check the result.

    The output captured through the output callback must be the element-wise
    negation of the input tensor.
    """
    g = mgb_graph.Graph()
    x = Tensor(np.random.randn(10).astype("float32"), device="xpux")._dev_tensor()
    v, _ = mgb_graph.input_callback(
        lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
    )
    neg = Elemwise(Elemwise.Mode.NEGATE)
    # apply_normal_varnode returns a sequence of output vars; NEGATE has one.
    v = mgb_graph.apply_normal_varnode(neg, v)[0]
    y = Future()
    v = mgb_graph.output_callback(y.set_result, v)
    f = g.compile(v)
    f()

    np.testing.assert_equal(x.numpy(), -y.result().numpy())


def test_exception():
    """Check that an error raised in an input callback reaches the caller.

    The input callback always raises ``RuntimeError(err_msg)``. Running the
    graph and fetching its output must raise an exception whose message
    contains ``err_msg``. If nothing is raised, the test fails. Previously a
    missing exception made the test pass silently.
    """
    err_msg = "QwQ"

    def throw_exc():
        raise RuntimeError(err_msg)

    g = mgb_graph.Graph()
    x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
    neg = Elemwise(Elemwise.Mode.NEGATE)
    y = mgb_graph.OutputNode(mgb_graph.apply_normal_varnode(neg, x)[0])
    f = g.compile(y.outputs[0])
    try:
        f.execute()
        y.get_value()
    except Exception as exc:
        # The exact exception type seen here is not fixed by this test (the
        # runtime may wrap the callback's error); only the message is checked.
        assert err_msg in str(exc)
    else:
        raise AssertionError(
            "expected the input callback's exception to propagate, but none was raised"
        )