提交 523ce65e 编写于 作者: M Megvii Engine Team

fix(mge/imperative): fix cgtools related tests

GitOrigin-RevId: 8f1eadb32e7d7f285f0aa97378f99828d9dceee7
上级 dd39265e
......@@ -8,6 +8,7 @@
import io
import numpy as np
import pytest
import megengine
import megengine.functional as F
......@@ -66,6 +67,7 @@ def test_replace_oprs():
np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
@pytest.mark.skip(reason="Please check opr index")
def test_graph_traversal():
net = M.Conv2d(3, 32, 3)
......
......@@ -26,18 +26,18 @@ from megengine.jit import exclude_from_trace, trace
def load_and_inference(file, inp_data):
cg, _, out_list = mgb_graph.load_graph(file)
cg, _, out_list = G.load_graph(file)
inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = mgb_graph.InputNode(
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = cgtools.replace_vars(out_list, replace_dict)
out_node_list = [mgb_graph.OutputNode(i) for i in new_out]
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
new_cg = new_out_list[0].graph
func = new_cg.compile(new_out_list)
......@@ -150,6 +150,7 @@ def test_capture_dump():
np.testing.assert_equal(result[0], y)
@pytest.mark.skip(reason="get MultipleDeviceTensorHolder instead of SharedDeviceTensor")
def test_dump_volatile():
p = as_raw_tensor([2])
......@@ -168,7 +169,7 @@ def test_dump_volatile():
file = io.BytesIO()
f.dump(file)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
cg, _, outputs = G.load_graph(file)
(out,) = outputs
assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册