未验证 提交 eca8dcc7 编写于 作者: Z Zhang Zheng 提交者: GitHub

Unify the implementation of activation operation (#32348)

上级 6f6e159a
......@@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out) const {
auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold);
out.device(d) = x * (temp1 + temp2).template cast<T>();
out.device(d) = x * (temp1 || temp2).template cast<T>();
}
};
......@@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold);
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 || temp2).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册