diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h
index 7bf79b08956885259e5ac3801274a1a675e6d975..78d20ddf5fd63b81fd5e7fba656d825897a67a11 100644
--- a/paddle/fluid/operators/layer_norm_op.h
+++ b/paddle/fluid/operators/layer_norm_op.h
@@ -17,6 +17,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
+#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
+    !defined(__OSX__)
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#endif
 #include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
@@ -191,6 +195,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
     out.ShareDataWith(*y);
     out.Resize(matrix_shape);
 
+#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
+    defined(__OSX__)
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     RowwiseMean2D<DeviceContext, T> row_mean(left, right,
                                              ctx.device_context());
@@ -217,6 +223,19 @@
       ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
           ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
     }
+#else
+    PADDLE_ENFORCE_EQ(mean->numel(), left);
+    PADDLE_ENFORCE_EQ(var->numel(), left);
+    PADDLE_ENFORCE_EQ(scale->numel(), right);
+    PADDLE_ENFORCE_EQ(bias->numel(), right);
+
+    const auto& ker = math::jitkernel::KernelPool::Instance()
+                          .template Get<math::jitkernel::LayerNormKernel<T>>(
+                              static_cast<int>(right));
+    ker->Compute(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
+                 scale->data<T>(), bias->data<T>(), static_cast<int>(left),
+                 static_cast<const float>(epsilon));
+#endif
   }
 };
 
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 8c5516b2329b5312b44519c2e97fc83eaacb6546..83ee9f6c51c64c6b000b20d73d41036b8590da5c 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -77,7 +77,7 @@ endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 if (NOT WIN32)
-    set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
+    set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc)
     set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
     if(WITH_XBYAK)
         list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 4d8d3cd79a16a3ea61c4f63da3493e105847d30b..665ba24872a09897c4c1cb9bb5fc163b0c564dda 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -145,6 +145,14 @@ class CRFDecodeKernel : public Kernel {
                int *track) const = 0;
 };
 
+template <typename T>
+class LayerNormKernel : public Kernel {
+ public:
+  virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale,
+                       const T *bias, int height,
+                       const float epsilon) const = 0;
+};
+
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..49904e6e8c7cd346bcbfb67c3a7574118b36e058
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
@@ -0,0 +1,241 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include <math.h>
+#include <limits>
+#include <string>
+#include "paddle/fluid/operators/math/jit_kernel_macro.h"
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+
+namespace jit = platform::jit;
+
+/* Layer Norm JitKernel */
+template <typename T, jit::cpu_isa_t isa, jit_block>
+class LayerNormKernelImpl : public LayerNormKernel<T> {
+ public:
+  explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
+    this->num_ = right;
+  }
+
+  void Compute(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
+               int height, const float epsilon) const override {
+    // get mean
+    for (int i = 0; i < height; i++) {
+      T sum = 0.0;
+      int offset = i * this->num_;
+      for (int j = 0; j < this->num_; j++) {
+        sum += x[offset + j];
+      }
+      mean[i] = sum / this->num_;
+    }
+
+    // get variance
+    for (int i = 0; i < height; i++) {
+      T sum = 0.0;
+      int offset = i * this->num_;
+      for (int j = 0; j < this->num_; j++) {
+        sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
+      }
+      var[i] = sum / this->num_;
+    }
+
+    for (int i = 0; i < height; i++) {
+      int offset = i * this->num_;
+      T sqrt_var = sqrt(var[i] + (T)epsilon);
+      for (int j = 0; j < this->num_; j++) {
+        out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
+      }
+    }
+    if (scale) {
+      for (int i = 0; i < height; i++) {
+        int offset = i * this->num_;
+        for (int j = 0; j < this->num_; j++) {
+          out[offset + j] *= scale[j];
+        }
+      }
+    }
+
+    if (bias) {
+      for (int i = 0; i < height; i++) {
+        int offset = i * this->num_;
+        for (int j = 0; j < this->num_; j++) {
+          out[offset + j] += bias[j];
+        }
+      }
+    }
+  }
+};
+
+#define INTRIAVX_FLOAT(isa, block) \
+  template <> \
+  LayerNormKernelImpl<float, isa, block>::LayerNormKernelImpl(int right) \
+      : LayerNormKernel<float>() { \
+    this->num_ = right; \
+    this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
+    this->end_ = this->num_ - this->rest_; \
+  } \
+  template <> \
+  void LayerNormKernelImpl<float, isa, block>::Compute( \
+      float* x, float* out, float* mean, float* var, const float* scale, \
+      const float* bias, int height, const float epsilon) const { \
+    __m256 sum; \
+    __m256 mean_vec, var_vec; \
+    __m128 hi, lo; \
+    __m256 tmp; \
+    size_t offset; \
+    size_t j; \
+    __m256 reverse_num_vec = \
+        _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
+    __m256 epsilon_vec = _mm256_set1_ps(epsilon); \
+    int rest_mask = \
+        ((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \
+        0x0ff; \
+    __m256i mask_vec = _mm256_set_epi32( \
+        rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \
+        rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \
+        rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \
+        rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); \
+  \
+    for (int i = 0; i < height; ++i) { \
+      offset = i * this->num_; \
+  \
+      /* get mean */ \
+      sum = _mm256_setzero_ps(); \
+      for (j = offset; j < end_ + offset; j += block) { \
+        sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \
+      } \
+      if (rest_ != 0) { \
+        j = offset + this->num_ - block; \
+        tmp = _mm256_loadu_ps((const float*)x + j); \
+        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, (__m256)mask_vec); \
+        sum = _mm256_add_ps(sum, tmp); \
+      } \
+      hi = _mm256_extractf128_ps(sum, 1); \
+      lo = _mm256_extractf128_ps(sum, 0); \
+      sum = _mm256_add_ps( \
+          sum, _mm256_insertf128_ps( \
+                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
+      sum = _mm256_hadd_ps(sum, sum); \
+      sum = _mm256_hadd_ps(sum, sum); \
+      mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \
+      mean[i] = *reinterpret_cast<float*>(&mean_vec); \
+  \
+      /* get variance */ \
+      sum = _mm256_setzero_ps(); \
+      for (j = offset; j < end_ + offset; j += block) { \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
+        tmp = _mm256_mul_ps(tmp, tmp); \
+        sum = _mm256_add_ps(sum, tmp); \
+      } \
+      if (rest_ != 0) { \
+        j = offset + this->num_ - block; \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
+        tmp = _mm256_mul_ps(tmp, tmp); \
+        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, (__m256)mask_vec); \
+        sum = _mm256_add_ps(sum, tmp); \
+      } \
+      hi = _mm256_extractf128_ps(sum, 1); \
+      lo = _mm256_extractf128_ps(sum, 0); \
+      sum = _mm256_add_ps( \
+          sum, _mm256_insertf128_ps( \
+                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
+      sum = _mm256_hadd_ps(sum, sum); \
+      sum = _mm256_hadd_ps(sum, sum); \
+      var_vec = _mm256_mul_ps(sum, reverse_num_vec); \
+      var[i] = *reinterpret_cast<float*>(&var_vec); \
+  \
+      /* get x_norm and calculate output*/ \
+      for (j = offset; j < end_ + offset; j += block) { \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
+        tmp = _mm256_div_ps( \
+            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
+        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
+      } \
+      if (rest_ != 0) { \
+        j = offset + num_ - block; \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
+        tmp = _mm256_div_ps( \
+            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
+        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
+      } \
+  \
+      if (scale) { \
+        if (rest_ != 0) { \
+          j = offset + this->num_ - block; \
+          tmp = _mm256_loadu_ps((const float*)out + j); \
+        } \
+        for (j = offset; j < end_ + offset; j += block) { \
+          _mm256_storeu_ps( \
+              reinterpret_cast<float*>(out) + j, \
+              _mm256_mul_ps( \
+                  _mm256_loadu_ps((const float*)out + j), \
+                  _mm256_loadu_ps((const float*)scale + j - offset))); \
+        } \
+        if (rest_ != 0) { \
+          j = offset + this->num_ - block; \
+          _mm256_storeu_ps( \
+              reinterpret_cast<float*>(out) + j, \
+              _mm256_mul_ps( \
+                  tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \
+        } \
+      } \
+  \
+      if (bias) { \
+        if (rest_ != 0) { \
+          j = offset + this->num_ - block; \
+          tmp = _mm256_loadu_ps((const float*)out + j); \
+        } \
+        for (j = offset; j < end_ + offset; j += block) { \
+          _mm256_storeu_ps( \
+              reinterpret_cast<float*>(out) + j, \
+              _mm256_add_ps( \
+                  _mm256_loadu_ps((const float*)out + j), \
+                  _mm256_loadu_ps((const float*)bias + j - offset))); \
+        } \
+        if (rest_ != 0) { \
+          j = offset + this->num_ - block; \
+          _mm256_storeu_ps( \
+              reinterpret_cast<float*>(out) + j, \
+              _mm256_add_ps( \
+                  tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \
+        } \
+      } \
+    } \
+  }
+
+#ifdef __AVX__
+INTRIAVX_FLOAT(jit::avx, kEQ8);
+INTRIAVX_FLOAT(jit::avx, kGT8LT16);
+INTRIAVX_FLOAT(jit::avx, kEQ16);
+INTRIAVX_FLOAT(jit::avx, kGT16);
+#endif
+#ifdef __AVX2__
+INTRIAVX_FLOAT(jit::avx2, kEQ8);
+INTRIAVX_FLOAT(jit::avx2, kGT8LT16);
+INTRIAVX_FLOAT(jit::avx2, kEQ16);
+INTRIAVX_FLOAT(jit::avx2, kGT16);
+#endif
+
+#undef INTRIAVX_FLOAT
+
+REGISTER_JITKERNEL_DEPRECATED(layer_norm, LayerNormKernel);
+
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
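
Usage note (not part of the patch): a minimal sketch of how the new JIT kernel is driven, mirroring the #else branch added to layer_norm_op.h above. The free function RunLayerNormSketch and the std::vector buffers are illustrative assumptions; inside the operator the pointers come from LoDTensor storage instead.

// Sketch only: calls jitkernel::LayerNormKernel<float> directly, assuming a
// CPU build in which the new include guard in layer_norm_op.h selects the
// JIT path.
#include <vector>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jk = paddle::operators::math::jitkernel;

void RunLayerNormSketch(int left, int right) {  // left rows, right features
  std::vector<float> x(left * right, 1.0f), out(left * right);
  std::vector<float> mean(left), var(left);
  std::vector<float> scale(right, 1.0f), bias(right, 0.0f);

  // The pool caches one kernel per (type, right); the ISA/block
  // specialization (kEQ8 / kGT8LT16 / kEQ16 / kGT16) is chosen when the
  // kernel is first created.
  const auto& ker =
      jk::KernelPool::Instance().Get<jk::LayerNormKernel<float>>(right);
  ker->Compute(x.data(), out.data(), mean.data(), var.data(), scale.data(),
               bias.data(), left, 1e-5f);
}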
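
Review aid (also not part of the patch): the subtlest piece of the AVX code is the rest_mask arithmetic. When num_ % YMM_FLOAT_BLOCK != 0, the tail is handled with one extra full-width load starting at num_ - 8, which overlaps lanes that were already accumulated; the blend mask keeps only the highest rest_ lanes so no element is counted twice. A self-contained check of that expression, assuming only YMM_FLOAT_BLOCK == 8 as defined in jit_kernel.h:

#include <cstdio>

int main() {
  const int YMM_FLOAT_BLOCK = 8;  // float lanes per __m256
  for (int rest = 1; rest < YMM_FLOAT_BLOCK; ++rest) {
    // Same expression as in jit_kernel_layer_norm.cc: builds a byte whose
    // top `rest` bits are set, i.e. it selects the highest `rest` lanes.
    int rest_mask =
        ((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest))))) &
        0x0ff;
    printf("rest=%d mask=0x%02x\n", rest, rest_mask);  // rest=3 -> 0xe0
  }
  return 0;
}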