Commit 61f65cd4 authored by: Megvii Engine Team

test(mge): fix megbrain_graph/cgtools test

GitOrigin-RevId: 33ac776b56541035955ea898ef03aa57bb38afab
Parent: ae47fd4e
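
The gist of the change: the tests previously built graph nodes through the functional wrappers F.neg, F.add and F.mul; they now construct the builtin Elemwise op directly and apply it to VarNodes with megbrain_graph.apply_normal_op. A minimal sketch of the pattern, using only calls that appear in the diff below (the helper name negate_var is made up for illustration):

from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor import megbrain_graph as mgb_graph

def negate_var(v):
    # v: a graph VarNode, e.g. one returned by mgb_graph.input_callback (see the test below)
    neg = Elemwise(Elemwise.Mode.NEGATE)         # replaces the old F.neg(v)
    return mgb_graph.apply_normal_op(neg, v)[0]  # apply_normal_op returns a tuple of output VarNodes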
@@ -10,7 +10,7 @@ from concurrent.futures import Future
import numpy as np
import megengine.functional as F
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.tensor import Tensor
@@ -71,7 +71,8 @@ def test_op():
v, _ = mgb_graph.input_callback(
lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
)
v = F.neg(v)
neg = Elemwise(Elemwise.Mode.NEGATE)
v = mgb_graph.apply_normal_op(neg, v)[0]
y = Future()
v = mgb_graph.output_callback(y.set_result, v)
f = g.compile(v)
@@ -88,7 +89,8 @@ def test_exception():
g = mgb_graph.Graph()
x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
y = mgb_graph.OutputNode(F.neg(x))
neg = Elemwise(Elemwise.Mode.NEGATE)
y = mgb_graph.OutputNode(mgb_graph.apply_normal_op(neg, x)[0])
f = g.compile(y.outputs[0])
try:
f.execute()
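
For context, the full flow that test_op exercises after this change: input_callback feeds a value into the graph, the Elemwise op is applied, output_callback hands the result to a Future, and the compiled function is executed. A rough, non-authoritative sketch assembled only from calls visible in this diff; the input callable produce_input and its return value are placeholders (the real test captures a tensor x defined outside the shown hunk):

from concurrent.futures import Future
import numpy as np
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor import megbrain_graph as mgb_graph

def produce_input():
    # placeholder for the test's `lambda: x`; the exact value type expected here is an assumption
    return np.ones((3,), dtype="float32")

g = mgb_graph.Graph()
v, _ = mgb_graph.input_callback(produce_input, device="xpux", dtype="float32", graph=g)
neg = Elemwise(Elemwise.Mode.NEGATE)
v = mgb_graph.apply_normal_op(neg, v)[0]
y = Future()                                    # the output callback delivers the result here
v = mgb_graph.output_callback(y.set_result, v)
f = g.compile(v)
f.execute()                                     # after execution, y holds the negated input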
@@ -14,14 +14,15 @@ import megengine
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.megbrain_graph import apply_normal_op
from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
return megengine.tensor(value, dtype=dtype, device=device)._dev_tensor()
def test_replace_vars():
@@ -30,10 +31,12 @@ def test_replace_vars():
device = "xpux"
dtype = np.float32
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
const = g.make_const(1.234)
a_plus_a = F.add(a.outputs[0], a.outputs[0])
a_plus_a_mul_const = F.mul(a_plus_a, const)
rst = F.add(a_plus_a_mul_const, a.outputs[0])
const = g.make_const(1.234, device=device)
add_op = Elemwise(Elemwise.Mode.ADD)
mul_op = Elemwise(Elemwise.Mode.MUL)
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0]
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0]
rst = apply_normal_op(add_op, a_plus_a_mul_const, a.outputs[0])[0]
(new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
func = g.compile(out.outputs[0])
@@ -50,11 +53,13 @@ def test_replace_oprs():
device = "xpux"
dtype = np.float32
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
const = g.make_const(1.25)
a_plus_a = F.add(a.outputs[0], a.outputs[0])
const = g.make_const(1.25, device=device)
add_op = Elemwise(Elemwise.Mode.ADD)
mul_op = Elemwise(Elemwise.Mode.MUL)
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0]
old_opr = a_plus_a.op
a_plus_a_mul_const = F.mul(a_plus_a, const)
a_mul_a = F.mul(a.outputs[0], a.outputs[0])
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0]
a_mul_a = apply_normal_op(mul_op, a.outputs[0], a.outputs[0])[0]
new_opr = a_mul_a.op
(new,) = cgtools.replace_oprs(
[a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
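
Spelling out what the two rewritten tests verify: in test_replace_vars the graph computes (a + a) * const + a with const = 1.234, and cgtools.replace_vars swaps the const VarNode for a_plus_a, so the rewritten graph computes (a + a) * (a + a) + a; in test_replace_oprs the ADD operator behind a_plus_a is swapped for the MUL operator behind a_mul_a, so (a + a) * 1.25 becomes (a * a) * 1.25. A pure-numpy sketch of those equivalences, purely illustrative and not part of the tests (the example input is made up):

import numpy as np

a = np.array([1.0, 2.0], dtype=np.float32)   # arbitrary example input

# test_replace_vars: const (1.234) replaced by a_plus_a
vars_before = (a + a) * 1.234 + a
vars_after = (a + a) * (a + a) + a

# test_replace_oprs: the ADD opr producing a_plus_a replaced by the MUL opr producing a_mul_a
oprs_before = (a + a) * 1.25
oprs_after = (a * a) * 1.25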