From f3aec87117a287c7b5f2e2cd770f0e4155d2cc45 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Mon, 27 Feb 2023 15:12:17 +0800
Subject: [PATCH] [bug fix] fix fp16 dtype checking for argmax op (#50811)

* fix fp16 dtype checking for argmax op

* run fp16 test when place is gpu

* Update search.py

fix doc
---
 .../fluid/tests/unittests/test_arg_min_max_v2_op.py | 13 +++++++++++++
 python/paddle/tensor/search.py                      | 12 ++++++++++--
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
index f5cb975019c..d23648ba65f 100644
--- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py
@@ -366,5 +366,18 @@ class TestArgMinMaxOpError(unittest.TestCase):
         self.assertRaises(ValueError, test_argmin_dtype_type)
 
 
+class TestArgMaxOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_np = np.random.random((10, 16)).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
+            out = paddle.argmax(x)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_np}, fetch_list=[out])
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 2b6b4671efb..30ddfb13986 100755
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -127,7 +127,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
         element along the provided axis.
 
     Args:
-        x(Tensor): An input N-D Tensor with type float32, float64, int16,
+        x(Tensor): An input N-D Tensor with type float16, float32, float64, int16,
             int32, int64, uint8.
         axis(int, optional): Axis to compute indices along. The effective range
             is [-R, R), where R is x.ndim. when axis < 0, it works the same way
@@ -185,7 +185,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
     check_variable_and_dtype(
         x,
         'x',
-        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
+        [
+            'float16',
+            'float32',
+            'float64',
+            'int16',
+            'int32',
+            'int64',
+            'uint8',
+        ],
        'paddle.argmax',
     )
     check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
-- 
GitLab
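
Note (not part of the patch): below is a minimal usage sketch of the behaviour this change enables, calling paddle.argmax on a float16 input in static graph mode without tripping the dtype check. It mirrors the new TestArgMaxOpFp16 test above; the names x_np and main_prog and the use of the public paddle.is_compiled_with_cuda() guard are illustrative assumptions, and execution requires a CUDA-enabled PaddlePaddle build that contains this fix.

# Illustrative sketch, not part of the patch: exercise the fp16 path that
# this fix enables. The GPU guard mirrors the "run fp16 test when place is
# gpu" commit in this patch.
import numpy as np
import paddle

paddle.enable_static()

x_np = np.random.random((10, 16)).astype('float16')
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    # Before this patch, a float16 input failed paddle.argmax's
    # check_variable_and_dtype call in static graph mode.
    x = paddle.static.data(name='x', shape=[10, 16], dtype='float16')
    out = paddle.argmax(x)
    if paddle.is_compiled_with_cuda():
        exe = paddle.static.Executor(paddle.CUDAPlace(0))
        exe.run(paddle.static.default_startup_program())
        (result,) = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])
        print(result)  # int64 index of the maximum over the flattened input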