From 7003dcaa2f5814da8584d5cf3b9b1a97cffdc8f2 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 21 Apr 2022 14:21:41 +0800 Subject: [PATCH] Support FP16 argmax/argmin kernel (#42038) * support int16 argmax kernel * add fp16 test --- paddle/phi/kernels/gpu/arg_min_max_kernel.cu | 2 ++ .../tests/unittests/test_arg_min_max_op.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu index 6feee512cc9..385ddb5e521 100644 --- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu +++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu @@ -259,6 +259,7 @@ PD_REGISTER_KERNEL(arg_min, GPU, ALL_LAYOUT, phi::ArgMinKernel, + phi::dtype::float16, float, double, int32_t, @@ -270,6 +271,7 @@ PD_REGISTER_KERNEL(arg_max, GPU, ALL_LAYOUT, phi::ArgMaxKernel, + phi::dtype::float16, float, double, int32_t, diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py index c11fb3d1e28..cbcb4af9269 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py @@ -68,6 +68,26 @@ class TestCase2(BaseTestCase): self.axis = 0 +@unittest.skipIf(not paddle.is_compiled_with_cuda(), + "FP16 test runs only on GPU") +class TestCase0FP16(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = np.float16 + self.axis = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_cuda(), + "FP16 test runs only on GPU") +class TestCase1FP16(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_min' + self.dims = (3, 4) + self.dtype = np.float16 + self.axis = 1 + + class TestCase2_1(BaseTestCase): def initTestCase(self): self.op_type = 'arg_max' @@ -202,4 +222,5 @@ class BaseTestComplex2_2(OpTest): if __name__ == '__main__': + paddle.enable_static() unittest.main() -- GitLab