未验证 提交 553630aa 编写于 作者: L Leo Chen 提交者: GitHub

unbind support bool dtype (#52080)

* unbind support bool dtype

* replace np.array_equal
上级 a6449634
...@@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind, ...@@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::UnbindKernel, phi::UnbindKernel,
bool,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
......
...@@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind, ...@@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::UnbindKernel, phi::UnbindKernel,
bool,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
......
...@@ -201,6 +201,7 @@ class TestUnbindOp4(TestUnbindOp): ...@@ -201,6 +201,7 @@ class TestUnbindOp4(TestUnbindOp):
class TestUnbindBF16Op(OpTest): class TestUnbindBF16Op(OpTest):
def setUp(self): def setUp(self):
paddle.disable_static()
self._set_op_type() self._set_op_type()
self.python_api = paddle.unbind self.python_api = paddle.unbind
self.dtype = self.get_dtype() self.dtype = self.get_dtype()
...@@ -247,5 +248,13 @@ class TestUnbindAxisError(unittest.TestCase): ...@@ -247,5 +248,13 @@ class TestUnbindAxisError(unittest.TestCase):
self.assertRaises(ValueError, test_invalid_axis) self.assertRaises(ValueError, test_invalid_axis)
class TestUnbindBool(unittest.TestCase):
def test_bool(self):
x = paddle.to_tensor([[True, True], [False, False]])
xs = paddle.unbind(x, axis=0)
self.assertEqual(len(xs), 2)
np.testing.assert_array_equal(xs[0].numpy(), [True, True])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -2759,7 +2759,7 @@ def unbind(input, axis=0): ...@@ -2759,7 +2759,7 @@ def unbind(input, axis=0):
Removes a tensor dimension, then split the input tensor into multiple sub-Tensors. Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
Args: Args:
input (Tensor): The input variable which is an N-D Tensor, data type being float16, float32, float64, int32 or int64. input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64.
axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0. If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
Returns: Returns:
...@@ -2808,7 +2808,7 @@ def unbind(input, axis=0): ...@@ -2808,7 +2808,7 @@ def unbind(input, axis=0):
check_dtype( check_dtype(
dtype, dtype,
'unbind', 'unbind',
['float16', 'float32', 'float64', 'int32', 'int64'], ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'unbind', 'unbind',
) )
outs = [ outs = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册