diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index e219f607bf8d10cb941b625cef4cec1042169841..842fef523f06f6e02988e30de2e579da3cc81bee 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -570,7 +570,9 @@ class trace: if h not in h2v: assert info.external assert info.bound_data - h2v[h] = graph.make_const(info.bound_data._dev_tensor()) + h2v[h] = graph.make_const( + info.bound_data.numpy(), dtype=info.dtype, device=info.device + ) ivars.append(h2v[h]) ovars = apply(op, *ivars) assert len(ovars) == len(ohandles) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index ecc811bd27f188237ad26f443d4189c5d9246682..443222c4ca16de0d4e5b12877b0082c6e417e636 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -150,7 +150,7 @@ def test_dump_volatile(): (out,) = outputs assert ( cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1]) - == "SharedDeviceTensor" + == "ImmutableTensor" ) @@ -235,6 +235,18 @@ def test_optimize_for_inference(): assert computing_input.dtype == np.float16 +def test_optimize_for_inference_broadcast(): + a = tensor(np.ones(1, dtype=np.float32)) + + @trace(capture_as_const=True, tensor_shape=True) + def f(): + (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) + return b + + f() + f.dump(io.BytesIO()) + + def test_trace_cvt_bool(): set_tensor_shape(True) x = tensor([0], dtype=np.int32) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 9e40def9439e18ec6ad58a70fd5f6ba2246d20d8..f26b81b8b597ac4946833e66dc13f4dfa725768f 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -561,7 +561,7 @@ void ParamFusePass::apply(OptState &state) const { } SymbolVar new_var; - bool is_default_format = var->layout().format.is_default(); + bool is_default_format = var->format().is_default(); if (cg::is_static_var_value(var) && is_default_format) { // use ImmutableTensor for inferable vars HostTensorND hv;