提交 2e530779 编写于 作者: M Megvii Engine Team

fix(mge/trace): use xpux device when dump

GitOrigin-RevId: f37285f70e9d21ca0c3951ebe917351e94e1ec3f
上级 739f927c
......@@ -20,6 +20,7 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
......@@ -588,6 +589,8 @@ class trace:
len(self._output_bindings)
)
)
if arg_names is None:
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
if arg_names and not isinstance(arg_names, collections.abc.Sequence):
arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings):
......@@ -598,6 +601,8 @@ class trace:
)
output_names = output_names or self._output_names
dumped_device = as_device("xpux")
h2v = {}
graph = G.Graph()
# only graph_opt_level takes effect in dump
......@@ -607,14 +612,14 @@ class trace:
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype,
device=info.device,
device=dumped_device,
shape=info.shape,
name=arg_names[i] if arg_names else None,
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype, device=info.device, shape=info.shape, name=k
dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
)
for op, ihandles, ohandles in self._seq:
......@@ -625,7 +630,7 @@ class trace:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=info.device
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
)
ivars.append(h2v[h])
ovars = apply(op, *ivars)
......
......@@ -100,8 +100,8 @@ def test_dump():
file = io.BytesIO()
dump_info = f.dump(file)
assert dump_info.nr_opr == 3
np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"])
np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"])
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
file.seek(0)
result = cgtools.load_and_inference(file, [a, b])
np.testing.assert_equal(result[0], y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册