From 370224821cb8c36316248e36e8c78bb38e434010 Mon Sep 17 00:00:00 2001
From: Guoxia Wang
Date: Mon, 27 Dec 2021 10:14:33 +0800
Subject: [PATCH] gelu using normcdf for cudnn (#38450)

---
 paddle/fluid/operators/gelu_op.cu | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu
index d533d79a036..8151d21fa67 100644
--- a/paddle/fluid/operators/gelu_op.cu
+++ b/paddle/fluid/operators/gelu_op.cu
@@ -41,11 +41,7 @@ struct GeluWithoutApproximateFunctor {
   inline HOSTDEVICE T operator()(T arg_x) {
     // actual gelu with approximation = false
     MPType x = static_cast<MPType>(arg_x);
-    MPType one = static_cast<MPType>(1);
-    MPType half = static_cast<MPType>(0.5);
-    MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
-    MPType out = x * half * (one + erf_out);
-    return static_cast<T>(out);
+    return static_cast<T>(x * normcdf(x));
   }
 };
 
@@ -100,12 +96,10 @@ struct GeluWithoutApproximateGradFunctor {
   inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
     MPType x = static_cast<MPType>(arg_x);
     MPType dout = static_cast<MPType>(arg_dout);
-    MPType one = static_cast<MPType>(1);
-    MPType half = static_cast<MPType>(0.5);
-    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
-    auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
-               half * kAlpha * x * exp(-half * x * x);
-    return static_cast<T>(ans * dout);
+    constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
+    const MPType cdf = normcdf(x);
+    const MPType pdf = exp(static_cast<MPType>(-0.5) * x * x) * kBeta;
+    return static_cast<T>(dout * (cdf + x * pdf));
   }
 };
 
-- 
GitLab