From 59a9275c665dddc818f5eba25e5fd7034a3cbef0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 16 Sep 2020 16:14:10 +0800 Subject: [PATCH] fix(mge): fix optimize_for_inference during trace.dump GitOrigin-RevId: e10f7c323a1832a9727211c9ee6cd9242c869c3b --- imperative/python/megengine/jit/tracing.py | 4 +++- imperative/python/test/unit/test_tracing.py | 14 +++++++++++++- src/gopt/impl/inference.cpp | 2 +- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index e219f607b..842fef523 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -570,7 +570,9 @@ class trace: if h not in h2v: assert info.external assert info.bound_data - h2v[h] = graph.make_const(info.bound_data._dev_tensor()) + h2v[h] = graph.make_const( + info.bound_data.numpy(), dtype=info.dtype, device=info.device + ) ivars.append(h2v[h]) ovars = apply(op, *ivars) assert len(ovars) == len(ohandles) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index ecc811bd2..443222c4c 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -150,7 +150,7 @@ def test_dump_volatile(): (out,) = outputs assert ( cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) - == "SharedDeviceTensor" + == "ImmutableTensor" ) @@ -235,6 +235,18 @@ def test_optimize_for_inference(): assert computing_input.dtype == np.float16 +def test_optimize_for_inference_broadcast(): + a = tensor(np.ones(1, dtype=np.float32)) + + @trace(capture_as_const=True, tensor_shape=True) + def f(): + (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) + return b + + f() + f.dump(io.BytesIO()) + + def test_trace_cvt_bool(): set_tensor_shape(True) x = tensor([0], dtype=np.int32) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 9e40def94..f26b81b8b 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -561,7 +561,7 @@ void ParamFusePass::apply(OptState &state) const { } SymbolVar new_var; - bool is_default_format = var->layout().format.is_default(); + bool is_default_format = var->format().is_default(); if (cg::is_static_var_value(var) && is_default_format) { // use ImmutableTensor for inferable vars HostTensorND hv; -- GitLab