Commit e0be63bf authored by F fengjiayi

change activations

Parent 874cac0c
@@ -22,8 +22,8 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
@@ -32,7 +32,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
};
@@ -41,11 +41,11 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sigmoid operator");
AddOutput("Y", "Output of Sigmoid operator");
AddOutput("Out", "Output of Sigmoid operator");
AddComment(R"DOC(
Sigmoid Activation Operator
$$y = \frac{1}{1 + e^{-x}}$$
$$out = \frac{1}{1 + e^{-x}}$$
)DOC");
}
@@ -56,11 +56,11 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LogSigmoid operator");
AddOutput("Y", "Output of LogSigmoid operator");
AddOutput("Out", "Output of LogSigmoid operator");
AddComment(R"DOC(
Logsigmoid Activation Operator
$$y = \log \frac{1}{1 + e^{-x}}$$
$$out = \log \frac{1}{1 + e^{-x}}$$
)DOC");
}
@@ -71,11 +71,11 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
AddOutput("Y", "Output of Exp operator");
AddOutput("Out", "Output of Exp operator");
AddComment(R"DOC(
Exp Activation Operator.
$y = e^x$
$out = e^x$
)DOC");
}
@@ -86,11 +86,11 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Y", "Output of Relu operator");
AddOutput("Out", "Output of Relu operator");
AddComment(R"DOC(
Relu Activation Operator.
$y = \max(x, 0)$
$out = \max(x, 0)$
)DOC");
}
@@ -101,12 +101,12 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Y", "Output of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
AddComment(R"DOC(
LeakyRelu Activation Operator.
$y = \max(x, \alpha * x)$
$out = \max(x, \alpha * x)$
)DOC");
}
@@ -117,13 +117,13 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator");
AddOutput("Y", "Output of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
Softshrink Activation Operator.
$$
y = \begin{cases}
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
@@ -139,11 +139,11 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Y", "Output of Tanh operator");
AddOutput("Out", "Output of Tanh operator");
AddComment(R"DOC(
Tanh Activation Operator.
$$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC");
}
@@ -154,11 +154,11 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of TanhShrink operator");
AddOutput("Y", "Output of TanhShrink operator");
AddOutput("Out", "Output of TanhShrink operator");
AddComment(R"DOC(
TanhShrink Activation Operator.
$$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC");
}
@@ -169,14 +169,14 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
AddOutput("Y", "Output of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
.SetDefault(0.5f);
AddComment(R"DOC(
HardShrink Activation Operator.
$$
y = \begin{cases}
out = \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
@@ -192,11 +192,11 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Y", "Output of Sqrt operator");
AddOutput("Out", "Output of Sqrt operator");
AddComment(R"DOC(
Sqrt Activation Operator.
$y = \sqrt{x}$
$out = \sqrt{x}$
)DOC");
}
@@ -207,11 +207,11 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Y", "Output of Abs operator");
AddOutput("Out", "Output of Abs operator");
AddComment(R"DOC(
Abs Activation Operator.
$y = |x|$
$out = |x|$
)DOC");
}
@@ -222,11 +222,11 @@ class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Ceil operator");
AddOutput("Y", "Output of Ceil operator");
AddOutput("Out", "Output of Ceil operator");
AddComment(R"DOC(
Ceil Activation Operator.
$y = ceil(x)$
$out = ceil(x)$
)DOC");
}
@@ -237,11 +237,11 @@ class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Floor operator");
AddOutput("Y", "Output of Floor operator");
AddOutput("Out", "Output of Floor operator");
AddComment(R"DOC(
Floor Activation Operator.
$y = floor(x)$
$out = floor(x)$
)DOC");
}
@@ -252,11 +252,11 @@ class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Round operator");
AddOutput("Y", "Output of Round operator");
AddOutput("Out", "Output of Round operator");
AddComment(R"DOC(
Round Activation Operator.
$y = [x]$
$out = [x]$
)DOC");
}
@@ -267,11 +267,11 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Reciprocal operator");
AddOutput("Y", "Output of Reciprocal operator");
AddOutput("Out", "Output of Reciprocal operator");
AddComment(R"DOC(
Reciprocal Activation Operator.
$$y = \frac{1}{x}$$
$$out = \frac{1}{x}$$
)DOC");
}
@@ -282,11 +282,11 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker {
LogOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Log operator");
AddOutput("Y", "Output of Log operator");
AddOutput("Out", "Output of Log operator");
AddComment(R"DOC(
Log Activation Operator.
$y = \ln(x)$
$out = \ln(x)$
Natural logarithm of x.
@@ -299,11 +299,11 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Square operator");
AddOutput("Y", "Output of Square operator");
AddOutput("Out", "Output of Square operator");
AddComment(R"DOC(
Square Activation Operator.
$y = x^2$
$out = x^2$
)DOC");
}
@@ -314,11 +314,11 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softplus operator");
AddOutput("Y", "Output of Softplus operator");
AddOutput("Out", "Output of Softplus operator");
AddComment(R"DOC(
Softplus Activation Operator.
$y = \ln(1 + e^{x})$
$out = \ln(1 + e^{x})$
)DOC");
}
@@ -329,11 +329,11 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softsign operator");
AddOutput("Y", "Output of Softsign operator");
AddOutput("Out", "Output of Softsign operator");
AddComment(R"DOC(
Softsign Activation Operator.
$$y = \frac{x}{1 + |x|}$$
$$out = \frac{x}{1 + |x|}$$
)DOC");
}
@@ -344,7 +344,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
AddOutput("Y", "Output of BRelu operator");
AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<float>(0));
AddAttr<float>("t_max", "The max marginal value of BRelu")
@@ -352,7 +352,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
BRelu Activation Operator.
$y = \max(\min(x, t_{min}), t_{max})$
$out = \min(\max(x, t_{min}), t_{max})$
)DOC");
}
@@ -363,13 +363,13 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Y", "Output of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
.SetDefault(40.0f);
AddComment(R"DOC(
SoftRelu Activation Operator.
$y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
$out = \ln(1 + \exp(\min(\max(x, -threshold), threshold)))$
)DOC");
}
@@ -380,7 +380,7 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator");
AddOutput("Y", "Output of ELU operator");
AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
AddComment(R"DOC(
ELU Activation Operator.
@@ -388,7 +388,7 @@ ELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.
$y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
)DOC");
}
@@ -399,13 +399,13 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator");
AddOutput("Y", "Output of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
.SetDefault(6.0f);
AddComment(R"DOC(
Relu6 Activation Operator.
$y = \min(\max(0, x), 6)$
$out = \min(\max(0, x), 6)$
)DOC");
}
@@ -416,12 +416,12 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
AddOutput("Y", "Output of Pow operator");
AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
AddComment(R"DOC(
Pow Activation Operator.
$y = x^{factor}$
$out = x^{factor}$
)DOC");
}
@@ -432,7 +432,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
AddOutput("Y", "Output of STanh operator");
AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
.SetDefault(2.0f / 3.0f);
AddAttr<float>("scale_b", "The scale parameter of b for the input")
@@ -440,7 +440,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
STanh Activation Operator.
$$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
)DOC");
}
@@ -451,14 +451,14 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Y", "Output of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
.SetDefault(1.0f);
AddComment(R"DOC(
ThresholdedRelu Activation Operator.
$$
y = \begin{cases}
out = \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
@@ -473,7 +473,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Y", "Output of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(0.2f);
AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
@@ -484,7 +484,7 @@ HardSigmoid Activation Operator.
Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
$y = \max(0, \min(1, slope * x + shift))$
$out = \max(0, \min(1, slope * x + offset))$
The slope should be positive. The offset can be either positive or negative.
The default slope and shift are set according to the above reference.
@@ -499,12 +499,12 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Swish operator");
AddOutput("Y", "Output of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
AddComment(R"DOC(
Swish Activation Operator.
$$y = \frac{x}{1 + e^{- \beta x}}$$
$$out = \frac{x}{1 + e^{- \beta x}}$$
)DOC");
}
......
@@ -27,11 +27,11 @@ class ActivationKernel
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
Y->mutable_data<T>(context.GetPlace());
auto* Out = context.Output<framework::Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
@@ -40,7 +40,7 @@ class ActivationKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(*place, x, y);
functor(*place, x, out);
}
};
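A minimal standalone sketch of the functor pattern above (the `SigmoidSketch` name, raw buffers, and `main` harness are illustrative and not part of the operator code; it assumes only Eigen's unsupported Tensor module): the kernel flattens every tensor into a rank-1 Eigen view and the functor writes the elementwise result through `.device(d)`.

#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

// Illustrative stand-in for SigmoidFunctor; the body is copied verbatim
// from the functor below.
template <typename T>
struct SigmoidSketch {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
  }
};

int main() {
  float in[4] = {-2.f, -1.f, 0.f, 2.f};
  float res[4];
  // TensorMap plays the role of framework::EigenVector<T>::Flatten(*X):
  // a rank-1 view over externally owned memory.
  Eigen::TensorMap<Eigen::Tensor<float, 1>> x(in, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> out(res, 4);
  Eigen::DefaultDevice cpu;  // stands in for the device context's eigen_device()
  SigmoidSketch<float>()(cpu, x, out);
  for (int i = 0; i < 4; ++i) std::printf("sigmoid(%g) = %g\n", in[i], res[i]);
  return 0;
}

Because `TensorMap` aliases external memory, copying it by value into the functor is cheap and the writes land in the caller's buffer, which is exactly how the flattened views behave in the kernel.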
@@ -51,14 +51,15 @@ class ActivationGradKernel
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dout = framework::EigenVector<T>::Flatten(*dOut);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
@@ -67,7 +68,7 @@ class ActivationGradKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(*place, x, y, dy, dx);
functor(*place, x, out, dout, dx);
}
};
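Every *GradFunctor below implements the chain rule with the renamed arguments: given $out = f(x)$,

$$dx = dout \cdot \frac{\partial out}{\partial x},$$

where each functor expresses the derivative in terms of whichever of $x$ and $out$ is cheaper to reuse.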
@@ -83,17 +84,18 @@ struct BaseActivationFunctor {
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
}
};
template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y * (static_cast<T>(1) - y);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out * (static_cast<T>(1) - out);
}
};
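Since $out = \frac{1}{1 + e^{-x}}$, the derivative can be written entirely in terms of the forward output, $\frac{\partial out}{\partial x} = out\,(1 - out)$, which is why the backward functor reads only `out` and `dout` and never touches `x`.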
@@ -101,7 +103,7 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// y = -log( exp(0) + exp(-x)) [since exp(0) = 1]
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
// = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
// = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
// max(-x, 0)))
@@ -112,10 +114,10 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
}
};
@@ -124,62 +126,66 @@ struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
dx.device(d) =
dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
}
};
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.exp();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.exp();
}
};
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out;
}
};
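Here $\frac{\partial out}{\partial x} = e^{x} = out$, so the backward pass again reuses the forward output: `dx = dout * out`.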
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0));
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
}
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.tanh();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) - y * y);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) - out * out);
}
};
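With $out = \tanh(x)$, $\frac{\partial out}{\partial x} = 1 - \tanh^{2}(x) = 1 - out^{2}$, matching `dout * (1 - out * out)` above.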
@@ -187,17 +193,18 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x - x.tanh();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x - x.tanh();
}
};
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x.tanh() * x.tanh());
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x.tanh() * x.tanh());
}
};
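The forward map is $out = x - \tanh(x)$, so $\frac{\partial out}{\partial x} = 1 - (1 - \tanh^{2} x) = \tanh^{2} x$; hence the gradient multiplies `dout` by `x.tanh() * x.tanh()` and needs $x$ rather than $out$.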
@@ -210,11 +217,11 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2);
out.device(d) = x * (temp1 + temp2);
}
};
@@ -226,11 +233,12 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
};
@@ -243,12 +251,12 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
return {{"lambda", &lambda}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
}
};
@@ -258,46 +266,49 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"lambda", &lambda}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
};
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.sqrt();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
const Out out_conj = Eigen::numext::conj(out);
dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
}
};
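With $out = \sqrt{x}$, $\frac{\partial out}{\partial x} = \frac{1}{2\sqrt{x}} = \frac{1}{2\,out}$, giving `0.5 * dout / out`; `Eigen::numext::conj` is an identity for real scalar types and only matters if the functor were instantiated for complex types.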
// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.ceil();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0) / x;
}
};
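Ceil, floor, and round are piecewise constant, so their derivative is zero almost everywhere; `static_cast<T>(0) / x` builds a zero expression with the same shape as `x` (note it yields NaN at $x = 0$, where the functions are not differentiable, rather than defining a subgradient there).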
@@ -305,86 +316,90 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.ceil();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.floor();
}
};
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.round();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.round();
}
};
// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.abs();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.abs();
}
};
template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * x.sign();
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.sign();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / x;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = static_cast<T>(1) / x;
}
};
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(-1) * y * y;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(-1) * out * out;
}
};
// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.log();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log();
}
};
template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) / x);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) / x);
}
};
// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.square();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.square();
}
};
template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(2) * x;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(2) * x;
}
};
@@ -399,9 +414,9 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
}
};
@@ -413,9 +428,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
.template cast<T>();
}
@@ -430,9 +446,9 @@ struct Relu6Functor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
}
};
@@ -443,9 +459,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
.template cast<T>();
}
@@ -458,10 +475,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
}
};
@@ -471,19 +488,21 @@ struct SoftplusFunctor : public BaseActivationFunctor<T> {
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
dx.device(d) =
dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
}
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x / (static_cast<T>(1) + x.abs());
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
@@ -491,10 +510,11 @@ struct SoftsignFunctor : public BaseActivationFunctor<T> {
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
dx.device(d) =
dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}
};
@@ -505,11 +525,11 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto tmp = static_cast<T>(threshold);
auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast<T>(1) + temp.exp()).log();
out.device(d) = (static_cast<T>(1) + temp.exp()).log();
}
};
@@ -519,11 +539,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = static_cast<T>(threshold);
auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
}
};
@@ -534,9 +555,9 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
}
};
@@ -546,12 +567,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
};
@@ -562,9 +584,9 @@ struct ELUFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)) +
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0)) +
(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
.cwiseMin(static_cast<T>(0));
}
@@ -576,10 +598,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (y + static_cast<T>(alpha)) *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
dout * (out + static_cast<T>(alpha)) *
(x < static_cast<T>(0)).template cast<T>();
}
};
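For $x < 0$ the forward value is $out = \alpha(e^{x} - 1)$, so $\frac{\partial out}{\partial x} = \alpha e^{x} = out + \alpha$; that identity is what the `dout * (out + alpha)` term encodes, while the $x > 0$ branch contributes a plain pass-through gradient.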
@@ -591,9 +614,9 @@ struct PowFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(static_cast<T>(factor));
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(static_cast<T>(factor));
}
};
@@ -603,9 +626,10 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(factor) *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1)));
}
};
@@ -618,9 +642,9 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
}
};
@@ -633,12 +657,13 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dy * a * b * (static_cast<T>(1) - temp);
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
}
};
@@ -649,10 +674,10 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto th = static_cast<T>(threshold);
y.device(d) = (x > th).template cast<T>() * x;
out.device(d) = (x > th).template cast<T>() * x;
}
};
@@ -663,10 +688,11 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto th = static_cast<T>(threshold);
dx.device(d) = dy * (x > th).template cast<T>();
dx.device(d) = dout * (x > th).template cast<T>();
}
};
@@ -678,10 +704,11 @@ struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
out.device(d) =
temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
}
};
@@ -693,11 +720,12 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy *
((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
.template cast<T>() *
static_cast<T>(slope);
}
};
@@ -709,9 +737,9 @@ struct SwishFunctor : public BaseActivationFunctor<T> {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
}
};
@@ -722,12 +750,13 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
auto temp2 = temp1 * (static_cast<T>(1) - (beta * y));
dx.device(d) = dy * ((beta * y) + temp2);
auto temp2 = temp1 * (static_cast<T>(1) - (beta * out));
dx.device(d) = dout * ((beta * out) + temp2);
}
};
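Writing $\sigma_{\beta}(x) = \frac{1}{1 + e^{-\beta x}}$ so that $out = x\,\sigma_{\beta}(x)$, the derivative is

$$\frac{\partial out}{\partial x} = \sigma_{\beta}(x) + \beta x\,\sigma_{\beta}(x)\bigl(1 - \sigma_{\beta}(x)\bigr) = \beta\,out + \sigma_{\beta}(x)\,(1 - \beta\,out),$$

which is exactly `temp1` ($\sigma_{\beta}(x)$), `temp2` ($\sigma_{\beta}(x)(1 - \beta\,out)$), and the final `dout * ((beta * out) + temp2)`.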
......