From dadace3178ab1f038bec7d8fcdfb849e8fc6963f Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 14:02:29 +0800 Subject: [PATCH] add more activation functors --- paddle/operators/activation_op.h | 62 +++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 4421c1095..9bf340f2e 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -55,6 +55,8 @@ class ActivationGradKernel : public framework::OpKernel { } }; +// sigmoid = 1 / (1 + exp(-x) +template struct SigmoidFunctor { template void operator()(Device d, X x, Y y) { @@ -69,6 +71,7 @@ struct SigmoidGradFunctor { } }; +// exp(x) = e^x struct ExpFunctor { template void operator()(Device d, X x, Y y) { @@ -79,10 +82,11 @@ struct ExpFunctor { struct ExpGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = y; + dx.device(d) = dy * y; } }; +// relu(x) = max(x, 0) template struct ReluFunctor { template @@ -99,6 +103,7 @@ struct ReluGradFunctor { } }; +// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) struct TanhFunctor { template void operator()(Device d, X x, Y y) { @@ -114,6 +119,7 @@ struct TanhGradFunctor { } }; +// sqrt(x) = x^(1/2) struct SqrtFunctor { template void operator()(Device d, X x, Y y) { @@ -130,5 +136,59 @@ struct SqrtGradFunctor { } }; +// abs(x) = |x| +struct AbsFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.abs(); + } +}; + +// reciprocal(x) = 1 / x +template +struct ReciprocalFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = 1. / x; + } +}; + +struct ReciprocalGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (-1.0) * y * y; + } +}; + +// log(x) = natural logarithm of x +struct LogFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.log(); + } +}; + +struct LogGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (1. / x); + } +}; + +// square(x) = x^2 +struct SquareFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.square(); + } +} + +struct SquareGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * 2 * x; + } +}; + } // namespace operators } // namespace paddle -- GitLab