diff --git a/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt b/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
index de83d80e7757ad161c810bf17f456d143f3fe597..c4a50138636a377d1fbbe14bfa6fd915717b4223 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
@@ -6,3 +6,4 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
 
 # use mkl kernels by name and type
 USE_JITKERNEL_MORE(crfdecoding, intrinsic)
+USE_JITKERNEL_MORE(layernorm, intrinsic)
diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fafc12914e3f04542b015edc62285f404d163c3e
--- /dev/null
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
@@ -0,0 +1,168 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/jit/more/intrinsic/layer_norm.h"
+#include <immintrin.h>
+#include "paddle/fluid/operators/jit/registry.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace more {
+namespace intrinsic {
+
+void LayerNorm(float* x, float* out, float* mean, float* var,
+               const float* scale, const float* bias, int height,
+               const float epsilon, int right) {
+  __m256 sum;
+  __m256 mean_vec, var_vec;
+  __m128 hi, lo;
+  __m256 tmp;
+  size_t offset;
+  size_t j;
+  int block = YMM_FLOAT_BLOCK;
+  const int rest = right % block;
+  const int end = right - rest;
+
+  __m256 reverse_num_vec =
+      _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
+  __m256 epsilon_vec = _mm256_set1_ps(epsilon);
+  int rest_mask =
+      ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
+  __m256i mask_vec = _mm256_set_epi32(
+      rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
+      rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
+      rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
+      rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
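+
+  // Tail handling: the last (right % block) elements of each row are covered
+  // by one extra pass over the row's final full-width vector. mask_vec keeps
+  // only those highest `rest` lanes, so lanes already processed by the
+  // blocked loops below are blended to zero before they are accumulated.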
+
+  for (int i = 0; i < height; ++i) {
+    offset = i * right;
+
+    /* get mean */
+    sum = _mm256_setzero_ps();
+    for (j = offset; j < end + offset; j += block) {
+      sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
+    }
+    if (rest != 0) {
+      j = offset + right - block;
+      tmp = _mm256_loadu_ps((const float*)x + j);
+      tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
+                             *(__m256*)&mask_vec);  // NOLINT
+      sum = _mm256_add_ps(sum, tmp);
+    }
+    // Horizontal reduction: swap the two 128-bit halves and add, then two
+    // hadds leave the row total broadcast in every lane.
+    hi = _mm256_extractf128_ps(sum, 1);
+    lo = _mm256_extractf128_ps(sum, 0);
+    sum = _mm256_add_ps(
+        sum, _mm256_insertf128_ps(
+                 _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
+    sum = _mm256_hadd_ps(sum, sum);
+    sum = _mm256_hadd_ps(sum, sum);
+    mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
+    mean[i] = *reinterpret_cast<float*>(&mean_vec);
+
+    /* get variance */
+    sum = _mm256_setzero_ps();
+    for (j = offset; j < end + offset; j += block) {
+      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+      tmp = _mm256_mul_ps(tmp, tmp);
+      sum = _mm256_add_ps(sum, tmp);
+    }
+    if (rest != 0) {
+      j = offset + right - block;
+      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+      tmp = _mm256_mul_ps(tmp, tmp);
+      tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
+                             *(__m256*)&mask_vec);  // NOLINT
+      sum = _mm256_add_ps(sum, tmp);
+    }
+    hi = _mm256_extractf128_ps(sum, 1);
+    lo = _mm256_extractf128_ps(sum, 0);
+    sum = _mm256_add_ps(
+        sum, _mm256_insertf128_ps(
+                 _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
+    sum = _mm256_hadd_ps(sum, sum);
+    sum = _mm256_hadd_ps(sum, sum);
+    var_vec = _mm256_mul_ps(sum, reverse_num_vec);
+    var[i] = *reinterpret_cast<float*>(&var_vec);
+
+    /* get x_norm and calculate output */
+    for (j = offset; j < end + offset; j += block) {
+      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+      tmp = _mm256_div_ps(
+          tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
+      _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
+    }
+    if (rest != 0) {
+      j = offset + right - block;
+      tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
+      tmp = _mm256_div_ps(
+          tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
+      _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
+    }
+
+    if (scale) {
+      // Preload the tail block before the blocked loop overwrites its
+      // overlapping lanes, so the tail store scales the original values.
+      if (rest != 0) {
+        j = offset + right - block;
+        tmp = _mm256_loadu_ps((const float*)out + j);
+      }
+      for (j = offset; j < end + offset; j += block) {
+        _mm256_storeu_ps(
+            reinterpret_cast<float*>(out) + j,
+            _mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
+                          _mm256_loadu_ps((const float*)scale + j - offset)));
+      }
+      if (rest != 0) {
+        j = offset + right - block;
+        _mm256_storeu_ps(
+            reinterpret_cast<float*>(out) + j,
+            _mm256_mul_ps(tmp,
+                          _mm256_loadu_ps((const float*)scale + j - offset)));
+      }
+    }
+
+    if (bias) {
+      if (rest != 0) {
+        j = offset + right - block;
+        tmp = _mm256_loadu_ps((const float*)out + j);
+      }
+      for (j = offset; j < end + offset; j += block) {
+        _mm256_storeu_ps(
+            reinterpret_cast<float*>(out) + j,
+            _mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
+                          _mm256_loadu_ps((const float*)bias + j - offset)));
+      }
+      if (rest != 0) {
+        j = offset + right - block;
+        _mm256_storeu_ps(
+            reinterpret_cast<float*>(out) + j,
+            _mm256_add_ps(tmp,
+                          _mm256_loadu_ps((const float*)bias + j - offset)));
+      }
+    }
+  }
+}
+
+bool LayerNormKernel::UseMe(int d) const {
+  // The tail pass reloads the last full vector of a row, so each row must
+  // hold at least one block of floats; otherwise fall back to other kernels.
+  return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
+}
+
+}  // namespace intrinsic
+}  // namespace more
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
+
+namespace intrinsic = paddle::operators::jit::more::intrinsic;
+
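+// Register this implementation in the jit "more" registry: the dispatcher
+// can then pick it by kernel key whenever UseMe(attr) returns true, ahead of
+// the portable reference kernel.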
+REGISTER_JITKERNEL_MORE(layernorm, intrinsic, intrinsic::LayerNormKernel);
diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..b802f56f57fd89f0d6f4a255be295f4015a3da41
--- /dev/null
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#pragma once
+
+#include <type_traits>
+#include "paddle/fluid/operators/jit/kernel_base.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace more {
+namespace intrinsic {
+
+// x: [height, right] input; out: normalized (optionally affine) result;
+// mean/var: per-row statistics, each of length `height`.
+void LayerNorm(float* x, float* out, float* mean, float* var,
+               const float* scale, const float* bias, int height,
+               const float epsilon, int right);
+
+class LayerNormKernel : public KernelImpl<LayerNormTuples<float>> {
+ public:
+  LayerNormKernel() { this->func = LayerNorm; }
+  bool UseMe(typename LayerNormTuples<float>::attr_type) const override;
+};
+
+}  // namespace intrinsic
+}  // namespace more
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
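
Reviewer note (not part of the patch): a minimal scalar sketch of the
semantics the AVX kernel above implements, handy as a cross-check against
the intrinsic version. The name LayerNormRef is illustrative only, and the
input pointer is const here although the kernel signature takes float* x.

#include <cmath>

// For each of `height` rows of width `right`: normalize to zero mean and
// unit variance (epsilon added inside the sqrt), then optionally apply the
// per-column scale and bias. mean/var receive the per-row statistics.
void LayerNormRef(const float* x, float* out, float* mean, float* var,
                  const float* scale, const float* bias, int height,
                  float epsilon, int right) {
  for (int i = 0; i < height; ++i) {
    const float* row = x + i * right;
    float* dst = out + i * right;
    float m = 0.f;
    for (int j = 0; j < right; ++j) m += row[j];
    m /= right;
    float v = 0.f;
    for (int j = 0; j < right; ++j) v += (row[j] - m) * (row[j] - m);
    v /= right;
    mean[i] = m;
    var[i] = v;
    const float inv_std = 1.f / std::sqrt(v + epsilon);
    for (int j = 0; j < right; ++j) {
      float norm = (row[j] - m) * inv_std;
      if (scale) norm *= scale[j];
      if (bias) norm += bias[j];
      dst[j] = norm;
    }
  }
}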