diff --git a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
index 599beb7a07a787cac3c1604ad54bfe71d3e8e656..8d1d23efaba60a8ce6c4013b99d2e2fad8cd1356 100644
--- a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/kthvalue_grad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -76,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/kthvalue_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
index 235abdbc803c3982210be7cd30d8831cfa349c15..2ecec80c27b242c1555dae2e9c2ce75e4c67e3e7 100644
--- a/paddle/phi/kernels/gpu/kthvalue_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
@@ -13,8 +13,8 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/kthvalue_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(kthvalue,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
diff --git a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
index b757974126cbb2537b2a304163fff11cc4c94fb2..66389a870e46f177acc8e024d66b24a3acbf1203 100644
--- a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
@@ -15,10 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 def cal_kthvalue(x, k, axis, keepdim=False):
@@ -207,5 +208,74 @@ class TestModeOpInStatic(unittest.TestCase):
         np.testing.assert_allclose(paddle_result, expect_value, rtol=1e-05)
 
 
+class TestKthvalueFP16Op(OpTest):
+    def init_args(self):
+        self.k = 5
+        self.axis = -1
+        self.keepdim = False
+        self.input_data = np.random.random((2, 1, 2, 4, 10))
+        self.dtype = np.float16
+
+    def setUp(self):
+        self.op_type = "kthvalue"
+        self.python_api = paddle.kthvalue
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': self.keepdim}
+        output, indices = cal_kthvalue(
+            self.input_data, k=self.k, axis=self.axis, keepdim=self.keepdim
+        )
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output()
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        self.check_grad({'X'}, 'Out')
+
+
+class TestKthvalueWithKeepdimFP16Op(TestKthvalueFP16Op):
+    def init_args(self):
+        self.k = 2
+        self.axis = 1
+        self.keepdim = True
+        self.input_data = np.random.random((1, 3, 2, 4, 10))
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestKthvalueBF16Op(OpTest):
+    def init_args(self):
+        self.k = 2
+        self.axis = 1
+
+    def setUp(self):
+        self.init_args()
+        self.op_type = 'kthvalue'
+        self.python_api = paddle.kthvalue
+        self.dtype = np.uint16
+        x = np.random.random((1, 3, 2, 4, 10))
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': True}
+        out, indices = cal_kthvalue(x, k=self.k, axis=self.axis, keepdim=True)
+        self.outputs = {'Out': convert_float_to_uint16(out), 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, {'X'}, 'Out')
+
+
 if __name__ == '__main__':
     unittest.main()
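
For reference, a minimal dynamic-graph sketch of what this patch enables. It is illustrative only, not part of the change: it assumes a CUDA build of Paddle on a GPU with bfloat16 support, and the shapes and k value are arbitrary.

    # Sketch: exercise the newly registered bfloat16 kthvalue kernels.
    # Assumes a CUDA build with bfloat16 support; shapes/k are arbitrary.
    import paddle

    paddle.set_device('gpu')                    # route to the GPU kernels
    x = paddle.randn([2, 3, 10]).astype('bfloat16')
    x.stop_gradient = False

    # values/indices of the 2nd smallest element along the last axis
    values, indices = paddle.kthvalue(x, k=2, axis=-1)
    values.sum().backward()                     # runs kthvalue_grad in bf16
    print(values.dtype, indices.dtype)          # paddle.bfloat16 paddle.int64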