提交 59a9275c 编写于 作者: M Megvii Engine Team

fix(mge): fix optimize_for_inference during trace.dump

GitOrigin-RevId: e10f7c323a1832a9727211c9ee6cd9242c869c3b
上级 bd7f885a
......@@ -570,7 +570,9 @@ class trace:
if h not in h2v:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(info.bound_data._dev_tensor())
h2v[h] = graph.make_const(
info.bound_data.numpy(), dtype=info.dtype, device=info.device
)
ivars.append(h2v[h])
ovars = apply(op, *ivars)
assert len(ovars) == len(ohandles)
......
......@@ -150,7 +150,7 @@ def test_dump_volatile():
(out,) = outputs
assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
== "SharedDeviceTensor"
== "ImmutableTensor"
)
......@@ -235,6 +235,18 @@ def test_optimize_for_inference():
assert computing_input.dtype == np.float16
def test_optimize_for_inference_broadcast():
a = tensor(np.ones(1, dtype=np.float32))
@trace(capture_as_const=True, tensor_shape=True)
def f():
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
return b
f()
f.dump(io.BytesIO())
def test_trace_cvt_bool():
set_tensor_shape(True)
x = tensor([0], dtype=np.int32)
......
......@@ -561,7 +561,7 @@ void ParamFusePass::apply(OptState &state) const {
}
SymbolVar new_var;
bool is_default_format = var->layout().format.is_default();
bool is_default_format = var->format().is_default();
if (cg::is_static_var_value(var) && is_default_format) {
// use ImmutableTensor for inferable vars
HostTensorND hv;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册