Commit b6260f38 authored by Yihua Xu, committed by Tao Luo

Optimize the kernel implementation of layernorm with openmp (#20895)

Parent 8c4573a3
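The patch hoists the loop-invariant block/rest/end computation above a new OpenMP parallel region and splits the per-row loop across threads with a work-sharing for, so the __m256 temporaries declared inside the region stay thread-private. Below is a minimal sketch of that structure only, not the Paddle kernel: scalar arithmetic stands in for the AVX intrinsics, scale/bias handling is omitted, the function name is made up, and a generic _OPENMP guard replaces PADDLE_WITH_MKLML.

#include <cmath>
#include <cstddef>

// Sketch of the parallelization pattern in this patch: rows are independent,
// so the row loop is shared across threads with "omp for", while temporaries
// declared inside the parallel region are automatically thread-private.
void LayerNormRowsSketch(const float* x, float* out, float* mean, float* var,
                         int height, int right, float epsilon) {
#ifdef _OPENMP
#pragma omp parallel
  {
#endif
    float sum, sq_sum, inv_std;  // thread-private, like the __m256 temps
#ifdef _OPENMP
#pragma omp for
#endif
    for (int i = 0; i < height; ++i) {
      const std::size_t offset = static_cast<std::size_t>(i) * right;
      sum = 0.f;
      for (int j = 0; j < right; ++j) sum += x[offset + j];
      mean[i] = sum / right;
      sq_sum = 0.f;
      for (int j = 0; j < right; ++j) {
        const float d = x[offset + j] - mean[i];
        sq_sum += d * d;
      }
      var[i] = sq_sum / right;
      inv_std = 1.f / std::sqrt(var[i] + epsilon);
      for (int j = 0; j < right; ++j) {
        out[offset + j] = (x[offset + j] - mean[i]) * inv_std;
      }
    }
#ifdef _OPENMP
  }
#endif
}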
@@ -26,16 +26,19 @@ namespace intrinsic {
 void LayerNorm(float* x, float* out, float* mean, float* var,
                const float* scale, const float* bias, int height,
                const float epsilon, int right) {
+  int block = YMM_FLOAT_BLOCK;
+  const int rest = right % block;
+  const int end = right - rest;
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel
+  {
+#endif
   __m256 sum;
   __m256 mean_vec, var_vec;
   __m128 hi, lo;
   __m256 tmp;
   size_t offset;
   size_t j;
-  int block = YMM_FLOAT_BLOCK;
-  const int rest = right % block;
-  const int end = right - rest;
   __m256 reverse_num_vec =
       _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
   __m256 epsilon_vec = _mm256_set1_ps(epsilon);
@@ -47,6 +50,9 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
       rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
       rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp for
+#endif
   for (int i = 0; i < height; ++i) {
     offset = i * right;
@@ -100,15 +106,15 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
     /* get x_norm and calculate output*/
     for (j = offset; j < end + offset; j += block) {
       tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
-      tmp = _mm256_div_ps(tmp,
-                          _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
+      tmp = _mm256_div_ps(
+          tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
       _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
     }
     if (rest != 0) {
       j = offset + right - block;
       tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
-      tmp = _mm256_div_ps(tmp,
-                          _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
+      tmp = _mm256_div_ps(
+          tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
       _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
     }
@@ -145,12 +151,16 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
     }
     if (rest != 0) {
       j = offset + right - block;
-      _mm256_storeu_ps(reinterpret_cast<float*>(out) + j,
-                       _mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias +
-                                                          j - offset)));
+      _mm256_storeu_ps(
+          reinterpret_cast<float*>(out) + j,
+          _mm256_add_ps(tmp,
+                        _mm256_loadu_ps((const float*)bias + j - offset)));
     }
   }
+#ifdef PADDLE_WITH_MKLML
+  }
+#endif
 }
 bool LayerNormKernel::CanBeUsed(const int& d) const {
......
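For context on the rest/end bookkeeping that the patch moves above the parallel region: when the row width is not a multiple of YMM_FLOAT_BLOCK (8 floats), the kernel handles the leftover elements by re-loading one last full vector that ends exactly at the row boundary (j = offset + right - block), overlapping the previous block instead of falling back to scalar cleanup. A simplified, hypothetical illustration of that trick, not the actual Paddle code, assuming AVX support and right >= 8:

#include <immintrin.h>

// Normalize one row to zero mean and unit variance using full 8-float AVX
// blocks; the tail is covered by redoing the last full block aligned to the
// end of the row, so no lane reads or writes past the row.
void ScaleTailSketch(const float* x, float* out, int right, float mean,
                     float inv_std) {
  const int block = 8;  // floats per __m256, like YMM_FLOAT_BLOCK
  const int rest = right % block;
  const int end = right - rest;
  const __m256 mean_vec = _mm256_set1_ps(mean);
  const __m256 scale_vec = _mm256_set1_ps(inv_std);
  for (int j = 0; j < end; j += block) {
    __m256 v = _mm256_sub_ps(_mm256_loadu_ps(x + j), mean_vec);
    _mm256_storeu_ps(out + j, _mm256_mul_ps(v, scale_vec));
  }
  if (rest != 0) {
    // Re-process the last full block; the overlapped lanes are simply
    // overwritten with the same values, which is harmless.
    const int j = right - block;
    __m256 v = _mm256_sub_ps(_mm256_loadu_ps(x + j), mean_vec);
    _mm256_storeu_ps(out + j, _mm256_mul_ps(v, scale_vec));
  }
}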