未验证 提交 553630aa 编写于 作者: L Leo Chen 提交者: GitHub

unbind support bool dtype (#52080)

* unbind support bool dtype

* replace np.array_equal
上级 a6449634
......@@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind,
CPU,
ALL_LAYOUT,
phi::UnbindKernel,
bool,
float,
double,
phi::dtype::float16,
......
......@@ -21,6 +21,7 @@ PD_REGISTER_KERNEL(unbind,
GPU,
ALL_LAYOUT,
phi::UnbindKernel,
bool,
float,
double,
phi::dtype::float16,
......
......@@ -201,6 +201,7 @@ class TestUnbindOp4(TestUnbindOp):
class TestUnbindBF16Op(OpTest):
def setUp(self):
paddle.disable_static()
self._set_op_type()
self.python_api = paddle.unbind
self.dtype = self.get_dtype()
......@@ -247,5 +248,13 @@ class TestUnbindAxisError(unittest.TestCase):
self.assertRaises(ValueError, test_invalid_axis)
class TestUnbindBool(unittest.TestCase):
def test_bool(self):
x = paddle.to_tensor([[True, True], [False, False]])
xs = paddle.unbind(x, axis=0)
self.assertEqual(len(xs), 2)
np.testing.assert_array_equal(xs[0].numpy(), [True, True])
if __name__ == '__main__':
unittest.main()
......@@ -2759,7 +2759,7 @@ def unbind(input, axis=0):
Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
Args:
input (Tensor): The input variable which is an N-D Tensor, data type being float16, float32, float64, int32 or int64.
input (Tensor): The input variable which is an N-D Tensor, data type being bool, float16, float32, float64, int32 or int64.
axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
Returns:
......@@ -2808,7 +2808,7 @@ def unbind(input, axis=0):
check_dtype(
dtype,
'unbind',
['float16', 'float32', 'float64', 'int32', 'int64'],
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'unbind',
)
outs = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册