/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <cmath>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

#include "paddle/phi/kernels/funcs/activation_functor.h"

namespace paddle {
namespace operators {

using framework::To32BitIndex;

using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps;

/* The following operators can be used to process SelectedRows, because the
 * output of those operators for zero is zero too. */
static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
    "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};

inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                                    const framework::Tensor** X,
                                    framework::Tensor** Out) {
  auto x_var = context.InputVar("X");
  auto out_var = context.OutputVar("Out");
  PADDLE_ENFORCE_NOT_NULL(
      x_var,
      platform::errors::NotFound(
          "Cannot get input Variable X, variable name = %s",
          context.InputName("X")));
  PADDLE_ENFORCE_NOT_NULL(
      out_var,
      platform::errors::NotFound(
          "Cannot get output Variable Out, variable name = %s",
          context.OutputName("Out")));
  if (CanBeUsedBySelectedRows.count(context.Type())) {
    *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
    *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        out_var);
  } else {
    *X = context.Input<framework::Tensor>("X");
    *Out = context.Output<framework::Tensor>("Out");
  }

  PADDLE_ENFORCE_NOT_NULL(
      *Out,
      platform::errors::NotFound("Cannot get the tensor from the Variable "
                                 "Output(Out), variable name = %s",
                                 context.OutputName("Out")));
}

template <ActBwdOpFwdDeps kDepValue>
inline void ExtractActivationGradTensor(
    const framework::ExecutionContext& context,
    const framework::Tensor** X,
    const framework::Tensor** Out,
    const framework::Tensor** dOut,
    framework::Tensor** dX) {
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
  const framework::Variable* out_var = nullptr;

  if (static_cast<int>(kDepValue) &
      static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
    out_var = context.InputVar("Out");
    PADDLE_ENFORCE_NOT_NULL(
        out_var,
        platform::errors::NotFound(
            "Cannot get input Variable Out, variable name = %s",
            context.InputName("Out")));
  }

  PADDLE_ENFORCE_NOT_NULL(
      out_grad_var,
      platform::errors::NotFound(
          "Cannot get input Variable %s, variable name = %s",
          framework::GradVarName("Out"),
          context.InputName(framework::GradVarName("Out"))));
  PADDLE_ENFORCE_NOT_NULL(
      x_grad_var,
      platform::errors::NotFound(
          "Cannot get output Variable %s, variable name = %s",
          framework::GradVarName("X"),
          context.OutputName(framework::GradVarName("X"))));

  if (CanBeUsedBySelectedRows.count(context.Type())) {
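    // Note: the CanBeUsedBySelectedRows whitelist above only admits
    // zero-preserving activations. A SelectedRows input stores just the
    // populated rows of a sparse tensor, so an op may be applied to the
    // stored rows alone only if f(0) == 0 (abs(0) == 0, 0^2 == 0,
    // sqrt(0) == 0); an op such as exp, where exp(0) == 1, would have to
    // densify the missing rows first and is therefore kept off the list.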
*dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( *out_grad_var); *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( x_grad_var); if (out_var) { *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); } else { *Out = *dOut; // fake out } } else { *Out = context.Input("Out"); *dOut = context.Input(framework::GradVarName("Out")); *dX = context.Output(framework::GradVarName("X")); if (out_var) { *Out = &(out_var->Get()); } else { *Out = *dOut; // fake out } } PADDLE_ENFORCE_NOT_NULL(*dX, platform::errors::NotFound( "Cannot get the tensor from the Variable " "Output(Out), variable name = %s", context.OutputName(framework::GradVarName("X")))); if (static_cast(kDepValue) & static_cast(ActBwdOpFwdDeps::kDepX)) { auto x_var = context.InputVar("X"); PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound( "Cannot get the tensor from the " "Variable Input(X), variable name = %s", context.InputName("X"))); if (CanBeUsedBySelectedRows.count(context.Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); } else { *X = context.Input("X"); } } else { VLOG(10) << " Inplace activation of Op : " << context.Type(); *X = *dX; } } template class ActivationKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* X = nullptr; framework::Tensor* Out = nullptr; ExtractActivationTensor(context, &X, &Out); Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "Activation")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "Activation")); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = platform::is_gpu_place(context.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out)); } else { functor(*place, x, out); } } }; template class ActivationGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor *X, *Out, *dOut; framework::Tensor* dX = nullptr; X = Out = dOut = nullptr; ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = platform::is_gpu_place(context.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout), To32BitIndex(dx)); } else { functor(*place, x, out, dout, 
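/* Note on the 32-bit fast path taken by the two kernels above: when the
 * flattened element count fits in a 32-bit int, To32BitIndex() re-maps the
 * Eigen expressions to int32 indexing, which produces cheaper index
 * arithmetic in the generated device code; otherwise the default 64-bit
 * Eigen::Index path is used. The switch is additionally gated on
 * is_gpu_place because the saving only matters for GPU kernels. */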
dx); } } }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; // sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; #define USE_PHI_FUNCTOR(name) \ template \ using name##Functor = phi::funcs::name##Functor; \ template \ using name##GradFunctor = phi::funcs::name##GradFunctor; #define USE_PHI_DOUBLE_GRAD_FUNCTOR(name) \ template \ using name##GradGradFunctor = phi::funcs::name##GradGradFunctor; #define USE_PHI_TRIPLE_GRAD_FUNCTOR(name) \ template \ using name##TripleGradFunctor = phi::funcs::name##TripleGradFunctor; USE_PHI_FUNCTOR(Cos) USE_PHI_FUNCTOR(Tan) USE_PHI_FUNCTOR(Acos) USE_PHI_FUNCTOR(Sin) USE_PHI_FUNCTOR(Asin) USE_PHI_FUNCTOR(Atan) USE_PHI_FUNCTOR(Sinh) USE_PHI_FUNCTOR(Cosh) USE_PHI_FUNCTOR(Asinh) USE_PHI_FUNCTOR(Acosh) USE_PHI_FUNCTOR(Atanh) USE_PHI_FUNCTOR(Tanh) USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh) USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh) USE_PHI_FUNCTOR(BRelu) USE_PHI_FUNCTOR(ThresholdedRelu) USE_PHI_FUNCTOR(LeakyRelu) USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) template struct SigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out * (static_cast(1) - out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; /* Out DOut -> SigmoidGradGrad -> DOutNew DDX DDOut DDOut = (1-Out)*Out*DDX DOutNew = (1-2*Out)*DOut*DDX */ template struct SigmoidGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* Out, const framework::Tensor* ddX, const framework::Tensor* dOut, framework::Tensor* dOutNew, framework::Tensor* ddOut) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidGradGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidGradGrad")); if (dOutNew) { auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidGradGrad")); auto dout_new = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "SigmoidGradGrad")); dout_new.device(*d) = (static_cast(1) - static_cast(2) * out) * dout * ddx; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SigmoidGradGrad")); ddout.device(*d) = (static_cast(1) - out) * out * ddx; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; /* Out DOut D_Dout DDx -> SigmoidTripleGrad -> D_DDx D_DDout d_OutNew D_Dout_new D_Dout = (1-2*Out)*DDx*D_Dout_new D_DDx = (1-Out)*Out*D_DDout + (1-2*Out)*DOut*D_Dout_new D_OutNew = (DDx-2*Out*DDx)*D_DDout - 2*DOut*DDx*D_Dout_new Out, DDX, DOut, D_DDOut, D_DOut_New // input D_OutNew, D_DOut, D_DDx // output */ template struct SigmoidTripleGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* Out, const framework::Tensor* ddX, const framework::Tensor* dOut, const framework::Tensor* d_DDOut, const framework::Tensor* d_dOut_New, framework::Tensor* d_d_Out, framework::Tensor* d_Out_New, framework::Tensor* d_DDx) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidTripleGrad")); auto out = 
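/* Derivation sketch for the SigmoidGradGrad formulas quoted above: with
 * y = sigmoid(x), the first-order rule is dx = dout * y * (1 - y). Viewing
 * dx as a function of (y, dout) and perturbing along DDX gives
 *   DDOut   = d(dx)/d(dout) * DDX = y * (1 - y) * DDX
 *   DOutNew = d(dx)/d(y)    * DDX = dout * (1 - 2 * y) * DDX,
 * which is exactly what SigmoidGradGradFunctor computes; the triple-grad
 * functor here differentiates those two expressions once more. */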
framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidTripleGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidTripleGrad")); auto d_ddOut = framework::EigenVector::Flatten( GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad")); auto d_dOutNew = framework::EigenVector::Flatten(GET_DATA_SAFELY( d_dOut_New, "Input", "D_DOut_New", "SigmoidTripleGrad")); if (d_Out_New) { auto d_OutNew = framework::EigenVector::Flatten(GET_DATA_SAFELY( d_Out_New, "Output", "D_OutNew", "SigmoidTripleGrad")); d_OutNew.device(*d) = (ddx - static_cast(2) * out * ddx) * d_ddOut - static_cast(2) * dout * ddx * d_dOutNew; } if (d_d_Out) { auto d_dOut = framework::EigenVector::Flatten( GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "SigmoidTripleGrad")); d_dOut.device(*d) = (static_cast(1) - static_cast(2) * out) * ddx * d_dOutNew; } if (d_DDx) { auto d_ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "SigmoidTripleGrad")); d_ddx.device(*d) = (static_cast(1) - out) * out * d_ddOut + (static_cast(1) - static_cast(2) * out) * dout * d_dOutNew; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // silu(x) = x / (1 + exp(-x)) template struct SiluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { auto temp = static_cast(1) / (static_cast(1) + (-x).exp()); out.device(d) = x * temp; } }; // silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})) template struct SiluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(1) + (-x).exp(); // 1+e^(-x) auto temp2 = x * (-x).exp(); // x*e^(-x) dx.device(d) = dout * ((static_cast(1) / temp1) * (static_cast(1) + (temp2 / temp1))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // Originally: logsigmoid(x) = -log (1 + exp(-x)) // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // We can rewrite the above equation as: // out = -log( exp(0) + exp(-x)) [since exp(0) = 1] // = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) // = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x - // max(-x, 0))) // = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // = -log( exp(max(-x, 0)) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) // // Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0)) // + exp(-x - max(-x, 0)))) template struct LogSigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); } }; // Originally: f' = exp(-x) / (1 + exp(-x)) // For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + // exp(-x - max(-x, 0))) template struct LogSigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) dx.device(d) = dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = 
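/* Quick numeric check of the stabilized logsigmoid above: for x = -1000 the
 * naive -log(1 + exp(-x)) overflows because exp(1000) is inf, while the
 * rewritten form gives
 *   -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
 *     = -(1000 + log(exp(-1000) + exp(0))) ~= -1000,
 * matching the exact limit logsigmoid(x) -> x as x -> -inf. */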
x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // expm1(x) = e^x - 1 template struct Expm1Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.expm1(); } }; template struct Expm1GradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out + dout; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // relu(x) = max(x, 0) template using ReluCPUFunctor = phi::funcs::ReluCPUFunctor; template using ReluGradFunctor = phi::funcs::ReluGradFunctor; template using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor; template using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhShrinkFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x - x.tanh(); } }; template struct TanhShrinkGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x.tanh() * x.tanh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct HardShrinkFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); out.device(d) = x * (temp1 || temp2).template cast(); } }; template struct HardShrinkGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = x < static_cast(threshold * -1.f); auto temp2 = x > static_cast(threshold); dx.device(d) = dout * (temp1 || temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 // otherwise template struct SoftShrinkFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast(); auto temp2 = (x < -lambdaT).template cast(); out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); } }; template struct SoftShrinkGradFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast(); auto temp2 = (x < -lambdaT).template cast(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { 
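  /* The exp/expm1 grads above and the sqrt/rsqrt grads below are written
   * purely in terms of the forward output, hence their
   * ActBwdOpFwdDeps::kDepOut dependence: exp'(x) = out, so dx = dout * out;
   * expm1'(x) = out + 1, which is the "dout * out + dout" form above; and
   * with y = sqrt(x), dy/dx = 1 / (2 * y), so dx = 0.5 * dout / out. */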
template void operator()(Device d, X x, Out out) const { out.device(d) = x.sqrt(); } }; template struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0.5) * dout / out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // rsqrt(x) = x^(-1/2) template struct RsqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.rsqrt(); } }; template struct RsqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(-0.5) * dout * out * out * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // ceil(x) = ceiling(x) template struct CeilFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.ceil(); } }; template struct ZeroGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0) * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kNoDeps; } }; // floor(x) = flooring(x) template struct FloorFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.floor(); } }; // round(x) = [x] template struct RoundFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.round(); } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / x; } }; template struct ReciprocalGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(-1) * out * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.log(); } }; template struct LogGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) / x); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // log2(x) = logarithm to the base 2 of the elements of x template struct Log2Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.log() / static_cast(log(2)); } }; // the gradient of log2(x) is 1/(x*ln(2)) template struct Log2GradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x * static_cast(log(2))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // log10(x) = logarithm to the base 10 of the elements of x template struct Log10Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.log() / static_cast(log(10)); } }; // the gradient of log10(x) is 1/(x*ln(10)) template struct Log10GradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (x * 
static_cast(log(10))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // log1p(x) = natural logarithm of x+1 template struct Log1pFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = (static_cast(1) + x).log(); } }; template struct Log1pGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) / (x + static_cast(1))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.square(); } }; template struct SquareGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(2) * x; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // relu6(x) = min(max(0, x), 6) template struct Relu6Functor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); } }; template struct Relu6GradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((out > static_cast(0)) * (out < static_cast(threshold))) .template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // HardSwish = min(max(0, x+3), 6) * x / 6 template struct HardSwishFunctor : public BaseActivationFunctor { float threshold; float scale; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = (x + static_cast(offset)) .cwiseMax(static_cast(0)) .cwiseMin(static_cast(threshold)) * x / static_cast(scale); } }; template struct HardSwishGradFunctor : public BaseActivationFunctor { float threshold; float scale; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = ((x + static_cast(offset)) < static_cast(threshold)) .template cast(); dx.device(d) = dout * (((x + static_cast(offset)) > static_cast(0)).template cast() * (static_cast(2) * x + static_cast(offset)) / static_cast(scale) * tmp + static_cast(1) * (static_cast(1) - tmp)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // For numerical stability, using the following formula instead of softplus(x) = // log(1 + exp(x)) // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = // 1, threshold = 20 by default), otherwise x template struct SoftplusFunctor : public BaseActivationFunctor { float beta; float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}, {"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) { auto x_beta = static_cast(beta) * x; out.device(d) = (x_beta > static_cast(threshold)) .select(x, 
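/* Rationale for this select: for large beta * x, exp(beta * x) overflows in
 * float/double, but in that regime
 *   log(1 + exp(beta * x)) / beta ~= (beta * x) / beta = x,
 * so once beta * x exceeds `threshold` (20 by default) the functor simply
 * passes x through and only evaluates the exact formula on the numerically
 * safe branch. SoftplusGradFunctor below uses the same split. */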
(static_cast(1) + x_beta.exp()).log() / static_cast(beta)); } }; // For numerical stability, using the following formula instead of // d(softplus(x))/dx = 1 / (1 + exp(-x)) // d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta // = 1, threshold = 20 by default), otherwise x template struct SoftplusGradFunctor : public BaseActivationFunctor { float beta; float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}, {"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto x_beta = static_cast(beta) * x; dx.device(d) = (x_beta > static_cast(threshold)) .select(dout, dout / (static_cast(1) + (-x_beta).exp())); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // mish(x) = x * tanh(softplus(x)) // softplus(x) = x, if x > threshold // = ln(1 + exp(x)), otherwise template struct MishFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) { auto sp = (x > static_cast(threshold)) .select(x, (static_cast(1) + x.exp()).log()); out.device(d) = x * sp.tanh(); } }; // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) // sp = softplus(x) template struct MishGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto sp = (x > static_cast(threshold)) .select(x, (static_cast(1) + x.exp()).log()); auto gsp = static_cast(1) - (-sp).exp(); auto tsp = sp.tanh(); dx.device(d) = dout * (tsp + x * (static_cast(1) - tsp * tsp) * gsp); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { out.device(d) = x / (static_cast(1) + x.abs()); } }; // d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template struct SoftsignGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) { dx.device(d) = dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct SoftReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto tmp = static_cast(threshold); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); out.device(d) = (static_cast(1) + temp.exp()).log(); } }; template struct SoftReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = static_cast(threshold); auto temp = ((out > -tmp) * (out < tmp)).template cast(); dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct ELUFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) 
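/* Derivation sketch for MishGradFunctor above: with sp = softplus(x),
 *   d/dx [x * tanh(sp)] = tanh(sp) + x * (1 - tanh(sp)^2) * sp'(x),
 * and sp'(x) = sigmoid(x) = 1 - exp(-sp) (the `gsp` term), giving
 *   dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * (1 - exp(-sp))). */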
const { out.device(d) = (x < static_cast(0)) .select(static_cast(alpha) * (x.exp() - static_cast(1)), x); } }; template struct ELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { // case 1: alpha >= 0 // dx = dout, if out > 0 // dx = dout * (out + alpha), if out <= 0 dx.device(d) = (out > static_cast(0)) .select(dout, dout * (out + static_cast(alpha))); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { // case 2: alpha < 0 // dx = dout, if x > 0 // dx = dout * (out + alpha), if x <=0 dx.device(d) = (x > static_cast(0)) .select(dout, dout * static_cast(alpha) * x.exp()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template class ELUGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Out = context.Input("Out"); auto* dOut = context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); const float alpha = context.Attr("alpha"); dX->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "elu_grad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad")); auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad")); auto* place = context.template device_context().eigen_device(); if (alpha > 0) { ELUGradFunctor functor; functor.alpha = alpha; functor(*place, x, out, dout, dx); } else { ELUGradNegativeAlphaFunctor functor; functor.alpha = alpha; functor(*place, x, out, dout, dx); } } }; template struct CELUFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = (x < static_cast(0)) .select(static_cast(alpha) * ((x / static_cast(alpha)).exp() - static_cast(1)), x); } }; template struct CELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp_a_pos = static_cast(alpha > 0); auto temp_a_neg = static_cast(alpha <= 0); auto temp_x_pos = (x > static_cast(0)).template cast(); auto temp_x_neg = (x <= static_cast(0)).template cast(); // dx = dout, if alpha > 0 and x > 0 // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0 // dx = dout , if alpha < 0 and x > 0 // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0 dx.device(d) = dout * temp_a_pos * temp_x_pos + dout * (x / static_cast(alpha)).exp() * temp_a_pos * temp_x_neg + dout * temp_a_neg * temp_x_pos + dout * (x / static_cast(alpha)).exp() * temp_a_neg * temp_x_neg; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 template struct PowFunctor : public BaseActivationFunctor { float 
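/* Note on the alpha dispatch in ELUGradKernel above: when alpha >= 0,
 * out > 0 holds exactly when x > 0, so the gradient can be expressed with
 * the forward output alone (dx = dout for out > 0, otherwise
 * dout * (out + alpha) = dout * alpha * exp(x)). When alpha < 0 the
 * negative branch alpha * (exp(x) - 1) is itself positive, so the sign of
 * out no longer identifies the branch and ELUGradNegativeAlphaFunctor
 * falls back to testing x and recomputing alpha * exp(x) directly. */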
factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.pow(static_cast(factor)); } }; template struct PowGradFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * x.pow(static_cast(factor) - static_cast(1)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct LogitFunctor { template void operator()(Device d, X x, Out out, P p, float eps) const { // logit(x) = ln(x/(1-x)) auto tmp_x = (x.cwiseMin(static_cast(1.0 - eps))).cwiseMax(static_cast(eps)); if (!eps) { out.device(d) = (x < static_cast(0.0) || x > static_cast(1.0)) .select(p.constant(static_cast(NAN)), (tmp_x / (static_cast(1) - tmp_x)).log()); } else { out.device(d) = (tmp_x / (static_cast(1) - tmp_x)).log(); } } }; template struct LogitGradFunctor { template void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const { // logit(x)' = 1/(x*(1-x)) dx.device(d) = (x < static_cast(eps) || x > static_cast(1.0 - eps)) .select(p.constant(static_cast(0)), dout * (static_cast(1) / ((static_cast(1) - x) * x))); } }; template struct STanhFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; template struct STanhGradFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto a = static_cast(scale_a); auto b = static_cast(scale_b); auto temp = (a * x).tanh() * (a * x).tanh(); dx.device(d) = dout * a * b * (static_cast(1) - temp); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct HardSigmoidFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out) const { auto temp = x * static_cast(slope) + static_cast(offset); out.device(d) = temp.cwiseMax(static_cast(0)).cwiseMin(static_cast(1)); } }; template struct HardSigmoidGradFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((out > static_cast(0)) * (out < static_cast(1))) .template cast() * static_cast(slope); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct SwishFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); } }; template struct SwishGradFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void 
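/* Derivation sketch for the swish gradient computed below: with
 * s = sigmoid(beta * x) (temp1) and out = x * s,
 *   d/dx [x * s] = s + beta * x * s * (1 - s)
 *                = beta * out + s * (1 - beta * out),
 * i.e. the `(beta * out) + temp2` expression, where
 * temp2 = s * (1 - beta * out). */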
operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const { auto temp1 = static_cast(1) / (static_cast(1) + (static_cast(-beta) * x).exp()); auto out = x * temp1; auto temp2 = temp1 * (static_cast(1) - (static_cast(beta) * out)); dx.device(d) = dout * ((static_cast(beta) * out) + temp2); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct AbsGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* Out, const framework::Tensor* ddX, framework::Tensor* ddOut, framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "AbsGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "AbsGradGrad")); if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "AbsGradGrad")); ddout.device(*d) = ddx * x.sign(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct ELUGradGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad")); if (dX) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad")); dx.device(*d) = ddx * dout * static_cast(alpha) * x.exp() * (x <= static_cast(0)).template cast(); } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad")); ddout.device(*d) = ddx * ((x > static_cast(0)).template cast() + static_cast(alpha) * x.exp() * (x <= static_cast(0)).template cast()) .template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CELUGradGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad")); if (dX) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad")); dx.device(*d) = ddx * dout / static_cast(alpha) * (x / static_cast(alpha)).exp() * (x <= static_cast(0)).template cast(); } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad")); ddout.device(*d) = ddx * ((x > static_cast(0)).template cast() + (x / static_cast(alpha)).exp() * (x <= static_cast(0)).template cast()) .template cast(); } } static 
constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct SqrtGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* Out, const framework::Tensor* ddX, framework::Tensor* ddOut, framework::Tensor* dOut, const framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad")); // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx // calculate dy first, so ddy can inplace ddx if (dOut) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad")); dout.device(*d) = dx * ddx * static_cast(-1) / out; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad")); ddout.device(*d) = ddx * static_cast(0.5) / out; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct RsqrtGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* Out, const framework::Tensor* ddX, framework::Tensor* ddOut, framework::Tensor* dOut, const framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad")); // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx if (dOut) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad")); dout.device(*d) = (static_cast(3.0) / out) * dx * ddx; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad")); ddout.device(*d) = ddx * static_cast(-0.5) * out * out * out; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct SquareGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad")); // square GradGrad: ddy=2x*ddx, dx=2dy*ddx // calculate dx first, so ddy can inplace ddx if (dX) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad")); dx.device(*d) = ddx * static_cast(2) * dout; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad")); ddout.device(*d) = ddx * static_cast(2) * x; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // TODO(dengkaipeng): double gradient calculation for Square/Sqrt need // DOut(dy) as input(not output), tensor extraction is different from // others. 
Impliment extraction kernel seperately here. inline void ExtractDoubleGradTensorWithInputDOut( const framework::ExecutionContext& ctx, const framework::Tensor** X, const framework::Tensor** ddX, framework::Tensor** dX, const framework::Tensor** dOut, framework::Tensor** ddOut) { // extract ddX(output), ddOut(input) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("DDX"))); *ddX = ctx.Input("DDX"); if (ddo_var) { *ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get the tensor from the Variable DDX, variable name = %s", ctx.OutputName("DDX"))); // extract x(input), dx(output) auto x_var = ctx.InputVar("X"); PADDLE_ENFORCE_NOT_NULL( x_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("X"))); auto dx_var = ctx.OutputVar("DX"); *X = ctx.Input("X"); if (dx_var) { *dX = ctx.Output("DX"); } // extract dOut(input) auto dout_var = ctx.InputVar("DOut"); if (dout_var) { *dOut = ctx.Input("DOut"); } } template class SigmoidDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *ddX, *dOut; framework::Tensor *dOutNew, *ddOut; Out = ddX = dOut = nullptr; dOutNew = ddOut = nullptr; // extract ddx(input) and out(input) ddX = ctx.Input("DDX"); Out = ctx.Input("Out"); PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable ddX, variable name = %s", ctx.InputName("DDX"))); PADDLE_ENFORCE_NOT_NULL( Out, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); // set output ddout ddOut = ctx.Output("DDOut"); // extract dOut(intput) dOut = ctx.Input("DOut"); PADDLE_ENFORCE_NOT_NULL( dOut, platform::errors::NotFound( "Cannot get input Variable dOut, variable name = %s", ctx.InputName("DOut"))); dOutNew = ctx.Output("DOutNew"); if (dOutNew) dOutNew->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, dOut, dOutNew, ddOut); } }; // Out, DDX, DOut, D_DDOut, D_DOut_New // input // D_OutNew, D_DOut, D_DDx // output template class SigmoidTripleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *ddX, *dOut, *d_ddOut, *d_dOutNew; framework::Tensor *d_OutNew, *d_dOut, *d_ddx; Out = ddX = dOut = d_ddOut = d_dOutNew = nullptr; d_OutNew = d_dOut = d_ddx = nullptr; // extract ddx(input), out(input), dOut(input), d_ddOut(input), // d_dOutNew(input) ddX = ctx.Input("DDX"); Out = ctx.Input("Out"); dOut = ctx.Input("DOut"); d_ddOut = ctx.Input("D_DDOut"); d_dOutNew = ctx.Input("D_DOut_New"); PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable ddX, variable name = %s", ctx.InputName("DDX"))); PADDLE_ENFORCE_NOT_NULL( Out, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); PADDLE_ENFORCE_NOT_NULL( dOut, platform::errors::NotFound( "Cannot get input Variable dOut, variable name = %s", ctx.InputName("DOut"))); PADDLE_ENFORCE_NOT_NULL( d_ddOut, platform::errors::NotFound( "Cannot get input Variable d_ddOut, 
variable name = %s", ctx.InputName("D_DDOut"))); PADDLE_ENFORCE_NOT_NULL( d_dOutNew, platform::errors::NotFound( "Cannot get input Variable d_dOutNew, variable name = %s", ctx.InputName("D_DOutNew"))); // set output d_OutNew、d_dOut、d_ddx d_dOut = ctx.Output("D_DOut"); d_OutNew = ctx.Output("D_OutNew"); d_ddx = ctx.Output("D_DDx"); if (d_dOut) d_dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (d_OutNew) d_OutNew->mutable_data(Out->dims(), ctx.GetPlace()); if (d_ddx) d_ddx->mutable_data(ddX->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, dOut, d_ddOut, d_dOutNew, // input d_dOut, d_OutNew, d_ddx); // output } }; template class TanhDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *ddX, *dOut; framework::Tensor *dOutNew, *ddOut; Out = ddX = dOut = nullptr; dOutNew = ddOut = nullptr; // extract ddx(input) and out(input) auto ddx_var = ctx.InputVar("DDX"); auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable ddx, variable name = %s", ctx.InputName("DDX"))); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get input Variable out, variable name = %s", ctx.InputName("Out"))); ddX = ctx.Input("DDX"); Out = ctx.Input("Out"); // set output ddout auto ddout_var = ctx.OutputVar("DDOut"); if (ddout_var) { ddOut = ctx.Output("DDOut"); } // extract dOut(intput) auto dout_var = ctx.InputVar("DOut"); PADDLE_ENFORCE_NOT_NULL( dout_var, platform::errors::NotFound( "Cannot get input Variable dout_var, variable name = %s", ctx.InputName("DOut"))); dOut = ctx.Input("DOut"); // set output dout_new auto dout_new_var = ctx.OutputVar("DOutNew"); if (dout_new_var) { dOutNew = ctx.Output("DOutNew"); } if (dOutNew) dOutNew->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, dOut, dOutNew, ddOut); } }; template class TanhTripeGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *ddX, *dOut, *d_ddOut, *d_dOutNew; framework::Tensor *d_OutNew, *d_dOut, *d_ddx; Out = ddX = dOut = d_ddOut = d_dOutNew = nullptr; d_OutNew = d_dOut = d_ddx = nullptr; // extract ddx(input), out(input), dOut(input), d_ddOut(input), // d_dOutNew(input) ddX = ctx.Input("DDX"); Out = ctx.Input("Out"); dOut = ctx.Input("DOut"); d_ddOut = ctx.Input("D_DDOut"); d_dOutNew = ctx.Input("D_DOut_New"); PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable ddX, variable name = %s", ctx.InputName("DDX"))); PADDLE_ENFORCE_NOT_NULL( Out, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); PADDLE_ENFORCE_NOT_NULL( dOut, platform::errors::NotFound( "Cannot get input Variable dOut, variable name = %s", ctx.InputName("DOut"))); PADDLE_ENFORCE_NOT_NULL( d_ddOut, platform::errors::NotFound( "Cannot get input Variable d_ddOut, variable name = %s", ctx.InputName("D_DDOut"))); PADDLE_ENFORCE_NOT_NULL( d_dOutNew, platform::errors::NotFound( "Cannot get input Variable d_dOutNew, variable name = %s", ctx.InputName("D_DOutNew"))); // set output d_OutNew、d_dOut、d_ddx d_dOut = ctx.Output("D_DOut"); d_OutNew = 
ctx.Output("D_OutNew"); d_ddx = ctx.Output("D_DDx"); if (d_dOut) d_dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (d_OutNew) d_OutNew->mutable_data(Out->dims(), ctx.GetPlace()); if (d_ddx) d_ddx->mutable_data(ddX->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, dOut, d_ddOut, d_dOutNew, // input d_dOut, d_OutNew, d_ddx); // output } }; template class SquareDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *X, *ddX, *dOut; X = ddX = dOut = nullptr; framework::Tensor *dX, *ddOut; dX = ddOut = nullptr; ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, X, ddX, ddOut, dOut, dX); } }; template class LogDoubleGradKernel : public SquareDoubleGradKernel {}; template class ELUDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *X, *ddX, *dOut; X = ddX = dOut = nullptr; framework::Tensor *dX, *ddOut; dX = ddOut = nullptr; ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = ctx.Attr(attr.first); } functor(place, X, ddX, ddOut, dOut, dX); } }; template class CELUDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *X, *ddX, *dOut; X = ddX = dOut = nullptr; framework::Tensor *dX, *ddOut; dX = ddOut = nullptr; ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = ctx.Attr(attr.first); } functor(place, X, ddX, ddOut, dOut, dX); } }; template class SqrtDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *dX, *ddX; Out = dX = ddX = nullptr; framework::Tensor *ddOut, *dOut; ddOut = dOut = nullptr; // extract ddx(input), ddout(output) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); ddX = ctx.Input("DDX"); if (ddo_var) { ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); // extract out(input), dout(output) auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); auto dout_var = ctx.OutputVar("DOut"); Out = ctx.Input("Out"); if (dout_var) { dOut = ctx.Output("DOut"); } // extract dx(input) 
auto dx_var = ctx.InputVar("DX"); PADDLE_ENFORCE_NOT_NULL( dx_var, platform::errors::NotFound( "Cannot get input Variable DX, variable name = %s", ctx.InputName("DX"))); if (dx_var) { dX = ctx.Input("DX"); } if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, ddOut, dOut, dX); } }; // rsqrt Grad: dx = -0.5 * dy * y * y * y // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx template class RsqrtDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *dX, *ddX; Out = dX = ddX = nullptr; framework::Tensor *ddOut, *dOut; ddOut = dOut = nullptr; // extract ddx(input), ddout(output) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); ddX = ctx.Input("DDX"); if (ddo_var) { ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); // extract out(input), dout(output) auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); auto dout_var = ctx.OutputVar("DOut"); Out = ctx.Input("Out"); if (dout_var) { dOut = ctx.Output("DOut"); } // extract dx(input) auto dx_var = ctx.InputVar("DX"); PADDLE_ENFORCE_NOT_NULL( dx_var, platform::errors::NotFound( "Cannot get input Variable DX, variable name = %s", ctx.InputName("DX"))); if (dx_var) { dX = ctx.Input("DX"); } if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, ddOut, dOut, dX); } }; template class PowKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* X = nullptr; framework::Tensor* Out = nullptr; ExtractActivationTensor(context, &X, &Out); Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "Pow")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "Pow")); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } // get FactorTensor auto* factor_tensor = context.HasInput("FactorTensor") ? 
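/* FactorTensor handling in PowKernel / PowGradKernel: when the optional
 * FactorTensor input is fed, its single element overrides the "factor"
 * attribute at run time (the tensor is synchronously copied to CPU first
 * if it resides on the GPU, and a one-element shape is enforced). Only
 * when FactorTensor is absent does the static attribute value apply. */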
context.Input("FactorTensor") : nullptr; if (factor_tensor) { auto* factor_data = factor_tensor->data(); framework::Tensor cpu_factor_tensor; if (platform::is_gpu_place(factor_tensor->place())) { framework::TensorCopySync(*factor_tensor, platform::CPUPlace(), &cpu_factor_tensor); factor_data = cpu_factor_tensor.data(); } auto factor = std::vector(factor_data, factor_data + factor_tensor->numel()); PADDLE_ENFORCE_EQ( factor.size(), 1, platform::errors::InvalidArgument( "The shape of factor(tensor) must be [1] rather than %d", factor.size())); for (auto& attr : attrs) { *attr.second = factor[0]; } } functor(*place, x, out); } }; template class PowGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor *X, *Out, *dOut; framework::Tensor* dX = nullptr; X = Out = dOut = nullptr; ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad")); auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "PowGrad")); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } // get FactorTensor auto* factor_tensor = context.HasInput("FactorTensor") ? context.Input("FactorTensor") : nullptr; if (factor_tensor) { auto* factor_data = factor_tensor->data(); framework::Tensor cpu_factor_tensor; if (platform::is_gpu_place(factor_tensor->place())) { framework::TensorCopySync(*factor_tensor, platform::CPUPlace(), &cpu_factor_tensor); factor_data = cpu_factor_tensor.data(); } auto factor = std::vector(factor_data, factor_data + factor_tensor->numel()); PADDLE_ENFORCE_EQ( factor.size(), 1, platform::errors::InvalidArgument( "The shape of factor(tensor) must be [1] rather than %d", factor.size())); for (auto& attr : attrs) { *attr.second = factor[0]; } } functor(*place, x, out, dout, dx); } }; template class LogitKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* out = context.Output("Out"); auto* in = context.Input("X"); auto eps = context.Attr("eps"); out->mutable_data(in->place()); auto eigen_out = framework::EigenVector::Flatten(*out); auto eigen_in = framework::EigenVector::Flatten(*in); auto& place = *context.template device_context().eigen_device(); auto eigen_p = framework::EigenVector::Flatten(*out); LogitFunctor functor; functor(place, eigen_in, eigen_out, eigen_p, eps); } }; template class LogitGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto eps = context.Attr("eps"); dx->mutable_data(dout->place()); auto eigen_x = framework::EigenVector::Flatten(*x); auto eigen_dout = framework::EigenVector::Flatten(*dout); auto eigen_dx = framework::EigenVector::Flatten(*dx); auto& place = *context.template device_context().eigen_device(); auto eigen_p = framework::EigenVector::Flatten(*x); LogitGradFunctor functor; functor(place, 
eigen_x, eigen_dout, eigen_dx, eigen_p, eps); } }; template struct LogGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad")); // ddout = ddx / x; dx = -(dout / x) * (ddx / x) // calculate dx first, so ddout can inplace ddx if (dX) { auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad")); auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad")); dx.device(*d) = dout * static_cast(-1) * ddx / (x * x); } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad")); ddout.device(*d) = ddx * static_cast(1) / x; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; } // namespace operators } // namespace paddle #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(silu, Silu, SiluFunctor, SiluGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ __macro(log2, Log2, Log2Functor, Log2GradFunctor); \ __macro(log10, Log10, Log10Functor, Log10GradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \ __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ HardSigmoidGradFunctor); \ __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ __macro(mish, Mish, MishFunctor, MishGradFunctor); \ __macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);
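/* Illustration only (not part of this header): FOR_EACH_ACTIVATION_OP is
 * meant to be fed a registration macro by the corresponding .cc/.cu files.
 * A minimal sketch of such a consumer, using a hypothetical
 * REGISTER_SIMPLE_ACTIVATION_KERNEL name and assuming
 * `namespace ops = paddle::operators;`, could look like:
 *
 *   #define REGISTER_SIMPLE_ACTIVATION_KERNEL(act_type, op_name, functor, \
 *                                             grad_functor)               \
 *     REGISTER_OP_CPU_KERNEL(                                             \
 *         act_type,                                                       \
 *         ops::ActivationKernel<paddle::platform::CPUDeviceContext,       \
 *                               ops::functor<float>>,                     \
 *         ops::ActivationKernel<paddle::platform::CPUDeviceContext,       \
 *                               ops::functor<double>>);                   \
 *     REGISTER_OP_CPU_KERNEL(                                             \
 *         act_type##_grad,                                                \
 *         ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,   \
 *                                   ops::grad_functor<float>>,            \
 *         ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,   \
 *                                   ops::grad_functor<double>>)
 *
 *   FOR_EACH_ACTIVATION_OP(REGISTER_SIMPLE_ACTIVATION_KERNEL);
 */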