diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index 2703c67e274399174911e8133b6e9cc22ae036e6..6ad39a990a56b35118333855653c6b545862b332
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -122,6 +122,33 @@ def test_reshape():
     np.testing.assert_equal(yy.numpy(), y)
 
 
+def test_reshape_shape_inference():
+    x_shape_known = tensor([1, 2, 3, 4], dtype="float32")
+    x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
+    tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known)
+    tshp_known = astensor1d((2, 2), x_shape_known)
+    tshp_known_unspec = astensor1d((2, -1), x_shape_known)
+
+    def check_shape(output, target):
+        source = output.shape
+        if isinstance(source, tensor):
+            source = source.numpy()
+        np.testing.assert_equal(source, target)
+
+    def func(x, target_shape):
+        return x.reshape(target_shape)
+
+    cases = [
+        {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
+        {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
+        {"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
+        {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
+        {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
+        {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
+    ]
+    opr_test(cases, func, compare_fn=check_shape, test_trace=True)
+
+
 def test_squeeze():
     x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
     xx = tensor(x)
diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 3ea2e870684dc19e94be740b3b6821f2749a8fd5..71a019ff5c4f55672589426cba8b73f075b3f97d
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -115,9 +115,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     }
     mgb_assert(
             tshp.layout.ndim == 1,
-            "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+            "target shape of Reshape expects ndim=1; got ndim=%lu actually",
             tshp.layout.ndim);
 
+    if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
+    }
+
     size_t target_ndim = tshp.layout.shape[0];
     out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dim_t>();