From ce014acbdee85825c6f7c0949d74f40b886b2fca Mon Sep 17 00:00:00 2001 From: Zhenghai Zhang <65210872+ccsuzzh@users.noreply.github.com> Date: Mon, 27 Feb 2023 22:31:46 +0800 Subject: [PATCH] fix fp16 dtype checking for argsort op (#50939) --- .../paddle/fluid/tests/unittests/test_argsort_op.py | 13 +++++++++++++ python/paddle/tensor/search.py | 12 ++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py index e3f90d7fd2d..2b993280af7 100644 --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -499,5 +499,18 @@ class TestArgsortWithInputNaN(unittest.TestCase): paddle.enable_static() +class TestArgsortOpFp16(unittest.TestCase): + def test_fp16(self): + x_np = np.random.random((2, 8)).astype('float16') + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[2, 8], name='x', dtype='float16') + out = paddle.argsort(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': x_np}, fetch_list=[out]) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 30ddfb13986..a7762a3e786 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -39,7 +39,7 @@ def argsort(x, axis=-1, descending=False, name=None): Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. Args: - x(Tensor): An input N-D Tensor with type float32, float64, int16, + x(Tensor): An input N-D Tensor with type float16, float32, float64, int16, int32, int64, uint8. axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is Rank(x). when axis<0, it works the same way @@ -101,7 +101,15 @@ def argsort(x, axis=-1, descending=False, name=None): check_variable_and_dtype( x, 'x', - ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], + [ + 'float16', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + 'uint8', + ], 'argsort', ) -- GitLab