diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index 6c2b6d15f6f639b00b914b1bc087cc88bc750b25..8d461a2bfb5aa1fba1f94e17b4cca634212fddce 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -254,6 +254,39 @@ class TestGatherNdAPI2(unittest.TestCase): expected_output = np.array([[3, 4]]) np.testing.assert_allclose(result, expected_output, rtol=1e-05) + def test_static_fp16_with_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.array( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]], + dtype='float16', + ) + index = np.array([[0, 1]], dtype='int32') + res_np = np.array([[3, 4]], dtype='float16') + + x = paddle.static.data( + name="x", shape=[2, 3, 2], dtype="float16" + ) + x.desc.set_need_check_feed(False) + idx = paddle.static.data( + name="index", shape=[1, 2], dtype="int32" + ) + idx.desc.set_need_check_feed(False) + + y = paddle.gather_nd(x, idx) + + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={"x": input, "index": index}, + fetch_list=[y], + ) + + np.testing.assert_allclose(res[0], res_np, rtol=1e-05) + def test_imperative(self): paddle.disable_static() input_1 = np.array([[1, 2], [3, 4], [5, 6]]) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 7d82a83e2a0a9a0aaeb59750e133a549d2397c4b..c8147b1cbe2a8981b2679aee3b3e84fba9ce9684 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3677,7 +3677,7 @@ def gather_nd(x, index, name=None): = [23] Args: - x (Tensor): The input Tensor which it's data type should be bool, float32, float64, int32, int64. + x (Tensor): The input Tensor which it's data type should be bool, float16, float32, float64, int32, int64. index (Tensor): The index input with rank > 1, index.shape[-1] <= input.rank. Its dtype should be int32, int64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -3704,7 +3704,15 @@ def gather_nd(x, index, name=None): check_variable_and_dtype( x, 'x', - ['bool', 'float32', 'float64', 'int16', 'int32', 'int64'], + [ + 'bool', + 'float16', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + ], 'gather_np', ) check_variable_and_dtype(