Commit 36588b33 authored by tensor-tang

fix illegal instruction of rnn1 and text

Parent 6447155d
@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
    SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
-    DEPS cpu_info cblas activation_functions)
+    DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
@@ -69,37 +69,225 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif

-#define INTRI8_FLOAT(isa) \
+namespace detail {
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST(0x7f, 0x7f);
_PS256_CONST(one, 1.f);
_PS256_CONST(0p5, 0.5f);
_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
typedef union imm_xmm_union {
__m256i imm;
__m128i xmm[2];
} imm_xmm_union;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the AVX2 bitop */ \
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */ \
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32);
__m256 ExpAVX(__m256 x) {
__m256 tmp = _mm256_setzero_ps(), fx;
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
__m256i imm0;
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
/* express exp(x) as exp(g + n*log(2)) */
fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
tmp = _mm256_floor_ps(fx);
/* if greater, subtract 1 */
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
mask = _mm256_and_ps(mask, one);
fx = _mm256_sub_ps(tmp, mask);
tmp =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
__m256 z =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
x = _mm256_sub_ps(x, tmp);
x = _mm256_sub_ps(x, z);
z = _mm256_mul_ps(x, x);
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
y = _mm256_mul_ps(y, z);
y = _mm256_add_ps(y, x);
y = _mm256_add_ps(y, one);
/* build 2^n */
imm0 = _mm256_cvttps_epi32(fx);
// two AVX2 instructions using SSE2
imm0 = avx2_mm256_add_epi32(imm0,
*reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = avx2_mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
__m256 tmp = _mm256_setzero_ps(), fx;
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
__m256i imm0;
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
/* express exp(x) as exp(g + n*log(2)) */
fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
tmp = _mm256_floor_ps(fx);
/* if greater, subtract 1 */
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
mask = _mm256_and_ps(mask, one);
fx = _mm256_sub_ps(tmp, mask);
tmp =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
__m256 z =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
x = _mm256_sub_ps(x, tmp);
x = _mm256_sub_ps(x, z);
z = _mm256_mul_ps(x, x);
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
y = _mm256_mul_ps(y, z);
y = _mm256_add_ps(y, x);
y = _mm256_add_ps(y, one);
/* build 2^n */
imm0 = _mm256_cvttps_epi32(fx);
// two AVX2 instructions
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = _mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
} // namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
-_mm256_storeu_ps(y, detail::Exp(tmp)); \
+_mm256_storeu_ps(y, expisa(tmp)); \
}

-#define INTRI16_FLOAT(isa) \
+#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
-tmp0 = detail::Exp(tmp0); \
-tmp1 = detail::Exp(tmp1); \
+tmp0 = expisa(tmp0); \
+tmp1 = expisa(tmp1); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}

#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx, detail::ExpAVX);
+INTRI16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
#endif
#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#endif
// TODO(TJ): eq16 test and complete avx512
@@ -135,26 +323,26 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
std::shared_ptr<const VExpKernel<T>> vexp_;
};

-#define INTRI_SIGMOID(tmp, min, max) \
+#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
-tmp = detail::Exp(tmp); \
+tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)

-#define INTRI8_FLOAT(isa) \
+#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
-__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
+/*use static const??*/ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
-INTRI_SIGMOID(tmp, min, max); \
+INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
}

-#define INTRI16_FLOAT(isa) \
+#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x, \
float* y) const { \
@@ -162,13 +350,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
-INTRI_SIGMOID(tmp0, min, max); \
-INTRI_SIGMOID(tmp1, min, max); \
+INTRI_SIGMOID(tmp0, min, max, expisa); \
+INTRI_SIGMOID(tmp1, min, max, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}

-#define INTRI_GT8LT16_FLOAT(isa) \
+#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
@@ -184,7 +372,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
-INTRI_SIGMOID(tmp, min, max); \
+INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
@@ -198,7 +386,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
} \
}

-#define INTRI_GT16_FLOAT(isa) \
+#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
@@ -215,7 +403,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
-INTRI_SIGMOID(tmp, min, max); \
+INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
const float min_ = SIGMOID_THRESHOLD_MIN; \
@@ -231,22 +419,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
}

#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-INTRI_GT8LT16_FLOAT(jit::avx);
-INTRI_GT16_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx, detail::ExpAVX);
+INTRI16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
-// INTRI_GT8LT16_FLOAT(jit::avx2);
-// INTRI_GT16_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
+// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
-// INTRI_GT8LT16_FLOAT(jit::avx512f);
-// INTRI_GT16_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
+// maybe use avx2 at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
@@ -280,36 +466,36 @@ class VTanhKernelImpl : public VTanhKernel<T> {
std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
};

-#define INTRI_VTANH(tmp) \
+#define INTRI_VTANH(tmp, expisa) \
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
-tmp = detail::Exp(tmp); \
+tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))

-#define INTRI8_FLOAT(isa) \
+#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
-INTRI_VTANH(tmp); \
+INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
}

-#define INTRI16_FLOAT(isa) \
+#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
-INTRI_VTANH(tmp0); \
-INTRI_VTANH(tmp1); \
+INTRI_VTANH(tmp0, expisa); \
+INTRI_VTANH(tmp1, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}

-#define INTRI_GT8LT16_FLOAT(isa) \
+#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
@@ -327,7 +513,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
-INTRI_VTANH(tmp); \
+INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
@@ -337,7 +523,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_->Compute(-1.f, y, y); \
}

-#define INTRI_GT16_FLOAT(isa) \
+#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
@@ -356,7 +542,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
-INTRI_VTANH(tmp); \
+INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
x += this->end_; \
@@ -368,19 +554,19 @@ class VTanhKernelImpl : public VTanhKernel<T> {
}
#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-INTRI_GT8LT16_FLOAT(jit::avx);
-INTRI_GT16_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx, detail::ExpAVX);
+INTRI16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
...
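For reference, the ExpAVX/ExpAVX2 routines added above vectorize the classic Cephes-style approximation: exp(x) = 2^n * exp(g) with n = round(x * log2(e)), g = x - n*ln(2), a degree-5 polynomial for exp(g), and 2^n assembled directly in the float exponent bits. The scalar sketch below is an illustration only, not part of the commit; the function name is hypothetical.

// Scalar sketch of the math vectorized by ExpAVX/ExpAVX2 (illustration only).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

float ExpScalarSketch(float x) {  // hypothetical name
  x = std::min(x, 88.3762626647949f);          // same clamps as _ps256_exp_hi/_ps256_exp_lo
  x = std::max(x, -88.3762626647949f);
  float fx = x * 1.44269504088896341f + 0.5f;  // x * log2(e), biased for rounding
  float n = std::floor(fx);
  if (n > fx) n -= 1.0f;                       // mirrors the compare-and-subtract mask
  x -= n * 0.693359375f;                       // g = x - n*ln(2), ln(2) split for precision
  x -= n * -2.12194440e-4f;
  float z = x * x;
  float y = 1.9875691500E-4f;                  // Cephes polynomial in Horner form
  y = y * x + 1.3981999507E-3f;
  y = y * x + 8.3334519073E-3f;
  y = y * x + 4.1665795894E-2f;
  y = y * x + 1.6666665459E-1f;
  y = y * x + 5.0000001201E-1f;
  y = y * z + x + 1.0f;
  int32_t bits = (static_cast<int32_t>(n) + 0x7f) << 23;  // build 2^n in exponent bits
  float pow2n;
  std::memcpy(&pow2n, &bits, sizeof(pow2n));
  return y * pow2n;
}

The sigmoid and tanh macros reuse the same exp through the expisa argument: INTRI_SIGMOID computes 1 / (1 + exp(-x)) after clamping x to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX], and INTRI_VTANH computes 2 / (1 + exp(-2x)) - 1.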