// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <type_traits>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"

#ifdef PADDLE_WITH_XPU_KP
#define __forceinline__ __inline__
#endif

namespace phi {
namespace funcs {

enum ActBwdOpFwdDeps {
  kNoDeps = 0x00,  // Do not need any forward input/output
  kDepX = 0x01,    // Only need forward input X
  kDepOut = 0x02,  // Only need forward output Out
};

template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return AttrPair(); }
};

template <typename T>
struct Sine {
  HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};

template <>
struct Sine<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(sin(static_cast<float>(val)));
  }
};

template <typename T>
struct Cosine {
  HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};

template <>
struct Cosine<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(cos(static_cast<float>(val)));
  }
};

// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Cosine<T>());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Sine<T>());
  }
};

// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / x;
  }
};

template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = -dout * x.unaryExpr(Sine<T>());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Cosine<T>());
  }
};
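// Illustrative sketch only (not part of the original header): how an
// elementwise functor above is typically driven with Eigen on CPU. The
// tensor names and the surrounding setup are hypothetical; kernels normally
// obtain the `x`/`out` views via EigenVector<T>::Flatten over allocated
// DenseTensors.
//
//   auto x = EigenVector<float>::Flatten(x_tensor);      // input view
//   auto out = EigenVector<float>::Flatten(out_tensor);  // output view
//   Eigen::DefaultDevice d;
//   SinFunctor<float>()(d, x, out);                // out = sin(x), elementwise
//   SinGradFunctor<float>()(d, x, out, dout, dx);  // dx = dout * cos(x)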
template <typename T>
struct LogitFunctor {
  template <typename Device, typename X, typename Out, typename P>
  void operator()(Device d, X x, Out out, P p, float eps) const {
    // logit(x) = ln(x/(1-x))
    auto tmp_x =
        (x.cwiseMin(static_cast<T>(1.0 - eps))).cwiseMax(static_cast<T>(eps));

    if (!eps) {
      out.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
                          .select(p.constant(static_cast<T>(NAN)),
                                  (tmp_x / (static_cast<T>(1) - tmp_x)).log());
    } else {
      out.device(d) = (tmp_x / (static_cast<T>(1) - tmp_x)).log();
    }
  }
};

// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
//             = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto sp = (x > static_cast<T>(threshold))
                  .select(x, (static_cast<T>(1) + x.exp()).log());
    out.device(d) = x * sp.tanh();
  }
};

// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
template <typename T>
struct MishGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto sp = (x > static_cast<T>(threshold))
                  .select(x, (static_cast<T>(1) + x.exp()).log());
    auto gsp = static_cast<T>(1) - (-sp).exp();
    auto tsp = sp.tanh();
    dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
  }
};

template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Tangent {
  HOSTDEVICE T operator()(const T& val) const { return tan(val); }
};

template <>
struct Tangent<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(tan(static_cast<float>(val)));
  }
};

// tangent'(x) = 1 / cos(x)^2
template <typename T>
struct TanGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout / x.unaryExpr(Cosine<T>()).square();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.square();
  }
};

template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.sqrt();
  }
};

template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(0.5) * dout / out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};
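// Note on SqrtGradFunctor's kDepOut: with out = sqrt(x),
// d(out)/dx = 1 / (2 * sqrt(x)) = 0.5 / out, so dx = 0.5 * dout / out is
// computable from the forward output alone, which is what FwdDeps() reports.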
// rsqrt(x) = x^(-1/2)
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.rsqrt();
  }
};

template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// For numerical stability, using the following formula instead of
// softplus(x) = log(1 + exp(x)):
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold
// (beta = 1, threshold = 20 by default), otherwise x
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto x_beta = static_cast<T>(beta) * x;
    out.device(d) = (x_beta > static_cast<T>(threshold))
                        .select(x,
                                (static_cast<T>(1) + x_beta.exp()).log() /
                                    static_cast<T>(beta));
  }
};

// For numerical stability, using the following formula instead of
// d(softplus(x))/dx = 1 / (1 + exp(-x)):
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold
// (beta = 1, threshold = 20 by default), otherwise 1
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto x_beta = static_cast<T>(beta) * x;
    dx.device(d) =
        (x_beta > static_cast<T>(threshold))
            .select(dout, dout / (static_cast<T>(1) + (-x_beta).exp()));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
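// Numerical sanity check for the guard above (float, beta = 1): the naive
// log(1 + exp(100)) overflows to inf, while the guarded branch returns x
// directly because beta * x > threshold (20). When beta * x <= 20,
// exp(beta * x) <= exp(20) ~ 4.85e8, which float represents comfortably.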
// tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Tangent<T>());
  }
};

template <typename T>
struct Sinh {
  HOSTDEVICE T operator()(const T& val) const { return sinh(val); }
};

template <>
struct Sinh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(sinhf(static_cast<float>(val)));
  }
};

template <typename T>
struct Cosh {
  HOSTDEVICE T operator()(const T& val) const { return cosh(val); }
};

template <>
struct Cosh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(coshf(static_cast<float>(val)));
  }
};

// sinh(x) = sinh(x)
template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Sinh<T>());
  }
};

// cosh(x) = cosh(x)
template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Cosh<T>());
  }
};

// sinh'(x) = cosh(x)
template <typename T>
struct SinhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Cosh<T>());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// cosh'(x) = sinh(x)
template <typename T>
struct CoshGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Sinh<T>());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct Acos {
  HOSTDEVICE T operator()(const T& val) const { return acos(val); }
};

template <>
struct Acos<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(acos(static_cast<float>(val)));
  }
};

// acos(x) = acos(x)
template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Acos<T>());
  }
};

// acos'(x) = -1/sqrt(1-x^2)
template <typename T>
struct AcosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        -dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct Asin {
  HOSTDEVICE T operator()(const T& val) const { return asin(val); }
};

template <>
struct Asin<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(asin(static_cast<float>(val)));
  }
};

// asin(x) = asin(x)
template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Asin<T>());
  }
};

// asin'(x) = 1/sqrt(1-x^2)
template <typename T>
struct AsinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct Atan {
  HOSTDEVICE T operator()(const T& val) const { return atan(val); }
};

template <>
struct Atan<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(atan(static_cast<float>(val)));
  }
};

// atan(x) = atan(x)
template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Atan<T>());
  }
};

// atan'(x) = 1 / (1 + x^2)
template <typename T>
struct AtanGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct LogitGradFunctor {
  template <typename Device, typename X, typename dOut, typename dX, typename P>
  void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
    // logit'(x) = 1 / (x * (1 - x))
    dx.device(d) =
        (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
            .select(p.constant(static_cast<T>(0)),
                    dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
  }
};

template <typename T>
struct Acosh {
  HOSTDEVICE T operator()(const T& val) const { return acosh(val); }
};

template <>
struct Acosh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(acosh(static_cast<float>(val)));
  }
};

// acosh(x) = acosh(x)
template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Acosh<T>());
  }
};

// acosh'(x) = 1/sqrt(x^2 - 1)
template <typename T>
struct AcoshGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * static_cast<T>(1) / (x * x - static_cast<T>(1)).sqrt();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct Asinh {
  HOSTDEVICE T operator()(const T& val) const { return asinh(val); }
};

template <>
struct Asinh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(asinh(static_cast<float>(val)));
  }
};
// asinh(x) = asinh(x)
template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Asinh<T>());
  }
};

// asinh'(x) = 1/sqrt(x^2 + 1)
template <typename T>
struct AsinhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * static_cast<T>(1) / (x.square() + static_cast<T>(1)).sqrt();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct Atanh {
  HOSTDEVICE T operator()(const T& val) const { return atanh(val); }
};

template <>
struct Atanh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    return dtype::float16(atanh(static_cast<float>(val)));
  }
};

// atanh(x) = atanh(x)
template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Atanh<T>());
  }
};

// atanh'(x) = 1/(1 - x^2)
template <typename T>
struct AtanhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) - x.square());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// exp functor
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.exp();
  }
};

template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// expm1(x) = e^x - 1
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.expm1();
  }
};

template <typename T>
struct Expm1GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out + dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};
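// Note on Expm1GradFunctor: since out = expm1(x) = e^x - 1, we have
// d(out)/dx = e^x = out + 1, so dx = dout * out + dout. Writing it this way
// reuses the forward output (hence kDepOut) instead of recomputing exp(x).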
// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
      return v > static_cast<T>(0) ? v : static_cast<T>(0);
    });
  }
};

template <typename T>
struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
      ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.tanh();
  }
};

template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  DenseTensor* dOutNew,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhGradGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "TanhGradGrad"));
    // tanh grad grad : ddout = (1 - out^2) * ddx,
    //                  dout_new = -(dout * 2 * out * ddx)
    if (dOutNew) {
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhGradGrad"));
      auto dout_new = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "TanhGradGrad"));
      dout_new.device(*d) =
          static_cast<T>(-1) * dout * static_cast<T>(2) * out * ddx;
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "TanhGradGrad"));
      ddout.device(*d) = (static_cast<T>(1) - out * out) * ddx;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};
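// Derivation sketch for TanhGradGradFunctor: the first-order rule is
// dx = dout * (1 - out^2). Perturbing x by ddx (forward-over-reverse) gives
//   ddout    = (1 - out^2) * ddx
//   dout_new = -2 * out * dout * ddx
// which is exactly what the two guarded branches above compute.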
/*
    Out
    DOut                            D_Dout
    DDx     -> TanhTripleGrad ->    D_DDx
    D_DDout                         d_OutNew
    D_Dout_new

    D_Dout = (-2) * Out * DDx * D_Dout_new
    D_DDx = (1 - Out^2) * D_DDout + (-2) * Out * DOut * D_Dout_new
    D_OutNew = (-2) * Out * DDx * D_DDout + (-2) * DOut * DDx * D_Dout_new

    Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    D_OutNew, D_DOut, D_DDx               // output
*/
template <typename T>
struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  const DenseTensor* d_DDOut,
                  const DenseTensor* d_dOut_New,
                  DenseTensor* d_d_Out,
                  DenseTensor* d_Out_New,
                  DenseTensor* d_DDx) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhTripleGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad"));
    auto dout = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad"));
    auto d_ddOut = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
    auto d_dOutNew = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));

    if (d_Out_New) {
      auto d_OutNew = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad"));
      d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut) -
                            (static_cast<T>(2) * dout * ddx * d_dOutNew);
    }
    if (d_d_Out) {
      auto d_dOut = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad"));
      d_dOut.device(*d) = static_cast<T>(-2) * out * ddx * d_dOutNew;
    }
    if (d_DDx) {
      auto d_ddx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad"));
      d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut -
                         static_cast<T>(2) * out * dout * d_dOutNew;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: This GetAttrs intentionally hides (does not override)
  // `BaseActivationFunctor<T>::GetAttrs`; there is no virtual dispatch,
  // for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
  }
};

template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                       .template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    if (alpha < 1.f) {
      out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
    } else {
      out.device(d) = x.cwiseMin(static_cast<T>(alpha) * x);
    }
  }
};

template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 =
        static_cast<T>(alpha) * (x < static_cast<T>(0)).template cast<T>();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  DenseTensor* dOut,
                  DenseTensor* dX) const {
    if (ddOut) {
      auto* d = dev.eigen_device();
      auto ddx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
      auto x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad"));
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
      ddout.device(*d) =
          ddx *
          ((x > static_cast<T>(0)).template cast<T>() +
           static_cast<T>(alpha) * (x <= static_cast<T>(0)).template cast<T>())
              .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto th = static_cast<T>(threshold);
    out.device(d) = (x > th).template cast<T>() * x;
  }
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dout * (x > th).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout *
        ((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
            .template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x - x.tanh();
  }
};

template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// hardshrink(x) = x, if x > threshold or x < -threshold
//               = 0, otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp1 = x < static_cast<T>(threshold * -1.f);
    auto temp2 = x > static_cast<T>(threshold);
    out.device(d) = x * (temp1 || temp2).template cast<T>();
  }
};

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = x < static_cast<T>(threshold * -1.f);
    auto temp2 = x > static_cast<T>(threshold);
    dx.device(d) = dout * (temp1 || temp2).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// softshrink(x) = x - lambda, if x > lambda
//                 x + lambda, if x < -lambda
//                 0, otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>();
    auto temp2 = (x < -lambdaT).template cast<T>();
    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
  }
};

template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>();
    auto temp2 = (x < -lambdaT).template cast<T>();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        (x < static_cast<T>(0))
            .select(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)), x);
  }
};

template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // case 1: alpha >= 0
    // dx = dout, if out > 0
    // dx = dout * (out + alpha), if out <= 0
    dx.device(d) = (out > static_cast<T>(0))
                       .select(dout, dout * (out + static_cast<T>(alpha)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // case 2: alpha < 0
    // dx = dout, if x > 0
    // dx = dout * (out + alpha), if x <= 0
    dx.device(d) = (x > static_cast<T>(0))
                       .select(dout, dout * static_cast<T>(alpha) * x.exp());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  const DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
      ddout.device(*d) = ddx *
                         ((x > static_cast<T>(0)).template cast<T>() +
                          static_cast<T>(alpha) * x.exp() *
                              (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
    out.device(d) = x * temp;
  }
};
// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) + (-x).exp();  // 1+e^(-x)
    auto temp2 = x * (-x).exp();                  // x*e^(-x)
    dx.device(d) = dout * ((static_cast<T>(1) / temp1) *
                           (static_cast<T>(1) + (temp2 / temp1)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
  }
};

template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

/*
    Out
    DOut -> SigmoidGradGrad -> DOutNew
    DDX                        DDOut

    DDOut = (1 - Out) * Out * DDX
    DOutNew = (1 - 2 * Out) * DOut * DDX
*/
template <typename T>
struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  DenseTensor* dOutNew,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidGradGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidGradGrad"));

    if (dOutNew) {
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidGradGrad"));
      auto dout_new = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "SigmoidGradGrad"));
      dout_new.device(*d) =
          (static_cast<T>(1) - static_cast<T>(2) * out) * dout * ddx;
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SigmoidGradGrad"));
      ddout.device(*d) = (static_cast<T>(1) - out) * out * ddx;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};
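// Derivation sketch for SigmoidGradGradFunctor: the first-order rule is
// dx = dout * out * (1 - out). Perturbing x with ddx gives
//   DDOut   = out * (1 - out) * DDX
//   DOutNew = (1 - 2 * Out) * DOut * DDX
// matching the two guarded assignments above.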
/*
    Out
    DOut                            D_Dout
    DDx     -> SigmoidTripleGrad -> D_DDx
    D_DDout                         d_OutNew
    D_Dout_new

    D_Dout = (1 - 2 * Out) * DDx * D_Dout_new
    D_DDx = (1 - Out) * Out * D_DDout + (1 - 2 * Out) * DOut * D_Dout_new
    D_OutNew = (DDx - 2 * Out * DDx) * D_DDout - 2 * DOut * DDx * D_Dout_new

    Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    D_OutNew, D_DOut, D_DDx               // output
*/
template <typename T>
struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  const DenseTensor* d_DDOut,
                  const DenseTensor* d_dOut_New,
                  DenseTensor* d_d_Out,
                  DenseTensor* d_Out_New,
                  DenseTensor* d_DDx) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidTripleGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidTripleGrad"));
    auto dout = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidTripleGrad"));
    auto d_dOutNew = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "SigmoidTripleGrad"));

    if (d_Out_New) {
      auto d_OutNew = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "SigmoidTripleGrad"));
      d_OutNew.device(*d) = -static_cast<T>(2) * dout * ddx * d_dOutNew;
      if (d_DDOut) {
        auto d_ddOut = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad"));
        d_OutNew.device(*d) =
            (ddx - static_cast<T>(2) * out * ddx) * d_ddOut + d_OutNew;
      }
    }
    if (d_d_Out) {
      auto d_dOut = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "SigmoidTripleGrad"));
      d_dOut.device(*d) =
          (static_cast<T>(1) - static_cast<T>(2) * out) * ddx * d_dOutNew;
    }
    if (d_DDx) {
      auto d_ddx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "SigmoidTripleGrad"));
      d_ddx.device(*d) =
          (static_cast<T>(1) - static_cast<T>(2) * out) * dout * d_dOutNew;
      if (d_DDOut) {
        auto d_ddOut = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad"));
        d_ddx.device(*d) = d_ddx + (static_cast<T>(1) - out) * out * d_ddOut;
      }
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// out = -log( exp(0) + exp(-x))  [since exp(0) = 1]
//     = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//     = -log( exp(max(-x, 0)) * exp(-max(-x, 0))
//             + exp(max(-x, 0)) * exp(-x - max(-x, 0)))
//     = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//     = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//
// Hence, logsigmoid(x) = -(max(-x, 0) + log(exp(-max(-x, 0))
//                        + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
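// Sanity check of the log-sum-exp form above (float): for x = -100 the naive
// -log(1 + exp(100)) overflows, while the stable form uses temp = 100 and
// yields -(100 + log(exp(-100) + exp(0))) ~ -100, the correct value. For
// x >= 0, temp = 0 and the expression collapses back to -log(1 + exp(-x)).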
template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
    out.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
                       .template cast<T>() *
                   static_cast<T>(slope);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log() / static_cast<T>(log(2));
  }
};

// the gradient of log2(x) is 1/(x*ln(2))
template <typename T>
struct Log2GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log() / static_cast<T>(log(10));
  }
};

// the gradient of log10(x) is 1/(x*ln(10))
template <typename T>
struct Log10GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = (static_cast<T>(1) + x).log();
  }
};

template <typename T>
struct Log1pGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / (x + static_cast<T>(1)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  const DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
    // ddout = ddx / x; dx = -(dout / x) * (ddx / x)
    // calculate dx first, so ddout can inplace ddx
    if (dX) {
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
      dx.device(*d) = dout * static_cast<T>(-1) * ddx / (x * x);
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
      ddout.device(*d) = ddx * static_cast<T>(1) / x;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// HardSwish = min(max(0, x+3), 6) * x / 6
template <typename T>
struct HardSwishFunctor : public BaseActivationFunctor<T> {
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = (x + static_cast<T>(offset))
                        .cwiseMax(static_cast<T>(0))
                        .cwiseMin(static_cast<T>(threshold)) *
                    x / static_cast<T>(scale);
  }
};
&threshold}, {"scale", &scale}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = ((x + static_cast(offset)) < static_cast(threshold)) .template cast(); dx.device(d) = dout * (((x + static_cast(offset)) > static_cast(0)).template cast() * (static_cast(2) * x + static_cast(offset)) / static_cast(scale) * tmp + static_cast(1) * (static_cast(1) - tmp)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct SwishFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); } }; template struct SwishGradFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const { auto temp1 = static_cast(1) / (static_cast(1) + (static_cast(-beta) * x).exp()); auto out = x * temp1; auto temp2 = temp1 * (static_cast(1) - (static_cast(beta) * out)); dx.device(d) = dout * ((static_cast(beta) * out) + temp2); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 template struct PowFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.pow(static_cast(factor)); } }; template struct PowGradFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * x.pow(static_cast(factor) - static_cast(1)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // floor(x) = flooring(x) template struct FloorFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.floor(); } }; // round(x) = [x] template struct RoundFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.round(); } }; // ceil(x) = ceiling(x) template struct CeilFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.ceil(); } }; template struct NegativeFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = -x; } }; template struct ZeroGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0) * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kNoDeps; } }; template struct SqrtGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* Out, const DenseTensor* dX, const DenseTensor* ddX, DenseTensor* dOut, DenseTensor* ddOut) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad")); // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx // calculate dy first, so ddy can inplace ddx if 
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* dX,
                  const DenseTensor* ddX,
                  DenseTensor* dOut,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
    // calculate dy first, so ddy can inplace ddx
    if (dOut) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
      dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
      ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* dX,
                  const DenseTensor* ddX,
                  DenseTensor* dOut,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
    // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
    if (dOut) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
      dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
      ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        (x < static_cast<T>(0))
            .select(static_cast<T>(alpha) *
                        ((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
                    x);
  }
};

template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp_a_pos = static_cast<T>(alpha > 0);
    auto temp_a_neg = static_cast<T>(alpha <= 0);
    auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
    auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();

    // dx = dout, if alpha > 0 and x > 0
    // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
    // dx = dout, if alpha < 0 and x > 0
    // dx = dout * (x/alpha).exp(), if alpha < 0 and x <= 0
    dx.device(d) =
        dout * temp_a_pos * temp_x_pos +
        dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
        dout * temp_a_neg * temp_x_pos +
        dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
      dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
                      (x / static_cast<T>(alpha)).exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
      ddout.device(*d) = ddx *
                         ((x > static_cast<T>(0)).template cast<T>() +
                          (x / static_cast<T>(alpha)).exp() *
                              (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
    // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
    // calculate dx first, so ddy can inplace ddx
    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
      dx.device(*d) = ddx * static_cast<T>(2) * dout;
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
      ddout.device(*d) = ddx * static_cast<T>(2) * x;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // relu(x) = max(x, 0)
  __device__ __forceinline__ T operator()(const T x) const {
    return x > zero ? x : zero;
  }
};
template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // dx = dout * (out > 0)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return out > zero ? dout : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // cos(x) = cos(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(cos(x));
  }
};

template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * (-sin(x))
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(-dout * sin(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  // exp(x) = expf(x)
  __device__ __forceinline__ T operator()(const T x) const {
    return static_cast<T>(expf(static_cast<float>(x)));
  }
};

template <>
struct CudaExpFunctor<double> : public BaseActivationFunctor<double> {
  // exp(x) = exp(x)
  __device__ __forceinline__ double operator()(const double x) const {
    return exp(x);
  }
};

template <typename T>
struct CudaSeluFunctor : public BaseActivationFunctor<T> {
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale", &scale}, {"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T x) const {
    T res = x;
    if (res <= zero) {
      res = alpha * expf(res) - alpha;
    }
    res *= scale;
    return res;
  }

 private:
  float scale;
  float alpha;
  T zero = static_cast<T>(0.0f);
};

template <>
struct CudaSeluFunctor<double> : public BaseActivationFunctor<double> {
  typename BaseActivationFunctor<double>::AttrPair GetAttrs() {
    return {{"scale", &scale}, {"alpha", &alpha}};
  }

  __device__ __forceinline__ double operator()(const double x) const {
    double res = x;
    double alpha_cast = static_cast<double>(alpha);
    double scale_cast = static_cast<double>(scale);
    if (res <= zero) {
      res = alpha_cast * exp(res) - alpha_cast;
    }
    res *= scale_cast;
    return res;
  }

 private:
  float scale;
  float alpha;
  double zero = static_cast<double>(0.0f);
};

template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // square(x) = x * x
  __device__ __forceinline__ T operator()(const T x) const { return x * x; }
};

template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // dx = dout * 2 * x
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout * two * x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // reciprocal(x) = 1 / x
  __device__ __forceinline__ T operator()(const T x) const { return one / x; }
};

template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // dx = -dout * out^2
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return -dout * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // expm1(x) = expm1(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(expm1(x));
  }
};
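// Note on the MPType pattern used by the CUDA functors in this block:
// MPTypeTrait<T>::Type promotes low-precision element types (float16,
// bfloat16) to float, so the transcendental math runs at float precision and
// only the final result is cast back to T. For T = float or double, MPType
// is T itself and the casts are no-ops.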
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out + dout
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * out + dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // sin(x) = sin(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sin(x));
  }
};

template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * cos(x)
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * cos(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tan(x) = tan(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(tan(x));
  }
};

template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout / cos(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout / (cos(x) * cos(x)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // asin(x) = asin(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(asin(x));
  }
};

template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout / sqrt(1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout / sqrt(one - x * x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // acos(x) = acos(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(acos(x));
  }
};

template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = -dout / sqrt(1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(-dout / sqrt(one - x * x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // cosh(x) = cosh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(cosh(x));
  }
};
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // sinh(x) = sinh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sinh(x));
  }
};

template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * cosh(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * cosh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Acosh(x) = acosh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(acosh(x));
  }
};

template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / sqrt(x^2 - 1)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * one / sqrt(x * x - one));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Asinh(x) = asinh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(asinh(x));
  }
};

template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / sqrt(x^2 + 1)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * one / sqrt(x * x + one));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // Atanh(x) = atanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(atanh(x));
  }
};

template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // stanh(x) = b * tanh(a * x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(scale_a);
    MPType b = static_cast<MPType>(scale_b);
    return static_cast<T>(b * tanh(a * x));
  }
};

template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(scale_a);
    MPType b = static_cast<MPType>(scale_b);
    MPType temp = tanh(a * x);
    return static_cast<T>(dout * a * b * (one - temp * temp));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
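// Illustrative sketch: verifying the stanh gradient used by
// CudaSTanhGradFunctor, d/dx [b * tanh(a * x)] = a * b * (1 - tanh(a*x)^2),
// against a finite difference. Standalone host-only code; the scale values
// below are just sample inputs. Kept under "#if 0".
#if 0
#include <cmath>
#include <cstdio>

int main() {
  const double a = 0.67, b = 1.7159;  // sample scale_a / scale_b values
  const double x = 0.3, h = 1e-6;
  const double t = std::tanh(a * x);
  const double analytic = a * b * (1.0 - t * t);
  const double numeric =
      (b * std::tanh(a * (x + h)) - b * std::tanh(a * (x - h))) / (2.0 * h);
  std::printf("analytic=%.9f numeric=%.9f\n", analytic, numeric);
  return 0;
}
#endif  // illustrative example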
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    MPType x_beta = x * b;
    return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
  }
};

template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    MPType x_beta = x * b;
    return x_beta > t ? arg_dout
                      : static_cast<T>(dout / (one + exp(-x_beta)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / (1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * one / (one - x * x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // sqrt(x) = sqrt(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(sqrt(x));
  }
};

template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // dx = dout * 0.5 / out
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return one_half * dout / out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // rsqrt(x) = rsqrt(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(rsqrt(x));
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // dx = -0.5 * dout * out^3
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return minus_one_half * dout * out * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // atan(x) = atan(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(atan(x));
  }
};
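// Illustrative sketch: why CudaSoftplusFunctor short-circuits to x when
// beta * x exceeds the threshold. For large arguments exp(beta * x)
// overflows, while softplus(x) is already indistinguishable from x, so
// returning x is both safe and accurate. Standalone host-only code under
// "#if 0".
#if 0
#include <cmath>
#include <cstdio>

int main() {
  const double beta = 1.0, x = 800.0;
  // Naive evaluation: exp(800) overflows to inf, so the result is inf.
  const double naive = std::log(1.0 + std::exp(beta * x)) / beta;
  // The shortcut returns x; its true error is log1p(exp(-beta*x))/beta,
  // which is astronomically small here (~exp(-800)).
  std::printf("naive=%g shortcut=%g\n", naive, x);  // prints inf vs 800
  return 0;
}
#endif  // illustrative example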
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x^2)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (one + x * x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tanh(x) = tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(tanh(x));
  }
};

template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * (1 - out^2)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * (one - out * out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // brelu(x) = min(max(x, t_min), t_max)
  __device__ __forceinline__ T operator()(const T x) const {
    T t_min_cast = static_cast<T>(t_min);
    T t_max_cast = static_cast<T>(t_max);
    T temp_max = x > t_min_cast ? x : t_min_cast;
    T temp_min = temp_max < t_max_cast ? temp_max : t_max_cast;
    return temp_min;
  }
};

template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // mish(x) = x * tanh(softplus(x))
  // softplus(x) = x, if x > threshold
  //             = ln(1 + exp(x)), otherwise
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
    return static_cast<T>(x * tanh(sp));
  }
};

template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
  // sp = softplus(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
    MPType gsp =
        (x > static_cast<MPType>(threshold)) ? one : one / (one + exp(-x));
    MPType tsp = tanh(sp);
    return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float t_min;
  float t_max;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // dx = (x > t_min && x < t_max) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T t_min_cast = static_cast<T>(t_min);
    T t_max_cast = static_cast<T>(t_max);
    return (x > t_min_cast && x < t_max_cast) ? dout : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
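// Illustrative sketch: checking CudaMishGradFunctor's closed form
// d/dx [x * tanh(softplus(x))] = tanh(sp) + x * (1 - tanh(sp)^2) * sigmoid(x),
// where sp = softplus(x) and d(sp)/dx = sigmoid(x) = 1 - exp(-sp), against a
// finite difference. Standalone host-only code under "#if 0".
#if 0
#include <cmath>
#include <cstdio>

double mish(double x) { return x * std::tanh(std::log1p(std::exp(x))); }

int main() {
  const double x = -0.4, h = 1e-6;
  const double sp = std::log1p(std::exp(x));
  const double tsp = std::tanh(sp);
  const double gsp = 1.0 / (1.0 + std::exp(-x));  // sigmoid(x) = d softplus/dx
  const double analytic = tsp + x * (1.0 - tsp * tsp) * gsp;
  const double numeric = (mish(x + h) - mish(x - h)) / (2.0 * h);
  std::printf("analytic=%.9f numeric=%.9f\n", analytic, numeric);
  return 0;
}
#endif  // illustrative example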
template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // thresholded_relu(x) = x > threshold ? x : 0
  __device__ __forceinline__ T operator()(const T x) const {
    return x > static_cast<T>(threshold) ? x : zero;
  }
};

template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = x > threshold ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return x > static_cast<T>(threshold) ? dout : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), 6)
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    return x <= zero ? zero : (x < t ? x : t);
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (out > 0 && out < t) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    T t = static_cast<T>(threshold);
    return (out > zero && out < t) ? dout : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // leakyrelu(x) = x > 0 ? x : alpha * x
  __device__ __forceinline__ T operator()(const T x) const {
    return x > zero ? x : static_cast<T>(alpha) * x;
  }
};

template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout * (x > 0 ? 1 : alpha)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return x > zero ? dout : static_cast<T>(alpha) * dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  __device__ __forceinline__ T operator()(const T x) const {
    T l = static_cast<T>(lambda);
    T temp1 = static_cast<T>(x > l);
    T temp2 = static_cast<T>(x < -l);
    return temp1 * (x - l) + temp2 * (x + l);
  }
};

template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = dout, if x > lambda or x < -lambda, else 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T l = static_cast<T>(lambda);
    return (x >= -l && x <= l) ? zero : dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
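// Illustrative sketch: CudaSoftShrinkFunctor encodes its three-way piecewise
// definition branchlessly, as temp1 * (x - l) + temp2 * (x + l) with 0/1
// indicator values, which avoids divergent branches on the GPU. The
// standalone host-only check below (under "#if 0") compares that form
// against a straightforward if/else version.
#if 0
#include <cstdio>
#include <initializer_list>

double softshrink_branchless(double x, double l) {
  const double temp1 = static_cast<double>(x > l);
  const double temp2 = static_cast<double>(x < -l);
  return temp1 * (x - l) + temp2 * (x + l);
}

double softshrink_branchy(double x, double l) {
  if (x > l) return x - l;
  if (x < -l) return x + l;
  return 0.0;
}

int main() {
  const double l = 0.5;
  for (double x : {-2.0, -0.5, 0.0, 0.25, 1.5}) {
    std::printf("x=%5.2f branchless=%6.3f branchy=%6.3f\n",
                x, softshrink_branchless(x, l), softshrink_branchy(x, l));
  }
  return 0;
}
#endif  // illustrative example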
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x - tanh(x));
  }
};

template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * tanh(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * tanh(x) * tanh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = (x > -threshold && x < threshold) ? 0 : x
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : x;
  }
};

template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (x > -threshold && x < threshold) ? 0 : dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = x, if x > 0
  // elu(x) = alpha * (e^x - 1), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_x) const {
    CT x = static_cast<CT>(arg_x);
    CT temp = static_cast<CT>(alpha) * (exp(x) - one);
    CT res = x > zero ? x : temp;
    return static_cast<T>(res);
  }
};
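// Illustrative sketch: for x <= 0, elu(x) = alpha * (exp(x) - 1), so
// d(elu)/dx = alpha * exp(x) = out + alpha. This identity is why
// CudaELUGradFunctor below can be written purely in terms of the forward
// output (kDepOut). Standalone host-only check under "#if 0".
#if 0
#include <cmath>
#include <cstdio>

int main() {
  const double alpha = 1.2, x = -0.8, h = 1e-6;
  const double out = alpha * (std::exp(x) - 1.0);
  const double analytic = out + alpha;  // the out-based gradient form
  const double numeric = (alpha * (std::exp(x + h) - 1.0) -
                          alpha * (std::exp(x - h) - 1.0)) / (2.0 * h);
  std::printf("analytic=%.9f numeric=%.9f\n", analytic, numeric);
  return 0;
}
#endif  // illustrative example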
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 1: alpha >= 0
  // dx = dout, if out > 0
  // dx = dout * (out + alpha), if out <= 0
  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType a = static_cast<MPType>(alpha);
    MPType out_pos = static_cast<MPType>(out > zero);
    MPType out_neg = static_cast<MPType>(out <= zero);
    return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 2: alpha < 0
  // dx = dout, if x > 0
  // dx = dout * (out + alpha), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    MPType x_pos = static_cast<MPType>(x > zero);
    MPType x_neg = static_cast<MPType>(x <= zero);
    return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x / (one + exp(-x)));
  }
};

template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + abs(x))
  __device__ __forceinline__ T operator()(const T x) const {
    // Using abs directly would cause a namespace conflict
    return x / (one + (x > -x ? x : -x));
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + abs(x))^2
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    // Using abs directly would cause a namespace conflict
    T temp = one + (x > -x ? x : -x);
    return dout / (temp * temp);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
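// Illustrative sketch: the softsign functors spell |x| as (x > -x ? x : -x)
// to sidestep the abs() overload-resolution conflicts noted in the comments
// above. The standalone host-only check below (under "#if 0") confirms the
// expression matches std::fabs for signed values, including -0.0.
#if 0
#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
  for (double x : {-3.5, -0.0, 0.0, 2.25}) {
    const double manual = x > -x ? x : -x;
    std::printf("x=%5.2f manual=%.2f fabs=%.2f\n", x, manual, std::fabs(x));
  }
  return 0;
}
#endif  // illustrative example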
template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // sigmoid(x) = 1 / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(one / (one + exp(-x)));
  }
};

template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * out * (1 - out)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * out * (one - out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x)))
  // For numerical stability,
  // logsigmoid(x) =
  //   - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType temp = x > zero ? zero : -x;
    return static_cast<T>(-temp - log(exp(-temp) + exp(-x - temp)));
  }
};

template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // dx = dout * exp(-x) / (1 + exp(-x))
  // For numerical stability:
  // dx = dout * exp(-x - max(-x, 0)) /
  //      (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType temp1 = x > zero ? zero : -x;
    MPType temp2 = exp(-x - temp1);
    return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = 0, when x <= -3
  //                   1, when x >= 3
  //                   x * slope + offset, otherwise
  __device__ __forceinline__ T operator()(const T x) const {
    T temp = x * static_cast<T>(slope) + static_cast<T>(offset);
    T temp_max = temp > zero ? temp : zero;
    T temp_min = temp_max < one ? temp_max : one;
    return temp_min;
  }
};

template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // dx = (out > 0 && out < 1) ? dout * slope : 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return (out > zero && out < one) ? dout * static_cast<T>(slope) : zero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};
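// Illustrative sketch: why CudaLogSigmoidFunctor uses the shifted form
// -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))). The naive
// log(1 / (1 + exp(-x))) overflows exp() for very negative x and collapses
// to log(0) = -inf. Standalone host-only code under "#if 0".
#if 0
#include <cmath>
#include <cstdio>

int main() {
  const double x = -800.0;
  // Naive: exp(800) -> inf, 1/(1+inf) -> 0, log(0) -> -inf.
  const double naive = std::log(1.0 / (1.0 + std::exp(-x)));
  // Stable shifted form, as in the functor above.
  const double m = x > 0.0 ? 0.0 : -x;  // max(-x, 0)
  const double stable = -(m + std::log(std::exp(-m) + std::exp(-x - m)));
  std::printf("naive=%g stable=%g\n", naive, stable);  // -inf vs -800
  return 0;
}
#endif  // illustrative example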
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // log(x) = log(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log(x));
  }
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / x
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // log1p(x) = log(1 + x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log(one + x));
  }
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (one + x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // log2(x) = log2(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log2(x));
  }
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

  // dx = dout / (x * log(2))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (x * log_two);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // log10(x) = log10(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log10(x));
  }
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

  // dx = dout / (x * log(10))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (x * log_ten);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    return static_cast<T>(x / (one + exp(-b * x)));
  }
};

template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x)) / (1 + exp(-b * x))^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType temp1 = one / (one + exp(-b * x));
    MPType out = x * temp1;
    MPType temp2 = b * out;
    MPType temp3 = temp1 * (one - temp2);
    return static_cast<T>(dout * (temp2 + temp3));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
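// Illustrative sketch: CudaSwishGradFunctor's temp2 + temp3, with
// temp1 = sigmoid(b*x), out = x*temp1, temp2 = b*out, temp3 = temp1*(1-temp2),
// is algebraically equal to d/dx [x * sigmoid(b*x)]
//   = sigmoid(b*x) + b*x*sigmoid(b*x)*(1 - sigmoid(b*x)).
// Standalone host-only check against a finite difference, under "#if 0".
#if 0
#include <cmath>
#include <cstdio>

double swish(double x, double b) { return x / (1.0 + std::exp(-b * x)); }

int main() {
  const double b = 1.0, x = 0.9, h = 1e-6;
  const double temp1 = 1.0 / (1.0 + std::exp(-b * x));
  const double out = x * temp1;
  const double temp2 = b * out;
  const double temp3 = temp1 * (1.0 - temp2);
  const double analytic = temp2 + temp3;
  const double numeric = (swish(x + h, b) - swish(x - h, b)) / (2.0 * h);
  std::printf("analytic=%.9f numeric=%.9f\n", analytic, numeric);
  return 0;
}
#endif  // illustrative example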
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = 0, when x <= -offset
  //                 x, when x >= threshold - offset
  //                 x * (x + offset) / scale, otherwise
  // threshold = scale = 6, offset = 3 by default
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    T temp = x + static_cast<T>(offset);
    T temp_max = temp > zero ? temp : zero;
    T temp_min = temp_max < t ? temp_max : t;
    return temp_min * x / static_cast<T>(scale);
  }
};

template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  T two = static_cast<T>(2.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // dx = 0, when x <= -offset
  //      dout, when x >= threshold - offset
  //      dout * (2 * x / scale + offset / scale), otherwise
  // threshold = scale = 6, offset = 3 by default
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T o = static_cast<T>(offset);
    T s = static_cast<T>(scale);
    T temp1 = static_cast<T>(x + o > zero);
    T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
    return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // ceil(x) = ceil(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(ceil(x));
  }
};

template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // floor(x) = floor(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(floor(x));
  }
};

template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // round(x) = round(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(round(x));
  }
};

// GradFunctor for ceil, floor and round
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T x) const {
    return static_cast<T>(0.0f);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kNoDeps;
  }
};
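// Illustrative sketch: ceil, floor, and round are piecewise constant, so
// their derivative is zero almost everywhere (undefined only at the jumps);
// CudaZeroGradFunctor therefore returns 0 and reports kNoDeps, meaning the
// backward pass needs neither x nor out. A finite difference away from the
// integer jump points illustrates this (standalone, under "#if 0").
#if 0
#include <cmath>
#include <cstdio>

int main() {
  const double x = 2.3, h = 1e-6;  // away from the jump at integers
  const double numeric = (std::floor(x + h) - std::floor(x - h)) / (2.0 * h);
  std::printf("numeric d(floor)/dx at %.1f = %g\n", x, numeric);  // 0
  return 0;
}
#endif  // illustrative example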
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    CT x = static_cast<CT>(arg_x);
    CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
    return static_cast<T>(res);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout,                  if alpha > 0 and x > 0
  // dx = dout * exp(x / alpha), if alpha > 0 and x <= 0
  // dx = dout,                  if alpha < 0 and x > 0
  // dx = dout * exp(x / alpha), if alpha < 0 and x <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
    MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
    MPType temp_x_pos = static_cast<MPType>(x > zero);
    MPType temp_x_neg = static_cast<MPType>(x <= zero);
    return static_cast<T>(
        dout *
        (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
         temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

#endif

}  // namespace funcs
}  // namespace phi