未验证 提交 37022482 编写于 作者: G Guoxia Wang 提交者: GitHub

gelu using normcdf for cudnn (#38450)

上级 5d902954
......@@ -41,11 +41,7 @@ struct GeluWithoutApproximateFunctor {
// Computes GELU of arg_x in its exact (non-approximate) form.
// The arithmetic is done in MPType and the result is cast back to T.
// NOTE(review): this body looks like a diff rendering in which removed and
// added lines are interleaved without +/- markers; see the notes below.
inline HOSTDEVICE T operator()(T arg_x) {
// actual gelu with approximation = false
// Promote the input to MPType for the intermediate computation.
MPType x = static_cast<MPType>(arg_x);
// erf-based form: out = x * 0.5 * (1 + erf(x * M_SQRT1_2)).
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
MPType out = x * half * (one + erf_out);
return static_cast<T>(out);
// NOTE(review): as written, the statement below follows an unconditional
// return and can never execute. Presumably it is the replacement for the
// erf-based lines above (x * normcdf(x) is the same quantity, assuming
// normcdf is the standard normal CDF) -- confirm against the real file.
return static_cast<T>(x * normcdf(x));
}
};
......@@ -100,12 +96,10 @@ struct GeluWithoutApproximateGradFunctor {
// Computes the gradient of exact (non-approximate) GELU with respect to its
// input, scaled by the incoming gradient arg_dout.
// The arithmetic is done in MPType and the result is cast back to T.
// NOTE(review): this body looks like a diff rendering in which removed and
// added lines are interleaved without +/- markers; see the notes below.
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
// Promote both operands to MPType for the intermediate computation.
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
// erf-based form:
//   ans = 0.5 * (1 + erf(x * M_SQRT1_2)) + 0.5 * kAlpha * x * exp(-0.5 * x * x)
// with kAlpha = M_2_SQRTPI * M_SQRT1_2.
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
half * kAlpha * x * exp(-half * x * x);
return static_cast<T>(ans * dout);
// NOTE(review): as written, everything below follows an unconditional return
// and can never execute. Presumably these lines are the replacement for the
// erf-based lines above -- confirm against the real file.
// kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5, i.e. the same factor as half * kAlpha
// in the erf-based form.
constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
// cdf takes the place of 0.5 * (1 + erf(x * M_SQRT1_2)), assuming normcdf is
// the standard normal CDF.
const MPType cdf = normcdf(x);
// pdf = exp(-0.5 * x * x) * kBeta.
const MPType pdf = exp(static_cast<MPType>(-0.5) * x * x) * kBeta;
// Result: dout * (cdf + x * pdf).
return static_cast<T>(dout * (cdf + x * pdf));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册