From b6260f3866a1de60c7089a331a5b9f60939b5963 Mon Sep 17 00:00:00 2001
From: Yihua Xu
Date: Thu, 31 Oct 2019 13:47:25 +0800
Subject: [PATCH] Optimize the kernel implementation of layernorm with openmp
 (#20895)

---
 .../jit/more/intrinsic/layer_norm.cc          | 218 +++++++++---------
 1 file changed, 114 insertions(+), 104 deletions(-)

diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
index a4e3246f10..61d8c50c56 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
@@ -26,131 +26,141 @@ namespace intrinsic {
 void LayerNorm(float* x, float* out, float* mean, float* var,
                const float* scale, const float* bias, int height,
                const float epsilon, int right) {
-  __m256 sum;
-  __m256 mean_vec, var_vec;
-  __m128 hi, lo;
-  __m256 tmp;
-  size_t offset;
-  size_t j;
   int block = YMM_FLOAT_BLOCK;
   const int rest = right % block;
   const int end = right - rest;
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel
+  {
+#endif
+    __m256 sum;
+    __m256 mean_vec, var_vec;
+    __m128 hi, lo;
+    __m256 tmp;
+    size_t offset;
+    size_t j;
+    __m256 reverse_num_vec =
+        _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
+    __m256 epsilon_vec = _mm256_set1_ps(epsilon);
+    int rest_mask =
+        ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
+    __m256i mask_vec = _mm256_set_epi32(
+        rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
+        rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
+        rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
+        rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
 
-  __m256 reverse_num_vec =
-      _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
-  __m256 epsilon_vec = _mm256_set1_ps(epsilon);
-  int rest_mask =
-      ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
-  __m256i mask_vec = _mm256_set_epi32(
-      rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
-      rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
-      rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
-      rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp for
+#endif
+    for (int i = 0; i < height; ++i) {
+      offset = i * right;
 
-  for (int i = 0; i < height; ++i) {
-    offset = i * right;
-
-    /* get mean */
-    sum = _mm256_setzero_ps();
-    for (j = offset; j < end + offset; j += block) {
-      sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
-    }
-    if (rest != 0) {
-      j = offset + right - block;
-      tmp = _mm256_loadu_ps((const float*)x + j);
-      tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
-                             *(__m256*)&mask_vec);  // NOLINT
-      sum = _mm256_add_ps(sum, tmp);
-    }
-    hi = _mm256_extractf128_ps(sum, 1);
-    lo = _mm256_extractf128_ps(sum, 0);
-    sum = _mm256_add_ps(
-        sum, _mm256_insertf128_ps(
-                 _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
-    sum = _mm256_hadd_ps(sum, sum);
-    sum = _mm256_hadd_ps(sum, sum);
-    mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
-    mean[i] = *reinterpret_cast<float*>(&mean_vec);
-
-    /* get variance */
-    sum = _mm256_setzero_ps();
-    for (j = offset; j < end + offset; j += block) {
-      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
-      tmp = _mm256_mul_ps(tmp, tmp);
-      sum = _mm256_add_ps(sum, tmp);
-    }
-    if (rest != 0) {
-      j = offset + right - block;
-      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
-      tmp = _mm256_mul_ps(tmp, tmp);
-      tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
-                             *(__m256*)&mask_vec);  // NOLINT
-      sum = _mm256_add_ps(sum, tmp);
-    }
-    hi = _mm256_extractf128_ps(sum, 1);
-    lo = _mm256_extractf128_ps(sum, 0);
-    sum = _mm256_add_ps(
-        sum, _mm256_insertf128_ps(
-                 _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
-    sum = _mm256_hadd_ps(sum, sum);
-    sum = _mm256_hadd_ps(sum, sum);
-    var_vec = _mm256_mul_ps(sum, reverse_num_vec);
-    var[i] = *reinterpret_cast<float*>(&var_vec);
-
-    /* get x_norm and calculate output*/
-    for (j = offset; j < end + offset; j += block) {
-      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
-      tmp = _mm256_div_ps(tmp,
-                          _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
-      _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
-    }
-    if (rest != 0) {
-      j = offset + right - block;
-      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
-      tmp = _mm256_div_ps(tmp,
-                          _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
-      _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
-    }
-
-    if (scale) {
-      if (rest != 0) {
-        j = offset + right - block;
-        tmp = _mm256_loadu_ps((const float*)out + j);
-      }
+      /* get mean */
+      sum = _mm256_setzero_ps();
       for (j = offset; j < end + offset; j += block) {
-        _mm256_storeu_ps(
-            reinterpret_cast<float*>(out) + j,
-            _mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
-                          _mm256_loadu_ps((const float*)scale + j - offset)));
+        sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
       }
       if (rest != 0) {
         j = offset + right - block;
-        _mm256_storeu_ps(
-            reinterpret_cast<float*>(out) + j,
-            _mm256_mul_ps(tmp,
-                          _mm256_loadu_ps((const float*)scale + j - offset)));
+        tmp = _mm256_loadu_ps((const float*)x + j);
+        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
+                               *(__m256*)&mask_vec);  // NOLINT
+        sum = _mm256_add_ps(sum, tmp);
       }
-    }
+      hi = _mm256_extractf128_ps(sum, 1);
+      lo = _mm256_extractf128_ps(sum, 0);
+      sum = _mm256_add_ps(
+          sum, _mm256_insertf128_ps(
+                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
+      sum = _mm256_hadd_ps(sum, sum);
+      sum = _mm256_hadd_ps(sum, sum);
+      mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
+      mean[i] = *reinterpret_cast<float*>(&mean_vec);
 
-    if (bias) {
+      /* get variance */
+      sum = _mm256_setzero_ps();
+      for (j = offset; j < end + offset; j += block) {
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+        tmp = _mm256_mul_ps(tmp, tmp);
+        sum = _mm256_add_ps(sum, tmp);
+      }
       if (rest != 0) {
         j = offset + right - block;
-        tmp = _mm256_loadu_ps((const float*)out + j);
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+        tmp = _mm256_mul_ps(tmp, tmp);
+        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
+                               *(__m256*)&mask_vec);  // NOLINT
+        sum = _mm256_add_ps(sum, tmp);
       }
+      hi = _mm256_extractf128_ps(sum, 1);
+      lo = _mm256_extractf128_ps(sum, 0);
+      sum = _mm256_add_ps(
+          sum, _mm256_insertf128_ps(
+                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
+      sum = _mm256_hadd_ps(sum, sum);
+      sum = _mm256_hadd_ps(sum, sum);
+      var_vec = _mm256_mul_ps(sum, reverse_num_vec);
+      var[i] = *reinterpret_cast<float*>(&var_vec);
+
+      /* get x_norm and calculate output*/
       for (j = offset; j < end + offset; j += block) {
-        _mm256_storeu_ps(
-            reinterpret_cast<float*>(out) + j,
-            _mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
-                          _mm256_loadu_ps((const float*)bias + j - offset)));
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+        tmp = _mm256_div_ps(
+            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
+        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
       }
       if (rest != 0) {
         j = offset + right - block;
-        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j,
-                         _mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias +
-                                                            j - offset)));
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+        tmp = _mm256_div_ps(
+            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
+        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
+      }
+
+      if (scale) {
+        if (rest != 0) {
+          j = offset + right - block;
+          tmp = _mm256_loadu_ps((const float*)out + j);
+        }
+        for (j = offset; j < end + offset; j += block) {
+          _mm256_storeu_ps(
+              reinterpret_cast<float*>(out) + j,
+              _mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
+                            _mm256_loadu_ps((const float*)scale + j - offset)));
+        }
+        if (rest != 0) {
+          j = offset + right - block;
+          _mm256_storeu_ps(
+              reinterpret_cast<float*>(out) + j,
+              _mm256_mul_ps(tmp,
+                            _mm256_loadu_ps((const float*)scale + j - offset)));
+        }
+      }
+
+      if (bias) {
+        if (rest != 0) {
+          j = offset + right - block;
+          tmp = _mm256_loadu_ps((const float*)out + j);
+        }
+        for (j = offset; j < end + offset; j += block) {
+          _mm256_storeu_ps(
+              reinterpret_cast<float*>(out) + j,
+              _mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
+                            _mm256_loadu_ps((const float*)bias + j - offset)));
+        }
+        if (rest != 0) {
+          j = offset + right - block;
+          _mm256_storeu_ps(
+              reinterpret_cast<float*>(out) + j,
+              _mm256_add_ps(tmp,
+                            _mm256_loadu_ps((const float*)bias + j - offset)));
+        }
       }
     }
+#ifdef PADDLE_WITH_MKLML
   }
+#endif
 }
 
 bool LayerNormKernel<float>::CanBeUsed(const int& d) const {
--
GitLab
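
Notes on the change, for review context: the patch moves every per-row temporary (sum, mean_vec, var_vec, hi, lo, tmp, offset, j, plus the precomputed reverse_num_vec, epsilon_vec, and the tail mask) inside an `omp parallel` region and marks the row loop with `omp for`, so each thread processes its own rows with private accumulators while sharing the read-only inputs. Splitting `parallel` and `for` (rather than a single `parallel for`) lets the thread-private declarations sit between the two pragmas; the `#ifdef PADDLE_WITH_MKLML` guards presumably reflect that Paddle only links an OpenMP runtime in MKLML builds. The pattern is easier to see without the AVX details; below is a minimal scalar sketch of the same scheme. It is illustrative only, not code from this patch: LayerNormRef is a hypothetical name, and OpenMP is assumed to be enabled (e.g. -fopenmp); without it the pragma is ignored and the loop simply runs serially.

#include <cmath>
#include <cstddef>

// Scalar sketch of the parallelization pattern: one independent row of
// x per loop iteration, distributed across threads. All accumulators
// are declared inside the loop body, so they are private per thread.
void LayerNormRef(const float* x, float* out, float* mean, float* var,
                  const float* scale, const float* bias, int height,
                  float epsilon, int right) {
#pragma omp parallel for
  for (int i = 0; i < height; ++i) {
    const float* row = x + static_cast<std::size_t>(i) * right;
    float* dst = out + static_cast<std::size_t>(i) * right;

    // Per-row mean.
    float m = 0.f;
    for (int j = 0; j < right; ++j) m += row[j];
    m /= right;
    mean[i] = m;

    // Per-row (biased) variance.
    float v = 0.f;
    for (int j = 0; j < right; ++j) {
      float d = row[j] - m;
      v += d * d;
    }
    v /= right;
    var[i] = v;

    // Normalize, then apply the optional affine scale and bias.
    float inv_std = 1.f / std::sqrt(v + epsilon);
    for (int j = 0; j < right; ++j) {
      float y = (row[j] - m) * inv_std;
      if (scale) y *= scale[j];
      if (bias) y += bias[j];
      dst[j] = y;
    }
  }
}

Each output row, mean[i], and var[i] is written by exactly one thread, so no synchronization is needed beyond the implicit barrier at the end of the parallel loop.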
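The remainder handling in the kernel is worth spelling out as well: when right is not a multiple of the 8-float YMM block, the last partial block is handled by reloading the final full block of the row (ending exactly at the row boundary) and blending away the lanes that the main loop already consumed. A hedged sketch of that idea, assuming AVX and 0 < rest < 8; TailOnly is an illustrative helper, not part of the kernel:

#include <immintrin.h>

// Load the last 8 floats of a row and zero the lanes that overlap the
// blocks already summed, leaving only the trailing `rest` lanes.
static inline __m256 TailOnly(const float* last_block, int rest) {
  const int block = 8;  // floats per __m256
  // Bit j is set iff lane j belongs to the unprocessed tail, i.e. the
  // high `rest` lanes of the reloaded block.
  unsigned keep = ((1u << rest) - 1u) << (block - rest);
  __m256i mask = _mm256_set_epi32(
      keep & 0x80 ? -1 : 0, keep & 0x40 ? -1 : 0, keep & 0x20 ? -1 : 0,
      keep & 0x10 ? -1 : 0, keep & 0x08 ? -1 : 0, keep & 0x04 ? -1 : 0,
      keep & 0x02 ? -1 : 0, keep & 0x01 ? -1 : 0);
  __m256 v = _mm256_loadu_ps(last_block);
  return _mm256_blendv_ps(_mm256_setzero_ps(), v, _mm256_castsi256_ps(mask));
}

In the kernel above the equivalent mask (rest_mask/mask_vec) is built once per thread rather than per call, and because the reloaded block overlaps the main loop by block - rest elements with those lanes zeroed, every tail element enters the reduction exactly once.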