From acc6ae49b18cb55db4dd84cd09069ebe01a1b54a Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Wed, 12 Dec 2018 00:31:59 +0800 Subject: [PATCH] Fix the issue to run on AVX2 and AVX512F machines (#14851) test=develop --- .../fluid/operators/math/jit_kernel_layer_norm.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc index fead13eba..cb49e6648 100644 --- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc +++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc @@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel { } }; -#define INTRIAVX_FLOAT(isa, block) \ +#define INTRIAVX_FLOAT(isa, jit_block) \ template <> \ - LayerNormKernelImpl::LayerNormKernelImpl(int right) \ + LayerNormKernelImpl::LayerNormKernelImpl(int right) \ : LayerNormKernel() { \ this->num_ = right; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ this->end_ = this->num_ - this->rest_; \ } \ template <> \ - void LayerNormKernelImpl::Compute( \ + void LayerNormKernelImpl::Compute( \ float* x, float* out, float* mean, float* var, const float* scale, \ const float* bias, int height, const float epsilon) const { \ __m256 sum; \ @@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel { __m256 tmp; \ size_t offset; \ size_t j; \ + size_t block = YMM_FLOAT_BLOCK; \ __m256 reverse_num_vec = \ _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \ __m256 epsilon_vec = _mm256_set1_ps(epsilon); \ @@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8); INTRIAVX_FLOAT(platform::avx, kGT8LT16); INTRIAVX_FLOAT(platform::avx, kEQ16); INTRIAVX_FLOAT(platform::avx, kGT16); -#endif -#ifdef __AVX2__ INTRIAVX_FLOAT(platform::avx2, kEQ8); INTRIAVX_FLOAT(platform::avx2, kGT8LT16); INTRIAVX_FLOAT(platform::avx2, kEQ16); INTRIAVX_FLOAT(platform::avx2, kGT16); +INTRIAVX_FLOAT(platform::avx512f, kEQ8); +INTRIAVX_FLOAT(platform::avx512f, kGT8LT16); +INTRIAVX_FLOAT(platform::avx512f, kEQ16); +INTRIAVX_FLOAT(platform::avx512f, kGT16); #endif #undef INTRIAVX_FLOAT -- GitLab