diff --git a/paddle/phi/kernels/gpu/trunc_grad_kernel.cu b/paddle/phi/kernels/gpu/trunc_grad_kernel.cu
index 8a88383e6e4f0cb6d5eea57115cc4cf57253bb88..40e1404cd900df4d5fac5fd4746645b077ae1e83 100644
--- a/paddle/phi/kernels/gpu/trunc_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/trunc_grad_kernel.cu
@@ -52,4 +52,6 @@ PD_REGISTER_KERNEL(trunc_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/trunc_kernel.cu b/paddle/phi/kernels/gpu/trunc_kernel.cu
index dfc4f6589e9cf8e5babf8c3583a93862f75e13f4..bdbdb80a97f5c508c1a58afcfa16f1a1cc50d208 100644
--- a/paddle/phi/kernels/gpu/trunc_kernel.cu
+++ b/paddle/phi/kernels/gpu/trunc_kernel.cu
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
@@ -27,7 +28,10 @@ template <typename T>
 class TruncFunctor {
  public:
   __device__ TruncFunctor(const T x) : x_(x) {}
-  __device__ T operator()() { return trunc(x_); }
+  __device__ T operator()() {
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    return static_cast<T>(trunc(static_cast<MPType>(x_)));
+  }
 
  public:
   const T x_;
@@ -78,5 +82,13 @@ void TruncKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    trunc, GPU, ALL_LAYOUT, phi::TruncKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(trunc,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TruncKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py
index 7a05a47a2458aff2794710ee80189b1508a65abc..404403fca0453316392e712a304d547f423775df 100644
--- a/python/paddle/fluid/tests/unittests/test_trunc_op.py
+++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 paddle.enable_static()
@@ -90,5 +91,35 @@ class TestTruncAPI(unittest.TestCase):
             self.assertRaises(TypeError, paddle.trunc, x)
 
 
+class TestTruncFP16OP(TestTruncOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestTruncBF16OP(OpTest):
+    def setUp(self):
+        self.python_api = paddle.trunc
+        self.op_type = "trunc"
+        self.dtype = np.uint16
+        np.random.seed(2021)
+        x = np.random.random((20, 20)).astype("float32")
+        out = np.trunc(x)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=1e-5)
+
+
 if __name__ == "__main__":
     unittest.main()
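
Note on the functor change above: float16 and bfloat16 have no device-side trunc overload, so the functor promotes the input to the wider type given by phi::dtype::MPTypeTrait (float for both half types), truncates, and casts back to T. As a quick sanity check of the newly registered dtypes, a minimal dynamic-graph snippet might look like the sketch below. This is an illustration only, assuming a CUDA build of Paddle that contains this change; the comparison is exact because truncating values of this magnitude yields integers that float16 represents exactly.

    import numpy as np
    import paddle

    # Sketch: exercise the newly registered float16 GPU trunc kernel.
    # Requires a CUDA build of Paddle that includes this change.
    paddle.disable_static()
    paddle.set_device("gpu")

    x_fp32 = np.random.uniform(-10.0, 10.0, (20, 20)).astype("float32")
    x_fp16 = paddle.to_tensor(x_fp32.astype("float16"))

    out = paddle.trunc(x_fp16)  # fp16 input, promoted to float inside the functor

    # Host-side reference: truncate the same fp16 values in float32.
    ref = np.trunc(x_fp32.astype("float16").astype("float32"))
    np.testing.assert_allclose(out.astype("float32").numpy(), ref)
    print("float16 trunc matches the float32 reference")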