diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
index 6feee512cc9f4ec411167d1dc26feed1d766787d..385ddb5e521a2e63f9cb3917608ab7de4e8389d5 100644
--- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
+++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
@@ -259,6 +259,7 @@ PD_REGISTER_KERNEL(arg_min,
                    GPU,
                    ALL_LAYOUT,
                    phi::ArgMinKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int32_t,
@@ -270,6 +271,7 @@ PD_REGISTER_KERNEL(arg_max,
                    GPU,
                    ALL_LAYOUT,
                    phi::ArgMaxKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int32_t,
diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
index c11fb3d1e28aaa4fe44fcd695d701bc674c267f1..cbcb4af926951bdd35a038a5188b02c1bf18bfdd 100644
--- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
@@ -68,6 +68,26 @@ class TestCase2(BaseTestCase):
         self.axis = 0
 
 
+@unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                 "FP16 test runs only on GPU")
+class TestCase0FP16(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = np.float16
+        self.axis = 0
+
+
+@unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                 "FP16 test runs only on GPU")
+class TestCase1FP16(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = np.float16
+        self.axis = 1
+
+
 class TestCase2_1(BaseTestCase):
     def initTestCase(self):
         self.op_type = 'arg_max'
@@ -202,4 +222,5 @@ class BaseTestComplex2_2(OpTest):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()