提交 acc6ae49 作者: Yihua Xu 提交者: tensor-tang

Fix the issue to run on AVX2 and AVX512F machines (#14851)

test=develop
上级 45fb357b
......@@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
}
};
#define INTRIAVX_FLOAT(isa, block) \
#define INTRIAVX_FLOAT(isa, jit_block) \
template <> \
LayerNormKernelImpl<float, isa, block>::LayerNormKernelImpl(int right) \
LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
: LayerNormKernel<float>() { \
this->num_ = right; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
this->end_ = this->num_ - this->rest_; \
} \
template <> \
void LayerNormKernelImpl<float, platform::avx, block>::Compute( \
void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \
__m256 sum; \
......@@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
__m256 tmp; \
size_t offset; \
size_t j; \
size_t block = YMM_FLOAT_BLOCK; \
__m256 reverse_num_vec = \
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \
......@@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(platform::avx, kGT16);
#endif
// Expand INTRIAVX_FLOAT once per (ISA tag, jit_block) pair for the avx2 and
// avx512f tags. Both sets are compiled only when __AVX2__ is defined.
// NOTE: the macro body visible above is written with __m256 values and
// YMM_FLOAT_BLOCK, so the avx512f specializations produced here use that same
// __m256-based code path rather than a separate 512-bit implementation.
#ifdef __AVX2__
INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(platform::avx2, kGT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ8);
INTRIAVX_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ16);
INTRIAVX_FLOAT(platform::avx512f, kGT16);
#endif  // __AVX2__
#undef INTRIAVX_FLOAT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册