提交 e515f18d 编写于 作者: Q qijun

add tanh and sqrt activation operators

上级 3110bf9a
...@@ -99,5 +99,36 @@ struct ReluGradFunctor { ...@@ -99,5 +99,36 @@ struct ReluGradFunctor {
} }
}; };
struct TanhFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (T(1) - y * y);
}
};
struct SqrtFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
const T y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册