提交 b6260f38 作者: Yihua Xu 提交者: Tao Luo

Optimize the kernel implementation of LayerNorm with OpenMP (#20895)

上级 8c4573a3
......@@ -26,16 +26,19 @@ namespace intrinsic {
void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right) {
int block = YMM_FLOAT_BLOCK;
const int rest = right % block;
const int end = right - rest;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel
{
#endif
__m256 sum;
__m256 mean_vec, var_vec;
__m128 hi, lo;
__m256 tmp;
size_t offset;
size_t j;
int block = YMM_FLOAT_BLOCK;
const int rest = right % block;
const int end = right - rest;
__m256 reverse_num_vec =
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
__m256 epsilon_vec = _mm256_set1_ps(epsilon);
......@@ -47,6 +50,9 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
#ifdef PADDLE_WITH_MKLML
#pragma omp for
#endif
for (int i = 0; i < height; ++i) {
offset = i * right;
......@@ -100,15 +106,15 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
/* get x_norm and calculate output*/
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
tmp = _mm256_div_ps(
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
tmp = _mm256_div_ps(
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
......@@ -145,12 +151,16 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j,
_mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias +
j - offset)));
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_add_ps(tmp,
_mm256_loadu_ps((const float*)bias + j - offset)));
}
}
}
#ifdef PADDLE_WITH_MKLML
}
#endif
}
bool LayerNormKernel::CanBeUsed(const int& d) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册