From 6951ef9a55768c8e923623431247825a40bd522a Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 12 Dec 2018 15:25:07 +0800
Subject: [PATCH] Fix the gelu backward to avoid nan (#14857)

* Fix the gelu backward to avoid nan

test=develop

* Remove unnecessary calls

test=develop
---
 paddle/fluid/operators/activation_op.h | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 87d549678a0..c7df3ea58a9 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -301,23 +301,22 @@ template <typename T>
 struct GeluFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    auto temp =
-        ((x * static_cast<T>(M_SQRT1_2)).erf()).template cast<T>().eval();
+    auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
     out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
   }
 };
 
 template <typename T>
 struct GeluGradFunctor : BaseActivationFunctor<T> {
-  bool Inplace() const { return IsInplace("gelu"); }
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp = (static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
-                 ((-static_cast<T>(0.5) * x.square()).exp()))
-                    .template cast<T>()
-                    .eval();
-    dx.device(d) = dout * (out / x + temp);
+    auto first = static_cast<T>(0.5) *
+                 (static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
+
+    auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
+                  (-static_cast<T>(0.5) * x.square()).exp();
+    dx.device(d) = dout * (first + second);
   }
 };
 
-- 
GitLab