From d8407c51b79d38c39461b83af1fe2ae926f2bc20 Mon Sep 17 00:00:00 2001
From: LoneRanger <836253168@qq.com>
Date: Thu, 18 May 2023 14:17:53 +0800
Subject: [PATCH] add fp16 and bf16 for trunc (#53876)

---
 paddle/phi/kernels/gpu/trunc_grad_kernel.cu   |  4 ++-
 paddle/phi/kernels/gpu/trunc_kernel.cu        | 18 ++++++++--
 .../fluid/tests/unittests/test_trunc_op.py    | 33 ++++++++++++++++++-
 3 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/gpu/trunc_grad_kernel.cu b/paddle/phi/kernels/gpu/trunc_grad_kernel.cu
index 8a88383e6e4..40e1404cd90 100644
--- a/paddle/phi/kernels/gpu/trunc_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/trunc_grad_kernel.cu
@@ -52,4 +52,6 @@ PD_REGISTER_KERNEL(trunc_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/trunc_kernel.cu b/paddle/phi/kernels/gpu/trunc_kernel.cu
index dfc4f6589e9..bdbdb80a97f 100644
--- a/paddle/phi/kernels/gpu/trunc_kernel.cu
+++ b/paddle/phi/kernels/gpu/trunc_kernel.cu
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
@@ -27,7 +28,10 @@ template <typename T>
 class TruncFunctor {
  public:
   __device__ TruncFunctor(const T x) : x_(x) {}
-  __device__ T operator()() { return trunc(x_); }
+  __device__ T operator()() {
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    return static_cast<T>(trunc(static_cast<MPType>(x_)));
+  }
 
  public:
   const T x_;
@@ -78,5 +82,13 @@ void TruncKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    trunc, GPU, ALL_LAYOUT, phi::TruncKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(trunc,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TruncKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py
index 7a05a47a245..404403fca04 100644
--- a/python/paddle/fluid/tests/unittests/test_trunc_op.py
+++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 paddle.enable_static()
 
@@ -90,5 +91,35 @@ class TestTruncAPI(unittest.TestCase):
             self.assertRaises(TypeError, paddle.trunc, x)
 
 
+class TestTruncFP16OP(TestTruncOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestTruncBF16OP(OpTest):
+    def setUp(self):
+        self.python_api = paddle.trunc
+        self.op_type = "trunc"
+        self.dtype = np.uint16
+        np.random.seed(2021)
+        x = np.random.random((20, 20)).astype("float32")
+        out = np.trunc(x)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=1e-5)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
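
A minimal usage sketch of the newly registered low-precision kernels (assumptions: a CUDA build of Paddle that includes this patch, a visible GPU, and dynamic-graph mode rather than the static mode the unit test enables):

    import paddle

    paddle.set_device("gpu")

    # float16: TruncFunctor promotes the value to float32 via MPTypeTrait,
    # truncates, and casts back, so results match the float32 path.
    x = paddle.to_tensor([[1.7, -2.3], [0.5, -0.5]], dtype="float16")
    print(paddle.trunc(x))  # elements truncated toward zero: 1., -2., 0., -0.

    # bfloat16 takes the same promoted path.
    xb = paddle.cast(x, "bfloat16")
    print(paddle.trunc(xb))

The promotion through MPTypeTrait is the same pattern Paddle uses elsewhere for half-precision kernels: compute in float32 to avoid calling trunc() on types the device math library does not overload, then narrow the result back to T.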