From cbff4d7c1ad565a10fbd07390e638d51b2491fd4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 29 Jan 2021 14:57:24 +0800 Subject: [PATCH] fix(imperative/ops): fix infer_output_attrs_fallible for reshape GitOrigin-RevId: a93567d79abe501ac86fb6d8384019a9c0c34d06 --- .../test/unit/functional/test_tensor.py | 27 +++++++++++++++++++ imperative/src/impl/ops/broadcast.cpp | 6 ++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 2703c67e..6ad39a99 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -122,6 +122,33 @@ def test_reshape(): np.testing.assert_equal(yy.numpy(), y) +def test_reshape_shape_inference(): + x_shape_known = tensor([1, 2, 3, 4], dtype="float32") + x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum()) + tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known) + tshp_known = astensor1d((2, 2), x_shape_known) + tshp_known_unspec = astensor1d((2, -1), x_shape_known) + + def check_shape(output, target): + source = output.shape + if isinstance(source, tensor): + source = source.numpy() + np.testing.assert_equal(source, target) + + def func(x, target_shape): + return x.reshape(target_shape) + + cases = [ + {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]}, + {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]}, + {"input": [x_shape_known, tshp_known], "output": [(2, 2),]}, + {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]}, + {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, + {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, + ] + opr_test(cases, func, compare_fn=check_shape, test_trace=True) + + def test_squeeze(): x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) xx = tensor(x) diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index 3ea2e870..71a019ff 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -115,9 +115,13 @@ std::tuple, bool> infer_output_attrs_fallible( } mgb_assert( tshp.layout.ndim == 1, - "target shape of Broadcast expects ndim=1; got ndim=%lu actually", + "target shape of Reshape expects ndim=1; got ndim=%lu actually", tshp.layout.ndim); + if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; + } + size_t target_ndim = tshp.layout.shape[0]; out_shape.ndim = target_ndim; auto* ptr = tshp.value.ptr(); -- GitLab