// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #ifndef _USE_MATH_DEFINES #define _USE_MATH_DEFINES #endif #include #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { namespace funcs { enum ActBwdOpFwdDeps { kNoDeps = 0x00, // Do not need any forward input/output kDepX = 0x01, // Only need forward input X kDepOut = 0x02, // Only need forward output Out }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; template struct Sine { HOSTDEVICE T operator()(const T& val) const { return sin(val); } }; template <> struct Sine { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(sin(static_cast(val))); } }; template struct Cosine { HOSTDEVICE T operator()(const T& val) const { return cos(val); } }; template <> struct Cosine { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(cos(static_cast(val))); } }; // sine'(x) = cos(x) template struct SinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sine(x) = sin(x) template struct SinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sine()); } }; // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * x.unaryExpr(Sine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosine(x) = cos(x) template struct CosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosine()); } }; template struct Tangent { HOSTDEVICE T operator()(const T& val) const { return tan(val); } }; template <> struct Tangent { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(tan(static_cast(val))); } }; // Tangent'(x) = -Tangent(x) template struct TanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout / x.unaryExpr(Cosine()).square(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // Tangent(x) = tan(x) template struct TanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Tangent()); } }; template struct Sinh { HOSTDEVICE T operator()(const T& val) const { return sinh(val); } }; template <> struct Sinh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(sinhf(static_cast(val))); } }; template struct Cosh { HOSTDEVICE T operator()(const T& val) const { return cosh(val); } }; template <> struct Cosh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(coshf(static_cast(val))); } }; // sinh(x) = sinh(x) template struct SinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sinh()); } }; // cosh(x) = cosh(x) template struct CoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosh()); } }; // sinh'(x) = cosh(x) template struct SinhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosh'(x) = sinh(x) template struct CoshGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Sinh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Acos { HOSTDEVICE T operator()(const T& val) const { return acos(val); } }; template <> struct Acos { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(acos(static_cast(val))); } }; // Acos(x) = acos(x) template struct AcosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acos()); } }; // acos'(x) = -1/sqrt(1-x^2) template struct AcosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asin { HOSTDEVICE T operator()(const T& val) const { return asin(val); } }; template <> struct Asin { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(asin(static_cast(val))); } }; // Asin(x) = asin(x) template struct AsinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asin()); } }; // asin'(x) = 1/sqrt(1-x^2) template struct AsinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atan { HOSTDEVICE T operator()(const T& val) const { return atan(val); } }; template <> struct Atan { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(atan(static_cast(val))); } }; // Atan(x) = atan(x) template struct AtanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atan()); } }; // atan'(x) = 1 / (1 + x^2) template struct AtanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Acosh { HOSTDEVICE T operator()(const T& val) const { return acosh(val); } }; template <> struct Acosh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(acosh(static_cast(val))); } }; // Acosh(x) = acosh(x) template struct AcoshFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acosh()); } }; // acosh'(x) = 1/sqrt(x^2 - 1) template struct AcoshGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x * x - static_cast(1)).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asinh { HOSTDEVICE T operator()(const T& val) const { return asinh(val); } }; template <> struct Asinh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(asinh(static_cast(val))); } }; // Asinh(x) = asinh(x) template struct AsinhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asinh()); } }; // asinh'(x) = 1/sqrt(x^2 + 1) template struct AsinhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x.square() + static_cast(1)).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atanh { HOSTDEVICE T operator()(const T& val) const { return atanh(val); } }; template <> struct Atanh { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { return dtype::float16(atanh(static_cast(val))); } }; // Atanh(x) = atanh(x) template struct AtanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atanh()); } }; // atanh'(x) = 1/(1 - x^2) template struct AtanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // relu(x) = max(x, 0) template struct ReluCPUFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) { return v > static_cast(0) ? v : static_cast(0); }); } }; template struct ReluCUDAFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (out > static_cast(0)).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct ReluGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* ddX, DenseTensor* ddOut, DenseTensor* dOut, DenseTensor* dX) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad")); if (ddOut) { auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad")); ddout.device(*d) = ddx * (out > static_cast(0)).template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.tanh(); } }; template struct TanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) - out * out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct TanhGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* Out, const DenseTensor* ddX, const DenseTensor* dOut, DenseTensor* dOutNew, DenseTensor* ddOut) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhGradGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "TanhGradGrad")); // tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out // * ddx) if (dOutNew) { auto dout = EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhGradGrad")); auto dout_new = EigenVector::Flatten( GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "TanhGradGrad")); dout_new.device(*d) = static_cast(-1) * dout * static_cast(2) * out * ddx; } if (ddOut) { auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "TanhGradGrad")); ddout.device(*d) = (static_cast(1) - out * out) * ddx; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; /* Out DOut D_Dout DDx -> TanhTripleGrad -> D_DDx D_DDout d_OutNew D_Dout_new D_Dout = (-2) * Out * DDx * D_Dout_new D_DDx = (1-Out^2)*D_DDout + (-2) * Out * DOut * D_Dout_new D_OutNew = (-2) * Out * DDx * D_DDout + (-2) * DOut * DDx * D_Dout_new Out, DDX, DOut, D_DDOut, D_DOut_New // input D_OutNew, D_DOut, D_DDx // output */ template struct TanhTripleGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const DenseTensor* Out, const DenseTensor* ddX, const DenseTensor* dOut, const DenseTensor* d_DDOut, const DenseTensor* d_dOut_New, DenseTensor* d_d_Out, DenseTensor* d_Out_New, DenseTensor* d_DDx) const { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhTripleGrad")); auto out = EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad")); auto dout = EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad")); auto d_ddOut = EigenVector::Flatten( GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); auto d_dOutNew = EigenVector::Flatten( GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); if (d_Out_New) { auto d_OutNew = EigenVector::Flatten( GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad")); d_OutNew.device(*d) = (static_cast(-2) * out * ddx * d_ddOut) - (static_cast(2) * dout * ddx * d_dOutNew); } if (d_d_Out) { auto d_dOut = EigenVector::Flatten( GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad")); d_dOut.device(*d) = static_cast(-2) * out * ddx * d_dOutNew; } if (d_DDx) { auto d_ddx = EigenVector::Flatten( GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad")); d_ddx.device(*d) = (static_cast(1) - (out * out)) * d_ddOut - static_cast(2) * out * dout * d_dOutNew; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct BReluFunctor : public BaseActivationFunctor { float t_min; float t_max; // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` // not polymorphism for speed. typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max)); } }; template struct BReluGradFunctor : public BaseActivationFunctor { float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { if (alpha < 1.f) { out.device(d) = x.cwiseMax(static_cast(alpha) * x); } else { out.device(d) = x.cwiseMin(static_cast(alpha) * x); } } }; template struct LeakyReluGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(alpha) * (x < static_cast(0)).template cast(); auto temp2 = (x >= static_cast(0)).template cast(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct LeakyReluGradGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(const Device& dev, const DenseTensor* X, const DenseTensor* Out, const DenseTensor* ddX, DenseTensor* ddOut, DenseTensor* dOut, DenseTensor* dX) const { if (ddOut) { auto* d = dev.eigen_device(); auto ddx = EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad")); auto x = EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad")); auto ddout = EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad")); ddout.device(*d) = ddx * ((x > static_cast(0)).template cast() + static_cast(alpha) * (x <= static_cast(0)).template cast()) .template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ThresholdedReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto th = static_cast(threshold); out.device(d) = (x > th).template cast() * x; } }; template struct ThresholdedReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto th = static_cast(threshold); dx.device(d) = dout * (x > th).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) template struct CudaReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); // relu(x) = max(x, 0) __device__ __forceinline__ T operator()(const T x) const { return x > zero ? x : zero; } }; template struct CudaReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); // dx = dout * (out > 0) __device__ __forceinline__ T operator()(const T dout, const T out) const { return out > zero ? dout : zero; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct CudaCosFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // cos(x) = cos(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(cos(x)); } }; template struct CudaCosGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * (-sin(x)) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(-dout * sin(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaSinFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // sin(x) = sin(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(sin(x)); } }; template struct CudaSinGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * cos(x) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * cos(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaTanFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // tan(x) = tan(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(tan(x)); } }; template struct CudaTanGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout / cos(x)^2 __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout / (cos(x) * cos(x))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAsinFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // asin(x) = asin(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(asin(x)); } }; template struct CudaAsinGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout / sqrt(1 - x^2) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAcosFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // acos(x) = acos(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(acos(x)); } }; template struct CudaAcosGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = -dout / sqrt(1 - x^2) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(-dout / sqrt(one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaCoshFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // cosh(x) = cosh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(cosh(x)); } }; template struct CudaCoshGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * sinh(x) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * sinh(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaSinhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // sinh(x) = sinh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(sinh(x)); } }; template struct CudaSinhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // dx = dout * cosh(x) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * cosh(x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAcoshFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // Acosh(x) = acosh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(acosh(x)); } }; template struct CudaAcoshGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1 / sqrt(x^2 - 1) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * one / sqrt(x * x - one)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAsinhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // Asinh(x) = asinh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(asinh(x)); } }; template struct CudaAsinhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1/sqrt(x^2 + 1) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * one / sqrt(x * x + one)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAtanhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // Atanh(x) = atanh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(atanh(x)); } }; template struct CudaAtanhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); // dx = dout * 1/(1- x^2) __device__ __forceinline__ T operator()(const T arg_dout, const T arg_x) const { MPType dout = static_cast(arg_dout); MPType x = static_cast(arg_x); return static_cast(dout * one / (one - x * x)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaAtanFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // atan(x) = atan(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(atan(x)); } }; template struct CudaAtanGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); // dx = dout / (1 + x^2) __device__ __forceinline__ T operator()(const T dout, const T x) const { return dout / (one + x * x); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaTanhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; // tanh(x) = tanh(x) __device__ __forceinline__ T operator()(const T arg_x) const { MPType x = static_cast(arg_x); return static_cast(tanh(x)); } }; template struct CudaTanhGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); // dx = dout * (1 - out^2) __device__ __forceinline__ T operator()(const T dout, const T out) const { return dout * (one - out * out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct CudaBReluFunctor : public BaseActivationFunctor { float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } // brelu(x) = min(max(x, t_min), t_max) __device__ __forceinline__ T operator()(const T x) const { T t_min_cast = static_cast(t_min); T t_max_cast = static_cast(t_max); T temp_max = x > t_min_cast ? x : t_min_cast; T temp_min = temp_max < t_max_cast ? temp_max : t_max_cast; return temp_min; } }; template struct CudaBReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } // dx = (x > t_min && x < t_max) ? dout : 0 __device__ __forceinline__ T operator()(const T dout, const T x) const { T t_min_cast = static_cast(t_min); T t_max_cast = static_cast(t_max); return (x > t_min_cast && x < t_max_cast) ? dout : zero; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaThresholdedReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } // thresholded_relu(x) = x > threshold ? x : 0 __device__ __forceinline__ T operator()(const T x) const { return x > static_cast(threshold) ? x : zero; } }; template struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } // dx = x > threshold ? dout : 0 __device__ __forceinline__ T operator()(const T dout, const T x) const { return x > static_cast(threshold) ? dout : zero; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CudaLeakyReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } // leakyrelu(x) = x > 0 ? x : alpha * x __device__ __forceinline__ T operator()(const T x) const { return x > zero ? x : static_cast(alpha) * x; } }; template struct CudaLeakyReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } // dx = dout * (x > 0 ? 1 : alpha) __device__ __forceinline__ T operator()(const T dout, const T x) const { return x > zero ? dout : static_cast(alpha) * dout; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; #endif } // namespace funcs } // namespace phi