From 61f65cd4958a62b6ac2d9882bfc8b787f2cec0bf Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 28 Dec 2020 21:38:25 +0800 Subject: [PATCH] test(mge): fix megbrain_graph/cgtools test GitOrigin-RevId: 33ac776b56541035955ea898ef03aa57bb38afab --- .../test/unit/core/test_megbrain_graph.py | 8 +++--- imperative/python/test/unit/test_cgtools.py | 25 +++++++++++-------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/imperative/python/test/unit/core/test_megbrain_graph.py b/imperative/python/test/unit/core/test_megbrain_graph.py index ee651dd1c..ec2b93547 100644 --- a/imperative/python/test/unit/core/test_megbrain_graph.py +++ b/imperative/python/test/unit/core/test_megbrain_graph.py @@ -10,7 +10,7 @@ from concurrent.futures import Future import numpy as np -import megengine.functional as F +from megengine.core.ops.builtin import Elemwise from megengine.core.tensor import megbrain_graph as mgb_graph from megengine.tensor import Tensor @@ -71,7 +71,8 @@ def test_op(): v, _ = mgb_graph.input_callback( lambda: x, device=x.comp_node, dtype=x.dtype, graph=g ) - v = F.neg(v) + neg = Elemwise(Elemwise.Mode.NEGATE) + v = mgb_graph.apply_normal_op(neg, v)[0] y = Future() v = mgb_graph.output_callback(y.set_result, v) f = g.compile(v) @@ -88,7 +89,8 @@ def test_exception(): g = mgb_graph.Graph() x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g) - y = mgb_graph.OutputNode(F.neg(x)) + neg = Elemwise(Elemwise.Mode.NEGATE) + y = mgb_graph.OutputNode(mgb_graph.apply_normal_op(neg, x)[0]) f = g.compile(y.outputs[0]) try: f.execute() diff --git a/imperative/python/test/unit/test_cgtools.py b/imperative/python/test/unit/test_cgtools.py index c90760b48..5955bb3db 100644 --- a/imperative/python/test/unit/test_cgtools.py +++ b/imperative/python/test/unit/test_cgtools.py @@ -14,14 +14,15 @@ import megengine import megengine.functional as F import megengine.module as M import megengine.utils.comp_graph_tools as cgtools +from megengine.core.ops.builtin import Elemwise from megengine.core.tensor import megbrain_graph as mgb_graph -from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine.core.tensor.megbrain_graph import apply_normal_op from megengine.core.tensor.utils import astensor1d from megengine.jit import trace def make_dev_tensor(value, dtype=None, device=None): - return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor() + return megengine.tensor(value, dtype=dtype, device=device)._dev_tensor() def test_replace_vars(): @@ -30,10 +31,12 @@ def test_replace_vars(): device = "xpux" dtype = np.float32 a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g) - const = g.make_const(1.234) - a_plus_a = F.add(a.outputs[0], a.outputs[0]) - a_plus_a_mul_const = F.mul(a_plus_a, const) - rst = F.add(a_plus_a_mul_const, a.outputs[0]) + const = g.make_const(1.234, device=device) + add_op = Elemwise(Elemwise.Mode.ADD) + mul_op = Elemwise(Elemwise.Mode.MUL) + a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] + a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] + rst = apply_normal_op(add_op, a_plus_a_mul_const, a.outputs[0])[0] (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node}) out = mgb_graph.OutputNode(mgb_graph.VarNode(new)) func = g.compile(out.outputs[0]) @@ -50,11 +53,13 @@ def test_replace_oprs(): device = "xpux" dtype = np.float32 a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g) - const = g.make_const(1.25) - a_plus_a = F.add(a.outputs[0], a.outputs[0]) + const = g.make_const(1.25, device=device) + add_op = Elemwise(Elemwise.Mode.ADD) + mul_op = Elemwise(Elemwise.Mode.MUL) + a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] old_opr = a_plus_a.op - a_plus_a_mul_const = F.mul(a_plus_a, const) - a_mul_a = F.mul(a.outputs[0], a.outputs[0]) + a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] + a_mul_a = apply_normal_op(mul_op, a.outputs[0], a.outputs[0])[0] new_opr = a_mul_a.op (new,) = cgtools.replace_oprs( [a_plus_a_mul_const._node], {old_opr._node: new_opr._node} -- GitLab