diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 0b7e171e722e62d987675033b7c48f762048a61d..4421c109574f0a5f5ada6c4959330c70e247588c 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -99,5 +99,36 @@ struct ReluGradFunctor { } }; +struct TanhFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.tanh(); + } +}; + +template +struct TanhGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (T(1) - y * y); + } +}; + +struct SqrtFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.sqrt(); + } +}; + +template +struct SqrtGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + const T y_conj = Eigen::numext::conj(y); + dx.device(d) = static_cast(0.5) * dy / y_conj; + } +}; + } // namespace operators } // namespace paddle