diff --git a/paddle/phi/kernels/funcs/complex_functors.h b/paddle/phi/kernels/funcs/complex_functors.h
index 8b292cb5dc52e24f90ca54c2b62b08f42552c10e..e6ffeb3b5602e9bd090e5df444c1d62e06c84498 100644
--- a/paddle/phi/kernels/funcs/complex_functors.h
+++ b/paddle/phi/kernels/funcs/complex_functors.h
@@ -110,53 +110,6 @@ struct AbsFunctor<T, NoComplex<T, dtype::Real<T>>> {
   int64_t numel_;
 };
 
-template <typename T>
-struct AbsGradCUDAFunctor {
-  HOSTDEVICE inline AbsGradCUDAFunctor() {}
-
-  HOSTDEVICE inline T operator()(const T x, const T dout) const {
-    T output;
-    if (x == T(0)) {
-      output = T(0);
-    } else {
-      output = T(dout) * (x / T(std::abs(x)));
-    }
-    return output;
-  }
-};
-
-template <>
-struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
-  HOSTDEVICE inline AbsGradCUDAFunctor() {}
-  HOSTDEVICE inline phi::dtype::complex<float> operator()(
-      const phi::dtype::complex<float> x, const float dout) const {
-    phi::dtype::complex<float> output;
-    if (x == phi::dtype::complex<float>(0)) {
-      output = phi::dtype::complex<float>(0);
-    } else {
-      output = phi::dtype::complex<float>(dout) *
-               (x / phi::dtype::complex<float>(abs(x)));
-    }
-    return output;
-  }
-};
-
-template <>
-struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
-  HOSTDEVICE inline AbsGradCUDAFunctor() {}
-  HOSTDEVICE inline phi::dtype::complex<double> operator()(
-      const phi::dtype::complex<double> x, const double dout) const {
-    phi::dtype::complex<double> output;
-    if (x == phi::dtype::complex<double>(0)) {
-      output = phi::dtype::complex<double>(0);
-    } else {
-      output = phi::dtype::complex<double>(dout) *
-               (x / phi::dtype::complex<double>(abs(x)));
-    }
-    return output;
-  }
-};
-
 template <typename T>
 struct AbsGradFunctor {
   AbsGradFunctor(const dtype::Real<T>* dout,
@@ -179,6 +132,28 @@ struct AbsGradFunctor {
   int64_t numel_;
 };
 
+template <>
+struct AbsGradFunctor<phi::dtype::bfloat16> {
+  AbsGradFunctor(const dtype::Real<phi::dtype::bfloat16>* dout,
+                 const phi::dtype::bfloat16* x,
+                 phi::dtype::bfloat16* output,
+                 int64_t numel)
+      : dout_(dout), x_(x), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    if (x_[idx] == static_cast<phi::dtype::bfloat16>(0)) {
+      output_[idx] = static_cast<phi::dtype::bfloat16>(0);
+    } else {
+      output_[idx] = dout_[idx] * (x_[idx] / (abs(x_[idx])));
+    }
+  }
+
+  const dtype::Real<phi::dtype::bfloat16>* dout_;
+  const phi::dtype::bfloat16* x_;
+  phi::dtype::bfloat16* output_;
+  int64_t numel_;
+};
+
 template <>
 struct AbsGradFunctor<phi::dtype::complex<float>> {
   AbsGradFunctor(const float* dout,
diff --git a/paddle/phi/kernels/gpu/abs_grad_kernel.cu b/paddle/phi/kernels/gpu/abs_grad_kernel.cu
index 8edb6b71224d6d4a1b601bc632675efe890dbf86..a1afa8569b2fa9a2b18e2485d520521afe081151 100644
--- a/paddle/phi/kernels/gpu/abs_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/abs_grad_kernel.cu
@@ -31,6 +31,7 @@ PD_REGISTER_KERNEL(abs_grad,
                    int,
                    int64_t,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    complex<float>,
                    complex<double>) {
   kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
diff --git a/paddle/phi/kernels/gpu/abs_kernel.cu b/paddle/phi/kernels/gpu/abs_kernel.cu
index d025f4b61e76315558768d2797cc7d91c6de098f..9f27c986166f478582ba1c39a18253aed2537da9 100644
--- a/paddle/phi/kernels/gpu/abs_kernel.cu
+++ b/paddle/phi/kernels/gpu/abs_kernel.cu
@@ -16,8 +16,8 @@
 
 #include <algorithm>
 #include <vector>
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
@@ -36,7 +36,18 @@ struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::dtype::Real<T>>> {
 };
 
 template <typename T>
-struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::dtype::Real<T>>> {
+struct CudaAbsFunctor<
+    T,
+    std::enable_if_t<std::is_same<T, phi::dtype::Real<T>>::value &&
+                     std::is_same<T, phi::dtype::bfloat16>::value>> {
+  __device__ __forceinline__ T operator()(const T x) const { return abs(x); }
+};
+
+template <typename T>
+struct CudaAbsFunctor<
+    T,
+    std::enable_if_t<std::is_same<T, phi::dtype::Real<T>>::value &&
+                     !std::is_same<T, phi::dtype::bfloat16>::value>> {
   __device__ __forceinline__ T operator()(const T x) const {
     return std::abs(x);
   }
@@ -63,5 +74,6 @@ PD_REGISTER_KERNEL(abs,
                    int,
                    int64_t,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/impl/abs_grad_kernel_impl.h b/paddle/phi/kernels/impl/abs_grad_kernel_impl.h
index 9dad40b57c916c0670763802cbaf3fc89d49d0c0..7064eec4f9e99aaadaebef4f73760358cda3b10d 100644
--- a/paddle/phi/kernels/impl/abs_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/abs_grad_kernel_impl.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/kernels/abs_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -22,6 +23,70 @@ namespace phi {
 
 #if defined(__NVCC__)
+
+template <typename T>
+struct AbsGradCUDAFunctor {
+  HOSTDEVICE inline AbsGradCUDAFunctor() {}
+
+  HOSTDEVICE inline T operator()(const T x, const T dout) const {
+    T output;
+    if (x == T(0)) {
+      output = T(0);
+    } else {
+      output = T(dout) * (x / T(std::abs(x)));
+    }
+    return output;
+  }
+};
+
+template <>
+struct AbsGradCUDAFunctor<phi::dtype::bfloat16> {
+  HOSTDEVICE inline AbsGradCUDAFunctor() {}
+
+  HOSTDEVICE inline phi::dtype::bfloat16 operator()(
+      const phi::dtype::bfloat16 x, const phi::dtype::bfloat16 dout) const {
+    phi::dtype::bfloat16 output;
+    if (x == phi::dtype::bfloat16(0)) {
+      output = static_cast<phi::dtype::bfloat16>(0);
+    } else {
+      output = (dout) * (x / abs(x));
+    }
+    return output;
+  }
+};
+
+template <>
+struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
+  HOSTDEVICE inline AbsGradCUDAFunctor() {}
+  HOSTDEVICE inline phi::dtype::complex<float> operator()(
+      const phi::dtype::complex<float> x, const float dout) const {
+    phi::dtype::complex<float> output;
+    if (x == phi::dtype::complex<float>(0)) {
+      output = phi::dtype::complex<float>(0);
+    } else {
+      output = phi::dtype::complex<float>(dout) *
+               (x / phi::dtype::complex<float>(abs(x)));
+    }
+    return output;
+  }
+};
+
+template <>
+struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
+  HOSTDEVICE inline AbsGradCUDAFunctor() {}
+  HOSTDEVICE inline phi::dtype::complex<double> operator()(
+      const phi::dtype::complex<double> x, const double dout) const {
+    phi::dtype::complex<double> output;
+    if (x == phi::dtype::complex<double>(0)) {
+      output = phi::dtype::complex<double>(0);
+    } else {
+      output = phi::dtype::complex<double>(dout) *
+               (x / phi::dtype::complex<double>(abs(x)));
+    }
+    return output;
+  }
+};
+
 template <typename T>
 void AbsGradKernelImpl(const GPUContext& dev_ctx,
                        const DenseTensor& x,
@@ -30,9 +95,10 @@ void AbsGradKernelImpl(const GPUContext& dev_ctx,
   std::vector<const DenseTensor*> ins = {&x, &dout};
   std::vector<DenseTensor*> outs = {dx};
   dev_ctx.Alloc<T>(dx);
-  phi::funcs::AbsGradCUDAFunctor<T> abs_grad_cuda_functor;
+  AbsGradCUDAFunctor<T> abs_grad_cuda_functor;
   phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, abs_grad_cuda_functor);
 }
+
 template <typename T, typename Context>
 void AbsGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 4411fdc3d1006bca44e78e2adee4e1e2b42893e1..913777c2515f43545356dc1e9cc719006aa5a3c5 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -3699,6 +3699,7 @@ def create_test_act_bf16_class(
 
 
 create_test_act_bf16_class(TestRelu)
+create_test_act_bf16_class(TestAbs)
 
 if __name__ == "__main__":
     unittest.main()
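
A note on exercising the new path (not part of the patch): the snippet below is a minimal dygraph sketch of how the bfloat16 registrations above can be checked end to end. It assumes a CUDA build of Paddle that includes this patch; the tensor names are ours, and the exact spelling of the bfloat16 cast and gradient access may vary slightly across Paddle 2.x releases.

    # Hypothetical sanity check for the bfloat16 abs forward/backward kernels.
    import numpy as np
    import paddle

    paddle.set_device("gpu")
    x_np = np.random.uniform(-1.0, 1.0, [11, 17]).astype("float32")

    # Cast to bfloat16 so the new CudaAbsFunctor / AbsGradCUDAFunctor
    # specializations are the ones dispatched.
    x = paddle.to_tensor(x_np).cast("bfloat16")
    x.stop_gradient = False

    y = paddle.abs(x)   # forward
    y.sum().backward()  # backward: d|x|/dx = x / |x|, defined as 0 at x == 0

    # bfloat16 keeps only 8 mantissa bits, so compare loosely.
    np.testing.assert_allclose(
        y.cast("float32").numpy(), np.abs(x_np), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(
        x.grad.cast("float32").numpy(), np.sign(x_np), rtol=1e-2, atol=1e-2)

The in-tree coverage comes from `create_test_act_bf16_class(TestAbs)` above, which runs the standard OpTest forward and backward checks for the op on a CUDA place with bfloat16 inputs.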