diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index b516fc8a418599d429e47748f53e8a6ed1f65624..1739aa2924d2e7fd97d07a2a39ba8323002f41c3 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -363,17 +363,64 @@ struct GeluFunctor : public BaseActivationFunctor { } }; +// gelu_grad(x) = dout * (0.5 * (1 + erf(x / sqrt(2))) + 0.5 * 2 / sqrt(pie) / +// sqrt(2) * x * exp (-0.5 * sqrt(x))) +// gelu_grad(x) = dout * (0.5 + 0.5 * erf(x * M_SQRT1_2) + (0.5 * M_2_SQRTPI * +// M_SQRT1_2) * x * exp (-0.5 * sqrt(x))) template struct GeluGradFunctor : BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { +#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ + !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + auto x_data = x.data(); + auto dx_data = dx.data(); + int n = std::min(x.size(), dx.size()); + + std::memset(dx_data, 0, n * sizeof(T)); + + // First(dx_data) = erf(x * M_SQRT1_2) + math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, dx_data, 1); + math::CBlas::VMERF(n, dx_data, dx_data, VML_LA); + + // Second = 0.5 * M_2_SQRTPI * M_SQRT1_2 * x * exp (-0.5 * sqrt(x)) + auto second = static_cast(std::malloc(n * sizeof(T))); + std::memset(second, 0, n * sizeof(T)); + + math::CBlas::VSQUARE(n, x_data, second); + for (int i = 0; i < n; i++) { + second[i] *= static_cast(-0.5); + } + math::CBlas::VEXP(n, second, second); + math::CBlas::VMUL(n, x_data, second, second); + T tmp = static_cast(0.5) * static_cast(M_SQRT1_2) * + static_cast(M_2_SQRTPI); + for (int i = 0; i < n; i++) { + second[i] *= tmp; + } + + // Sum = 0.5 * First + Second + math::CBlas::AXPY(n, static_cast(0.5), dx_data, 1, second, 1); + + // 0.5 + Sum + for (int i = 0; i < n; i++) { + second[i] += static_cast(0.5); + } + + // * dout + auto dout_data = dout.data(); + math::CBlas::VMUL(n, dout_data, second, dx_data); + + std::free(second); +#else auto first = static_cast(0.5) * (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * (-static_cast(0.5) * x.square()).exp(); dx.device(d) = dout * (first + second); +#endif } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }