From 43efb979201914085eab60c8a9746b52ea59bcd4 Mon Sep 17 00:00:00 2001
From: cyberslack_lee
Date: Fri, 14 Apr 2023 16:12:18 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon4=20No58=E3=80=91kthvalue=20(?=
 =?UTF-8?q?#51615)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../phi/kernels/gpu/kthvalue_grad_kernel.cu   |  2 +-
 paddle/phi/kernels/gpu/kthvalue_kernel.cu     |  3 +-
 .../fluid/tests/unittests/test_kthvalue_op.py | 72 ++++++++++++++++++-
 3 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
index 599beb7a07a..8d1d23efaba 100644
--- a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/kthvalue_grad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -76,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/kthvalue_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
index 235abdbc803..2ecec80c27b 100644
--- a/paddle/phi/kernels/gpu/kthvalue_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
@@ -13,8 +13,8 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/kthvalue_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(kthvalue,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
diff --git a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
index b757974126c..66389a870e4 100644
--- a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
@@ -15,10 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 def cal_kthvalue(x, k, axis, keepdim=False):
@@ -207,5 +208,74 @@ class TestModeOpInStatic(unittest.TestCase):
             np.testing.assert_allclose(paddle_result, expect_value, rtol=1e-05)
 
 
+class TestKthvalueFP16Op(OpTest):
+    def init_args(self):
+        self.k = 5
+        self.axis = -1
+        self.keepdim = False
+        self.input_data = np.random.random((2, 1, 2, 4, 10))
+        self.dtype = np.float16
+
+    def setUp(self):
+        self.op_type = "kthvalue"
+        self.python_api = paddle.kthvalue
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': self.keepdim}
+        output, indices = cal_kthvalue(
+            self.input_data, k=self.k, axis=self.axis, keepdim=self.keepdim
+        )
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output()
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        self.check_grad({'X'}, 'Out')
+
+
+class TestKthvalueWithKeepdimFP16Op(TestKthvalueFP16Op):
+    def init_args(self):
+        self.k = 2
+        self.axis = 1
+        self.keepdim = True
+        self.input_data = np.random.random((1, 3, 2, 4, 10))
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestKthvalueBF16Op(OpTest):
+    def init_args(self):
+        self.k = 2
+        self.axis = 1
+
+    def setUp(self):
+        self.init_args()
+        self.op_type = 'kthvalue'
+        self.python_api = paddle.kthvalue
+        self.dtype = np.uint16
+        x = np.random.random((1, 3, 2, 4, 10))
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': True}
+        out, indices = cal_kthvalue(x, k=self.k, axis=self.axis, keepdim=True)
+        self.outputs = {'Out': convert_float_to_uint16(out), 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, {'X'}, 'Out')
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab