diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 65080f04832d68dc7bb44a992b5d072bb3384813..5c26ab9bbbc8212fc43d3213cd5d8ae346d42ba1 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1641,9 +1641,7 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker, ops::LogitGradOpMaker, ops::LogitGradOpMaker); REGISTER_OPERATOR(logit_grad, ops::LogitGradOp); -REGISTER_OP_CPU_KERNEL( - logit_grad, ops::LogitGradKernel, - ops::LogitGradKernel); + /* ========================================================================== */ /* ======================== celu register ============================ @@ -1692,7 +1690,6 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor); REGISTER_OP_CPU_KERNEL( sqrt_grad_grad, ops::SqrtDoubleGradKernel>, @@ -1720,7 +1717,6 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_ACTIVATION_CPU_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); REGISTER_OP_CPU_KERNEL( rsqrt_grad_grad, ops::RsqrtDoubleGradKernel::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_OP_CPU_KERNEL(square, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - square_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); - REGISTER_OP_CPU_KERNEL( square_grad_grad, ops::SquareDoubleGradKernel, std::conditional>(), ops::ActFwdInplaceInferer, void>::type); -REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad, - ops::ActivationGradOpInplaceInferer); -REGISTER_OP_CPU_KERNEL( - expm1_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>, - ops::ActivationGradKernel>); /* ========================================================================== */ /* ========================== Log register ==================================*/ diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 98c73c3cb3f4f0b5fc7d761aee352f0c51e8794d..d3068ec6c79d02f5a53bada6079728091e7359c6 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -287,6 +287,16 @@ USE_PHI_FUNCTOR(Silu) USE_PHI_FUNCTOR(ELU) USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU) +USE_PHI_FUNCTOR(Expm1) +USE_PHI_FUNCTOR(Mish) +USE_PHI_FUNCTOR(STanh) +USE_PHI_FUNCTOR(Reciprocal) +USE_PHI_FUNCTOR(Square) +USE_PHI_FUNCTOR(Sqrt) +USE_PHI_FUNCTOR(Rsqrt) +USE_PHI_FUNCTOR(Softplus) +USE_PHI_FUNCTOR(Softsign) + template using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor; @@ -454,26 +464,62 @@ template using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor; template -struct SqrtGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = static_cast(0.5) * dout / out; +struct SqrtGradGradFunctor : public BaseActivationFunctor { + template + void operator()(const Device& dev, const framework::Tensor* Out, + const framework::Tensor* ddX, framework::Tensor* ddOut, + framework::Tensor* dOut, const framework::Tensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad")); + auto out = framework::EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Output", 
"Out", "SqrtGradGrad")); + // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx + // calculate dy first, so ddy can inplace ddx + if (dOut) { + auto dx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad")); + auto dout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad")); + dout.device(*d) = dx * ddx * static_cast(-1) / out; + } + if (ddOut) { + auto ddout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad")); + ddout.device(*d) = ddx * static_cast(0.5) / out; + } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template -struct RsqrtGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = static_cast(-0.5) * dout * out * out * out; - } +struct RsqrtGradGradFunctor : public BaseActivationFunctor { + template + void operator()(const Device& dev, const framework::Tensor* Out, + const framework::Tensor* ddX, framework::Tensor* ddOut, + framework::Tensor* dOut, const framework::Tensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad")); + auto out = framework::EigenVector::Flatten( + GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad")); + // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx + if (dOut) { + auto dx = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad")); + auto dout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad")); + dout.device(*d) = (static_cast(3.0) / out) * dx * ddx; + } + if (ddOut) { + auto ddout = framework::EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad")); + ddout.device(*d) = ddx * static_cast(-0.5) * out * out * out; + } + } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } @@ -519,19 +565,6 @@ struct RoundFunctor : public BaseActivationFunctor { } }; -template -struct ReciprocalGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(-1) * out * out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { @@ -614,17 +647,6 @@ struct Log1pGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct SquareGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(2) * x; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - // relu6(x) = min(max(0, x), 6) template struct Relu6Functor : public BaseActivationFunctor { @@ -706,66 +728,6 @@ struct HardSwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// For numerical stability, using the following formula instead of -// d(softplus(x))/dx = 1 / (1 + exp(-x)) -// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta -// = 1, threshold = 20 by default), otherwise x -template -struct SoftplusGradFunctor : public BaseActivationFunctor { - float beta; - float threshold; - 
typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}, {"threshold", &threshold}}; - } - - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) { - auto x_beta = static_cast(beta) * x; - dx.device(d) = - (x_beta > static_cast(threshold)) - .select(dout, dout / (static_cast(1) + (-x_beta).exp())); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) -// sp = softplus(x) -template -struct MishGradFunctor : public BaseActivationFunctor { - float threshold; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) { - auto sp = (x > static_cast(threshold)) - .select(x, (static_cast(1) + x.exp()).log()); - auto gsp = static_cast(1) - (-sp).exp(); - auto tsp = sp.tanh(); - dx.device(d) = dout * (tsp + x * (static_cast(1) - tsp * tsp) * gsp); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// d(softsign(x))/dx = 1 / (1 + |x|)^2 -// Taken from https://en.wikipedia.org/wiki/Activation_function -template -struct SoftsignGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) { - dx.device(d) = - dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct SoftReluFunctor : public BaseActivationFunctor { float threshold; @@ -909,38 +871,6 @@ struct PowGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct LogitGradFunctor { - template - void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const { - // logit(x)' = 1/(x*(1-x)) - dx.device(d) = - (x < static_cast(eps) || x > static_cast(1.0 - eps)) - .select(p.constant(static_cast(0)), - dout * (static_cast(1) / ((static_cast(1) - x) * x))); - } -}; - -template -struct STanhGradFunctor : public BaseActivationFunctor { - float scale_a; - float scale_b; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; - } - - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto a = static_cast(scale_a); - auto b = static_cast(scale_b); - auto temp = (a * x).tanh() * (a * x).tanh(); - dx.device(d) = dout * a * b * (static_cast(1) - temp); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct HardSigmoidFunctor : public BaseActivationFunctor { float slope; @@ -1071,68 +1001,6 @@ struct CELUGradGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct SqrtGradGradFunctor : public BaseActivationFunctor { - template - void operator()(const Device& dev, const framework::Tensor* Out, - const framework::Tensor* ddX, framework::Tensor* ddOut, - framework::Tensor* dOut, const framework::Tensor* dX) const { - auto* d = dev.eigen_device(); - auto ddx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad")); - auto out = framework::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad")); - // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx - // calculate dy first, so ddy can inplace ddx - if (dOut) { - auto dx = 
framework::EigenVector::Flatten( - GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad")); - auto dout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad")); - dout.device(*d) = dx * ddx * static_cast(-1) / out; - } - if (ddOut) { - auto ddout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad")); - ddout.device(*d) = ddx * static_cast(0.5) / out; - } - } - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct RsqrtGradGradFunctor : public BaseActivationFunctor { - template - void operator()(const Device& dev, const framework::Tensor* Out, - const framework::Tensor* ddX, framework::Tensor* ddOut, - framework::Tensor* dOut, const framework::Tensor* dX) const { - auto* d = dev.eigen_device(); - auto ddx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad")); - auto out = framework::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad")); - - // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx - if (dOut) { - auto dx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad")); - auto dout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad")); - dout.device(*d) = (static_cast(3.0) / out) * dx * ddx; - } - if (ddOut) { - auto ddout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad")); - ddout.device(*d) = ddx * static_cast(-0.5) * out * out * out; - } - } - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - template struct SquareGradGradFunctor : public BaseActivationFunctor { template @@ -1718,24 +1586,7 @@ class PowGradKernel template class LogitGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* dout = - context.Input(framework::GradVarName("Out")); - auto* dx = context.Output(framework::GradVarName("X")); - auto eps = context.Attr("eps"); - dx->mutable_data(dout->place()); - - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_dout = framework::EigenVector::Flatten(*dout); - auto eigen_dx = framework::EigenVector::Flatten(*dx); - auto& place = - *context.template device_context().eigen_device(); - auto eigen_p = framework::EigenVector::Flatten(*x); - - LogitGradFunctor functor; - functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps); - } + void Compute(const framework::ExecutionContext& context) const override {} }; template @@ -1776,17 +1627,12 @@ struct LogGradGradFunctor : public BaseActivationFunctor { __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ - __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ __macro(log2, Log2, Log2Functor, Log2GradFunctor); \ __macro(log10, Log10, Log10Functor, Log10GradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ - __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ - __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ 
HardSigmoidGradFunctor); \ __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ - __macro(mish, Mish, MishFunctor, MishGradFunctor); \ __macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor); diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index e63bf6ec47ef0838d4f42419eafc6af10518335b..b441f63b500b11e043b300711b23311288fce025 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -128,18 +128,6 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaReciprocalGradFunctor : public BaseActivationFunctor { - // dx = -dout * out^2 - __device__ __forceinline__ T operator()(const T dout, const T out) const { - return -dout * out * out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - template struct CudaLogFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -161,46 +149,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaSquareGradFunctor : public BaseActivationFunctor { - T two = static_cast(2.0f); - - // dx = dout * 2 * x - __device__ __forceinline__ T operator()(const T dout, const T x) const { - return dout * two * x; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct CudaSqrtGradFunctor : public BaseActivationFunctor { - T one_half = static_cast(0.5f); - - // dx = dout * 0.5 / out - __device__ __forceinline__ T operator()(const T dout, const T out) const { - return one_half * dout / out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct CudaRsqrtGradFunctor : public BaseActivationFunctor { - T minus_one_half = static_cast(-0.5f); - - // dx = -0.5 * dout * out^3 - __device__ __forceinline__ T operator()(const T dout, const T out) const { - return minus_one_half * dout * out * out * out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - template struct CudaLog1pFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -320,69 +268,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaSTanhGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float scale_a; - float scale_b; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; - } - - // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x)) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - MPType a = static_cast(scale_a); - MPType b = static_cast(scale_b); - MPType temp = tanh(a * x); - return static_cast(dout * a * b * (one - temp * temp)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct CudaSoftplusGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float beta; - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}, {"threshold", &threshold}}; - } - - // dx = x * beta > threshold ? 
dout : dout / (1 + exp(-beta * x)) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - MPType b = static_cast(beta); - MPType t = static_cast(threshold); - MPType x_beta = x * beta; - return x_beta > t ? arg_dout : static_cast(dout / (one + exp(-x_beta))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct CudaSoftsignGradFunctor : public BaseActivationFunctor { - T one = static_cast(1.0f); - - // dx = dout / (1 + abs(x))^2 - __device__ __forceinline__ T operator()(const T dout, const T x) const { - T temp = one + abs(x); - return dout / (temp * temp); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaRelu6Functor : public BaseActivationFunctor { T zero = static_cast(0.0f); @@ -506,34 +391,6 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaMishGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) - // sp = softplus(x) - // Inputs: args[0], the input dout - // args[1], the input x - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - MPType sp = (x > static_cast(threshold)) ? x : log(one + exp(x)); - MPType gsp = - (x > static_cast(threshold)) ? 
one : one / (one + exp(-x)); - MPType tsp = tanh(sp); - return static_cast(dout * (tsp + x * (one - tsp * tsp) * gsp)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaHardSwishFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); @@ -831,8 +688,6 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== sqrt register ============================= */ -REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor, - CudaSqrtGradFunctor); REGISTER_OP_CUDA_KERNEL( sqrt_grad_grad, @@ -848,8 +703,6 @@ REGISTER_OP_CUDA_KERNEL( /* =========================== rsqrt register ============================= */ -REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor, - CudaRsqrtGradFunctor); REGISTER_OP_CUDA_KERNEL( rsqrt_grad_grad, @@ -862,8 +715,6 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== square register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor, - CudaSquareGradFunctor); REGISTER_OP_CUDA_KERNEL( square_grad_grad, @@ -900,53 +751,12 @@ REGISTER_OP_CUDA_KERNEL( /* ========================== logit register ============================ */ namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - logit_grad, - ops::LogitGradKernel, - ops::LogitGradKernel, - ops::LogitGradKernel); /* ========================================================================== */ /* ========================== exp register ============================ */ -REGISTER_OP_CUDA_KERNEL( - exp, ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationCudaKernel>); -REGISTER_OP_CUDA_KERNEL( - exp_grad, ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>); /* ========================================================================== */ /* ========================== expm1 register ============================ */ - -REGISTER_OP_CUDA_KERNEL( - expm1, ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>); -REGISTER_OP_CUDA_KERNEL( - expm1_grad, ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>, - ops::ActivationGradCudaKernel>); /* ========================================================================== */ /* ========================== Log register ==================================*/ @@ -969,15 +779,10 @@ REGISTER_OP_CUDA_KERNEL( __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \ __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \ __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \ - __macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \ - CudaReciprocalGradFunctor); \ __macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor); \ __macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor); \ __macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); \ __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \ - __macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \ - __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \ - __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor); \ __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor); \ __macro(tanh_shrink, 
TanhShrink, CudaTanhShrinkFunctor, \ CudaTanhShrinkGradFunctor); \ @@ -986,7 +791,6 @@ REGISTER_OP_CUDA_KERNEL( __macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \ CudaHardSigmoidGradFunctor); \ __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \ - __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \ __macro(hard_swish, HardSwish, CudaHardSwishFunctor, \ CudaHardSwishGradFunctor); FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL) diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index d3d6989696b18a67e7b974bec70a80006254d51b..16e73f77bc112f2d25cbf9326e911a5cd3eb13b9 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -103,8 +103,14 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, AcoshGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, SquareGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, SoftsignGradFunctor); + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, Expm1GradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, ReciprocalGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, RsqrtGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor); @@ -122,11 +128,25 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, HardShrinkGradFunctor, threshold); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, + MishGradFunctor, + threshold); + DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, BReluGradFunctor, t_min, t_max); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, + STanhGradFunctor, + scale_a, + scale_b); + +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, + SoftplusGradFunctor, + beta, + threshold); + template void EluGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -190,6 +210,13 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad, ReluDoubleGradKernel) @@ -223,3 +250,14 @@ PD_REGISTER_KERNEL(expm1_grad, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL( + logit_grad, CPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {} +PD_REGISTER_KERNEL(square_grad, + CPU, + ALL_LAYOUT, + phi::SquareGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 
eb98ee2ea2637aeaaf5da587fad4101eeb1d2f6e..591793e6e49007d99616e1c4a7ad4bec47d76ef6 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -115,6 +115,22 @@ struct ReciprocalFunctor : public BaseActivationFunctor { } }; +template +struct ReciprocalGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(-1) * out * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { @@ -157,9 +173,9 @@ struct LogitFunctor { } }; -// // mish(x) = x * tanh(softplus(x)) -// // softplus(x) = x, if x > threshold -// // = ln(1 + exp(x)), otherwise +// mish(x) = x * tanh(softplus(x)) +// softplus(x) = x, if x > threshold +// = ln(1 + exp(x)), otherwise template struct MishFunctor : public BaseActivationFunctor { @@ -176,6 +192,33 @@ struct MishFunctor : public BaseActivationFunctor { } }; +// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) +// sp = softplus(x) + +template +struct MishGradFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto sp = (x > static_cast(threshold)) + .select(x, (static_cast(1) + x.exp()).log()); + auto gsp = static_cast(1) - (-sp).exp(); + auto tsp = sp.tanh(); + dx.device(d) = dout * (tsp + x * (static_cast(1) - tsp * tsp) * gsp); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct STanhFunctor : public BaseActivationFunctor { float scale_a; @@ -191,6 +234,29 @@ struct STanhFunctor : public BaseActivationFunctor { } }; +template +struct STanhGradFunctor : public BaseActivationFunctor { + float scale_a; + float scale_b; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto a = static_cast(scale_a); + auto b = static_cast(scale_b); + auto temp = (a * x).tanh() * (a * x).tanh(); + dx.device(d) = dout * a * b * (static_cast(1) - temp); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct Tangent { HOSTDEVICE T operator()(const T& val) const { return tan(val); } @@ -227,6 +293,20 @@ struct SquareFunctor : public BaseActivationFunctor { } }; +template +struct SquareGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(2) * x; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { @@ -236,6 +316,22 @@ struct SqrtFunctor : public BaseActivationFunctor { } }; +template +struct SqrtGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = static_cast(0.5) * dout / out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + // rsqrt(x) = x^(-1/2) template struct RsqrtFunctor : public BaseActivationFunctor { @@ -245,29 +341,28 @@ struct RsqrtFunctor : public BaseActivationFunctor { } }; 
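// ---------------------------------------------------------------------------
// Note (illustrative, not part of the patch): the reciprocal/sqrt/rsqrt grad
// functors added to this header report FwdDeps() == kDepOut because dx can be
// written purely in terms of the forward output, so the backward pass only
// needs Out to be saved:
//   reciprocal: out = 1/x       =>  d(out)/dx = -1/x^2         = -out^2
//                                   dx = dout * (-1) * out * out
//   sqrt:       out = sqrt(x)   =>  d(out)/dx = 1/(2*sqrt(x))  = 0.5/out
//                                   dx = 0.5 * dout / out
//   rsqrt:      out = x^(-1/2)  =>  d(out)/dx = -0.5*x^(-3/2)  = -0.5*out^3
//                                   dx = -0.5 * dout * out * out * out
// By contrast, the kDepX grad functors in this file (square, mish, stanh,
// softplus, softsign) express dx in terms of the forward input x, so they
// retain X instead.
// ---------------------------------------------------------------------------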
+template +struct RsqrtGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = static_cast(-0.5) * dout * out * out * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + // // For numerical stability, using the following formula instead of // softplus(x) = // // log(1 + exp(x)) // // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= // threshold(beta = // // 1, threshold = 20 by default), otherwise x -// template -// struct SoftplusFunctor : public BaseActivationFunctor { -// float beta; -// float threshold; -// typename BaseActivationFunctor::AttrPair GetAttrs() { -// return {{"beta", &beta}, {"threshold", &threshold}}; -// } - -// template -// void operator()(Device d, X x, Out out) { -// auto x_beta = static_cast(beta) * x; -// out.device(d) = (x_beta > static_cast(threshold)) -// .select(x, -// (static_cast(1) + x_beta.exp()).log() / -// static_cast(beta)); -// } -// }; template struct SoftplusFunctor : public BaseActivationFunctor { @@ -288,6 +383,33 @@ struct SoftplusFunctor : public BaseActivationFunctor { } }; +// For numerical stability, using the following formula instead of +// d(softplus(x))/dx = 1 / (1 + exp(-x)) +// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta +// = 1, threshold = 20 by default), otherwise x + +template +struct SoftplusGradFunctor : public BaseActivationFunctor { + float beta; + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto x_beta = static_cast(beta) * x; + dx.device(d) = + (x_beta > static_cast(threshold)) + .select(dout, dout / (static_cast(1) + (-x_beta).exp())); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + // Tangent(x) = tan(x) template struct TanFunctor : public BaseActivationFunctor { @@ -479,6 +601,18 @@ struct AtanGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct LogitGradFunctor { + template + void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const { + // logit(x)' = 1/(x*(1-x)) + dx.device(d) = + (x < static_cast(eps) || x > static_cast(1.0 - eps)) + .select(p.constant(static_cast(0)), + dout * (static_cast(1) / ((static_cast(1) - x) * x))); + } +}; + template struct Acosh { HOSTDEVICE T operator()(const T& val) const { return acosh(val); } @@ -868,6 +1002,37 @@ struct SoftsignFunctor : public BaseActivationFunctor { } }; +// template +// struct SoftsignGradFunctor : public BaseActivationFunctor { +// template +// void operator()(Device d, X x, Out out, dOut dout, dX dx) { +// dx.device(d) = +// dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); +// } + +// static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; +// } +// }; + +// d(softsign(x))/dx = 1 / (1 + |x|)^2 +// Taken from https://en.wikipedia.org/wiki/Activation_function + +template +struct SoftsignGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = + dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; @@ 
-1270,6 +1435,18 @@ struct CudaSquareFunctor : public BaseActivationFunctor { __device__ __forceinline__ T operator()(const T x) const { return x * x; } }; +template +struct CudaSquareGradFunctor : public BaseActivationFunctor { + T two = static_cast(2.0f); + + // dx = dout * 2 * x + __device__ __forceinline__ T operator()(const T dout, const T x) const { + return dout * two * x; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaExpGradFunctor : public BaseActivationFunctor { // dx = dout * out @@ -1290,6 +1467,18 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor { __device__ __forceinline__ T operator()(const T x) const { return one / x; } }; +template +struct CudaReciprocalGradFunctor : public BaseActivationFunctor { + // dx = -dout * out^2 + __device__ __forceinline__ T operator()(const T dout, const T out) const { + return -dout * out * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + template struct CudaExpm1Functor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1334,6 +1523,19 @@ struct CudaSoftsignFunctor : public BaseActivationFunctor { } }; +template +struct CudaSoftsignGradFunctor : public BaseActivationFunctor { + T one = static_cast(1.0f); + + // dx = dout / (1 + abs(x))^2 + __device__ __forceinline__ T operator()(const T dout, const T x) const { + T temp = one + abs(x); + return dout / (temp * temp); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaSinGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1564,6 +1766,31 @@ struct CudaSTanhFunctor : public BaseActivationFunctor { } }; +template +struct CudaSTanhGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float scale_a; + float scale_b; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; + } + + // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x)) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType a = static_cast(scale_a); + MPType b = static_cast(scale_b); + MPType temp = tanh(a * x); + return static_cast(dout * a * b * (one - temp * temp)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaSoftplusFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1585,6 +1812,31 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor { } }; +template +struct CudaSoftplusGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float beta; + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + + // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x)) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + MPType t = static_cast(threshold); + MPType x_beta = x * beta; + return x_beta > t ? 
arg_dout : static_cast(dout / (one + exp(-x_beta))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaAtanhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1611,6 +1863,20 @@ struct CudaSqrtFunctor : public BaseActivationFunctor { } }; +template +struct CudaSqrtGradFunctor : public BaseActivationFunctor { + T one_half = static_cast(0.5f); + + // dx = dout * 0.5 / out + __device__ __forceinline__ T operator()(const T dout, const T out) const { + return one_half * dout / out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + template struct CudaRsqrtFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1622,6 +1888,20 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor { } }; +template +struct CudaRsqrtGradFunctor : public BaseActivationFunctor { + T minus_one_half = static_cast(-0.5f); + + // dx = -0.5 * dout * out^3 + __device__ __forceinline__ T operator()(const T dout, const T out) const { + return minus_one_half * dout * out * out * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + template struct CudaAtanFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -1710,6 +1990,34 @@ struct CudaMishFunctor : public BaseActivationFunctor { } }; +template +struct CudaMishGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) + // sp = softplus(x) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType sp = (x > static_cast(threshold)) ? x : log(one + exp(x)); + MPType gsp = + (x > static_cast(threshold)) ? 
one : one / (one + exp(-x)); + MPType tsp = tanh(sp); + return static_cast(dout * (tsp + x * (one - tsp * tsp) * gsp)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaBReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 00c08ff497dbd8ff630f4e5b093291b91023cb2d..3c4d5420032c1c18ce6f83396b6ec90ad189e1d7 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -157,8 +157,14 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, CudaAcoshGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, CudaSoftsignGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, CudaSquareGradFunctor); + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, CudaExpGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, CudaExpm1GradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, CudaReciprocalGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, CudaSqrtGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, CudaRsqrtGradFunctor); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, CudaLeakyReluGradFunctor, @@ -173,11 +179,25 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, CudaHardShrinkGradFunctor, threshold); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, + CudaMishGradFunctor, + threshold); + DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, CudaBReluGradFunctor, t_min, t_max); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, + CudaSTanhGradFunctor, + scale_a, + scale_b); + +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, + CudaSoftplusGradFunctor, + beta, + threshold); + template void EluGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -266,6 +286,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad, LeakyReluDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, ThresholdedReluGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel) PD_REGISTER_KERNEL(exp_grad, GPU, @@ -290,3 +315,14 @@ PD_REGISTER_KERNEL(expm1_grad, float, double, phi::dtype::float16) {} + +PD_REGISTER_KERNEL( + logit_grad, GPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {} +PD_REGISTER_KERNEL(square_grad, + GPU, + ALL_LAYOUT, + phi::SquareGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h index a95f49c0e7cfd32802f1d1899a1fe1590fdf6a87..5692d73d8b70688b330d57fa4ec6249f4139415d 100644 --- a/paddle/phi/kernels/impl/activation_grad_impl.h +++ b/paddle/phi/kernels/impl/activation_grad_impl.h @@ -222,4 +222,22 @@ void EluDoubleGradKernel(const Context& dev_ctx, functor(dev_ctx, &x, &ddx, ddout, &dout, dx); } +template +void LogitGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + float eps, + 
                     DenseTensor* x_grad) {
+  dev_ctx.template Alloc<T>(x_grad);
+
+  auto eigen_x = EigenVector<T>::Flatten(x);
+  auto eigen_dout = EigenVector<T>::Flatten(out_grad);
+  auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
+  auto& place = *dev_ctx.eigen_device();
+  auto eigen_p = EigenVector<T>::Flatten(x);
+
+  funcs::LogitGradFunctor<T> functor;
+  functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps);
+}
+
 }  // namespace phi
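
For reference, below is a minimal standalone sketch of the scalar math behind a few of the grad functors this patch moves into phi::funcs (logit, softplus, mish), checked against central finite differences. It is plain C++ with no Paddle dependencies; the helper names (logit_grad, softplus_grad, mish_grad) are local to the sketch, not Paddle APIs, and the formulas simply mirror the comments in LogitGradFunctor, SoftplusGradFunctor and MishGradFunctor above. It compiles with any C++11 compiler.

#include <cassert>
#include <cmath>
#include <cstdio>

// logit(x) = log(x / (1 - x)); logit'(x) = 1 / (x * (1 - x)),
// clipped to 0 outside [eps, 1 - eps] as in LogitGradFunctor.
double logit_grad(double x, double dout, double eps) {
  if (x < eps || x > 1.0 - eps) return 0.0;
  return dout / ((1.0 - x) * x);
}

// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold,
// otherwise pass dout through, matching SoftplusGradFunctor.
double softplus_grad(double x, double dout, double beta, double threshold) {
  double x_beta = beta * x;
  return x_beta > threshold ? dout : dout / (1.0 + std::exp(-x_beta));
}

// d(mish(x))/dx = tanh(sp) + x * (1 - tanh(sp)^2) * (1 - exp(-sp)),
// sp = softplus(x) (linear above the threshold), matching MishGradFunctor.
double mish_grad(double x, double dout, double threshold) {
  double sp = x > threshold ? x : std::log1p(std::exp(x));
  double gsp = 1.0 - std::exp(-sp);
  double tsp = std::tanh(sp);
  return dout * (tsp + x * (1.0 - tsp * tsp) * gsp);
}

int main() {
  // Compare each analytic gradient against a central finite difference.
  auto logit = [](double x) { return std::log(x / (1.0 - x)); };
  auto softplus = [](double x) { return std::log1p(std::exp(x)); };
  auto mish = [&](double x) { return x * std::tanh(softplus(x)); };

  const double h = 1e-6;

  double x = 0.3;
  double fd_logit = (logit(x + h) - logit(x - h)) / (2 * h);
  assert(std::fabs(logit_grad(x, 1.0, 1e-8) - fd_logit) < 1e-4);

  x = 0.7;
  double fd_sp = (softplus(x + h) - softplus(x - h)) / (2 * h);
  assert(std::fabs(softplus_grad(x, 1.0, 1.0, 20.0) - fd_sp) < 1e-4);

  double fd_mish = (mish(x + h) - mish(x - h)) / (2 * h);
  assert(std::fabs(mish_grad(x, 1.0, 20.0) - fd_mish) < 1e-4);

  std::puts("gradient formulas match finite differences");
  return 0;
}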