提交 acc6ae49 编写于 作者: Yihua Xu 提交者: tensor-tang

Fix the LayerNorm kernel so that it runs on AVX2 and AVX512F machines (#14851)

test=develop
上级 45fb357b
...@@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> { ...@@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
} }
}; };
#define INTRIAVX_FLOAT(isa, block) \ #define INTRIAVX_FLOAT(isa, jit_block) \
template <> \ template <> \
LayerNormKernelImpl<float, isa, block>::LayerNormKernelImpl(int right) \ LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
: LayerNormKernel<float>() { \ : LayerNormKernel<float>() { \
this->num_ = right; \ this->num_ = right; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
this->end_ = this->num_ - this->rest_; \ this->end_ = this->num_ - this->rest_; \
} \ } \
template <> \ template <> \
void LayerNormKernelImpl<float, platform::avx, block>::Compute( \ void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \ float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \ const float* bias, int height, const float epsilon) const { \
__m256 sum; \ __m256 sum; \
...@@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> { ...@@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
__m256 tmp; \ __m256 tmp; \
size_t offset; \ size_t offset; \
size_t j; \ size_t j; \
size_t block = YMM_FLOAT_BLOCK; \
__m256 reverse_num_vec = \ __m256 reverse_num_vec = \
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \ _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \ __m256 epsilon_vec = _mm256_set1_ps(epsilon); \
...@@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8); ...@@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(platform::avx, kGT8LT16); INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(platform::avx, kEQ16); INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(platform::avx, kGT16); INTRIAVX_FLOAT(platform::avx, kGT16);
#endif
#ifdef __AVX2__
INTRIAVX_FLOAT(platform::avx2, kEQ8); INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(platform::avx2, kGT8LT16); INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(platform::avx2, kEQ16); INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(platform::avx2, kGT16); INTRIAVX_FLOAT(platform::avx2, kGT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ8);
INTRIAVX_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ16);
INTRIAVX_FLOAT(platform::avx512f, kGT16);
#endif #endif
#undef INTRIAVX_FLOAT #undef INTRIAVX_FLOAT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册