提交 dadace31 编写于 作者: Q qijun

add more activation functors

上级 e515f18d
...@@ -55,6 +55,8 @@ class ActivationGradKernel : public framework::OpKernel { ...@@ -55,6 +55,8 @@ class ActivationGradKernel : public framework::OpKernel {
} }
}; };
// sigmoid = 1 / (1 + exp(-x)
template <typename T>
struct SigmoidFunctor { struct SigmoidFunctor {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) {
...@@ -69,6 +71,7 @@ struct SigmoidGradFunctor { ...@@ -69,6 +71,7 @@ struct SigmoidGradFunctor {
} }
}; };
// exp(x) = e^x
struct ExpFunctor { struct ExpFunctor {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) {
...@@ -79,10 +82,11 @@ struct ExpFunctor { ...@@ -79,10 +82,11 @@ struct ExpFunctor {
struct ExpGradFunctor { struct ExpGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = y; dx.device(d) = dy * y;
} }
}; };
// relu(x) = max(x, 0)
template <typename T> template <typename T>
struct ReluFunctor { struct ReluFunctor {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
...@@ -99,6 +103,7 @@ struct ReluGradFunctor { ...@@ -99,6 +103,7 @@ struct ReluGradFunctor {
} }
}; };
// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
struct TanhFunctor { struct TanhFunctor {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) {
...@@ -114,6 +119,7 @@ struct TanhGradFunctor { ...@@ -114,6 +119,7 @@ struct TanhGradFunctor {
} }
}; };
// sqrt(x) = x^(1/2)
struct SqrtFunctor { struct SqrtFunctor {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) {
...@@ -130,5 +136,59 @@ struct SqrtGradFunctor { ...@@ -130,5 +136,59 @@ struct SqrtGradFunctor {
} }
}; };
// abs(x) = |x|
struct AbsFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.abs();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = 1. / x;
}
};
struct ReciprocalGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (-1.0) * y * y;
}
};
// log(x) = natural logarithm of x
struct LogFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.log();
}
};
struct LogGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (1. / x);
}
};
// square(x) = x^2
struct SquareFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.square();
}
}
struct SquareGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * 2 * x;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册