diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 431f807efd2ab27dcbd898d75dc2a3c4afc2618f..cf1beb5bc87d3a5aa4ad597009f07112bf81d0aa 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -42,6 +42,29 @@ class TestUnbind(unittest.TestCase): assert np.array_equal(res_1, input_1[0, 0:100]) assert np.array_equal(res_2, input_1[1, 0:100]) + def test_unbind_static_fp16_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.random.random([2, 3]).astype("float16") + + x = paddle.static.data(name="x", shape=[2, 3], dtype="float16") + y = paddle.unbind(x) + + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + }, + fetch_list=[y], + ) + + assert np.array_equal(res[0], input[0, :]) + assert np.array_equal(res[1], input[1, :]) + def test_unbind_dygraph(self): with fluid.dygraph.guard(): np_x = np.random.random([2, 3]).astype("float32") diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index c8147b1cbe2a8981b2679aee3b3e84fba9ce9684..2b4caeff7e56c2615c2b3e03ac07cdceca43b003 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2738,7 +2738,7 @@ def unbind(input, axis=0): Removes a tensor dimension, then split the input tensor into multiple sub-Tensors. Args: - input (Tensor): The input variable which is an N-D Tensor, data type being float32, float64, int32 or int64. + input (Tensor): The input variable which is an N-D Tensor, data type being float16, float32, float64, int32 or int64. axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0. Returns: @@ -2785,7 +2785,10 @@ def unbind(input, axis=0): check_type(input, 'input', (Variable), 'unbind') dtype = helper.input_dtype() check_dtype( - dtype, 'unbind', ['float32', 'float64', 'int32', 'int64'], 'unbind' + dtype, + 'unbind', + ['float16', 'float32', 'float64', 'int32', 'int64'], + 'unbind', ) outs = [ helper.create_variable_for_type_inference(