未验证 提交 336cd205 编写于 作者: 张春乔 提交者: GitHub

suppot fp16 in gather_nd (#50909)

上级 7ffbf7e3
...@@ -254,6 +254,39 @@ class TestGatherNdAPI2(unittest.TestCase): ...@@ -254,6 +254,39 @@ class TestGatherNdAPI2(unittest.TestCase):
expected_output = np.array([[3, 4]]) expected_output = np.array([[3, 4]])
np.testing.assert_allclose(result, expected_output, rtol=1e-05) np.testing.assert_allclose(result, expected_output, rtol=1e-05)
def test_static_fp16_with_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.array(
[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]],
dtype='float16',
)
index = np.array([[0, 1]], dtype='int32')
res_np = np.array([[3, 4]], dtype='float16')
x = paddle.static.data(
name="x", shape=[2, 3, 2], dtype="float16"
)
x.desc.set_need_check_feed(False)
idx = paddle.static.data(
name="index", shape=[1, 2], dtype="int32"
)
idx.desc.set_need_check_feed(False)
y = paddle.gather_nd(x, idx)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={"x": input, "index": index},
fetch_list=[y],
)
np.testing.assert_allclose(res[0], res_np, rtol=1e-05)
def test_imperative(self): def test_imperative(self):
paddle.disable_static() paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]]) input_1 = np.array([[1, 2], [3, 4], [5, 6]])
......
...@@ -3677,7 +3677,7 @@ def gather_nd(x, index, name=None): ...@@ -3677,7 +3677,7 @@ def gather_nd(x, index, name=None):
= [23] = [23]
Args: Args:
x (Tensor): The input Tensor which it's data type should be bool, float32, float64, int32, int64. x (Tensor): The input Tensor which it's data type should be bool, float16, float32, float64, int32, int64.
index (Tensor): The index input with rank > 1, index.shape[-1] <= input.rank. index (Tensor): The index input with rank > 1, index.shape[-1] <= input.rank.
Its dtype should be int32, int64. Its dtype should be int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
...@@ -3704,7 +3704,15 @@ def gather_nd(x, index, name=None): ...@@ -3704,7 +3704,15 @@ def gather_nd(x, index, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
['bool', 'float32', 'float64', 'int16', 'int32', 'int64'], [
'bool',
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
],
'gather_np', 'gather_np',
) )
check_variable_and_dtype( check_variable_and_dtype(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册