diff --git a/paddle/phi/kernels/cpu/unbind_kernel.cc b/paddle/phi/kernels/cpu/unbind_kernel.cc index 39cc2f8fc4662a0893fb8b73b138a52b810f59b8..e8d0c01352c97c479177fade7f59c19168da1c2f 100644 --- a/paddle/phi/kernels/cpu/unbind_kernel.cc +++ b/paddle/phi/kernels/cpu/unbind_kernel.cc @@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind, CPU, ALL_LAYOUT, phi::UnbindKernel, + bool, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/gpu/unbind_kernel.cu b/paddle/phi/kernels/gpu/unbind_kernel.cu index 8a7aa8f6033ab9b86f87e792bc37f912562578a7..37272cebdf1188f4acfcc1b7cea4f2dbd153e558 100644 --- a/paddle/phi/kernels/gpu/unbind_kernel.cu +++ b/paddle/phi/kernels/gpu/unbind_kernel.cu @@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind, GPU, ALL_LAYOUT, phi::UnbindKernel, + bool, float, double, phi::dtype::float16, diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 4bc54f84a756ed7cb24a232a8c91ffb18ae85447..1820eaa3547e43a744f5b8d5e6a58c0731eb417c 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -201,6 +201,7 @@ class TestUnbindOp4(TestUnbindOp): class TestUnbindBF16Op(OpTest): def setUp(self): + paddle.disable_static() self._set_op_type() self.python_api = paddle.unbind self.dtype = self.get_dtype() @@ -247,5 +248,13 @@ class TestUnbindAxisError(unittest.TestCase): self.assertRaises(ValueError, test_invalid_axis) +class TestUnbindBool(unittest.TestCase): + def test_bool(self): + x = paddle.to_tensor([[True, True], [False, False]]) + xs = paddle.unbind(x, axis=0) + self.assertEqual(len(xs), 2) + np.testing.assert_array_equal(xs[0].numpy(), [True, True]) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 59ebcbaafddc904692705d7db8704ff3b6459ccb..b61f5f6a5a910d6487b8919a0dea46bf2ad81815 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2759,7 +2759,7 @@ def unbind(input, axis=0): Removes a tensor dimension, then split the input tensor into multiple sub-Tensors. Args: - input (Tensor): The input variable which is an N-D Tensor, data type being float16, float32, float64, int32 or int64. + input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64. axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0. Returns: @@ -2808,7 +2808,7 @@ def unbind(input, axis=0): check_dtype( dtype, 'unbind', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'unbind', ) outs = [