diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 7365bfeeb8edf09a8ad5e1cb2c61300e86bdf518..c7bdec354735773a15b4c99baf9f7798f2d92564 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc - DEPS cpu_info cblas activation_functions) + DEPS cpu_info cblas) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index b62e130c43743f542e2074868fc01598047d6b19..15efeba41a26ec8e419b21b927ef06170d1a777b 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -69,37 +69,225 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #endif -#define INTRI8_FLOAT(isa) \ +namespace detail { + +#ifdef __AVX__ + +#define ALIGN32 __attribute__((aligned(32))) + +#define _PS256_CONST(Name, Val) \ + static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ + Val, Val, Val, Val} + +#define _PI256_CONST(Name, Val) \ + static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ + Val, Val, Val, Val} + +_PI256_CONST(0x7f, 0x7f); +_PS256_CONST(one, 1.f); +_PS256_CONST(0p5, 0.5f); +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341); +_PS256_CONST(cephes_exp_C1, 0.693359375); +_PS256_CONST(cephes_exp_C2, -2.12194440e-4); +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1); + +typedef union imm_xmm_union { + __m256i imm; + __m128i xmm[2]; +} imm_xmm_union; + +#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \ + { \ + imm_xmm_union u ALIGN32; \ + u.imm = imm_; \ + xmm0_ = u.xmm[0]; \ + xmm1_ = u.xmm[1]; \ + } + +#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \ + { \ + imm_xmm_union u ALIGN32; \ + u.xmm[0] = xmm0_; \ + u.xmm[1] = xmm1_; \ + imm_ = u.imm; \ + } + +#define AVX2_BITOP_USING_SSE2(fn) \ + static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \ + /* use SSE2 to perform the bitop AVX2 */ \ + __m128i x1, x2; \ + __m256i ret; \ + COPY_IMM_TO_XMM(x, x1, x2); \ + x1 = _mm_##fn(x1, y); \ + x2 = _mm_##fn(x2, y); \ + COPY_XMM_TO_IMM(x1, x2, ret); \ + return ret; \ + } + +#define AVX2_INTOP_USING_SSE2(fn) \ + static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \ + /* use SSE2 to perform the AVX2 integer operation */ \ + __m128i x1, x2; \ + __m128i y1, y2; \ + __m256i ret; \ + COPY_IMM_TO_XMM(x, x1, x2); \ + COPY_IMM_TO_XMM(y, y1, y2); \ + x1 = _mm_##fn(x1, y1); \ + x2 = _mm_##fn(x2, y2); \ + COPY_XMM_TO_IMM(x1, x2, ret); \ + return ret; \ + } + +AVX2_BITOP_USING_SSE2(slli_epi32); +AVX2_INTOP_USING_SSE2(add_epi32); + +__m256 ExpAVX(__m256 x) { + __m256 tmp = _mm256_setzero_ps(), fx; + __m256 one = *reinterpret_cast(_ps256_one); + __m256i imm0; + + x = _mm256_min_ps(x, *reinterpret_cast(_ps256_exp_hi)); + x = _mm256_max_ps(x, *reinterpret_cast(_ps256_exp_lo)); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_mul_ps(x, *reinterpret_cast(_ps256_cephes_LOG2EF)); + fx = _mm256_add_ps(fx, *reinterpret_cast(_ps256_0p5)); + + tmp = _mm256_floor_ps(fx); + + /* if greater, substract 1 */ + __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + tmp = + _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C1)); + __m256 z = + _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C2)); + x = _mm256_sub_ps(x, tmp); + x = _mm256_sub_ps(x, z); + z = _mm256_mul_ps(x, x); + + __m256 y = *reinterpret_cast(_ps256_cephes_exp_p0); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p1)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p2)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p3)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p4)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p5)); + y = _mm256_mul_ps(y, z); + y = _mm256_add_ps(y, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // two AVX2 instructions using SSE2 + imm0 = avx2_mm256_add_epi32(imm0, + *reinterpret_cast(_pi256_0x7f)); + imm0 = avx2_mm256_slli_epi32(imm0, 23); + __m256 pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} +#endif + +#ifdef __AVX2__ +__m256 ExpAVX2(__m256 x) { + __m256 tmp = _mm256_setzero_ps(), fx; + __m256 one = *reinterpret_cast _ps256_one; + __m256i imm0; + + x = _mm256_min_ps(x, *reinterpret_cast(_ps256_exp_hi)); + x = _mm256_max_ps(x, *reinterpret_cast(_ps256_exp_lo)); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_mul_ps(x, *reinterpret_cast(_ps256_cephes_LOG2EF)); + fx = _mm256_add_ps(fx, *reinterpret_cast(_ps256_0p5)); + + tmp = _mm256_floor_ps(fx); + + /* if greater, substract 1 */ + __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + tmp = + _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C1)); + __m256 z = + _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C2)); + x = _mm256_sub_ps(x, tmp); + x = _mm256_sub_ps(x, z); + z = _mm256_mul_ps(x, x); + __m256 y = *reinterpret_cast(_ps256_cephes_exp_p0); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p1)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p2)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p3)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p4)); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p5)); + y = _mm256_mul_ps(y, z); + y = _mm256_add_ps(y, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // two AVX2 instructions + imm0 = _mm256_add_epi32(imm0, *reinterpret_cast(_pi256_0x7f)); + imm0 = _mm256_slli_epi32(imm0, 23); + __m256 pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} +#endif + +} // namespace detail + +#define INTRI8_FLOAT(isa, expisa) \ template <> \ void VExpKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp = _mm256_loadu_ps(x); \ - _mm256_storeu_ps(y, detail::Exp(tmp)); \ + _mm256_storeu_ps(y, expisa(tmp)); \ } -#define INTRI16_FLOAT(isa) \ +#define INTRI16_FLOAT(isa, expisa) \ template <> \ void VExpKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = detail::Exp(tmp0); \ - tmp1 = detail::Exp(tmp1); \ + tmp0 = expisa(tmp0); \ + tmp1 = expisa(tmp1); \ _mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y + 8, tmp1); \ } #ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); +INTRI8_FLOAT(jit::avx, detail::ExpAVX); +INTRI16_FLOAT(jit::avx, detail::ExpAVX); #endif #ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); +INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); #endif #ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); +INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); #endif // TODO(TJ): eq16 test and complete avx512 @@ -135,26 +323,26 @@ class VSigmoidKernelImpl : public VSigmoidKernel { std::shared_ptr> vexp_; }; -#define INTRI_SIGMOID(tmp, min, max) \ +#define INTRI_SIGMOID(tmp, min, max, expisa) \ tmp = _mm256_max_ps(tmp, min); \ tmp = _mm256_min_ps(tmp, max); \ tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \ - tmp = detail::Exp(tmp); \ + tmp = expisa(tmp); \ tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) -#define INTRI8_FLOAT(isa) \ +#define INTRI8_FLOAT(isa, expisa) \ template <> \ void VSigmoidKernelImpl::Compute(const float* x, float* y) \ const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + /*use static const??*/ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa) \ +#define INTRI16_FLOAT(isa, expisa) \ template <> \ void VSigmoidKernelImpl::Compute(const float* x, \ float* y) const { \ @@ -162,13 +350,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel { __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_SIGMOID(tmp0, min, max); \ - INTRI_SIGMOID(tmp1, min, max); \ + INTRI_SIGMOID(tmp0, min, max, expisa); \ + INTRI_SIGMOID(tmp1, min, max, expisa); \ _mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y + 8, tmp1); \ } -#define INTRI_GT8LT16_FLOAT(isa) \ +#define INTRI_GT8LT16_FLOAT(isa, expisa) \ template <> \ VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ : VSigmoidKernel() { \ @@ -184,7 +372,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ _mm256_storeu_ps(y, tmp); \ const float min_ = SIGMOID_THRESHOLD_MIN; \ const float max_ = SIGMOID_THRESHOLD_MAX; \ @@ -198,7 +386,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { } \ } -#define INTRI_GT16_FLOAT(isa) \ +#define INTRI_GT16_FLOAT(isa, expisa) \ template <> \ VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ : VSigmoidKernel() { \ @@ -215,7 +403,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_SIGMOID(tmp, min, max); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ _mm256_storeu_ps(y + i, tmp); \ } \ const float min_ = SIGMOID_THRESHOLD_MIN; \ @@ -231,22 +419,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel { } #ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_GT8LT16_FLOAT(jit::avx); -INTRI_GT16_FLOAT(jit::avx); +INTRI8_FLOAT(jit::avx, detail::ExpAVX); +INTRI16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); #endif #ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); -// INTRI_GT8LT16_FLOAT(jit::avx2); -// INTRI_GT16_FLOAT(jit::avx2); +INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); +// maybe use avx at gt8lt16 and gt16 #endif #ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); -// INTRI_GT8LT16_FLOAT(jit::avx512f); -// INTRI_GT16_FLOAT(jit::avx512f); +INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); +// maybe use avx2 at gt8lt16 and gt16 #endif #undef INTRI8_FLOAT @@ -280,36 +466,36 @@ class VTanhKernelImpl : public VTanhKernel { std::shared_ptr> vaddbias_; }; -#define INTRI_VTANH(tmp) \ +#define INTRI_VTANH(tmp, expisa) \ tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \ tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \ - tmp = detail::Exp(tmp); \ + tmp = expisa(tmp); \ tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) -#define INTRI8_FLOAT(isa) \ +#define INTRI8_FLOAT(isa, expisa) \ template <> \ void VTanhKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp); \ + INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa) \ +#define INTRI16_FLOAT(isa, expisa) \ template <> \ void VTanhKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_VTANH(tmp0); \ - INTRI_VTANH(tmp1); \ + INTRI_VTANH(tmp0, expisa); \ + INTRI_VTANH(tmp1, expisa); \ _mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y + 8, tmp1); \ } -#define INTRI_GT8LT16_FLOAT(isa) \ +#define INTRI_GT8LT16_FLOAT(isa, expisa) \ template <> \ VTanhKernelImpl::VTanhKernelImpl(int d) \ : VTanhKernel() { \ @@ -327,7 +513,7 @@ class VTanhKernelImpl : public VTanhKernel { void VTanhKernelImpl::Compute(const float* x, \ float* y) const { \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp); \ + INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y, tmp); \ x += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \ @@ -337,7 +523,7 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_->Compute(-1.f, y, y); \ } -#define INTRI_GT16_FLOAT(isa) \ +#define INTRI_GT16_FLOAT(isa, expisa) \ template <> \ VTanhKernelImpl::VTanhKernelImpl(int d) \ : VTanhKernel() { \ @@ -356,7 +542,7 @@ class VTanhKernelImpl : public VTanhKernel { const { \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_VTANH(tmp); \ + INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y + i, tmp); \ } \ x += this->end_; \ @@ -368,19 +554,19 @@ class VTanhKernelImpl : public VTanhKernel { } #ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_GT8LT16_FLOAT(jit::avx); -INTRI_GT16_FLOAT(jit::avx); +INTRI8_FLOAT(jit::avx, detail::ExpAVX); +INTRI16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); #endif #ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); +INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); // maybe use avx at gt8lt16 and gt16 #endif #ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); +INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); // maybe use avx at gt8lt16 and gt16 #endif