diff --git a/paddle/fluid/operators/gelu_op.h b/paddle/fluid/operators/gelu_op.h index ad38ec1cc5a17e28e4a37ee405a9123708717906..329b8583192a41c6c088cdbbb3ee7bd68c77f373 100644 --- a/paddle/fluid/operators/gelu_op.h +++ b/paddle/fluid/operators/gelu_op.h @@ -41,9 +41,28 @@ struct GeluFunctor { .tanh(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); } else { +#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ + !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + auto x_data = x.data(); + auto out_data = out.data(); + int n = std::min(x.size(), out.size()); + + std::memset(out_data, 0, n * sizeof(T)); + math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, out_data, + 1); + math::CBlas::VMERF(n, out_data, out_data, VML_LA); + for (int i = 0; i < n; i++) { + out_data[i] += static_cast(1); + } + math::CBlas::VMUL(n, x_data, out_data, out_data); + for (int i = 0; i < n; i++) { + out_data[i] *= static_cast(0.5); + } +#else // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) auto temp = (x * static_cast(M_SQRT1_2)).erf(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); +#endif } } }; @@ -61,6 +80,41 @@ struct GeluGradFunctor { (static_cast(1) + y + (x - x * y.square()) * (kAlpha + kBeta * x.square())); } else { +#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ + !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + auto x_data = x.data(); + auto dx_data = dx.data(); + auto dout_data = dout.data(); + int n = std::min(x.size(), dx.size()); + + auto first = static_cast(std::malloc(n * sizeof(T))); + std::memset(first, 0, n * sizeof(T)); + auto second = static_cast(std::malloc(n * sizeof(T))); + std::memset(second, 0, n * sizeof(T)); + + // first = (0.5 * (1 + erf(x / sqrt(2)))) + math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, first, 1); + math::CBlas::VMERF(n, first, first, VML_LA); + for (int i = 0; i < n; i++) { + first[i] += static_cast(1); + } + math::CBlas::SCAL(n, static_cast(0.5), first, 1); + + // second = (0.5 * 2/sqrt(pi) * 1/sqrt(2) * x * exp(-0.5 * x^2)) + math::CBlas::VSQUARE(n, x_data, second); + math::CBlas::SCAL(n, -static_cast(0.5), second, 1); + math::CBlas::VEXP(n, second, second); + math::CBlas::VMUL(n, x_data, second, second); + math::CBlas::SCAL(n, static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2), + second, 1); + + // dx = dout * (first + second); + math::CBlas::VADD(n, first, second, first); + math::CBlas::VMUL(n, dout_data, first, dx_data); + + std::free(first); + std::free(second); +#else // gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2 / pi) * // exp(- x^2 / 2) auto first = @@ -70,6 +124,7 @@ struct GeluGradFunctor { auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * (-static_cast(0.5) * x.square()).exp(); dx.device(d) = dout * (first + second); +#endif } } };