/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #ifndef _USE_MATH_DEFINES #define _USE_MATH_DEFINES #endif #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/float16.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif namespace paddle { namespace operators { enum ActBwdOpFwdDeps { kNoDeps = 0x00, // Do not need any forward input/output kDepX = 0x01, // Only need forward input X kDepOut = 0x02, // Only need forward output Out // Never add kDepXOut, because Out can be always calculated // by forward input X in backward part. // FIXME(zjl): but in MKLDNN abs, X and Out are all needed... // Developers should not rely on this enum value! 
kDepXOut = 0x03 }; std::unique_ptr> GetInplaceOpSet(); static bool IsInplace(const std::string& op) { static auto InplaceOpSet = GetInplaceOpSet(); bool inplace = InplaceOpSet->count(op); // for op_grad const int kGradSuffixLen = 4; if (op.size() > kGradSuffixLen && op.compare(op.size() - kGradSuffixLen - 1, kGradSuffixLen, "grad")) { inplace = InplaceOpSet->count(op.substr(0, op.size() - (kGradSuffixLen + 1))); } return inplace; } /* The following operator can be used to process SelectedRows, because the * output of those operator for zero is zero too. */ static std::unordered_set CanBeUsedBySelectedRows = { "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"}; inline void ExtractActivationTensor(const framework::ExecutionContext& context, const framework::Tensor** X, framework::Tensor** Out) { auto x_var = context.InputVar("X"); auto out_var = context.OutputVar("Out"); PADDLE_ENFORCE(x_var != nullptr, "Cannot get input Variable X, variable name = %s", context.op().Input("X")); PADDLE_ENFORCE(out_var != nullptr, "Cannot get output Variable Out, variable name = %s", context.op().Output("Out")); if (CanBeUsedBySelectedRows.count(context.op().Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( out_var); } else { *X = context.Input("X"); *Out = context.Output("Out"); } PADDLE_ENFORCE(*Out != nullptr, "Cannot get output tensor Out, variable name = %s", context.op().Output("Out")); } template inline void ExtractActivationGradTensor( const framework::ExecutionContext& context, const framework::Tensor** X, const framework::Tensor** Out, const framework::Tensor** dOut, framework::Tensor** dX) { auto out_grad_var = context.InputVar(framework::GradVarName("Out")); auto x_grad_var = context.OutputVar(framework::GradVarName("X")); const framework::Variable* out_var = nullptr; if (static_cast(kDepValue) & static_cast(kDepOut)) { out_var = context.InputVar("Out"); 
PADDLE_ENFORCE(out_var != nullptr, "Cannot get input Variable Out, variable name = %s", context.op().Input("Out")); } PADDLE_ENFORCE(out_grad_var != nullptr, "Cannot get input Variable %s, variable name = %s", framework::GradVarName("Out"), context.op().Input(framework::GradVarName("Out"))); PADDLE_ENFORCE(x_grad_var != nullptr, "Cannot get output Variable %s, variable name = %s", framework::GradVarName("X"), context.op().Output(framework::GradVarName("X"))); if (CanBeUsedBySelectedRows.count(context.op().Type())) { *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( *out_grad_var); *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( x_grad_var); if (out_var) { *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); } else { *Out = *dOut; // fake out } } else { *Out = context.Input("Out"); *dOut = context.Input(framework::GradVarName("Out")); *dX = context.Output(framework::GradVarName("X")); if (out_var) { *Out = &(out_var->Get()); } else { *Out = *dOut; // fake out } } PADDLE_ENFORCE(*dX != nullptr, "Cannot get output tensor %s, variable name = %s", framework::GradVarName("X"), context.op().Output(framework::GradVarName("X"))); if (static_cast(kDepValue) & static_cast(kDepX)) { auto x_var = context.InputVar("X"); PADDLE_ENFORCE(x_var != nullptr, "Cannot get input tensor X, variable name = %s", context.op().Input("X")); if (CanBeUsedBySelectedRows.count(context.op().Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); } else { *X = context.Input("X"); } } else { VLOG(10) << " Inplace activation of Op : " << context.op().Type(); *X = *dX; } } template class ActivationKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* X = nullptr; framework::Tensor* Out = nullptr; ExtractActivationTensor(context, &X, &Out); Out->mutable_data(context.GetPlace()); 
auto x = framework::EigenVector::Flatten(detail::Ref(X)); auto out = framework::EigenVector::Flatten(detail::Ref(Out)); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } functor(*place, x, out); } }; template class ActivationGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor *X, *Out, *dOut; framework::Tensor* dX = nullptr; X = Out = dOut = nullptr; ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten(detail::Ref(dOut)); auto out = framework::EigenVector::Flatten(detail::Ref(Out)); auto dx = framework::EigenVector::Flatten(detail::Ref(dX)); auto x = framework::EigenVector::Flatten(detail::Ref(X)); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } functor(*place, x, out, dout, dx); } }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } /* NOTE(*): Output reuse X memory if X is not dependented by its Gradient. For example, sigmoid op's gradient didn't involve x, so its output can reuse input memory. But abs op's gradient use x, it can not be inplaced. gradient did use x. 
*/ bool Inplace() const { return false; } }; // sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; template struct SigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out * (static_cast(1) - out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // Originally: logsigmoid(x) = -log (1 + exp(-x)) // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // We can rewrite the above equation as: // out = -log( exp(0) + exp(-x)) [since exp(0) = 1] // = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) // = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x - // max(-x, 0))) // = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // = -log( exp(max(-x, 0)) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) // // Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0)) // + exp(-x - max(-x, 0)))) template struct LogSigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); } }; // Originally: f' = exp(-x) / (1 + exp(-x)) // For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + // exp(-x - max(-x, 0))) template struct LogSigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) dx.device(d) = dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // exp(x) = e^x 
template struct ExpFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // relu(x) = max(x, 0) template struct ReluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (out > static_cast(0)).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) template struct GeluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { // Because the execute or device context can not be deliver here, it keep the // marco for NVCC. 
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) auto x_data = x.data(); auto out_data = out.data(); int n = std::min(x.size(), out.size()); std::memset(out_data, 0, n * sizeof(T)); math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, out_data, 1); math::CBlas::VMERF(n, out_data, out_data, VML_LA); for (int i = 0; i < n; i++) { out_data[i] += static_cast(1); } math::CBlas::VMUL(n, x_data, out_data, out_data); for (int i = 0; i < n; i++) { out_data[i] *= static_cast(0.5); } #else auto temp = (x * static_cast(M_SQRT1_2)).erf(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); #endif } }; template struct GeluGradFunctor : BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto first = static_cast(0.5) * (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * (-static_cast(0.5) * x.square()).exp(); dx.device(d) = dout * (first + second); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.tanh(); } }; template struct TanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) - out * out); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhShrinkFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x - x.tanh(); } }; template struct TanhShrinkGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * 
(x.tanh() * x.tanh()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct HardShrinkFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); out.device(d) = x * (temp1 + temp2); } }; template struct HardShrinkGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 // otherwise template struct SoftShrinkFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); } }; template struct SoftShrinkGradFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template 
cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.sqrt(); } }; template struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0.5) * dout / out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // rsqrt(x) = x^(-1/2) template struct RsqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.rsqrt(); } }; template struct RsqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(-0.5) * dout * out * out * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // ceil(x) = ceiling(x) template struct CeilFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.ceil(); } }; template struct ZeroGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0) * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } }; // floor(x) = flooring(x) template struct FloorFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.floor(); } }; template struct Sine { HOSTDEVICE T operator()(const T& val) const { return sin(val); } }; template <> struct Sine { HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { return platform::float16(sin(static_cast(val))); } }; template struct Cosine { HOSTDEVICE T operator()(const T& val) const { return cos(val); } }; template <> struct Cosine { HOSTDEVICE platform::float16 
operator()(const platform::float16& val) const { return platform::float16(cos(static_cast(val))); } }; // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * x.unaryExpr(Sine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosine(x) = cos(x) template struct CosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosine()); } }; // sine'(x) = cos(x) template struct SinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosine()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sine(x) = sin(x) template struct SinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sine()); } }; template struct Acos { HOSTDEVICE T operator()(const T& val) const { return acos(val); } }; template <> struct Acos { HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { return platform::float16(acos(static_cast(val))); } }; // Acos(x) = acos(x) template struct AcosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Acos()); } }; // acos'(x) = -1/sqrt(1-x^2) template struct AcosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Asin { HOSTDEVICE T operator()(const T& val) const { return asin(val); } }; template <> struct Asin { HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { return platform::float16(asin(static_cast(val))); 
} }; // Asin(x) = asin(x) template struct AsinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Asin()); } }; // asin'(x) = 1/sqrt(1-x^2) template struct AsinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct Atan { HOSTDEVICE T operator()(const T& val) const { return atan(val); } }; template <> struct Atan { HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { return platform::float16(atan(static_cast(val))); } }; // Atan(x) = atan(x) template struct AtanFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Atan()); } }; // atan'(x) = 1 / (1 + x^2) template struct AtanGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // round(x) = [x] template struct RoundFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.round(); } }; // abs(x) = |x| template struct AbsFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.abs(); } }; template struct AbsGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.sign(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / x; } }; 
template struct ReciprocalGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(-1) * out * out; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.log(); } }; template struct LogGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) / x); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.square(); } }; template struct SquareGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(2) * x; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct BReluFunctor : public BaseActivationFunctor { float t_min; float t_max; // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` // not polymorphism for speed. 
typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max)); } }; template struct BReluGradFunctor : public BaseActivationFunctor { float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // relu6(x) = min(max(0, x), 6) template struct Relu6Functor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); } }; template struct Relu6GradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((out > static_cast(0)) * (out < static_cast(threshold))) .template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // softplus(x) = log(1 + exp(x)) // When x is a very large positive number, exp(x) may explode to inf, // Using trick below for numerical stability // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) template struct SoftplusFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); } }; // d(softplus(x))/dx = 
exp(x) / (1 + exp(x)) // For numerical stability: // d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + // exp(x - max(x, 0))) template struct SoftplusGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) dx.device(d) = dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { out.device(d) = x / (static_cast(1) + x.abs()); } }; // d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template struct SoftsignGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) { dx.device(d) = dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct SoftReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto tmp = static_cast(threshold); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); out.device(d) = (static_cast(1) + temp.exp()).log(); } }; template struct SoftReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = static_cast(threshold); auto temp = ((out > -tmp) * (out < tmp)).template cast().eval(); dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; 
typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(alpha) * x); } }; template struct LeakyReluGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(alpha) * (x < static_cast(0)).template cast().eval(); auto temp2 = (x >= static_cast(0)).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct ELUFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)) + (static_cast(alpha) * (x.exp() - static_cast(1))) .cwiseMin(static_cast(0)); } }; template struct ELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x > static_cast(0)).template cast() + dout * static_cast(alpha) * x.exp() * (x < static_cast(0)).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 template struct PowFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.pow(static_cast(factor)); } }; template struct PowGradFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void 
operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * x.pow(static_cast(factor) - static_cast(1)); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct STanhFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; template struct STanhGradFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto a = static_cast(scale_a); auto b = static_cast(scale_b); auto temp = (a * x).tanh() * (a * x).tanh(); dx.device(d) = dout * a * b * (static_cast(1) - temp); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct ThresholdedReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto th = static_cast(threshold); out.device(d) = (x > th).template cast() * x; } }; template struct ThresholdedReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto th = static_cast(threshold); dx.device(d) = dout * (x > th).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template struct HardSigmoidFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } 
template void operator()(Device d, X x, Out out) const { auto temp = x * static_cast(slope) + static_cast(offset); out.device(d) = temp.cwiseMax(static_cast(0)).cwiseMin(static_cast(1)); } }; template struct HardSigmoidGradFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((out > static_cast(0)) * (out < static_cast(1))) .template cast() * static_cast(slope); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template struct SwishFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); } }; template struct SwishGradFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const { auto temp1 = static_cast(1) / (static_cast(1) + (static_cast(-beta) * x).exp()); auto out = x * temp1; auto temp2 = temp1 * (static_cast(1) - (static_cast(beta) * out)); dx.device(d) = dout * ((static_cast(beta) * out) + temp2); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; /* * in arguments: x, out, ddx * out arguments: ddout, dout, dx */ template inline void ExtractActivationDoubleGradTensor( const framework::ExecutionContext& ctx, const framework::Tensor** X, const framework::Tensor** Out, const framework::Tensor** ddX, framework::Tensor** dX, framework::Tensor** dOut, framework::Tensor** ddOut) { auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE(ddx_var != nullptr, "Cannot get input Variable Out, variable name = %s", ctx.op().Input("DDX")); if 
(CanBeUsedBySelectedRows.count(ctx.op().Type())) {
    // SelectedRows-capable ops (listed in CanBeUsedBySelectedRows): read the
    // dense value held by the variable instead of requiring a plain tensor.
    *ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var);
    if (ddo_var) {
      *ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
          ddo_var);
    }
  } else {
    *ddX = ctx.Input("DDX");
    if (ddo_var) {
      *ddOut = ctx.Output("DDOut");
    }
  }
  // NOTE(review): DDX is read as an input (ctx.Input("DDX") just above), yet
  // this message calls it an "output tensor" and formats the name with
  // ctx.op().Output("DDX"); ctx.op().Input("DDX") looks like the intended
  // lookup -- confirm.
  PADDLE_ENFORCE(*ddX != nullptr,
                 "Cannot get output tensor DDX, variable name = %s",
                 ctx.op().Output("DDX"));

  // X / DX are only extracted when kDepValue has the kDepX bit set;
  // otherwise X is aliased to ddX.
  if (static_cast(kDepValue) & static_cast(kDepX)) {
    auto x_var = ctx.InputVar("X");
    // NOTE(review): this check is for X, but the message says "Variable Out".
    PADDLE_ENFORCE(x_var != nullptr,
                   "Cannot get input Variable Out, variable name = %s",
                   ctx.op().Input("X"));
    auto dx_var = ctx.OutputVar("DX");
    if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
      if (dx_var) {
        *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
            dx_var);
      }
    } else {
      *X = ctx.Input("X");
      if (dx_var) {
        *dX = ctx.Output("DX");
      }
    }
  } else {
    VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
    *X = *ddX;
  }
  // Out / DOut are only extracted when kDepValue has the kDepOut bit set;
  // otherwise Out is aliased to ddX, so callers always see a non-null Out.
  if (static_cast(kDepValue) & static_cast(kDepOut)) {
    auto out_var = ctx.InputVar("Out");
    PADDLE_ENFORCE(out_var != nullptr,
                   "Cannot get input tensor Out, variable name = %s",
                   ctx.op().Input("Out"));
    auto dout_var = ctx.OutputVar("DOut");
    if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
      *Out =
          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
      if (dout_var) {
        *dOut =
            paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
                dout_var);
      }
    } else {
      *Out = ctx.Input("Out");
      if (dout_var) {
        *dOut = ctx.Output("DOut");
      }
    }
  } else {
    VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
    *Out = *ddX;
  }
}

// Generic double-gradient kernel for element-wise activations.
// It gathers X/Out/DDX and the optional DX/DOut/DDOut through
// ExtractActivationDoubleGradTensor and allocates whichever outputs exist.
// It then overwrites each attribute the functor exposes via GetAttrs() with
// the op attribute of the same name, and runs the functor.
// NOTE(review): the template parameter lists look stripped in this copy
// (`Functor` is used below but not declared on the template line) -- confirm
// against upstream.
template class ActivationDoubleGradKernel : public framework::OpKernel {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor *X, *Out, *ddX;
    X = Out = ddX = nullptr;
    framework::Tensor *ddOut, *dOut, *dX;
    ddOut = dOut = dX = nullptr;
    ExtractActivationDoubleGradTensor(ctx, &X, &Out, &ddX, &dX, &dOut, &ddOut);
    // Outputs the op did not declare stay nullptr and are skipped.
    if (ddOut) ddOut->mutable_data(ctx.GetPlace());
    if (dOut) dOut->mutable_data(ctx.GetPlace());
    // Out is never null here: when the functor does not depend on Out the
    // extraction helper aliases it to ddX.
    if (dX) dX->mutable_data(Out->dims(), ctx.GetPlace());
    auto& place = ctx.template device_context();
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr(attr.first);
    }
    functor(place, X, Out, ddX, ddOut, dOut, dX);
  }
};

// ReLU grad-grad, computed from the forward output:
//   DDOut = DDX * (Out > 0)
//   DOut  = 0
// DX is never written.
template struct ReluGradGradFunctor : public BaseActivationFunctor {
  template void operator()(const Device& dev, const framework::Tensor* X,
                           const framework::Tensor* Out,
                           const framework::Tensor* ddX,
                           framework::Tensor* ddOut, framework::Tensor* dOut,
                           framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector::Flatten(detail::Ref(ddX));
    auto out = framework::EigenVector::Flatten(detail::Ref(Out));
    if (ddOut) {
      auto ddout = framework::EigenVector::Flatten(detail::Ref(ddOut));
      ddout.device(*d) = ddx * (out > static_cast(0)).template cast();
    }
    if (dOut) {
      auto dout = framework::EigenVector::Flatten(detail::Ref(dOut));
      dout.device(*d) = dout.constant(static_cast(0));
    }
  }
  // Only the forward output Out is needed.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// LeakyReLU grad-grad, computed from the forward input. The slope is 1 where
// X >= 0 and `alpha` where X < 0:
//   DDOut = DDX * slope(X)
//   DX    = 0
// DOut is never written.
template struct LeakyReluGradGradFunctor : public BaseActivationFunctor {
  float alpha;  // filled from the op attribute "alpha" by the kernel
  typename BaseActivationFunctor::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template void operator()(const Device& dev, const framework::Tensor* X,
                           const framework::Tensor* Out,
                           const framework::Tensor* ddX,
                           framework::Tensor* ddOut, framework::Tensor* dOut,
                           framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector::Flatten(detail::Ref(ddX));
    auto x = framework::EigenVector::Flatten(detail::Ref(X));
    if (ddOut) {
      auto ddout = framework::EigenVector::Flatten(detail::Ref(ddOut));
      ddout.device(*d) = ddx *
                         ((x >= static_cast(0)).template cast().eval() +
                          static_cast(alpha) *
                              (x < static_cast(0)).template cast().eval())
                             .template cast();
    }
    if (dX) {
      auto dx = framework::EigenVector::Flatten(detail::Ref(dX));
      dx.device(*d) = dx.constant(static_cast(0));
    }
  }
  // Only the forward input X is needed.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Sqrt grad-grad:
//   DDOut = 0.5 * DDX / Out
//   DOut  = -DX * DDX / Out
// This is consistent with out = sqrt(x) and a first-order grad of
// dx = 0.5 * dout / out (presumably -- confirm against SqrtGradFunctor).
// DX is an input here. When DOut is requested, detail::Ref(dX) is used, which
// presumably rejects a null DX.
template struct SqrtGradGradFunctor : public BaseActivationFunctor {
  template void operator()(const Device& dev, const framework::Tensor* Out,
                           const framework::Tensor* ddX,
                           framework::Tensor* ddOut, framework::Tensor* dOut,
                           const framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector::Flatten(detail::Ref(ddX));
    auto out = framework::EigenVector::Flatten(detail::Ref(Out));
    if (ddOut) {
      auto ddout = framework::EigenVector::Flatten(detail::Ref(ddOut));
      ddout.device(*d) = ddx * static_cast(0.5) / out;
    }
    if (dOut) {
      auto dx = framework::EigenVector::Flatten(detail::Ref(dX));
      auto dout = framework::EigenVector::Flatten(detail::Ref(dOut));
      dout.device(*d) = dx * ddx * static_cast(-1) / out;
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Square grad-grad:
//   DDOut = 2 * X * DDX
//   DX    = 2 * DOut * DDX
// This is consistent with out = x^2 and a first-order grad of
// dx = 2 * x * dout (presumably). DOut is an input here and is required
// whenever DX is requested.
template struct SquareGradGradFunctor : public BaseActivationFunctor {
  template void operator()(const Device& dev, const framework::Tensor* X,
                           const framework::Tensor* ddX,
                           framework::Tensor* ddOut,
                           const framework::Tensor* dOut,
                           framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector::Flatten(detail::Ref(ddX));
    auto x = framework::EigenVector::Flatten(detail::Ref(X));
    if (ddOut) {
      auto ddout = framework::EigenVector::Flatten(detail::Ref(ddOut));
      ddout.device(*d) = ddx * static_cast(2) * x;
    }
    if (dX) {
      auto dx = framework::EigenVector::Flatten(detail::Ref(dX));
      auto dout = framework::EigenVector::Flatten(detail::Ref(dOut));
      dx.device(*d) = ddx * static_cast(2) * dout;
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// TODO(dengkaipeng): the double-gradient calculation for Square/Sqrt needs
// DOut (dy) as an input (not an output), so its tensor extraction differs
// from the other activations. The extraction kernels are implemented
// separately below.
inline void ExtractDoubleGradTensorWithInputDOut( const framework::ExecutionContext& ctx, const framework::Tensor** X, const framework::Tensor** ddX, framework::Tensor** dX, const framework::Tensor** dOut, framework::Tensor** ddOut) { // extract ddX(output), ddOut(input) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE(ddx_var != nullptr, "Cannot get input Variable Out, variable name = %s", ctx.op().Input("DDX")); *ddX = ctx.Input("DDX"); if (ddo_var) { *ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE(*ddX != nullptr, "Cannot get output tensor DDX, variable name = %s", ctx.op().Output("DDX")); // extract x(input), dx(output) auto x_var = ctx.InputVar("X"); PADDLE_ENFORCE(x_var != nullptr, "Cannot get input Variable Out, variable name = %s", ctx.op().Input("X")); auto dx_var = ctx.OutputVar("DX"); *X = ctx.Input("X"); if (dx_var) { *dX = ctx.Output("DX"); } // extract dOut(input) auto dout_var = ctx.InputVar("DOut"); if (dout_var) { *dOut = ctx.Input("DOut"); } } template class SquareDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *X, *ddX, *dOut; X = ddX = dOut = nullptr; framework::Tensor *dX, *ddOut; dX = ddOut = nullptr; ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut); if (dX) dX->mutable_data(X->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, X, ddX, ddOut, dOut, dX); } }; template class SqrtDoubleGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor *Out, *dX, *ddX; Out = dX = ddX = nullptr; framework::Tensor *ddOut, *dOut; ddOut = dOut = nullptr; // extract ddx(input), ddout(output) auto ddx_var = ctx.InputVar("DDX"); auto 
ddo_var = ctx.OutputVar("DDOut"); PADDLE_ENFORCE(ddx_var != nullptr, "Cannot get input Variable DDX, variable name = %s", ctx.op().Input("DDX")); ddX = ctx.Input("DDX"); if (ddo_var) { ddOut = ctx.Output("DDOut"); } PADDLE_ENFORCE(ddX != nullptr, "Cannot get input Variable DDX, variable name = %s", ctx.op().Input("DDX")); // extract out(input), dout(output) auto out_var = ctx.InputVar("Out"); PADDLE_ENFORCE(out_var != nullptr, "Cannot get input Variable Out, variable name = %s", ctx.op().Input("Out")); auto dout_var = ctx.OutputVar("DOut"); Out = ctx.Input("Out"); if (dout_var) { dOut = ctx.Output("DOut"); } // extract dx(input) auto dx_var = ctx.InputVar("DX"); PADDLE_ENFORCE(dx_var != nullptr, "Cannot get input Variable DX, variable name = %s", ctx.op().Input("DX")); if (dx_var) { dX = ctx.Input("DX"); } if (dOut) dOut->mutable_data(Out->dims(), ctx.GetPlace()); if (ddOut) ddOut->mutable_data(Out->dims(), ctx.GetPlace()); auto& place = ctx.template device_context(); Functor functor; functor(place, Out, ddX, ddOut, dOut, dX); } }; } // namespace operators } // namespace paddle #define FOR_EACH_ACTIVATION_OP(__macro) \ __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, Exp, ExpFunctor, ExpGradFunctor); \ __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \ __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \ __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \ __macro(abs, Abs, AbsFunctor, AbsGradFunctor); \ __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(cos, Cos, CosFunctor, CosGradFunctor); \ __macro(acos, Acos, AcosFunctor, AcosGradFunctor); \ __macro(sin, Sin, SinFunctor, SinGradFunctor); \ __macro(asin, Asin, AsinFunctor, AsinGradFunctor); \ __macro(round, 
Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log, Log, LogFunctor, LogGradFunctor); \ __macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(pow, Pow, PowFunctor, PowGradFunctor); \ __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ __macro(elu, ELU, ELUFunctor, ELUGradFunctor); \ __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \ __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ HardSigmoidGradFunctor); \ __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ __macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \ ThresholdedReluGradFunctor);