Commit 95113cb2 authored by phlrain

fix error; test=develop

Parent c72cf5fa
...@@ -296,8 +296,6 @@ USE_PHI_FUNCTOR(Softsign)
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
template <typename T>
using ReluCPUFunctor = phi::funcs::ReluCPUFunctor<T>;
template <typename T>
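// The hunk above shrinks 8 lines to 6: by the surrounding context, the
// dropped pair is the fluid-side ReluCPUFunctor alias into phi::funcs.
// A minimal standalone sketch of the same element-wise relu math
// (plain C++, no Eigen/phi; relu_reference is an illustrative name,
// not Paddle API):
#include <algorithm>
#include <cstdio>

template <typename T>
T relu_reference(T x) {
  return std::max(x, static_cast<T>(0));  // relu(x) = max(x, 0)
}

int main() {
  std::printf("%f %f\n", relu_reference(-2.0), relu_reference(3.0));  // 0 3
  return 0;
}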
...@@ -717,106 +715,6 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
<<<<<<< HEAD
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
float slope;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"slope", &slope}, {"offset", &offset}};
=======
struct LogitFunctor {
template <typename Device, typename X, typename Out, typename P>
void operator()(Device d, X x, Out out, P p, float eps) const {
// logit(x) = ln(x/(1-x))
auto tmp_x =
(x.cwiseMin(static_cast<T>(1.0 - eps))).cwiseMax(static_cast<T>(eps));
if (!eps) {
out.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
.select(p.constant(static_cast<T>(NAN)),
(tmp_x / (static_cast<T>(1) - tmp_x)).log());
} else {
out.device(d) = (tmp_x / (static_cast<T>(1) - tmp_x)).log();
}
}
};
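// A minimal scalar sketch of the LogitFunctor logic above (plain C++,
// no Eigen; logit_reference is an illustrative name, not Paddle API).
// With eps == 0, inputs outside [0, 1] map to NaN; with eps > 0, the
// input is clamped into [eps, 1 - eps] before logit(x) = ln(x/(1-x)).
#include <cmath>
#include <cstdio>

template <typename T>
T logit_reference(T x, float eps) {
  T lo = static_cast<T>(eps);
  T hi = static_cast<T>(1.0 - eps);
  T tmp = std::fmin(std::fmax(x, lo), hi);  // cwiseMax/cwiseMin clamp
  if (eps == 0.0f && (x < T(0) || x > T(1))) return std::nan("");
  return std::log(tmp / (T(1) - tmp));
}

int main() {
  std::printf("%f\n", logit_reference(0.5, 0.0f));   // 0
  std::printf("%f\n", logit_reference(0.75, 0.0f));  // ln(3) ~= 1.0986
  std::printf("%f\n", logit_reference(1.5, 0.0f));   // nan
  std::printf("%f\n", logit_reference(1.5, 1e-6f));  // clamped, finite
  return 0;
}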
template <typename T>
struct LogitGradFunctor {
template <typename Device, typename X, typename dOut, typename dX, typename P>
void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
// logit(x)' = 1/(x*(1-x))
dx.device(d) =
(x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(0)),
dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
}
};
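// Scalar sketch of LogitGradFunctor above: d/dx logit(x) = 1/(x*(1-x)),
// with the gradient zeroed where x falls outside [eps, 1 - eps].
// Illustrative code only, not Paddle API.
#include <cstdio>

template <typename T>
T logit_grad_reference(T x, T dout, float eps) {
  if (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
    return static_cast<T>(0);  // out-of-range inputs get zero gradient
  return dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x));
}

int main() {
  // At x = 0.5: 1/(0.5 * 0.5) = 4, scaled by the upstream gradient.
  std::printf("%f\n", logit_grad_reference(0.5, 1.0, 1e-6f));  // 4
  std::printf("%f\n", logit_grad_reference(2.0, 1.0, 1e-6f));  // 0
  return 0;
}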
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
>>>>>>> 1904572ac8edb57dfb528e711588758002a168dd
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
<<<<<<< HEAD
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
out.device(d) =
temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
=======
out.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
>>>>>>> 1904572ac8edb57dfb528e711588758002a168dd
}
};
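// The two conflict branches above compute different forwards: the HEAD
// branch is hard_sigmoid(x) = clip(slope * x + offset, 0, 1), the
// incoming branch (1904572a) is stanh(x) = scale_b * tanh(scale_a * x).
// A scalar sketch of each, assuming the common Paddle defaults
// slope = 0.2, offset = 0.5 and scale_a = 0.67, scale_b = 1.7159
// (plain C++; names are illustrative, not Paddle API):
#include <cmath>
#include <cstdio>

double hard_sigmoid_reference(double x, double slope, double offset) {
  double t = slope * x + offset;
  return std::fmin(std::fmax(t, 0.0), 1.0);  // cwiseMax(0).cwiseMin(1)
}

double stanh_reference(double x, double scale_a, double scale_b) {
  return scale_b * std::tanh(scale_a * x);
}

int main() {
  std::printf("%f\n", hard_sigmoid_reference(0.0, 0.2, 0.5));  // 0.5
  std::printf("%f\n", stanh_reference(0.0, 0.67, 1.7159));     // 0
  return 0;
}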
template <typename T>
<<<<<<< HEAD
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
float slope;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
.template cast<T>() *
static_cast<T>(slope);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
=======
struct STanhGradFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
>>>>>>> 1904572ac8edb57dfb528e711588758002a168dd
};
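// Scalar sketches of the two gradient branches above (illustrative only):
// HardSigmoidGrad uses the forward output, dx = dout * 1{0 < out < 1} * slope
// (hence kDepOut), while STanhGrad uses the input,
// dx = dout * a * b * (1 - tanh^2(a*x)) (hence kDepX).
#include <cmath>
#include <cstdio>

double hard_sigmoid_grad_reference(double out, double dout, double slope) {
  // Gradient flows only where the forward clip was not saturated.
  return dout * ((out > 0.0 && out < 1.0) ? 1.0 : 0.0) * slope;
}

double stanh_grad_reference(double x, double dout, double a, double b) {
  double t = std::tanh(a * x);
  return dout * a * b * (1.0 - t * t);
}

int main() {
  std::printf("%f\n", hard_sigmoid_grad_reference(0.5, 1.0, 0.2));   // 0.2
  std::printf("%f\n", stanh_grad_reference(0.0, 1.0, 0.67, 1.7159)); // a*b ~= 1.1497
  return 0;
}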
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
...
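// The diff is truncated at the SwishFunctor definition. For reference,
// swish is conventionally swish(x) = x * sigmoid(beta * x); a scalar
// sketch under that assumption (illustrative, not the truncated Paddle
// implementation):
#include <cmath>
#include <cstdio>

double swish_reference(double x, double beta) {
  return x / (1.0 + std::exp(-beta * x));  // x * sigmoid(beta * x)
}

int main() {
  std::printf("%f\n", swish_reference(0.0, 1.0));  // 0
  std::printf("%f\n", swish_reference(1.0, 1.0));  // 1/(1+e^-1) ~= 0.7311
  return 0;
}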