/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"

namespace paddle {
namespace operators {

// Element-wise activation functors evaluated on the device by
// LaunchSameDimsElementwiseCudaKernel. Forward functors take `x`. Grad
// functors take `dout` plus whichever forward tensor FwdDeps() names (the
// forward output for kDepOut, the forward input for kDepX). Functors that
// declare an `MPType` compute in MPTypeTrait<T>::Type and cast the result
// back to T.

template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // reciprocal(x) = 1 / x
  __device__ __forceinline__ T operator()(const T x) const { return one / x; }
};

template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // dx = -dout * out^2
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return -dout * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // exp(x) = e^x
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(exp(x));
  }
};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // d/dx e^x = e^x = out, so dx = dout * out
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // expm1(x) = e^x - 1
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(expm1(x));
  }
};

template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  // d/dx expm1(x) = e^x = out + 1, so dx = dout * (out + 1)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * out + dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // square(x) = x * x
  __device__ __forceinline__ T operator()(const T x) const { return x * x; }
};

template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // dx = dout * 2 * x
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout * two * x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sqrt(x) = x^(1/2)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sqrt(x));
  }
};

template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // dx = dout * 0.5 / out
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return one_half * dout / out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // rsqrt(x) = x^(-1/2)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(rsqrt(x));
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // dx = -0.5 * dout * out^3
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return minus_one_half * dout * out * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType t = static_cast<MPType>(threshold);
    MPType temp_min = x < t ? x : t;
    MPType temp_max = temp_min > -t ? temp_min : -t;
    return static_cast<T>(log(one + exp(temp_max)));
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType t = static_cast<MPType>(threshold);
    return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
                                 : static_cast<T>(0.0f);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // stanh(x) = b * tanh(a * x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(scale_a);
    MPType b = static_cast<MPType>(scale_b);
    return static_cast<T>(b * tanh(a * x));
  }
};

template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(scale_a);
    MPType b = static_cast<MPType>(scale_b);
    MPType temp = tanh(a * x);
    return static_cast<T>(dout * a * b * (one - temp * temp));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    // Multiply by the MPType copy of beta so the product stays in MPType.
    MPType x_beta = x * b;
    return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
  }
};

template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    MPType x_beta = x * b;
    return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + abs(x))
  __device__ __forceinline__ T operator()(const T x) const {
    return x / (one + abs(x));
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + abs(x))^2
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T temp = one + abs(x);
    return dout / (temp * temp);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), threshold)
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    return x <= zero ? zero : (x < t ? x : t);
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (out > 0 && out < threshold) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    T t = static_cast<T>(threshold);
    return (out > zero && out < t) ? dout : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // mish(x) = x * tanh(softplus(x))
  // softplus(x) = x,               if x > threshold
  //             = ln(1 + exp(x)),  otherwise
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
    return static_cast<T>(x * tanh(sp));
  }
};

template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * d(sp)/dx), sp = softplus(x)
  // d(sp)/dx = 1 above the threshold, otherwise 1 / (1 + exp(-x))
  // (equivalently 1 - exp(-sp)).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
    MPType gsp =
        (x > static_cast<MPType>(threshold)) ? one : one / (one + exp(-x));
    MPType tsp = tanh(sp);
    return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    CT x = static_cast<CT>(arg_x);
    CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
    return static_cast<T>(res);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // For either sign of alpha:
  //   dx = dout,                   if x > 0
  //   dx = dout * exp(x / alpha),  if x <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    // Select the branch rather than blending both with 0/1 masks: for
    // alpha > 0 and large positive x, exp(x / a) overflows to inf and the
    // masked term 0 * inf would turn the gradient into NaN.
    return static_cast<T>(x > zero ? dout : dout * exp(x / a));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward kernel: reads X, writes Out = Functor(X) element-wise. Float
// attributes listed by Functor::GetAttrs() are copied from the op context
// before launch.
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(ctx, &x, &out);
    out->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    std::vector<const framework::Tensor*> ins = {x};
    std::vector<framework::Tensor*> outs = {out};

    auto functor = Functor();
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        dev_ctx, ins, &outs, functor);
  }
};

// Backward kernel: writes dX from dOut plus the forward tensor named by
// Functor::FwdDeps() (Out for kDepOut, X for kDepX, nothing extra otherwise).
template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto functor = Functor();
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }

    std::vector<const framework::Tensor*> ins = {d_out};
    std::vector<framework::Tensor*> outs = {d_x};
    // Append the single forward tensor the grad functor depends on, if any.
    if (static_cast<int>(Functor::FwdDeps()) ==
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      ins.push_back(out);
    } else if (static_cast<int>(Functor::FwdDeps()) ==
               static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      ins.push_back(x);
    }
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        dev_ctx, ins, &outs, functor);
  }
};

USE_PHI_FUNCTOR(CudaCos)
USE_PHI_FUNCTOR(CudaTan)
USE_PHI_FUNCTOR(CudaAcos)
USE_PHI_FUNCTOR(CudaSin)
USE_PHI_FUNCTOR(CudaAsin)
USE_PHI_FUNCTOR(CudaAtan)
USE_PHI_FUNCTOR(CudaSinh)
USE_PHI_FUNCTOR(CudaCosh)
USE_PHI_FUNCTOR(CudaAsinh)
USE_PHI_FUNCTOR(CudaAcosh)
USE_PHI_FUNCTOR(CudaAtanh)
USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu)
// Further functors pulled in through USE_PHI_FUNCTOR. The macro is defined
// outside this file; presumably each use aliases a phi functor and its grad
// counterpart into paddle::operators -- confirm against its definition.
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
USE_PHI_FUNCTOR(CudaSwish)
USE_PHI_FUNCTOR(CudaHardSwish)

// Template aliases exposing these phi::funcs functors inside
// paddle::operators.
template using CudaRoundFunctor = phi::funcs::CudaRoundFunctor;
template using CudaFloorFunctor = phi::funcs::CudaFloorFunctor;
template using CudaCeilFunctor = phi::funcs::CudaCeilFunctor;
template using CudaZeroGradFunctor = phi::funcs::CudaZeroGradFunctor;
template using CudaELUGradNegativeAlphaFunctor = phi::funcs::CudaELUGradNegativeAlphaFunctor;

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// REGISTER_ACTIVATION_CUDA_KERNEL registers the forward op `act_type` with
// four ActivationCudaKernel instantiations and the backward op
// `act_type##_grad` with four ActivationGradCudaKernel instantiations
// (presumably one per supported element type, built from `functor` and
// `grad_functor` respectively -- confirm).
// REGISTER_ACTIVATION_CUDA_KERNEL_INT does the same with six instantiations
// per op (presumably adding integer element types, per the _INT suffix).
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
                                        grad_functor) \
  REGISTER_OP_CUDA_KERNEL( \
      act_type, ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>); \
  REGISTER_OP_CUDA_KERNEL( \
      act_type##_grad, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>);

#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor, \
                                            grad_functor) \
  REGISTER_OP_CUDA_KERNEL( \
      act_type, ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>, \
      ops::ActivationCudaKernel>); \
  REGISTER_OP_CUDA_KERNEL( \
      act_type##_grad, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>, \
      ops::ActivationGradCudaKernel>);

/* ========================================================================== */

/* ======================== celu register ============================ */
// celu / celu_grad via the shared macro; celu_grad_grad has three
// CELUDoubleGradKernel instantiations.
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
                                CudaCELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad, ops::CELUDoubleGradKernel>,
    ops::CELUDoubleGradKernel>,
    ops::CELUDoubleGradKernel>);
/* ========================================================================== */

/* =========================== sqrt register ============================= */
// sqrt / sqrt_grad via the shared macro; sqrt_grad_grad has four
// SqrtDoubleGradKernel instantiations.
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel>,
    ops::SqrtDoubleGradKernel>,
    ops::SqrtDoubleGradKernel>,
    ops::SqrtDoubleGradKernel>);
/* ========================================================================== */

/* =========================== rsqrt register ============================= */
// rsqrt / rsqrt_grad via the shared macro; rsqrt_grad_grad has three
// RsqrtDoubleGradKernel instantiations (one fewer than sqrt_grad_grad).
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad, ops::RsqrtDoubleGradKernel>,
    ops::RsqrtDoubleGradKernel>,
    ops::RsqrtDoubleGradKernel>);
/* ========================================================================== */

/* =========================== square register ============================ */
// square / square_grad use the six-instantiation _INT macro;
// square_grad_grad likewise has six SquareDoubleGradKernel instantiations.
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
                                    CudaSquareGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad, ops::SquareDoubleGradKernel>,
    ops::SquareDoubleGradKernel>,
    ops::SquareDoubleGradKernel>,
    ops::SquareDoubleGradKernel>,
    ops::SquareDoubleGradKernel>,
    ops::SquareDoubleGradKernel>);
/* ========================================================================== */

/* ========================== logit register ============================ */
// Re-declares the `ops` alias already introduced after the namespaces close;
// redundant but harmless. logit and logit_grad each get three kernels.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    logit, ops::LogitKernel,
    ops::LogitKernel,
    ops::LogitKernel);
REGISTER_OP_CUDA_KERNEL(
    logit_grad, ops::LogitGradKernel,
    ops::LogitGradKernel,
    ops::LogitGradKernel);
/* ========================================================================== */

/* ========================== exp register ============================ */
// exp is registered by hand: two of its five forward entries use the generic
// ActivationKernel instead of ActivationCudaKernel (presumably for element
// types without a CUDA functor here -- confirm). exp_grad uses five
// ActivationGradCudaKernel entries.
REGISTER_OP_CUDA_KERNEL(
    exp, ops::ActivationCudaKernel>,
    ops::ActivationCudaKernel>,
    ops::ActivationKernel>,
    ops::ActivationKernel>,
    ops::ActivationCudaKernel>);
REGISTER_OP_CUDA_KERNEL(
    exp_grad, ops::ActivationGradCudaKernel>,
    ops::ActivationGradCudaKernel>,
    ops::ActivationGradCudaKernel>,
    ops::ActivationGradCudaKernel>,
    ops::ActivationGradCudaKernel>);
/* ========================================================================== */

/* ========================== expm1 register ============================ */
// expm1 and expm1_grad each get three kernels.
REGISTER_OP_CUDA_KERNEL(
    expm1, ops::ActivationCudaKernel>,
    ops::ActivationCudaKernel>,
    ops::ActivationCudaKernel>);
REGISTER_OP_CUDA_KERNEL(
    expm1_grad, ops::ActivationGradCudaKernel>,
    ops::ActivationGradCudaKernel>,
    ops::ActivationGradCudaKernel>);
/* ========================================================================== */

// Applies `__macro` to each (op_type, OpName, Functor, GradFunctor) tuple.
// It is expanded once below with REGISTER_ACTIVATION_CUDA_KERNEL, which
// registers every listed op and its _grad op.
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
  __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
          CudaSoftShrinkGradFunctor); \
  __macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
          CudaReciprocalGradFunctor); \
  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
  __macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \
  __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \
  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor); \
  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor); \
  __macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, \
          CudaTanhShrinkGradFunctor); \
  __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \
          CudaHardShrinkGradFunctor); \
  __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor);

FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)

// KP backend registrations on plat::XPUPlace, compiled only when
// PADDLE_WITH_XPU_KP is defined. Each op gets a forward ActivationCudaKernel
// entry and an `_grad` ActivationGradCudaKernel entry.
#ifdef PADDLE_WITH_XPU_KP
REGISTER_OP_KERNEL(
    brelu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
// Remaining KP (XPU) registrations, all under the PADDLE_WITH_XPU_KP guard
// closed by the #endif at the end. Ops are listed alphabetically; each op
// is registered on plat::XPUPlace with the KP library tag, the forward op
// using ActivationCudaKernel and the `_grad` op using
// ActivationGradCudaKernel.
REGISTER_OP_KERNEL(
    brelu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(ceil, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    ceil_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(celu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    celu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    elu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(exp, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    exp_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(floor, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    floor_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    hard_shrink, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    hard_shrink_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    hard_sigmoid, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    hard_sigmoid_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(hard_swish, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    hard_swish_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    leaky_relu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    leaky_relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(log, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    log_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(log1p, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    log1p_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    logsigmoid, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    logsigmoid_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    reciprocal, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    reciprocal_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    relu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(relu6, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    relu6_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(sigmoid, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    sigmoid_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(silu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    silu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(soft_relu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    soft_relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(softplus, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    softplus_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    softshrink, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    softshrink_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(softsign, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    softsign_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(sqrt, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    sqrt_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(square, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    square_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(swish, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    swish_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);

REGISTER_OP_KERNEL(
    thresholded_relu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel>);
REGISTER_OP_KERNEL(
    thresholded_relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel>);
#endif // PADDLE_WITH_XPU_KP