未验证 提交 6951ef9a 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix the gelu backward to avoid nan (#14857)

* Fix the gelu backward to avoid nan

test=develop

* Remove unnecessary calls

test=develop
上级 322bb8d5
......@@ -301,23 +301,22 @@ template <typename T>
// GELU forward activation functor.
//
// Evaluates, coefficient-wise,
//   out = 0.5 * x * (1 + erf(x / sqrt(2)))
// and writes the result into `out` through `out.device(d)`.
//
// NOTE(review): `x` and `out` are used as expression-template objects
// (`.erf()`, `.device(d)`), presumably Eigen tensors -- confirm against the
// callers of this functor.
struct GeluFunctor : public BaseActivationFunctor<T> {
  // d   : device handle forwarded to `out.device(d)`.
  // x   : input expression.
  // out : destination that receives the activation values.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // erf(x / sqrt(2)); M_SQRT1_2 == 1 / sqrt(2). Kept as a lazy expression:
    // no explicit cast<T>()/eval() is needed because the constant is already
    // converted to T before it is multiplied in.
    const auto erf_term = (x * static_cast<T>(M_SQRT1_2)).erf();
    out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + erf_term);
  }
};
// GELU backward (gradient) functor.
//
// With Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) and
//      phi(x) = exp(-x^2 / 2) / sqrt(2 * pi),
// the derivative of gelu(x) = x * Phi(x) is
//   d gelu / dx = Phi(x) + x * phi(x)
// so this functor writes dx = dout * (Phi(x) + x * phi(x)).
//
// The derivative is computed directly from `x` rather than as `out / x + ...`:
// dividing the forward output by the input evaluates 0 / 0 at x == 0 and
// produces NaN gradients.
template <typename T>
struct GeluGradFunctor : BaseActivationFunctor<T> {
  bool Inplace() const { return IsInplace("gelu"); }

  // d    : device handle forwarded to `dx.device(d)`.
  // x    : forward input expression.
  // out  : forward output; intentionally unused (see the note above), kept so
  //        the call signature stays unchanged for callers.
  // dout : gradient flowing in from the output side.
  // dx   : destination that receives the gradient with respect to `x`.
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // Phi(x) = 0.5 * (1 + erf(x / sqrt(2))); M_SQRT1_2 == 1 / sqrt(2).
    const auto cdf_term =
        static_cast<T>(0.5) *
        (static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
    // x * phi(x); 0.5 * M_2_SQRTPI * M_SQRT1_2 == 1 / sqrt(2 * pi).
    const auto pdf_term = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
                          (-static_cast<T>(0.5) * x.square()).exp();
    dx.device(d) = dout * (cdf_term + pdf_term);
  }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册