未验证 提交 eca8dcc7 编写于 作者: Z Zhang Zheng 提交者: GitHub

Unify the implementation of activation operation (#32348)

上级 6f6e159a
...@@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out) const { void operator()(Device d, X x, Out out) const {
auto temp1 = x < static_cast<T>(threshold * -1.f); auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold); auto temp2 = x > static_cast<T>(threshold);
out.device(d) = x * (temp1 + temp2).template cast<T>(); out.device(d) = x * (temp1 || temp2).template cast<T>();
} }
}; };
...@@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const { void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = x < static_cast<T>(threshold * -1.f); auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold); auto temp2 = x > static_cast<T>(threshold);
dx.device(d) = dout * (temp1 + temp2).template cast<T>(); dx.device(d) = dout * (temp1 || temp2).template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册