From e0be63bf0912bf84be2f92679d15239d31b5108d Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Tue, 26 Dec 2017 12:41:25 +0800
Subject: [PATCH] change activations

---
 paddle/operators/activation_op.cc | 118 ++++-----
 paddle/operators/activation_op.h  | 387 ++++++++++++++++--------------
 2 files changed, 267 insertions(+), 238 deletions(-)

diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index 2b4c7e5f0de..ef35779889f 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -22,8 +22,8 @@ class ActivationOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*->*/ "Y");
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
@@ -32,7 +32,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
   }
 };
 
@@ -41,11 +41,11 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Sigmoid operator");
-    AddOutput("Y", "Output of Sigmoid operator");
+    AddOutput("Out", "Output of Sigmoid operator");
     AddComment(R"DOC(
 Sigmoid Activation Operator.
 
-$$y = \frac{1}{1 + e^{-x}}$$
+$$out = \frac{1}{1 + e^{-x}}$$
 
 )DOC");
   }
@@ -56,11 +56,11 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of LogSigmoid operator");
-    AddOutput("Y", "Output of LogSigmoid operator");
+    AddOutput("Out", "Output of LogSigmoid operator");
     AddComment(R"DOC(
 Logsigmoid Activation Operator.
 
-$$y = \log \frac{1}{1 + e^{-x}}$$
+$$out = \log \frac{1}{1 + e^{-x}}$$
 
 )DOC");
   }
@@ -71,11 +71,11 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
   ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Exp operator");
-    AddOutput("Y", "Output of Exp operator");
+    AddOutput("Out", "Output of Exp operator");
     AddComment(R"DOC(
 Exp Activation Operator.
 
-$y = e^x$
+$out = e^x$
 
 )DOC");
   }
@@ -86,11 +86,11 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
   ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu operator");
-    AddOutput("Y", "Output of Relu operator");
+    AddOutput("Out", "Output of Relu operator");
     AddComment(R"DOC(
 Relu Activation Operator.
 
-$y = \max(x, 0)$
+$out = \max(x, 0)$
 
 )DOC");
   }
@@ -101,12 +101,12 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
   LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of LeakyRelu operator");
-    AddOutput("Y", "Output of LeakyRelu operator");
+    AddOutput("Out", "Output of LeakyRelu operator");
     AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
     AddComment(R"DOC(
 LeakyRelu Activation Operator.
 
-$y = \max(x, \alpha * x)$
+$out = \max(x, \alpha * x)$
 
 )DOC");
   }
@@ -117,13 +117,13 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softshrink operator");
-    AddOutput("Y", "Output of Softshrink operator");
+    AddOutput("Out", "Output of Softshrink operator");
     AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
     AddComment(R"DOC(
 Softshrink Activation Operator.
 
 $$
-y = \begin{cases}
+out = \begin{cases}
     x - \lambda, \text{if } x > \lambda \\
     x + \lambda, \text{if } x < -\lambda \\
     0,  \text{otherwise}
     \end{cases}
 $$
 
 )DOC");
   }
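An aside for readers checking the piecewise DOC blocks against the code: the three Softshrink cases translate to straightforward scalar code. A standalone sketch (hypothetical helper, not part of this patch):

#include <cstdio>

// Hypothetical scalar version of the Softshrink DOC cases above.
float softshrink(float x, float lambda = 0.5f) {
  if (x > lambda) return x - lambda;
  if (x < -lambda) return x + lambda;
  return 0.0f;  // inputs with |x| <= lambda are zeroed out
}

int main() {
  for (float x : {-2.0f, -0.3f, 0.0f, 0.3f, 2.0f}) {
    std::printf("softshrink(%+.1f) = %+.1f\n", x, softshrink(x));
  }
}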
@@ -139,11 +139,11 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
   TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Tanh operator");
-    AddOutput("Y", "Output of Tanh operator");
+    AddOutput("Out", "Output of Tanh operator");
     AddComment(R"DOC(
 Tanh Activation Operator.
 
-$$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
+$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
 
 )DOC");
   }
@@ -154,11 +154,11 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
   TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of TanhShrink operator");
-    AddOutput("Y", "Output of TanhShrink operator");
+    AddOutput("Out", "Output of TanhShrink operator");
     AddComment(R"DOC(
 TanhShrink Activation Operator.
 
-$$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
+$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
 
 )DOC");
   }
@@ -169,14 +169,14 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
   HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of HardShrink operator");
-    AddOutput("Y", "Output of HardShrink operator");
+    AddOutput("Out", "Output of HardShrink operator");
     AddAttr<float>("threshold", "The value of threshold for HardShrink")
         .SetDefault(0.5f);
     AddComment(R"DOC(
 HardShrink Activation Operator.
 
 $$
-y = \begin{cases}
+out = \begin{cases}
     x, \text{if } x > \lambda \\
     x, \text{if } x < -\lambda \\
     0,  \text{otherwise}
     \end{cases}
 $$
 
 )DOC");
   }
@@ -192,11 +192,11 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
   SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Sqrt operator");
-    AddOutput("Y", "Output of Sqrt operator");
+    AddOutput("Out", "Output of Sqrt operator");
     AddComment(R"DOC(
 Sqrt Activation Operator.
 
-$y = \sqrt{x}$
+$out = \sqrt{x}$
 
 )DOC");
   }
@@ -207,11 +207,11 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
   AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Abs operator");
-    AddOutput("Y", "Output of Abs operator");
+    AddOutput("Out", "Output of Abs operator");
     AddComment(R"DOC(
 Abs Activation Operator.
 
-$y = |x|$
+$out = |x|$
 
 )DOC");
   }
@@ -222,11 +222,11 @@ class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
   CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Ceil operator");
-    AddOutput("Y", "Output of Ceil operator");
+    AddOutput("Out", "Output of Ceil operator");
     AddComment(R"DOC(
 Ceil Activation Operator.
 
-$y = ceil(x)$
+$out = ceil(x)$
 
 )DOC");
   }
@@ -237,11 +237,11 @@ class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
   FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Floor operator");
-    AddOutput("Y", "Output of Floor operator");
+    AddOutput("Out", "Output of Floor operator");
     AddComment(R"DOC(
 Floor Activation Operator.
 
-$y = floor(x)$
+$out = floor(x)$
 
 )DOC");
   }
@@ -252,11 +252,11 @@ class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
   RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Round operator");
-    AddOutput("Y", "Output of Round operator");
+    AddOutput("Out", "Output of Round operator");
     AddComment(R"DOC(
 Round Activation Operator.
 
-$y = [x]$
+$out = [x]$
 
 )DOC");
   }
@@ -267,11 +267,11 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
   ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Reciprocal operator");
-    AddOutput("Y", "Output of Reciprocal operator");
+    AddOutput("Out", "Output of Reciprocal operator");
     AddComment(R"DOC(
 Reciprocal Activation Operator.
 
-$$y = \frac{1}{x}$$
+$$out = \frac{1}{x}$$
 
 )DOC");
   }
@@ -282,11 +282,11 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker {
   LogOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Log operator");
-    AddOutput("Y", "Output of Log operator");
+    AddOutput("Out", "Output of Log operator");
     AddComment(R"DOC(
 Log Activation Operator.
 
-$y = \ln(x)$
+$out = \ln(x)$
 
 Natural logarithm of x.
 
 )DOC");
   }
@@ -299,11 +299,11 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
   SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Square operator");
-    AddOutput("Y", "Output of Square operator");
+    AddOutput("Out", "Output of Square operator");
     AddComment(R"DOC(
 Square Activation Operator.
 
-$y = x^2$
+$out = x^2$
 
 )DOC");
   }
@@ -314,11 +314,11 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softplus operator");
-    AddOutput("Y", "Output of Softplus operator");
+    AddOutput("Out", "Output of Softplus operator");
     AddComment(R"DOC(
 Softplus Activation Operator.
 
-$y = \ln(1 + e^{x})$
+$out = \ln(1 + e^{x})$
 
 )DOC");
   }
@@ -329,11 +329,11 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softsign operator");
-    AddOutput("Y", "Output of Softsign operator");
+    AddOutput("Out", "Output of Softsign operator");
     AddComment(R"DOC(
 Softsign Activation Operator.
 
-$$y = \frac{x}{1 + |x|}$$
+$$out = \frac{x}{1 + |x|}$$
 
 )DOC");
   }
@@ -344,7 +344,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
   BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of BRelu operator");
-    AddOutput("Y", "Output of BRelu operator");
+    AddOutput("Out", "Output of BRelu operator");
     AddAttr<float>("t_min", "The min marginal value of BRelu")
         .SetDefault(static_cast<float>(0));
     AddAttr<float>("t_max", "The max marginal value of BRelu")
@@ -352,7 +352,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(static_cast<float>(24));
     AddComment(R"DOC(
 BRelu Activation Operator.
 
-$y = \min(\max(x, t_{min}), t_{max})$
+$out = \min(\max(x, t_{min}), t_{max})$
 
 )DOC");
   }
@@ -363,13 +363,13 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of SoftRelu operator");
-    AddOutput("Y", "Output of SoftRelu operator");
+    AddOutput("Out", "Output of SoftRelu operator");
     AddAttr<float>("threshold", "The threshold value of SoftRelu")
         .SetDefault(40.0f);
     AddComment(R"DOC(
 SoftRelu Activation Operator.
 
-$y = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$
+$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$
 
 )DOC");
   }
@@ -380,7 +380,7 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
   ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of ELU operator");
-    AddOutput("Y", "Output of ELU operator");
+    AddOutput("Out", "Output of ELU operator");
     AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
     AddComment(R"DOC(
 ELU Activation Operator.
@@ -388,7 +388,7 @@ ELU Activation Operator.
 
 Applies the following element-wise computation on the input according to
 https://arxiv.org/abs/1511.07289.
 
-$y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
+$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
 
 )DOC");
   }
@@ -399,13 +399,13 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
   Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu6 operator");
-    AddOutput("Y", "Output of Relu6 operator");
+    AddOutput("Out", "Output of Relu6 operator");
     AddAttr<float>("threshold", "The threshold value of Relu6")
         .SetDefault(6.0f);
     AddComment(R"DOC(
 Relu6 Activation Operator.
 
-$y = \min(\max(0, x), 6)$
+$out = \min(\max(0, x), 6)$
 
 )DOC");
   }
@@ -416,12 +416,12 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
   PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Pow operator");
-    AddOutput("Y", "Output of Pow operator");
+    AddOutput("Out", "Output of Pow operator");
     AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
     AddComment(R"DOC(
 Pow Activation Operator.
 
-$y = x^{factor}$
+$out = x^{factor}$
 
 )DOC");
   }
@@ -432,7 +432,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
   STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of STanh operator");
-    AddOutput("Y", "Output of STanh operator");
+    AddOutput("Out", "Output of STanh operator");
     AddAttr<float>("scale_a", "The scale parameter of a for the input")
         .SetDefault(2.0f / 3.0f);
     AddAttr<float>("scale_b", "The scale parameter of b for the input")
@@ -440,7 +440,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 STanh Activation Operator.
 
-$$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
+$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
 
 )DOC");
   }
@@ -451,14 +451,14 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
   ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of ThresholdedRelu operator");
-    AddOutput("Y", "Output of ThresholdedRelu operator");
+    AddOutput("Out", "Output of ThresholdedRelu operator");
     AddAttr<float>("threshold", "The threshold location of activation")
         .SetDefault(1.0f);
     AddComment(R"DOC(
 ThresholdedRelu Activation Operator.
 
 $$
-y = \begin{cases}
+out = \begin{cases}
     x, \text{if } x > threshold \\
     0,  \text{otherwise}
     \end{cases}
 $$
 
 )DOC");
   }
@@ -473,7 +473,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of HardSigmoid operator");
-    AddOutput("Y", "Output of HardSigmoid operator");
+    AddOutput("Out", "Output of HardSigmoid operator");
     AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
         .SetDefault(0.2f);
     AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
@@ -484,7 +484,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(0.5f);
     AddComment(R"DOC(
 HardSigmoid Activation Operator.
 
 Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
 which is much faster than sigmoid.
 
-$y = \max(0, \min(1, slope * x + offset))$
+$out = \max(0, \min(1, slope * x + offset))$
 
 The slope should be positive. The offset can be either positive or negative.
 The default slope and offset are set according to the above reference.
 )DOC");
   }
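The HardSigmoid DOC formula above is easy to spot-check in isolation. A minimal standalone sketch (hypothetical helper, not part of this patch), using the default slope 0.2 and offset 0.5:

#include <algorithm>
#include <cstdio>

// Hypothetical scalar version of out = max(0, min(1, slope * x + offset)).
float hard_sigmoid(float x, float slope = 0.2f, float offset = 0.5f) {
  return std::max(0.0f, std::min(1.0f, slope * x + offset));
}

int main() {
  // With the defaults this saturates at 0 below x = -2.5 and at 1 above 2.5.
  for (float x : {-3.0f, -1.0f, 0.0f, 1.0f, 3.0f}) {
    std::printf("hard_sigmoid(%+.1f) = %.2f\n", x, hard_sigmoid(x));
  }
}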
@@ -499,12 +499,12 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
   SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Swish operator");
-    AddOutput("Y", "Output of Swish operator");
+    AddOutput("Out", "Output of Swish operator");
     AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
     AddComment(R"DOC(
 Swish Activation Operator.
 
-$$y = \frac{x}{1 + e^{- \beta x}}$$
+$$out = \frac{x}{1 + e^{- \beta x}}$$
 
 )DOC");
   }
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index 75eefca8b8c..980bf93fcf7 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -27,11 +27,11 @@ class ActivationKernel
   void Compute(const framework::ExecutionContext& context) const override {
     auto* X = context.Input<framework::Tensor>("X");
-    auto* Y = context.Output<framework::Tensor>("Y");
-    Y->mutable_data<T>(context.GetPlace());
+    auto* Out = context.Output<framework::Tensor>("Out");
+    Out->mutable_data<T>(context.GetPlace());
 
     auto x = framework::EigenVector<T>::Flatten(*X);
-    auto y = framework::EigenVector<T>::Flatten(*Y);
+    auto out = framework::EigenVector<T>::Flatten(*Out);
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
     Functor functor;
@@ -40,7 +40,7 @@ class ActivationKernel
     for (auto& attr : attrs) {
       *attr.second = context.Attr<float>(attr.first);
     }
-    functor(*place, x, y);
+    functor(*place, x, out);
   }
 };
 
@@ -51,14 +51,15 @@ class ActivationGradKernel
   using T = typename Functor::ELEMENT_TYPE;
   void Compute(const framework::ExecutionContext& context) const override {
     auto* X = context.Input<framework::Tensor>("X");
-    auto* Y = context.Input<framework::Tensor>("Y");
-    auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
+    auto* Out = context.Input<framework::Tensor>("Out");
+    auto* dOut =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
     dX->mutable_data<T>(context.GetPlace());
 
-    auto dy = framework::EigenVector<T>::Flatten(*dY);
+    auto dout = framework::EigenVector<T>::Flatten(*dOut);
     auto x = framework::EigenVector<T>::Flatten(*X);
-    auto y = framework::EigenVector<T>::Flatten(*Y);
+    auto out = framework::EigenVector<T>::Flatten(*Out);
     auto dx = framework::EigenVector<T>::Flatten(*dX);
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
@@ -67,7 +68,7 @@ class ActivationGradKernel
     for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
     }
-    functor(*place, x, y, dy, dx);
+    functor(*place, x, out, dout, dx);
  }
 };
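Note how the two kernels stay agnostic of the concrete activation: they copy every float attribute through the functor's GetAttrs() pointer map and then apply the functor to flattened Eigen vectors. A stripped-down sketch of the same mechanism (hypothetical names, plain std::vector in place of Eigen, not part of this patch):

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a functor with a BaseActivationFunctor-style
// attribute map: GetAttrs() exposes named pointers into the functor's fields.
struct LeakyReluSketch {
  float alpha;
  std::vector<std::pair<std::string, float*>> GetAttrs() {
    return {{"alpha", &alpha}};
  }
  float operator()(float x) const { return x > 0.0f ? x : alpha * x; }
};

// Mimics ActivationKernel::Compute: fill the attributes by name, then apply
// the functor elementwise (the real kernel evaluates Eigen expressions).
template <typename Functor>
std::vector<float> Run(const std::vector<float>& x, float attr_value) {
  Functor functor;
  auto attrs = functor.GetAttrs();
  for (auto& attr : attrs) {
    *attr.second = attr_value;  // context.Attr<float>(attr.first) in the kernel
  }
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = functor(x[i]);
  return out;
}

int main() {
  std::vector<float> out = Run<LeakyReluSketch>({-1.0f, 2.0f}, 0.02f);
  std::printf("%g %g\n", out[0], out[1]);  // prints: -0.02 2
}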
@@ -83,17 +84,18 @@ struct BaseActivationFunctor {
 // sigmoid(x) = 1 / (1 + exp(-x))
 template <typename T>
 struct SigmoidFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
   }
 };
 
 template <typename T>
 struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * y * (static_cast<T>(1) - y);
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * out * (static_cast<T>(1) - out);
   }
 };
 
@@ -101,7 +103,7 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
 // For numerical stability, we can use the log-sum-exp trick:
 // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
 // We can rewrite the above equation as:
-// y = -log( exp(0) + exp(-x)) [since exp(0) = 1]
+// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
 // = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
 // = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x -
 // max(-x, 0)))
 // = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
 // = -log( exp(max(-x, 0))) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
 template <typename T>
 struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
     auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
-    y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
+    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
   }
 };
 
@@ -124,62 +126,66 @@ struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
 template <typename T>
 struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
     dx.device(d) =
-        dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
+        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
   }
 };
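The log-sum-exp rewrite in the comment above can be verified numerically. A standalone sketch (not part of this patch) comparing the naive form against the stable form used by LogSigmoidFunctor at an extreme input:

#include <cmath>
#include <cstdio>

// Naive: -log(1 + exp(-x)); exp(-x) overflows for very negative x.
double naive(double x) { return -std::log(1.0 + std::exp(-x)); }

// Stable form from LogSigmoidFunctor, with temp = max(-x, 0).
double stable(double x) {
  double temp = std::fmax(-x, 0.0);
  return -temp - std::log(std::exp(-temp) + std::exp(-x - temp));
}

int main() {
  std::printf("x = -1000: naive = %g, stable = %g\n", naive(-1000.0),
              stable(-1000.0));  // naive: -inf, stable: -1000
  std::printf("x = 2:     naive = %g, stable = %g\n", naive(2.0), stable(2.0));
}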
 // exp(x) = e^x
 template <typename T>
 struct ExpFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.exp();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.exp();
   }
 };
 
 template <typename T>
 struct ExpGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * y;
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * out;
   }
 };
 
 // relu(x) = max(x, 0)
 template <typename T>
 struct ReluFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.cwiseMax(static_cast<T>(0));
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.cwiseMax(static_cast<T>(0));
   }
 };
 
 template <typename T>
 struct ReluGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
  }
 };
 
 // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
 template <typename T>
 struct TanhFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.tanh();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.tanh();
   }
 };
 
 template <typename T>
 struct TanhGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * (static_cast<T>(1) - y * y);
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * (static_cast<T>(1) - out * out);
   }
 };
 
@@ -187,17 +193,18 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
 // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
 template <typename T>
 struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x - x.tanh();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x - x.tanh();
   }
 };
 
 template <typename T>
 struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * (x.tanh() * x.tanh());
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * (x.tanh() * x.tanh());
   }
 };
 
@@ -210,11 +217,11 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"threshold", &threshold}};
   }
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
     auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
     auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
-    y.device(d) = x * (temp1 + temp2);
+    out.device(d) = x * (temp1 + temp2);
   }
 };
 
@@ -226,11 +233,12 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
     return {{"threshold", &threshold}};
   }
 
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
     auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
-    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
+    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }
 };
 
@@ -243,12 +251,12 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
     return {{"lambda", &lambda}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
     auto lambdaT = static_cast<T>(lambda);
     auto temp1 = (x > lambdaT).template cast<T>().eval();
     auto temp2 = (x < -lambdaT).template cast<T>().eval();
-    y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
+    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
   }
 };
 
@@ -258,46 +266,49 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"lambda", &lambda}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto lambdaT = static_cast<T>(lambda);
     auto temp1 = (x > lambdaT).template cast<T>().eval();
     auto temp2 = (x < -lambdaT).template cast<T>().eval();
-    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
+    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }
 };
 
 // sqrt(x) = x^(1/2)
 template <typename T>
 struct SqrtFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.sqrt();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.sqrt();
   }
 };
 
 template <typename T>
 struct SqrtGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    const Y y_conj = Eigen::numext::conj(y);
-    dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    const Out out_conj = Eigen::numext::conj(out);
+    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
   }
 };
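SqrtGradFunctor relies on the identity out = sqrt(x), hence d out/dx = 1/(2 sqrt(x)) = 0.5/out, so the backward pass needs only the forward output. A finite-difference spot check (standalone sketch, not part of this patch):

#include <cmath>
#include <cstdio>

int main() {
  // SqrtGradFunctor computes dx = 0.5 * dout / out; take dout = 1 here.
  double x = 2.0, eps = 1e-6;
  double out = std::sqrt(x);
  double analytic = 0.5 / out;
  double numeric = (std::sqrt(x + eps) - std::sqrt(x - eps)) / (2.0 * eps);
  std::printf("analytic = %.8f, numeric = %.8f\n", analytic, numeric);
}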
 // ceil(x) = ceiling(x)
 template <typename T>
 struct CeilFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.ceil();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.ceil();
   }
 };
 
 template <typename T>
 struct ZeroGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     dx.device(d) = static_cast<T>(0) / x;
   }
 };
 
@@ -305,86 +316,90 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
 // floor(x) = flooring(x)
 template <typename T>
 struct FloorFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.ceil();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.floor();
   }
 };
 
 // round(x) = [x]
 template <typename T>
 struct RoundFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.round();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.round();
   }
 };
 
 // abs(x) = |x|
 template <typename T>
 struct AbsFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.abs();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.abs();
   }
 };
 
 template <typename T>
 struct AbsGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * x.sign();
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * x.sign();
   }
 };
 
 // reciprocal(x) = 1 / x
 template <typename T>
 struct ReciprocalFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = static_cast<T>(1) / x;
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = static_cast<T>(1) / x;
   }
 };
 
 template <typename T>
 struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * static_cast<T>(-1) * y * y;
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * static_cast<T>(-1) * out * out;
   }
 };
 
 // log(x) = natural logarithm of x
 template <typename T>
 struct LogFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.log();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.log();
   }
 };
 
 template <typename T>
 struct LogGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * (static_cast<T>(1) / x);
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * (static_cast<T>(1) / x);
   }
 };
 
 // square(x) = x^2
 template <typename T>
 struct SquareFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.square();
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.square();
   }
 };
 
 template <typename T>
 struct SquareGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * static_cast<T>(2) * x;
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * static_cast<T>(2) * x;
   }
 };
 
@@ -399,9 +414,9 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
     return {{"t_min", &t_min}, {"t_max", &t_max}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) =
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) =
         x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
   }
 };
 
@@ -413,9 +428,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"t_min", &t_min}, {"t_max", &t_max}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy *
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout *
                    ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                        .template cast<T>();
   }
 };
 
@@ -430,9 +446,9 @@ struct Relu6Functor : public BaseActivationFunctor<T> {
     return {{"threshold", &threshold}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) =
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) =
         x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
   }
 };
@@ -443,9 +459,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"threshold", &threshold}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy *
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout *
                    ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
                        .template cast<T>();
   }
 };
 
@@ -458,10 +475,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
 // softplus(x) = log(1 + exp(x))
 // Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
 template <typename T>
 struct SoftplusFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) {
     auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
-    y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
+    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
   }
 };
 
@@ -471,19 +488,21 @@ struct SoftplusFunctor : public BaseActivationFunctor<T> {
 // d(softplus(x))/dx = exp(x) / (1 + exp(x))
 // For numerical stability:
 // d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
 // exp(x - max(x, 0)))
 template <typename T>
 struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
     auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
-    dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
+    dx.device(d) =
+        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
   }
 };
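SoftplusFunctor applies the same max-shift used by LogSigmoidFunctor earlier: factoring out max(x, 0) keeps both exp() calls from overflowing. A standalone numeric check (not part of this patch):

#include <cmath>
#include <cstdio>

// Naive: log(1 + exp(x)); exp(x) overflows for large positive x.
double naive(double x) { return std::log(1.0 + std::exp(x)); }

// Stable form from SoftplusFunctor, with temp = max(x, 0).
double stable(double x) {
  double temp = std::fmax(x, 0.0);
  return temp + std::log(std::exp(-temp) + std::exp(x - temp));
}

int main() {
  std::printf("x = 1000: naive = %g, stable = %g\n", naive(1000.0),
              stable(1000.0));  // naive: inf, stable: 1000
  std::printf("x = -2:   naive = %g, stable = %g\n", naive(-2.0), stable(-2.0));
}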
 // softsign(x) = x / (1 + |x|)
 template <typename T>
 struct SoftsignFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) {
-    y.device(d) = x / (static_cast<T>(1) + x.abs());
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) {
+    out.device(d) = x / (static_cast<T>(1) + x.abs());
   }
 };
 
@@ -491,10 +510,11 @@ struct SoftsignFunctor : public BaseActivationFunctor<T> {
 // Taken from https://en.wikipedia.org/wiki/Activation_function
 template <typename T>
 struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
     dx.device(d) =
-        dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
+        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
   }
 };
 
@@ -505,11 +525,11 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> {
     return {{"threshold", &threshold}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
     auto tmp = static_cast<T>(threshold);
     auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
-    y.device(d) = (static_cast<T>(1) + temp.exp()).log();
+    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
   }
 };
 
@@ -519,11 +539,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"threshold", &threshold}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto tmp = static_cast<T>(threshold);
     auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
-    dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
+    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
   }
 };
 
@@ -534,9 +555,9 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
     return {{"alpha", &alpha}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
   }
 };
 
@@ -546,12 +567,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"alpha", &alpha}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 = static_cast<T>(alpha) *
                  (x < static_cast<T>(0)).template cast<T>().eval();
     auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
-    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
+    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }
 };
 
@@ -562,11 +584,11 @@ struct ELUFunctor : public BaseActivationFunctor<T> {
     return {{"alpha", &alpha}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.cwiseMax(static_cast<T>(0)) +
-                  (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
-                      .cwiseMin(static_cast<T>(0));
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.cwiseMax(static_cast<T>(0)) +
+                    (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
+                        .cwiseMin(static_cast<T>(0));
   }
 };
 
@@ -576,10 +598,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"alpha", &alpha}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() +
-                   dy * (y + static_cast<T>(alpha)) *
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
+                   dout * (out + static_cast<T>(alpha)) *
                        (x < static_cast<T>(0)).template cast<T>();
   }
 };
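ELUGradFunctor leans on an identity to reuse the forward output: for x < 0 we have out = alpha * (e^x - 1), so d out/dx = alpha * e^x = out + alpha, and no exp() needs recomputing in the backward pass. A finite-difference spot check (standalone sketch, not part of this patch):

#include <cmath>
#include <cstdio>

int main() {
  double alpha = 1.0, x = -0.5, eps = 1e-6;
  auto elu = [&](double v) {
    return v > 0.0 ? v : alpha * (std::exp(v) - 1.0);
  };
  double out = elu(x);
  double analytic = out + alpha;  // identity used by ELUGradFunctor for x < 0
  double numeric = (elu(x + eps) - elu(x - eps)) / (2.0 * eps);
  std::printf("analytic = %.8f, numeric = %.8f\n", analytic, numeric);
}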
 // pow(x, factor) = x^factor
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"factor", &factor}};
   }
@@ -591,9 +614,9 @@ struct PowFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x.pow(static_cast<T>(factor));
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.pow(static_cast<T>(factor));
   }
 };
 
@@ -603,9 +626,10 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"factor", &factor}};
   }
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) = dy * static_cast<T>(factor) *
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * static_cast<T>(factor) *
                    x.pow(static_cast<T>(factor - static_cast<T>(1)));
   }
 };
 
@@ -618,9 +642,9 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
     return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) =
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) =
         static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
   }
 };
 
@@ -633,12 +657,13 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
     return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
   }
 
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto a = static_cast<T>(scale_a);
     auto b = static_cast<T>(scale_b);
     auto temp = (a * x).tanh() * (a * x).tanh();
-    dx.device(d) = dy * a * b * (static_cast<T>(1) - temp);
+    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
   }
 };
 
@@ -649,10 +674,10 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
     return {{"threshold", &threshold}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
     auto th = static_cast<T>(threshold);
-    y.device(d) = (x > th).template cast<T>() * x;
+    out.device(d) = (x > th).template cast<T>() * x;
   }
 };
 
@@ -663,10 +688,11 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
     return {{"threshold", &threshold}};
   }
 
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto th = static_cast<T>(threshold);
-    dx.device(d) = dy * (x > th).template cast<T>();
+    dx.device(d) = dout * (x > th).template cast<T>();
   }
 };
 
@@ -678,10 +704,11 @@ struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
     return {{"slope", &slope}, {"offset", &offset}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
     auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
-    y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
+    out.device(d) =
+        temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
   }
 };
 
@@ -693,12 +720,13 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
     return {{"slope", &slope}, {"offset", &offset}};
   }
 
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-    dx.device(d) =
-        dy *
-        ((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
-        static_cast<T>(slope);
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout *
+                   ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
+                       .template cast<T>() *
+                   static_cast<T>(slope);
   }
 };
 
@@ -709,9 +737,9 @@ struct SwishFunctor : public BaseActivationFunctor<T> {
     return {{"beta", &beta}};
   }
 
-  template <typename Device, typename X, typename Y>
-  void operator()(Device d, X x, Y y) const {
-    y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
   }
 };
 
@@ -722,12 +750,13 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
     return {{"beta", &beta}};
   }
 
-  template <typename Device, typename X, typename Y, typename dY, typename dX>
-  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 = static_cast<T>(1) /
                  (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
-    auto temp2 = temp1 * (static_cast<T>(1) - (beta * y));
-    dx.device(d) = dy * ((beta * y) + temp2);
+    auto temp2 = temp1 * (static_cast<T>(1) - (beta * out));
+    dx.device(d) = dout * ((beta * out) + temp2);
   }
 };
-- 
GitLab