// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #ifndef _USE_MATH_DEFINES #define _USE_MATH_DEFINES #endif #include #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" #ifdef PADDLE_WITH_XPU_KP #define __forceinline__ __inline__ #endif namespace phi { namespace funcs { enum ActBwdOpFwdDeps { kNoDeps = 0x00, // Do not need any forward input/output kDepX = 0x01, // Only need forward input X kDepOut = 0x02, // Only need forward output Out }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; template struct Sine { HOSTDEVICE T operator()(const T& val) const { return sin(val); } }; template <> struct Sine { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(sin(static_cast(val))); } }; template struct Cosine { HOSTDEVICE T operator()(const T& val) const { return cos(val); } }; template <> struct Cosine { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return 
dtype::float16(cos(static_cast(val))); } }; // sine'(x) = cos(x) template struct SinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sine(x) = sin(x) template struct SinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sine()); } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / x; } }; // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * x.unaryExpr(Sine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosine(x) = cos(x) template struct CosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosine()); } }; template struct LogitFunctor { template void operator()(Device d, X x, Out out, P p, float eps) const { // logit(x) = ln(x/(1-x)) auto tmp_x = (x.cwiseMin(static_cast(1.0 - eps))).cwiseMax(static_cast(eps)); if (!eps) { out.device(d) = (x < static_cast(0.0) || x > static_cast(1.0)) .select(p.constant(static_cast(NAN)), (tmp_x / (static_cast(1) - tmp_x)).log()); } else { out.device(d) = (tmp_x / (static_cast(1) - tmp_x)).log(); } } }; // mish(x) = x * tanh(softplus(x)) // softplus(x) = x, if x > threshold // = ln(1 + exp(x)), otherwise template struct MishFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) { auto sp = (x > static_cast(threshold)) .select(x, (static_cast(1) + x.exp()).log()); 
out.device(d) = x * sp.tanh(); } }; template struct STanhFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; template struct Tangent { HOSTDEVICE T operator()(const T& val) const { return tan(val); } }; template <> struct Tangent { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(tan(static_cast(val))); } }; // Tangent'(x) = 1 / cos(x)^2 template struct TanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout / x.unaryExpr(Cosine()).square(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.square(); } }; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.sqrt(); } }; // rsqrt(x) = x^(-1/2) template struct RsqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.rsqrt(); } }; // For numerical stability, using the following formula instead of softplus(x) = // log(1 + exp(x)) // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = // 1, threshold = 20 by default), otherwise x template struct SoftplusFunctor : public BaseActivationFunctor { float beta; float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}, {"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) { auto x_beta = static_cast(beta) * x; out.device(d) = (x_beta > static_cast(threshold)) 
.select(x, (static_cast(1) + x_beta.exp()).log() / static_cast(beta)); } }; // Tangent(x) = tan(x) template struct TanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Tangent()); } }; template struct Sinh { HOSTDEVICE T operator()(const T& val) const { return sinh(val); } }; template <> struct Sinh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(sinhf(static_cast(val))); } }; template struct Cosh { HOSTDEVICE T operator()(const T& val) const { return cosh(val); } }; template <> struct Cosh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(coshf(static_cast(val))); } }; // sinh(x) = sinh(x) template struct SinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sinh()); } }; // cosh(x) = cosh(x) template struct CoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosh()); } }; // sinh'(x) = cosh(x) template struct SinhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosh'(x) = sinh(x) template struct CoshGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Sinh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Acos { HOSTDEVICE T operator()(const T& val) const { return acos(val); } }; template <> struct Acos { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(acos(static_cast(val))); } }; // Acos(x) = acos(x) template struct AcosFunctor : public BaseActivationFunctor { template void 
operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acos()); } }; // acos'(x) = -1/sqrt(1-x^2) template struct AcosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asin { HOSTDEVICE T operator()(const T& val) const { return asin(val); } }; template <> struct Asin { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(asin(static_cast(val))); } }; // Asin(x) = asin(x) template struct AsinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asin()); } }; // asin'(x) = 1/sqrt(1-x^2) template struct AsinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atan { HOSTDEVICE T operator()(const T& val) const { return atan(val); } }; template <> struct Atan { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(atan(static_cast(val))); } }; // Atan(x) = atan(x) template struct AtanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atan()); } }; // atan'(x) = 1 / (1 + x^2) template struct AtanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Acosh { HOSTDEVICE T operator()(const T& val) const { return acosh(val); } }; template <> struct Acosh { HOSTDEVICE 
dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(acosh(static_cast(val))); } }; // Acosh(x) = acosh(x) template struct AcoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acosh()); } }; // acosh'(x) = 1/sqrt(x^2 - 1) template struct AcoshGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x * x - static_cast(1)).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asinh { HOSTDEVICE T operator()(const T& val) const { return asinh(val); } }; template <> struct Asinh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(asinh(static_cast(val))); } }; // Asinh(x) = asinh(x) template struct AsinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asinh()); } }; // asinh'(x) = 1/sqrt(x^2 + 1) template struct AsinhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x.square() + static_cast(1)).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atanh { HOSTDEVICE T operator()(const T& val) const { return atanh(val); } }; template <> struct Atanh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(atanh(static_cast(val))); } }; // Atanh(x) = atanh(x) template struct AtanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atanh()); } }; // atanh'(x) = 1/(1 - x^2) template struct AtanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / 
(static_cast(1) - x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // exp functor // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // expm1(x) = e^x - 1 template struct Expm1Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.expm1(); } }; // relu(x) = max(x, 0) template struct ReluCPUFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) { return v > static_cast(0) ? v : static_cast(0); }); } }; template struct ReluCUDAFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (out > static_cast(0)).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct ReluGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* ddX, DenseTensor* ddOut, DenseTensor* dOut, DenseTensor* dX) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad")); if (ddOut) { auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad")); ddout.device(*d) = ddx * (out 
> static_cast(0)).template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.tanh(); } }; template struct TanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) - out * out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct TanhGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* Out, const DenseTensor* ddX, const DenseTensor* dOut, DenseTensor* dOutNew, DenseTensor* ddOut) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhGradGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "TanhGradGrad")); // tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out // * ddx) if (dOutNew) { auto dout = EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhGradGrad")); auto dout_new = EigenVector::Flatten( GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "TanhGradGrad")); dout_new.device(*d) = static_cast(-1) * dout * static_cast(2) * out * ddx; } if (ddOut) { auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "TanhGradGrad")); ddout.device(*d) = (static_cast(1) - out * out) * ddx; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; /* Out DOut D_Dout DDx -> TanhTripleGrad -> D_DDx D_DDout d_OutNew D_Dout_new D_Dout = (-2) * Out * DDx * D_Dout_new D_DDx = (1-Out^2)*D_DDout + (-2) * Out * DOut * D_Dout_new D_OutNew = (-2) * Out * DDx * D_DDout + (-2) * DOut * DDx * D_Dout_new Out, DDX, DOut, D_DDOut, D_DOut_New // input D_OutNew, D_DOut, D_DDx // output */ 
template struct TanhTripleGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* Out, const DenseTensor* ddX, const DenseTensor* dOut, const DenseTensor* d_DDOut, const DenseTensor* d_dOut_New, DenseTensor* d_d_Out, DenseTensor* d_Out_New, DenseTensor* d_DDx) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhTripleGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad")); auto dout = EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad")); auto d_ddOut = EigenVector::Flatten( GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); auto d_dOutNew = EigenVector::Flatten( GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); if (d_Out_New) { auto d_OutNew = EigenVector::Flatten( GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad")); d_OutNew.device(*d) = (static_cast(-2) * out * ddx * d_ddOut) - (static_cast(2) * dout * ddx * d_dOutNew); } if (d_d_Out) { auto d_dOut = EigenVector::Flatten( GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad")); d_dOut.device(*d) = static_cast(-2) * out * ddx * d_dOutNew; } if (d_DDx) { auto d_ddx = EigenVector::Flatten( GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad")); d_ddx.device(*d) = (static_cast(1) - (out * out)) * d_ddOut - static_cast(2) * out * dout * d_dOutNew; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct BReluFunctor : public BaseActivationFunctor { float t_min; float t_max; // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` // not polymorphism for speed. 
typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max)); } }; template struct BReluGradFunctor : public BaseActivationFunctor { float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { out.device(d) = x / (static_cast(1) + x.abs()); } }; template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { if (alpha < 1.f) { out.device(d) = x.cwiseMax(static_cast(alpha) * x); } else { out.device(d) = x.cwiseMin(static_cast(alpha) * x); } } }; template struct LeakyReluGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(alpha) * (x < static_cast(0)).template cast(); auto temp2 = (x >= static_cast(0)).template cast(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct LeakyReluGradGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(const Device& dev, const DenseTensor* X, const 
DenseTensor* Out, const DenseTensor* ddX, DenseTensor* ddOut, DenseTensor* dOut, DenseTensor* dX) const { if (ddOut) { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad")); auto x = EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad")); /* label matches the ddOut argument, as ReluGradGrad does with "DDOut" */ auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LeakyReluGradGrad")); /* ddout = ddx where x > 0, alpha * ddx where x <= 0 */ ddout.device(*d) = ddx * ((x > static_cast(0)).template cast() + static_cast(alpha) * (x <= static_cast(0)).template cast()) .template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ThresholdedReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto th = static_cast(threshold); out.device(d) = (x > th).template cast() * x; } }; template struct ThresholdedReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto th = static_cast(threshold); dx.device(d) = dout * (x > th).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhShrinkFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x - x.tanh(); } }; template struct TanhShrinkGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x.tanh() * x.tanh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // hardshrink(x) = x, if x < -threshold or x > threshold; 0 otherwise
template struct HardShrinkFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); out.device(d) = x * (temp1 || temp2).template cast(); } }; template struct HardShrinkGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); dx.device(d) = dout * (temp1 || temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 // otherwise template struct SoftShrinkFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast(); auto temp2 = (x < -lambdaT).template cast(); out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); } }; template struct SoftShrinkGradFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast(); auto temp2 = (x < -lambdaT).template cast(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ELUFunctor : public BaseActivationFunctor { float alpha; typename 
BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = (x < static_cast(0)) .select(static_cast(alpha) * (x.exp() - static_cast(1)), x); } }; template struct ELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { // case 1: alpha >= 0 // dx = dout, if out > 0 // dx = dout * (out + alpha), if out <= 0 dx.device(d) = (out > static_cast(0)) .select(dout, dout * (out + static_cast(alpha))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { // case 2: alpha < 0 // dx = dout, if x > 0 // dx = dout * alpha * exp(x), if x <= 0 dx.device(d) = (x > static_cast(0)) .select(dout, dout * static_cast(alpha) * x.exp()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ELUGradGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(const Device& dev, const DenseTensor* X, const DenseTensor* ddX, DenseTensor* ddOut, const DenseTensor* dOut, DenseTensor* dX) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad")); auto x = EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad")); if (dX) { auto dx = EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad")); /* dOut is a const input of this functor, so it is labelled "Input" */ auto dout = EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "ELUGradGrad")); dx.device(*d) = ddx * dout * static_cast(alpha) * 
x.exp() * (x <= static_cast(0)).template cast(); } if (ddOut) { auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad")); ddout.device(*d) = ddx * ((x > static_cast(0)).template cast() + static_cast(alpha) * x.exp() * (x <= static_cast(0)).template cast()) .template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // silu(x) = x / (1 + exp(-x)) template struct SiluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { auto temp = static_cast(1) / (static_cast(1) + (-x).exp()); out.device(d) = x * temp; } }; // silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x}) template struct SiluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(1) + (-x).exp(); // 1+e^(-x) auto temp2 = x * (-x).exp(); // x*e^(-x) dx.device(d) = dout * ((static_cast(1) / temp1) * (static_cast(1) + (temp2 / temp1))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) template struct CudaReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); // relu(x) = max(x, 0) __device__ __forceinline__ T operator()(const T x) const { return x > zero ? x : zero; } }; template struct CudaReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); // dx = dout * (out > 0) __device__ __forceinline__ T operator()(const T dout, const T out) const { return out > zero ? 
dout : zero; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct CudaCosFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // cos(x) = cos(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(cos(x)); } }; template struct CudaCosGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * (-sin(x)) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(-dout * sin(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaExpFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // exp(x) = exp(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(exp(x)); } }; template struct CudaSquareFunctor : public BaseActivationFunctor { // square(x) = x * x __device__ __forceinline__ T operator()(const T x) const { return x * x; } }; template struct CudaExpGradFunctor : public BaseActivationFunctor { // dx = dout * out __device__ __forceinline__ T operator()(const T dout, const T out) const { return dout * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct CudaReciprocalFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); // reciprocal(x) = 1 / x __device__ __forceinline__ T operator()(const T x) const { return one / x; } }; template struct CudaExpm1Functor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // expm1(x) = expm1(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(expm1(x)); } }; template struct CudaSinFunctor : public 
BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // sin(x) = sin(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sin(x));
  }
};

// Device functor for softsign; evaluated directly in T.
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + abs(x))
  __device__ __forceinline__ T operator()(const T x) const {
    return x / (one + abs(x));
  }
};

// Device functor for the gradient of sin. Inputs are widened to MPType (the
// compute type chosen by MPTypeTrait<T>) and the result is cast back to T.
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * cos(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * cos(x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for tan, computed in MPType.
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tan(x) = tan(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(tan(x));
  }
};

// Device functor for the gradient of tan, computed in MPType.
template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout / cos(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    // Evaluate cos(x) once and square it.
    const auto cos_x = cos(x);
    return static_cast<T>(dout / (cos_x * cos_x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for asin, computed in MPType.
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // asin(x) = asin(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(asin(x));
  }
};

// Device functor for the gradient of asin, computed in MPType.
template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout / sqrt(1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout / sqrt(one - x * x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for acos, computed in MPType.
template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // acos(x) = acos(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(acos(x));
  }
};

// Device functor for the gradient of acos, computed in MPType.
template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = -dout / sqrt(1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(-dout / sqrt(one - x * x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for cosh, computed in MPType.
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // cosh(x) = cosh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(cosh(x));
  }
};

// Device functor for the gradient of cosh, computed in MPType.
template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * sinh(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * sinh(x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for sinh, computed in MPType.
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // sinh(x) = sinh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sinh(x));
  }
};

// Device functor for the gradient of sinh, computed in MPType.
template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * cosh(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * cosh(x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for acosh, computed in MPType.
template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Acosh(x) = acosh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(acosh(x));
  }
};

// Device functor for the gradient of acosh, computed in MPType.
template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / sqrt(x^2 - 1)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * one / sqrt(x * x - one));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for asinh, computed in MPType.
template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Asinh(x) = asinh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(asinh(x));
  }
};

// Device functor for the gradient of asinh, computed in MPType.
template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / sqrt(x^2 + 1)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * one / sqrt(x * x + one));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for atanh, computed in MPType.
template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Atanh(x) = atanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(atanh(x));
  }
};

// Device functor for scaled tanh. The attributes "scale_a" and "scale_b" are
// exposed through GetAttrs() and widened to MPType before use.
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // stanh(x) = b * tanh(a * x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(scale_a);
    MPType b = static_cast<MPType>(scale_b);
    return static_cast<T>(b * tanh(a * x));
  }
};

// Device functor for softplus with attributes "beta" and "threshold",
// computed in MPType.
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // softplus(x) = (beta * x > threshold) ? x
  //                                      : log(1 + exp(beta * x)) / beta
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    // Use the already-widened beta so the whole expression stays in MPType.
    MPType x_beta = x * b;
    // NOTE(review): log(one + exp(v)) rounds to 0 once exp(v) drops below the
    // type's epsilon; log1p(exp(v)) would keep precision there but changes
    // numerical output -- confirm against the reference before switching.
    return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
  }
};

// Device functor for the gradient of atanh, computed in MPType.
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / (1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * one / (one - x * x));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for sqrt, computed in MPType.
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // sqrt(x) = sqrt(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sqrt(x));
  }
};

// Device functor for reciprocal square root, computed in MPType.
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // rsqrt(x) = rsqrt(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(rsqrt(x));
  }
};

// Device functor for atan, computed in MPType.
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // atan(x) = atan(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(atan(x));
  }
};

// Device functor for the gradient of atan; evaluated directly in T.
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x^2)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (one + x * x);
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for tanh, computed in MPType.
template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tanh(x) = tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(tanh(x));
  }
};

// Device functor for the gradient of tanh, expressed through the forward
// output; evaluated directly in T.
template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * (1 - out^2)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * (one - out * out);
  }

  // Backward only needs the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for bounded relu with attributes "t_min" and "t_max".
template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // brelu(x) = min(max(x, t_min), t_max)
  __device__ __forceinline__ T operator()(const T x) const {
    T t_min_cast = static_cast<T>(t_min);
    T t_max_cast = static_cast<T>(t_max);
    // Clamp from below first, then from above.
    T temp_max = x > t_min_cast ? x : t_min_cast;
    T temp_min = temp_max < t_max_cast ? temp_max : t_max_cast;
    return temp_min;
  }
};

// Device functor for mish with attribute "threshold", computed in MPType.
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // mish(x) = x * tanh(softplus(x))
  // softplus(x) = x, if x > threshold
  //             = ln(1 + exp(x)), otherwise
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    // NOTE(review): log1p(exp(x)) would be more precise for very negative x
    // but changes numerical output -- confirm before switching.
    MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
    return static_cast<T>(x * tanh(sp));
  }
};

// Device functor for the gradient of bounded relu.
template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float t_min;
  float t_max;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // dx = (x > t_min && x < t_max) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T t_min_cast = static_cast<T>(t_min);
    T t_max_cast = static_cast<T>(t_max);
    return (x > t_min_cast && x < t_max_cast) ? dout : zero;
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for thresholded relu with attribute "threshold".
template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // thresholded_relu(x) = x > threshold ? x : 0
  __device__ __forceinline__ T operator()(const T x) const {
    return x > static_cast<T>(threshold) ? x : zero;
  }
};

// Device functor for the gradient of thresholded relu.
template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = x > threshold ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return x > static_cast<T>(threshold) ? dout : zero;
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for leaky relu with attribute "alpha".
template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // leakyrelu(x) = x > 0 ? x : alpha * x
  __device__ __forceinline__ T operator()(const T x) const {
    return x > zero ? x : static_cast<T>(alpha) * x;
  }
};

// Device functor for the gradient of leaky relu.
template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout * (x > 0 ? 1 : alpha)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return x > zero ? dout : static_cast<T>(alpha) * dout;
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for soft shrink with attribute "lambda".
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  __device__ __forceinline__ T operator()(const T x) const {
    T l = static_cast<T>(lambda);
    // The comparison results are cast to T and used as 0/1 masks, so the
    // selection is done arithmetically instead of with branches.
    T temp1 = static_cast<T>(x > l);
    T temp2 = static_cast<T>(x < -l);
    return temp1 * (x - l) + temp2 * (x + l);
  }
};

// Device functor for the gradient of soft shrink.
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = dout, if x > lambda or x < -lambda else 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T l = static_cast<T>(lambda);
    return (x >= -l && x <= l) ? zero : dout;
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for tanh shrink, computed in MPType.
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x - tanh(x));
  }
};

// Device functor for the gradient of tanh shrink, computed in MPType.
template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * tanh(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    // Evaluate tanh(x) once; the multiplication order is unchanged.
    const auto tanh_x = tanh(x);
    return static_cast<T>(dout * tanh_x * tanh_x);
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for hard shrink with attribute "threshold".
template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = (x > -threshold && x < threshold) ? 0 : x
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : x;
  }
};

// Device functor for the gradient of hard shrink.
template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (x > -threshold && x < threshold) ? 0 : dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : dout;
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for ELU with attribute "alpha", computed in CT (the compute
// type chosen by MPTypeTrait<T>).
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = x, if x > 0
  // elu(x) = alpha * (e^x - 1), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_x) const {
    CT x = static_cast<CT>(arg_x);
    // NOTE(review): expm1(x) would be more precise than exp(x) - one near
    // x == 0 but changes numerical output -- confirm before switching.
    CT temp = static_cast<CT>(alpha) * (exp(x) - one);
    CT res = x > zero ? x : temp;
    return static_cast<T>(res);
  }
};

// Device functor for the gradient of ELU, expressed through the forward
// output only; computed in MPType.
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 1: alpha >= 0
  // dx = dout, if out > 0
  // dx = dout * (out + alpha), if out <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType a = static_cast<MPType>(alpha);
    // Comparison results cast to MPType act as 0/1 masks for the two cases.
    MPType out_pos = static_cast<MPType>(out > zero);
    MPType out_neg = static_cast<MPType>(out <= zero);
    return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
  }

  // Backward only needs the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for the gradient of ELU that selects the branch from the
// sign of x instead of the sign of out; computed in MPType.
template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 2: alpha < 0
  // dx = dout, if x > 0
  // dx = dout * (out + alpha), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    // Comparison results cast to MPType act as 0/1 masks for the two cases.
    MPType x_pos = static_cast<MPType>(x > zero);
    MPType x_neg = static_cast<MPType>(x <= zero);
    return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
  }

  // Declared as depending on the forward input X (the operator also reads
  // the forward output).
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Device functor for silu, computed in MPType.
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x / (one + exp(-x)));
  }
};

// Device functor for the gradient of silu, computed in MPType.
template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
  //    = dout * s * (1 + x * (1 - s)), where s = 1 / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  // Backward only needs the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

#endif

}  // namespace funcs
}  // namespace phi