diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 4421c109574f0a5f5ada6c4959330c70e247588c..9bf340f2ed46a8e7b64a15ea23a26feb68460c63 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -55,6 +55,8 @@ class ActivationGradKernel : public framework::OpKernel { } }; +// sigmoid = 1 / (1 + exp(-x) +template struct SigmoidFunctor { template void operator()(Device d, X x, Y y) { @@ -69,6 +71,7 @@ struct SigmoidGradFunctor { } }; +// exp(x) = e^x struct ExpFunctor { template void operator()(Device d, X x, Y y) { @@ -79,10 +82,11 @@ struct ExpFunctor { struct ExpGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = y; + dx.device(d) = dy * y; } }; +// relu(x) = max(x, 0) template struct ReluFunctor { template @@ -99,6 +103,7 @@ struct ReluGradFunctor { } }; +// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) struct TanhFunctor { template void operator()(Device d, X x, Y y) { @@ -114,6 +119,7 @@ struct TanhGradFunctor { } }; +// sqrt(x) = x^(1/2) struct SqrtFunctor { template void operator()(Device d, X x, Y y) { @@ -130,5 +136,59 @@ struct SqrtGradFunctor { } }; +// abs(x) = |x| +struct AbsFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.abs(); + } +}; + +// reciprocal(x) = 1 / x +template +struct ReciprocalFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = 1. / x; + } +}; + +struct ReciprocalGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (-1.0) * y * y; + } +}; + +// log(x) = natural logarithm of x +struct LogFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.log(); + } +}; + +struct LogGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (1. / x); + } +}; + +// square(x) = x^2 +struct SquareFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.square(); + } +} + +struct SquareGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * 2 * x; + } +}; + } // namespace operators } // namespace paddle