/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace operators { template class ActivationKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Y = context.Output("Y"); Y->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten(*X); auto y = framework::EigenVector::Flatten(*Y); auto place = context.GetEigenDevice(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } functor(place, x, y); } }; template class ActivationGradKernel : public framework::OpKernel { public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Y = context.Input("Y"); auto* dY = context.Input(framework::GradVarName("Y")); auto* dX = context.Output(framework::GradVarName("X")); dX->mutable_data(context.GetPlace()); auto dy = framework::EigenVector::Flatten(*dY); auto x = framework::EigenVector::Flatten(*X); auto y = framework::EigenVector::Flatten(*Y); auto dx = framework::EigenVector::Flatten(*dX); auto place = context.GetEigenDevice(); Functor functor; auto attrs = functor.GetAttrs(); for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } functor(place, x, y, dy, dx); } }; template struct BaseActivationFunctor { using ELEMENT_TYPE = T; using AttrPair = std::vector>; AttrPair GetAttrs() { return AttrPair(); } }; // sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; template struct SigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * y * (static_cast(1) - y); } }; // Originally: logsigmoid(x) = -log (1 + exp(-x)) // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // We can rewrite the above equation as: // y = -log( exp(0) + exp(-x)) [since exp(0) = 1] // = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) // = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x - // max(-x, 0))) // = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // = -log( exp(max(-x, 0)) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) // // Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0)) // + exp(-x - max(-x, 0)))) template struct LogSigmoidFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); } }; // Originally: f' = exp(-x) / (1 + exp(-x)) // For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + // exp(-x - max(-x, 0))) template struct LogSigmoidGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) dx.device(d) = dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } }; // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * y; } }; // relu(x) = max(x, 0) template struct ReluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (x > static_cast(0)).template cast(); } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.tanh(); } }; template struct TanhGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (static_cast(1) - y * y); } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhShrinkFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x - x.tanh(); } }; template struct TanhShrinkGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (x.tanh() * x.tanh()); } }; // tanhshrink(x) = x - tanh(x) // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct HardShrinkFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); y.device(d) = x * (temp1 + temp2); } }; template struct HardShrinkGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); dx.device(d) = dy * (temp1 + temp2).template cast(); } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < lambda; 0 // otherwise template struct SoftShrinkFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Y y) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); } }; template struct SoftShrinkGradFunctor : public BaseActivationFunctor { float lambda; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); dx.device(d) = dy * (temp1 + temp2).template cast(); } }; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.sqrt(); } }; template struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { const Y y_conj = Eigen::numext::conj(y); dx.device(d) = static_cast(0.5) * dy / y_conj; } }; // abs(x) = |x| template struct AbsFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.abs(); } }; template struct AbsGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * x.sign(); } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = static_cast(1) / x; } }; template struct ReciprocalGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * static_cast(-1) * y * y; } }; // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.log(); } }; template struct LogGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (static_cast(1) / x); } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) const { y.device(d) = x.square(); } }; template struct SquareGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * static_cast(2) * x; } }; template struct BReluFunctor : public BaseActivationFunctor { float t_min; float t_max; // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` // not polymorphism for speed. typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Y y) const { y.device(d) = x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max)); } }; template struct BReluGradFunctor : public BaseActivationFunctor { float t_min; float t_max; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } }; // relu6(x) = min(max(0, x), 6) template struct Relu6Functor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y) const { y.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); } }; template struct Relu6GradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * ((x > static_cast(0)) * (x < static_cast(threshold))) .template cast(); } }; // softplus(x) = log(1 + exp(x)) // When x is a very large positive number, exp(x) may explode to inf, // Using trick below for numerical stability // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) template struct SoftplusFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); } }; // d(softplus(x))/dx = exp(x) / (1 + exp(x)) // For numerical stability: // d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + // exp(x - max(x, 0))) template struct SoftplusGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y) { y.device(d) = x / (static_cast(1) + x.abs()); } }; // d(softsign(x))/dx = 1 / (1 + |x|)^2 // Taken from https://en.wikipedia.org/wiki/Activation_function template struct SoftsignGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { dx.device(d) = dy * (static_cast(1) / (static_cast(1) + x.abs()).square()); } }; template struct SoftReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y) const { auto tmp = static_cast(threshold); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); y.device(d) = (static_cast(1) + temp.exp()).log(); } }; template struct SoftReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto tmp = static_cast(threshold); auto temp = ((x > -tmp) * (x < tmp)).template cast().eval(); dx.device(d) = dy * (static_cast(1) - (-y).exp()) * temp; } }; template struct LeakyReluFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Y y) const { y.device(d) = x.cwiseMax(static_cast(alpha) * x); } }; template struct LeakyReluGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto temp1 = static_cast(alpha) * (x < static_cast(0)).template cast().eval(); auto temp2 = (x >= static_cast(0)).template cast().eval(); dx.device(d) = dy * (temp1 + temp2).template cast(); } }; template struct ELUFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Y y) const { y.device(d) = x.cwiseMax(static_cast(0)) + (static_cast(alpha) * (x.exp() - static_cast(1))) .cwiseMin(static_cast(0)); } }; template struct ELUGradFunctor : public BaseActivationFunctor { float alpha; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * (x > static_cast(0)).template cast() + dy * (y + static_cast(alpha)) * (x < static_cast(0)).template cast(); } }; template struct PowFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Y y) const { y.device(d) = x.pow(static_cast(factor)); } }; template struct PowGradFunctor : public BaseActivationFunctor { float factor; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * static_cast(factor) * x.pow(static_cast(factor - static_cast(1))); } }; template struct STanhFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Y y) const { y.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; template struct STanhGradFunctor : public BaseActivationFunctor { float scale_a; float scale_b; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto a = static_cast(scale_a); auto b = static_cast(scale_b); auto temp = (a * x).tanh() * (a * x).tanh(); dx.device(d) = dy * a * b * (static_cast(1) - temp); } }; template struct ThresholdedReluFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y) const { auto th = static_cast(threshold); y.device(d) = (x > th).template cast() * x; } }; template struct ThresholdedReluGradFunctor : public BaseActivationFunctor { float threshold; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { auto th = static_cast(threshold); dx.device(d) = dy * (x > th).template cast(); } }; template struct HardSigmoidFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Y y) const { auto temp = x * static_cast(slope) + static_cast(offset); y.device(d) = temp.cwiseMax(static_cast(0)).cwiseMin(static_cast(1)); } }; template struct HardSigmoidGradFunctor : public BaseActivationFunctor { float slope; float offset; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } template void operator()(Device d, X x, Y y, dY dy, dX dx) const { dx.device(d) = dy * ((y > static_cast(0)) * (y < static_cast(1))).template cast() * static_cast(slope); } }; } // namespace operators } // namespace paddle #define FOR_EACH_KERNEL_FUNCTOR(__macro) \ __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, ExpFunctor, ExpGradFunctor); \ __macro(relu, ReluFunctor, ReluGradFunctor); \ __macro(tanh, TanhFunctor, TanhGradFunctor); \ __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ __macro(abs, AbsFunctor, AbsGradFunctor); \ __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(log, LogFunctor, LogGradFunctor); \ __macro(square, SquareFunctor, SquareGradFunctor); \ __macro(brelu, BReluFunctor, BReluGradFunctor); \ __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(pow, PowFunctor, PowGradFunctor); \ __macro(stanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6Functor, Relu6GradFunctor); \ __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ __macro(elu, ELUFunctor, ELUGradFunctor); \ __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \ __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \ __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);