From 553630aafcd956b2dd60ea92520244b3e89c9684 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Mon, 27 Mar 2023 12:50:05 +0800 Subject: [PATCH] unbind support bool dtype (#52080) * unbind support bool dtype * replace np.array_equal --- paddle/phi/kernels/cpu/unbind_kernel.cc | 1 + paddle/phi/kernels/gpu/unbind_kernel.cu | 1 + python/paddle/fluid/tests/unittests/test_unbind_op.py | 9 +++++++++ python/paddle/tensor/manipulation.py | 4 ++-- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/cpu/unbind_kernel.cc b/paddle/phi/kernels/cpu/unbind_kernel.cc index 39cc2f8fc46..e8d0c01352c 100644 --- a/paddle/phi/kernels/cpu/unbind_kernel.cc +++ b/paddle/phi/kernels/cpu/unbind_kernel.cc @@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind, CPU, ALL_LAYOUT, phi::UnbindKernel, + bool, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/gpu/unbind_kernel.cu b/paddle/phi/kernels/gpu/unbind_kernel.cu index 8a7aa8f6033..37272cebdf1 100644 --- a/paddle/phi/kernels/gpu/unbind_kernel.cu +++ b/paddle/phi/kernels/gpu/unbind_kernel.cu @@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind, GPU, ALL_LAYOUT, phi::UnbindKernel, + bool, float, double, phi::dtype::float16, diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 4bc54f84a75..1820eaa3547 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -201,6 +201,7 @@ class TestUnbindOp4(TestUnbindOp): class TestUnbindBF16Op(OpTest): def setUp(self): + paddle.disable_static() self._set_op_type() self.python_api = paddle.unbind self.dtype = self.get_dtype() @@ -247,5 +248,13 @@ class TestUnbindAxisError(unittest.TestCase): self.assertRaises(ValueError, test_invalid_axis) +class TestUnbindBool(unittest.TestCase): + def test_bool(self): + x = paddle.to_tensor([[True, True], [False, False]]) + xs = paddle.unbind(x, axis=0) + self.assertEqual(len(xs), 2) + np.testing.assert_array_equal(xs[0].numpy(), [True, True]) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 59ebcbaafdd..b61f5f6a5a9 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2759,7 +2759,7 @@ def unbind(input, axis=0): Removes a tensor dimension, then split the input tensor into multiple sub-Tensors. Args: - input (Tensor): The input variable which is an N-D Tensor, data type being float16, float32, float64, int32 or int64. + input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64. axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0. Returns: @@ -2808,7 +2808,7 @@ def unbind(input, axis=0): check_dtype( dtype, 'unbind', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'unbind', ) outs = [ -- GitLab