提交 2e530779 编写于 作者: M Megvii Engine Team

fix(mge/trace): use xpux device when dumping

GitOrigin-RevId: f37285f70e9d21ca0c3951ebe917351e94e1ec3f
上级 739f927c
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
...@@ -588,6 +589,8 @@ class trace: ...@@ -588,6 +589,8 @@ class trace:
len(self._output_bindings) len(self._output_bindings)
) )
) )
if arg_names is None:
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
if arg_names and not isinstance(arg_names, collections.abc.Sequence): if arg_names and not isinstance(arg_names, collections.abc.Sequence):
arg_names = (arg_names,) arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings): if arg_names and len(arg_names) != len(self._arg_bindings):
...@@ -598,6 +601,8 @@ class trace: ...@@ -598,6 +601,8 @@ class trace:
) )
output_names = output_names or self._output_names output_names = output_names or self._output_names
dumped_device = as_device("xpux")
h2v = {} h2v = {}
graph = G.Graph() graph = G.Graph()
# only graph_opt_level takes effect in dump # only graph_opt_level takes effect in dump
...@@ -607,14 +612,14 @@ class trace: ...@@ -607,14 +612,14 @@ class trace:
info = self._tinfo[h] info = self._tinfo[h]
h2v[h] = graph.make_h2d( h2v[h] = graph.make_h2d(
dtype=info.dtype, dtype=info.dtype,
device=info.device, device=dumped_device,
shape=info.shape, shape=info.shape,
name=arg_names[i] if arg_names else None, name=arg_names[i] if arg_names else None,
) )
for k, h in self._kwarg_bindings.items(): for k, h in self._kwarg_bindings.items():
info = self._tinfo[h] info = self._tinfo[h]
h2v[h] = graph.make_h2d( h2v[h] = graph.make_h2d(
dtype=info.dtype, device=info.device, shape=info.shape, name=k dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
) )
for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
...@@ -625,7 +630,7 @@ class trace: ...@@ -625,7 +630,7 @@ class trace:
assert info.external assert info.external
assert info.bound_data assert info.bound_data
h2v[h] = graph.make_const( h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=info.device info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
) )
ivars.append(h2v[h]) ivars.append(h2v[h])
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
......
...@@ -100,8 +100,8 @@ def test_dump(): ...@@ -100,8 +100,8 @@ def test_dump():
file = io.BytesIO() file = io.BytesIO()
dump_info = f.dump(file) dump_info = f.dump(file)
assert dump_info.nr_opr == 3 assert dump_info.nr_opr == 3
np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"]) np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"]) np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
file.seek(0) file.seek(0)
result = cgtools.load_and_inference(file, [a, b]) result = cgtools.load_and_inference(file, [a, b])
np.testing.assert_equal(result[0], y) np.testing.assert_equal(result[0], y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册