未验证 提交 bcafe317 编写于 作者: F Feiyu Chan 提交者: GitHub

add MKL computation back to gelu's non-approximate part (#23420)

上级 dbfbd7ea
...@@ -41,9 +41,28 @@ struct GeluFunctor { ...@@ -41,9 +41,28 @@ struct GeluFunctor {
.tanh(); .tanh();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp); out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
} else { } else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data,
1);
math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf(); auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp); out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
} }
} }
}; };
...@@ -61,6 +80,41 @@ struct GeluGradFunctor { ...@@ -61,6 +80,41 @@ struct GeluGradFunctor {
(static_cast<T>(1) + y + (static_cast<T>(1) + y +
(x - x * y.square()) * (kAlpha + kBeta * x.square())); (x - x * y.square()) * (kAlpha + kBeta * x.square()));
} else { } else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto dx_data = dx.data();
auto dout_data = dout.data();
int n = std::min(x.size(), dx.size());
auto first = static_cast<T*>(std::malloc(n * sizeof(T)));
std::memset(first, 0, n * sizeof(T));
auto second = static_cast<T*>(std::malloc(n * sizeof(T)));
std::memset(second, 0, n * sizeof(T));
// first = (0.5 * (1 + erf(x / sqrt(2))))
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, first, 1);
math::CBlas<T>::VMERF(n, first, first, VML_LA);
for (int i = 0; i < n; i++) {
first[i] += static_cast<T>(1);
}
math::CBlas<T>::SCAL(n, static_cast<T>(0.5), first, 1);
// second = (0.5 * 2/sqrt(pi) * 1/sqrt(2) * x * exp(-0.5 * x^2))
math::CBlas<T>::VSQUARE(n, x_data, second);
math::CBlas<T>::SCAL(n, -static_cast<T>(0.5), second, 1);
math::CBlas<T>::VEXP(n, second, second);
math::CBlas<T>::VMUL(n, x_data, second, second);
math::CBlas<T>::SCAL(n, static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2),
second, 1);
// dx = dout * (first + second);
math::CBlas<T>::VADD(n, first, second, first);
math::CBlas<T>::VMUL(n, dout_data, first, dx_data);
std::free(first);
std::free(second);
#else
// gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2 / pi) * // gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2 / pi) *
// exp(- x^2 / 2) // exp(- x^2 / 2)
auto first = auto first =
...@@ -70,6 +124,7 @@ struct GeluGradFunctor { ...@@ -70,6 +124,7 @@ struct GeluGradFunctor {
auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
(-static_cast<T>(0.5) * x.square()).exp(); (-static_cast<T>(0.5) * x.square()).exp();
dx.device(d) = dout * (first + second); dx.device(d) = dout * (first + second);
#endif
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册