From d623e863c9348c19d9438f2f0b14f3b011877eda Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Tue, 19 Nov 2019 04:41:57 +0100 Subject: [PATCH] Fix GELU grad error (#21204) test=develop --- paddle/fluid/operators/activation_op.h | 49 +++++++++++--------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 8c70c9abccd..9f0203b2149 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -340,10 +340,8 @@ struct GeluFunctor : public BaseActivationFunctor { } }; -// gelu_grad(x) = dout * (0.5 * (1 + erf(x / sqrt(2))) + 0.5 * 2 / sqrt(pie) / -// sqrt(2) * x * exp (-0.5 * sqrt(x))) -// gelu_grad(x) = dout * (0.5 + 0.5 * erf(x * M_SQRT1_2) + (0.5 * M_2_SQRTPI * -// M_SQRT1_2) * x * exp (-0.5 * sqrt(x))) +// gelu_grad(x) = dout * (0.5 * (1 + erf(x / sqrt(2))) + 0.5 * 2 / sqrt(pi) / +// sqrt(2) * x * exp (-0.5 * x^2)) template struct GeluGradFunctor : BaseActivationFunctor { template { !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) auto x_data = x.data(); auto dx_data = dx.data(); + auto dout_data = dout.data(); int n = std::min(x.size(), dx.size()); - std::memset(dx_data, 0, n * sizeof(T)); - - // First(dx_data) = erf(x * M_SQRT1_2) - math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, dx_data, 1); - math::CBlas::VMERF(n, dx_data, dx_data, VML_LA); - - // Second = 0.5 * M_2_SQRTPI * M_SQRT1_2 * x * exp (-0.5 * sqrt(x)) + auto first = static_cast(std::malloc(n * sizeof(T))); + std::memset(first, 0, n * sizeof(T)); auto second = static_cast(std::malloc(n * sizeof(T))); std::memset(second, 0, n * sizeof(T)); - math::CBlas::VSQUARE(n, x_data, second); + // first = (0.5 * (1 + erf(x / sqrt(2)))) + math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, first, 1); + math::CBlas::VMERF(n, first, first, VML_LA); for (int i = 0; i < n; i++) { - second[i] *= static_cast(-0.5); + first[i] += static_cast(1); } + math::CBlas::SCAL(n, static_cast(0.5), first, 1); + + // second = (0.5 * 2/sqrt(pi) * 1/sqrt(2) * x * exp(-0.5 * x^2)) + math::CBlas::VSQUARE(n, x_data, second); + math::CBlas::SCAL(n, -static_cast(0.5), second, 1); math::CBlas::VEXP(n, second, second); math::CBlas::VMUL(n, x_data, second, second); - T tmp = static_cast(0.5) * static_cast(M_SQRT1_2) * - static_cast(M_2_SQRTPI); - for (int i = 0; i < n; i++) { - second[i] *= tmp; - } - - // Sum = 0.5 * First + Second - math::CBlas::AXPY(n, static_cast(0.5), dx_data, 1, second, 1); - - // 0.5 + Sum - for (int i = 0; i < n; i++) { - second[i] += static_cast(0.5); - } + math::CBlas::SCAL(n, static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2), + second, 1); - // * dout - auto dout_data = dout.data(); - math::CBlas::VMUL(n, dout_data, second, dx_data); + // dx = dout * (first + second); + math::CBlas::VADD(n, first, second, first); + math::CBlas::VMUL(n, dout_data, first, dx_data); + std::free(first); std::free(second); #else auto first = static_cast(0.5) * -- GitLab