From cabc5f360a9fbfaf91812214bacc993bc92973b6 Mon Sep 17 00:00:00 2001 From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com> Date: Mon, 13 Sep 2021 14:16:06 +0800 Subject: [PATCH] [Bugfix] reshape with zero input tensor (#35642) * reshape support zero-input * add unitest * revise error message --- paddle/fluid/operators/reshape_op.cc | 17 +++++++++++++++++ .../fluid/tests/unittests/test_reshape_op.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index ae7e1c07b14..8913642a594 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -185,6 +185,8 @@ class ReshapeOp : public framework::OperatorWithKernel { framework::make_ddim(shape), i, shape[i])); } + // NOTE all non-zero values will be converted to True (include negative + // value) capacity *= (shape[i] ? shape[i] : in_dims[i]); output_shape[i] = (shape[i] ? static_cast(shape[i]) : in_dims[i]); @@ -222,6 +224,21 @@ class ReshapeOp : public framework::OperatorWithKernel { in_dims, in_size, framework::make_ddim(shape), capacity)); } } + + // support reshape with zero-input(input tensor with product(shape) == 0) + // by now we require that if the input tensor is zero shape, the target + // shape of output must be zero + if (in_size == 0) { + PADDLE_ENFORCE_EQ( + capacity, in_size, + platform::errors::InvalidArgument( + "The 'shape' in ReshapeOp is invalid. " + "The input tensor X's shape = [%s], X's capacity = %d." + "But the target shape of Out is [%s], the " + "capacity of 'Out' is %d.", + in_dims, in_size, framework::make_ddim(shape), capacity)); + } + return framework::make_ddim(output_shape); } diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 4e296e7a889..a0063738d36 100755 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -464,5 +464,19 @@ class TestDygraphReshapeInplaceAPI(TestDygraphReshapeAPI): self.reshape = paddle.reshape_ +class TestReshapeZeroTensor(unittest.TestCase): + def test_reshape_zero_tensor_success(self): + zero_tensor = paddle.zeros([0, 2, 3]) + # since we use "0" as the dimension copy semantically in reshape, + # we need to copy the 0 dim in the src tensor in order to make a successful zero tensor reshape + zero_tensor = zero_tensor.reshape([0, 6]) + self.assertTrue(list(zero_tensor.shape) == [0, 6]) + + def test_reshape_zero_tensor_error(self): + zero_tensor = paddle.zeros([0, 2, 3]) + with self.assertRaises(ValueError): + zero_tensor.reshape([2, 3]) + + if __name__ == "__main__": unittest.main() -- GitLab