未验证 提交 cabc5f36 编写于 作者: J JZ-LIANG 提交者: GitHub

[Bugfix] reshape with zero input tensor (#35642)

* reshape support zero-input

* add unitest

* revise error message
上级 ecfe8375
......@@ -185,6 +185,8 @@ class ReshapeOp : public framework::OperatorWithKernel {
framework::make_ddim(shape), i, shape[i]));
}
// NOTE all non-zero values will be converted to True (include negative
// value)
capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] =
(shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
......@@ -222,6 +224,21 @@ class ReshapeOp : public framework::OperatorWithKernel {
in_dims, in_size, framework::make_ddim(shape), capacity));
}
}
// support reshape with zero-input(input tensor with product(shape) == 0)
// by now we require that if the input tensor is zero shape, the target
// shape of output must be zero
if (in_size == 0) {
PADDLE_ENFORCE_EQ(
capacity, in_size,
platform::errors::InvalidArgument(
"The 'shape' in ReshapeOp is invalid. "
"The input tensor X's shape = [%s], X's capacity = %d."
"But the target shape of Out is [%s], the "
"capacity of 'Out' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity));
}
return framework::make_ddim(output_shape);
}
......
......@@ -464,5 +464,19 @@ class TestDygraphReshapeInplaceAPI(TestDygraphReshapeAPI):
self.reshape = paddle.reshape_
class TestReshapeZeroTensor(unittest.TestCase):
def test_reshape_zero_tensor_success(self):
zero_tensor = paddle.zeros([0, 2, 3])
# since we use "0" as the dimension copy semantically in reshape,
# we need to copy the 0 dim in the src tensor in order to make a successful zero tensor reshape
zero_tensor = zero_tensor.reshape([0, 6])
self.assertTrue(list(zero_tensor.shape) == [0, 6])
def test_reshape_zero_tensor_error(self):
zero_tensor = paddle.zeros([0, 2, 3])
with self.assertRaises(ValueError):
zero_tensor.reshape([2, 3])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册