/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif namespace paddle { namespace operators { template class ActivationKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { auto& X = detail::Ref(context.Input("X"), "Cannot get input tensor X, variable name = %s", context.op().Input("X")); auto& Out = detail::Ref(context.Output("Out"), "Cannot get output tensor Out, variable name = %s", context.op().Output("Out")); Out.mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten(X); auto out = framework::EigenVector::Flatten(Out); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } functor(*place, x, out); } }; template class ActivationGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Out = context.Input("Out"); auto* dOut = context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten(*dOut); auto x = framework::EigenVector::Flatten(*X); auto out = framework::EigenVector::Flatten(*Out); auto dx = framework::EigenVector::Flatten(*dX); auto* place = context.template device_context().eigen_device(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } functor(*place, x, out, dout, dx); } }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; // sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; template struct SigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out * (static_cast(1) - out); } }; // Originally: logsigmoid(x) = -log (1 + exp(-x)) // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // We can rewrite the above equation as: // out = -log( exp(0) + exp(-x)) [since exp(0) = 1] // = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) // = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x - // max(-x, 0))) // = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // = -log( exp(max(-x, 0)) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) // // Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0)) // + exp(-x - max(-x, 0)))) template struct LogSigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); } }; // Originally: f' = exp(-x) / (1 + exp(-x)) // For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + // exp(-x - max(-x, 0))) template struct LogSigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) dx.device(d) = dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } }; // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out; } }; // relu(x) = max(x, 0) template struct ReluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x > static_cast(0)).template cast(); } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.tanh(); } }; template struct TanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) - out * out); } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhShrinkFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x - x.tanh(); } }; template struct TanhShrinkGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x.tanh() * x.tanh()); } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct HardShrinkFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); out.device(d) = x * (temp1 + temp2); } }; template struct HardShrinkGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 // otherwise template struct SoftShrinkFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); } }; template struct SoftShrinkGradFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } }; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.sqrt(); } }; template struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { const Out out_conj = Eigen::numext::conj(out); dx.device(d) = static_cast(0.5) * dout / out_conj; } }; // ceil(x) = ceiling(x) template struct CeilFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.ceil(); } }; template struct ZeroGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0) / x; } }; // floor(x) = flooring(x) template struct FloorFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.floor(); } }; template struct Sine { HOSTDEVICE T operator()(const T& val) const { return sin(val); } }; template struct Cosine { HOSTDEVICE T operator()(const T& val) const { return cos(val); } }; // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * x.unaryExpr(Sine()); } }; // cosine(x) = cos(x) template struct CosFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Cosine()); } }; // sine'(x) = cos(x) template struct SinGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosine()); } }; // sine(x) = sin(x) template struct SinFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.unaryExpr(Sine()); } }; // round(x) = [x] template struct RoundFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.round(); } }; // abs(x) = |x| template struct AbsFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.abs(); } }; template struct AbsGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.sign(); } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(1) / x; } }; template struct ReciprocalGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(-1) * out * out; } }; // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.log(); } }; template struct LogGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) / x); } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.square(); } }; template struct SquareGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(2) * x; } }; template struct BReluFunctor : public BaseActivationFunctor { float t_min; float t_max; // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` // not polymorphism for speed. typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max)); } }; template struct BReluGradFunctor : public BaseActivationFunctor { float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } }; // relu6(x) = min(max(0, x), 6) template struct Relu6Functor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); } }; template struct Relu6GradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((x > static_cast(0)) * (x < static_cast(threshold))) .template cast(); } }; // softplus(x) = log(1 + exp(x)) // When x is a very large positive number, exp(x) may explode to inf, // Using trick below for numerical stability // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) template struct SoftplusFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); } }; // d(softplus(x))/dx = exp(x) / (1 + exp(x)) // For numerical stability: // d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + // exp(x - max(x, 0))) template struct SoftplusGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) dx.device(d) = dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) { out.device(d) = x / (static_cast(1) + x.abs()); } }; // d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template struct SoftsignGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) { dx.device(d) = dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } }; template struct SoftReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto tmp = static_cast(threshold); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); out.device(d) = (static_cast(1) + temp.exp()).log(); } }; template struct SoftReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = static_cast(threshold); auto temp = ((x > -tmp) * (x < tmp)).template cast().eval(); dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } }; template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(alpha) * x); } }; template struct LeakyReluGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(alpha) * (x < static_cast(0)).template cast().eval(); auto temp2 = (x >= static_cast(0)).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } }; template struct ELUFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)) + (static_cast(alpha) * (x.exp() - static_cast(1))) .cwiseMin(static_cast(0)); } }; template struct ELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x > static_cast(0)).template cast() + dout * (out + static_cast(alpha)) * (x < static_cast(0)).template cast(); } }; // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 template struct PowFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x.pow(static_cast(factor)); } }; template struct PowGradFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * x.pow(static_cast(factor - static_cast(1))); } }; template struct STanhFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; template struct STanhGradFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto a = static_cast(scale_a); auto b = static_cast(scale_b); auto temp = (a * x).tanh() * (a * x).tanh(); dx.device(d) = dout * a * b * (static_cast(1) - temp); } }; template struct ThresholdedReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out) const { auto th = static_cast(threshold); out.device(d) = (x > th).template cast() * x; } }; template struct ThresholdedReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto th = static_cast(threshold); dx.device(d) = dout * (x > th).template cast(); } }; template struct HardSigmoidFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out) const { auto temp = x * static_cast(slope) + static_cast(offset); out.device(d) = temp.cwiseMax(static_cast(0)).cwiseMin(static_cast(1)); } }; template struct HardSigmoidGradFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * ((out > static_cast(0)) * (out < static_cast(1))) .template cast() * static_cast(slope); } }; template struct SwishFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out out) const { out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); } }; template struct SwishGradFunctor : public BaseActivationFunctor { float beta; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(1) / (static_cast(1) + (static_cast(-beta) * x).exp()); auto temp2 = temp1 * (static_cast(1) - (beta * out)); dx.device(d) = dout * ((beta * out) + temp2); } }; } // namespace operators } // namespace paddle #define FOR_EACH_KERNEL_FUNCTOR(__macro) \ __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, ExpFunctor, ExpGradFunctor); \ __macro(tanh, TanhFunctor, TanhGradFunctor); \ __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ __macro(abs, AbsFunctor, AbsGradFunctor); \ __macro(ceil, CeilFunctor, ZeroGradFunctor); \ __macro(floor, FloorFunctor, ZeroGradFunctor); \ __macro(cos, CosFunctor, CosGradFunctor); \ __macro(sin, SinFunctor, SinGradFunctor); \ __macro(round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log, LogFunctor, LogGradFunctor); \ __macro(square, SquareFunctor, SquareGradFunctor); \ __macro(brelu, BReluFunctor, BReluGradFunctor); \ __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(pow, PowFunctor, PowGradFunctor); \ __macro(stanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6Functor, Relu6GradFunctor); \ __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ __macro(elu, ELUFunctor, ELUGradFunctor); \ __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \ __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \ __macro(swish, SwishFunctor, SwishGradFunctor); \ __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);