/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #ifndef _USE_MATH_DEFINES #define _USE_MATH_DEFINES #endif #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif #include "paddle/phi/kernels/funcs/activation_functor.h" namespace paddle { namespace operators { using framework::To32BitIndex; using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps; /* The following operator can be used to process SelectedRows, because the * output of those operator for zero is zero too. */ static std::unordered_set CanBeUsedBySelectedRows = { "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"}; inline void ExtractActivationTensor(const framework::ExecutionContext& context, const framework::Tensor** X, framework::Tensor** Out) { auto x_var = context.InputVar("X"); auto out_var = context.OutputVar("Out"); PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound( "Cannot get input Variable X, variable name = %s", context.InputName("X"))); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get output Variable Out, variable name = %s", context.OutputName("Out"))); if (CanBeUsedBySelectedRows.count(context.Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( out_var); } else { *X = context.Input("X"); *Out = context.Output("Out"); } PADDLE_ENFORCE_NOT_NULL(*Out, platform::errors::NotFound( "Cannot get the tensor from the Variable " "Output(Out), variable name = %s", context.OutputName("Out"))); } template inline void ExtractActivationGradTensor( const framework::ExecutionContext& context, const framework::Tensor** X, const framework::Tensor** Out, const framework::Tensor** dOut, framework::Tensor** dX) { auto out_grad_var = context.InputVar(framework::GradVarName("Out")); auto x_grad_var = context.OutputVar(framework::GradVarName("X")); const framework::Variable* out_var = nullptr; if (static_cast(kDepValue) & static_cast(ActBwdOpFwdDeps::kDepOut)) { out_var = context.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", context.InputName("Out"))); } PADDLE_ENFORCE_NOT_NULL( out_grad_var, platform::errors::NotFound( "Cannot get input Variable %s, variable name = %s", framework::GradVarName("Out"), context.InputName(framework::GradVarName("Out")))); PADDLE_ENFORCE_NOT_NULL( x_grad_var, platform::errors::NotFound( "Cannot get output Variable %s, variable name = %s", framework::GradVarName("X"), context.OutputName(framework::GradVarName("X")))); if (CanBeUsedBySelectedRows.count(context.Type())) { *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( *out_grad_var); *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( x_grad_var); if (out_var) { *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); } else { *Out = *dOut; // fake out } } else { *Out = context.Input("Out"); *dOut = context.Input(framework::GradVarName("Out")); *dX = context.Output(framework::GradVarName("X")); if (out_var) { *Out = &(out_var->Get()); } else { *Out = *dOut; // fake out } } PADDLE_ENFORCE_NOT_NULL(*dX, platform::errors::NotFound( "Cannot get the tensor from the Variable " "Output(Out), variable name = %s", context.OutputName(framework::GradVarName("X")))); if (static_cast(kDepValue) & static_cast(ActBwdOpFwdDeps::kDepX)) { auto x_var = context.InputVar("X"); PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound( "Cannot get the tensor from the " "Variable Input(X), variable name = %s", context.InputName("X"))); if (CanBeUsedBySelectedRows.count(context.Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); } else { *X = context.Input("X"); } } else { VLOG(10) << " Inplace activation of Op : " << context.Type(); *X = *dX; } } template class ActivationKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* X = nullptr; framework::Tensor* Out = nullptr; ExtractActivationTensor(context, &X, &Out); Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "Activation")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "Activation")); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = platform::is_gpu_place(context.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out)); } else { functor(*place, x, out); } } }; template class ActivationGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor *X, *Out, *dOut; framework::Tensor* dX = nullptr; X = Out = dOut = nullptr; ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad")); auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad")); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } // use 32bit index to speed up computation bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); bool is_gpu_place = platform::is_gpu_place(context.GetPlace()); if (use_32bit_index && is_gpu_place) { functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout), To32BitIndex(dx)); } else { functor(*place, x, out, dout, dx); } } }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; #define USE_PHI_FUNCTOR(name) \ template \ using name##Functor = phi::funcs::name##Functor; \ template \ using name##GradFunctor = phi::funcs::name##GradFunctor; #define USE_PHI_DOUBLE_GRAD_FUNCTOR(name) \ template \ using name##GradGradFunctor = phi::funcs::name##GradGradFunctor; #define USE_PHI_TRIPLE_GRAD_FUNCTOR(name) \ template \ using name##TripleGradFunctor = phi::funcs::name##TripleGradFunctor; USE_PHI_FUNCTOR(Cos) USE_PHI_FUNCTOR(Tan) USE_PHI_FUNCTOR(Acos) USE_PHI_FUNCTOR(Sin) USE_PHI_FUNCTOR(Asin) USE_PHI_FUNCTOR(Atan) USE_PHI_FUNCTOR(Sinh) USE_PHI_FUNCTOR(Cosh) USE_PHI_FUNCTOR(Asinh) USE_PHI_FUNCTOR(Acosh) USE_PHI_FUNCTOR(Atanh) USE_PHI_FUNCTOR(Tanh) USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh) USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh) USE_PHI_FUNCTOR(BRelu) USE_PHI_FUNCTOR(ThresholdedRelu) USE_PHI_FUNCTOR(LeakyRelu) USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) USE_PHI_FUNCTOR(HardShrink) USE_PHI_FUNCTOR(SoftShrink) USE_PHI_FUNCTOR(TanhShrink) USE_PHI_FUNCTOR(Silu) USE_PHI_FUNCTOR(ELU) USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU) USE_PHI_FUNCTOR(Sigmoid) USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid) USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid) USE_PHI_FUNCTOR(LogSigmoid) USE_PHI_FUNCTOR(HardSigmoid) USE_PHI_FUNCTOR(Log) USE_PHI_DOUBLE_GRAD_FUNCTOR(Log) USE_PHI_FUNCTOR(Log2) USE_PHI_FUNCTOR(Log10) USE_PHI_FUNCTOR(Log1p) USE_PHI_FUNCTOR(Swish) USE_PHI_FUNCTOR(HardSwish) USE_PHI_FUNCTOR(Pow) template using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor; template using RoundFunctor = phi::funcs::RoundFunctor; template using FloorFunctor = phi::funcs::FloorFunctor; template using CeilFunctor = phi::funcs::CeilFunctor; template using ZeroGradFunctor = phi::funcs::ZeroGradFunctor; // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // expm1(x) = e^x - 1 template struct Expm1Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.expm1(); } }; template struct Expm1GradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out + dout; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // relu(x) = max(x, 0) template using ReluCPUFunctor = phi::funcs::ReluCPUFunctor; template using ReluGradFunctor = phi::funcs::ReluGradFunctor; template using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor; template using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.sqrt(); } }; template struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0.5) * dout / out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // rsqrt(x) = x^(-1/2) template struct RsqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.rsqrt(); } }; template struct RsqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(-0.5) * dout * out * out * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / x; } }; template struct ReciprocalGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(-1) * out * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.square(); } }; template struct SquareGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(2) * x; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // relu6(x) = min(max(0, x), 6) template struct Relu6Functor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); } }; template struct Relu6GradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((out > static_cast(0)) * (out < static_cast(threshold))) .template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; // For numerical stability, using the following formula instead of softplus(x) = // log(1 + exp(x)) // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = // 1, threshold = 20 by default), otherwise x template struct SoftplusFunctor : public BaseActivationFunctor { float beta; float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}, {"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) { auto x_beta = static_cast(beta) * x; out.device(d) = (x_beta > static_cast(threshold)) .select(x, (static_cast(1) + x_beta.exp()).log() / static_cast(beta)); } }; // For numerical stability, using the following formula instead of // d(softplus(x))/dx = 1 / (1 + exp(-x)) // d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta // = 1, threshold = 20 by default), otherwise x template struct SoftplusGradFunctor : public BaseActivationFunctor { float beta; float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}, {"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto x_beta = static_cast(beta) * x; dx.device(d) = (x_beta > static_cast(threshold)) .select(dout, dout / (static_cast(1) + (-x_beta).exp())); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // mish(x) = x * tanh(softplus(x)) // softplus(x) = x, if x > threshold // = ln(1 + exp(x)), otherwise template struct MishFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) { auto sp = (x > static_cast(threshold)) .select(x, (static_cast(1) + x.exp()).log()); out.device(d) = x * sp.tanh(); } }; // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) // sp = softplus(x) template struct MishGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto sp = (x > static_cast(threshold)) .select(x, (static_cast(1) + x.exp()).log()); auto gsp = static_cast(1) - (-sp).exp(); auto tsp = sp.tanh(); dx.device(d) = dout * (tsp + x * (static_cast(1) - tsp * tsp) * gsp); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { out.device(d) = x / (static_cast(1) + x.abs()); } }; // d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template struct SoftsignGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) { dx.device(d) = dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct SoftReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto tmp = static_cast(threshold); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); out.device(d) = (static_cast(1) + temp.exp()).log(); } }; template struct SoftReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = static_cast(threshold); auto temp = ((out > -tmp) * (out < tmp)).template cast(); dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template class ELUGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Out = context.Input("Out"); auto* dOut = context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); const float alpha = context.Attr("alpha"); dX->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "elu_grad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad")); auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad")); auto* place = context.template device_context().eigen_device(); if (alpha > 0) { ELUGradFunctor functor; functor.alpha = alpha; functor(*place, x, out, dout, dx); } else { ELUGradNegativeAlphaFunctor functor; functor.alpha = alpha; functor(*place, x, out, dout, dx); } } }; template struct CELUFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = (x < static_cast(0)) .select(static_cast(alpha) * ((x / static_cast(alpha)).exp() - static_cast(1)), x); } }; template struct CELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp_a_pos = static_cast(alpha > 0); auto temp_a_neg = static_cast(alpha <= 0); auto temp_x_pos = (x > static_cast(0)).template cast(); auto temp_x_neg = (x <= static_cast(0)).template cast(); // dx = dout, if alpha > 0 and x > 0 // dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0 // dx = dout , if alpha < 0 and x > 0 // dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0 dx.device(d) = dout * temp_a_pos * temp_x_pos + dout * (x / static_cast(alpha)).exp() * temp_a_pos * temp_x_neg + dout * temp_a_neg * temp_x_pos + dout * (x / static_cast(alpha)).exp() * temp_a_neg * temp_x_neg; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct LogitFunctor { template void operator()(Device d, X x, Out out, P p, float eps) const { // logit(x) = ln(x/(1-x)) auto tmp_x = (x.cwiseMin(static_cast(1.0 - eps))).cwiseMax(static_cast(eps)); if (!eps) { out.device(d) = (x < static_cast(0.0) || x > static_cast(1.0)) .select(p.constant(static_cast(NAN)), (tmp_x / (static_cast(1) - tmp_x)).log()); } else { out.device(d) = (tmp_x / (static_cast(1) - tmp_x)).log(); } } }; template struct LogitGradFunctor { template void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const { // logit(x)' = 1/(x*(1-x)) dx.device(d) = (x < static_cast(eps) || x > static_cast(1.0 - eps)) .select(p.constant(static_cast(0)), dout * (static_cast(1) / ((static_cast(1) - x) * x))); } }; template struct STanhFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; template struct STanhGradFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto a = static_cast(scale_a); auto b = static_cast(scale_b); auto temp = (a * x).tanh() * (a * x).tanh(); dx.device(d) = dout * a * b * (static_cast(1) - temp); } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct AbsGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* Out, const framework::Tensor* ddX, framework::Tensor* ddOut, framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "AbsGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "AbsGradGrad")); if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "AbsGradGrad")); ddout.device(*d) = ddx * x.sign(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct CELUGradGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad")); if (dX) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad")); dx.device(*d) = ddx * dout / static_cast(alpha) * (x / static_cast(alpha)).exp() * (x <= static_cast(0)).template cast(); } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad")); ddout.device(*d) = ddx * ((x > static_cast(0)).template cast() + (x / static_cast(alpha)).exp() * (x <= static_cast(0)).template cast()) .template cast(); } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; template struct SqrtGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* Out, const framework::Tensor* ddX, framework::Tensor* ddOut, framework::Tensor* dOut, const framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad")); // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx // calculate dy first, so ddy can inplace ddx if (dOut) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad")); dout.device(*d) = dx * ddx * static_cast(-1) / out; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad")); ddout.device(*d) = ddx * static_cast(0.5) / out; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct RsqrtGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* Out, const framework::Tensor* ddX, framework::Tensor* ddOut, framework::Tensor* dOut, const framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad")); auto out = framework::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad")); // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx if (dOut) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad")); dout.device(*d) = (static_cast(3.0) / out) * dx * ddx; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad")); ddout.device(*d) = ddx * static_cast(-0.5) * out * out * out; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; } }; template struct SquareGradGradFunctor : public BaseActivationFunctor { template void operator()(const Device& dev, const framework::Tensor* X, const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* dOut, framework::Tensor* dX) const { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad")); auto x = framework::EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad")); // square GradGrad: ddy=2x*ddx, dx=2dy*ddx // calculate dx first, so ddy can inplace ddx if (dX) { auto dx = framework::EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad")); auto dout = framework::EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad")); dx.device(*d) = ddx * static_cast(2) * dout; } if (ddOut) { auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad")); ddout.device(*d) = ddx * static_cast(2) * x; } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; // TODO(dengkaipeng): double gradient calculation for Square/Sqrt need // DOut(dy) as input(not output), tensor extraction is different from // others. Impliment extraction kernel seperately here. inline void ExtractDoubleGradTensorWithInputDOut( const framework::ExecutionContext& ctx, const framework::Tensor** X, const framework::Tensor** ddX, framework::Tensor** dX, const framework::Tensor** dOut, framework::Tensor** ddOut) { // extract ddX(output), ddOut(input) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("DDX"))); *ddX = ctx.Input("DDX"); if (ddo_var) { *ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get the tensor from the Variable DDX, variable name = %s", ctx.OutputName("DDX"))); // extract x(input), dx(output) auto x_var = ctx.InputVar("X"); PADDLE_ENFORCE_NOT_NULL( x_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("X"))); auto dx_var = ctx.OutputVar("DX"); *X = ctx.Input("X"); if (dx_var) { *dX = ctx.Output("DX"); } // extract dOut(input) auto dout_var = ctx.InputVar("DOut"); if (dout_var) { *dOut = ctx.Input("DOut"); } } template class SquareDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *X, *ddX, *dOut; X = ddX = dOut = nullptr; framework::Tensor *dX, *ddOut; dX = ddOut = nullptr; ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, X, ddX, ddOut, dOut, dX); } }; template class CELUDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *X, *ddX, *dOut; X = ddX = dOut = nullptr; framework::Tensor *dX, *ddOut; dX = ddOut = nullptr; ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = ctx.Attr(attr.first); } functor(place, X, ddX, ddOut, dOut, dX); } }; template class SqrtDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *dX, *ddX; Out = dX = ddX = nullptr; framework::Tensor *ddOut, *dOut; ddOut = dOut = nullptr; // extract ddx(input), ddout(output) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); ddX = ctx.Input("DDX"); if (ddo_var) { ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); // extract out(input), dout(output) auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); auto dout_var = ctx.OutputVar("DOut"); Out = ctx.Input("Out"); if (dout_var) { dOut = ctx.Output("DOut"); } // extract dx(input) auto dx_var = ctx.InputVar("DX"); PADDLE_ENFORCE_NOT_NULL( dx_var, platform::errors::NotFound( "Cannot get input Variable DX, variable name = %s", ctx.InputName("DX"))); if (dx_var) { dX = ctx.Input("DX"); } if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, ddOut, dOut, dX); } }; // rsqrt Grad: dx = -0.5 * dy * y * y * y // rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx template class RsqrtDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *dX, *ddX; Out = dX = ddX = nullptr; framework::Tensor *ddOut, *dOut; ddOut = dOut = nullptr; // extract ddx(input), ddout(output) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE_NOT_NULL( ddx_var, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); ddX = ctx.Input("DDX"); if (ddo_var) { ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE_NOT_NULL( ddX, platform::errors::NotFound( "Cannot get input Variable DDX, variable name = %s", ctx.InputName("DDX"))); // extract out(input), dout(output) auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE_NOT_NULL( out_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", ctx.InputName("Out"))); auto dout_var = ctx.OutputVar("DOut"); Out = ctx.Input("Out"); if (dout_var) { dOut = ctx.Output("DOut"); } // extract dx(input) auto dx_var = ctx.InputVar("DX"); PADDLE_ENFORCE_NOT_NULL( dx_var, platform::errors::NotFound( "Cannot get input Variable DX, variable name = %s", ctx.InputName("DX"))); if (dx_var) { dX = ctx.Input("DX"); } if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, ddOut, dOut, dX); } }; template class LogitKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* out = context.Output("Out"); auto* in = context.Input("X"); auto eps = context.Attr("eps"); out->mutable_data(in->place()); auto eigen_out = framework::EigenVector::Flatten(*out); auto eigen_in = framework::EigenVector::Flatten(*in); auto& place = *context.template device_context().eigen_device(); auto eigen_p = framework::EigenVector::Flatten(*out); LogitFunctor functor; functor(place, eigen_in, eigen_out, eigen_p, eps); } }; template class LogitGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto eps = context.Attr("eps"); dx->mutable_data(dout->place()); auto eigen_x = framework::EigenVector::Flatten(*x); auto eigen_dout = framework::EigenVector::Flatten(*dout); auto eigen_dx = framework::EigenVector::Flatten(*dx); auto& place = *context.template device_context().eigen_device(); auto eigen_p = framework::EigenVector::Flatten(*x); LogitGradFunctor functor; functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps); } }; } // namespace operators } // namespace paddle #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ __macro(mish, Mish, MishFunctor, MishGradFunctor);