diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 66e80a07e457d78477fa7eaa1e22e6a179426846..c4247580f491a7ca26259528ca74dd92e35785a9 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -141,50 +141,52 @@ typedef union imm_xmm_union { AVX2_BITOP_USING_SSE2(slli_epi32); AVX2_INTOP_USING_SSE2(add_epi32); +#define AVXEXP_BASE \ + __m256 tmp = _mm256_setzero_ps(), fx; \ + __m256 one = *reinterpret_cast(_ps256_one); \ + __m256i imm0; \ + x = _mm256_min_ps(x, *reinterpret_cast(_ps256_exp_hi)); \ + x = _mm256_max_ps(x, *reinterpret_cast(_ps256_exp_lo)); \ + /* express exp(x) as exp(g + n*log(2)) */ \ + fx = _mm256_mul_ps(x, \ + *reinterpret_cast(_ps256_cephes_LOG2EF)); \ + fx = _mm256_add_ps(fx, *reinterpret_cast(_ps256_0p5)); \ + tmp = _mm256_floor_ps(fx); \ + /* if greater, substract 1 */ \ + __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \ + mask = _mm256_and_ps(mask, one); \ + fx = _mm256_sub_ps(tmp, mask); \ + tmp = _mm256_mul_ps(fx, \ + *reinterpret_cast(_ps256_cephes_exp_C1)); \ + __m256 z = _mm256_mul_ps( \ + fx, *reinterpret_cast(_ps256_cephes_exp_C2)); \ + x = _mm256_sub_ps(x, tmp); \ + x = _mm256_sub_ps(x, z); \ + z = _mm256_mul_ps(x, x); \ + __m256 y = *reinterpret_cast(_ps256_cephes_exp_p0); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p1)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p2)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p3)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p4)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p5)); \ + y = _mm256_mul_ps(y, z); \ + y = _mm256_add_ps(y, x); \ + y = _mm256_add_ps(y, one); \ + /* build 2^n */ \ + imm0 = _mm256_cvttps_epi32(fx) + __m256 ExpAVX(__m256 x) { - __m256 tmp = _mm256_setzero_ps(), fx; - __m256 one = *reinterpret_cast(_ps256_one); - __m256i imm0; - - x = _mm256_min_ps(x, *reinterpret_cast(_ps256_exp_hi)); - x = _mm256_max_ps(x, *reinterpret_cast(_ps256_exp_lo)); - - /* express exp(x) as exp(g + n*log(2)) */ - fx = _mm256_mul_ps(x, *reinterpret_cast(_ps256_cephes_LOG2EF)); - fx = _mm256_add_ps(fx, *reinterpret_cast(_ps256_0p5)); - - tmp = _mm256_floor_ps(fx); - - /* if greater, substract 1 */ - __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); - mask = _mm256_and_ps(mask, one); - fx = _mm256_sub_ps(tmp, mask); - - tmp = - _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C1)); - __m256 z = - _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C2)); - x = _mm256_sub_ps(x, tmp); - x = _mm256_sub_ps(x, z); - z = _mm256_mul_ps(x, x); - - __m256 y = *reinterpret_cast(_ps256_cephes_exp_p0); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p1)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p2)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p3)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p4)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p5)); - y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, x); - y = _mm256_add_ps(y, one); - - /* build 2^n */ - imm0 = _mm256_cvttps_epi32(fx); + AVXEXP_BASE; // two AVX2 instructions using SSE2 imm0 = avx2_mm256_add_epi32(imm0, *reinterpret_cast(_pi256_0x7f)); @@ -197,48 +199,7 @@ __m256 ExpAVX(__m256 x) { #ifdef __AVX2__ __m256 ExpAVX2(__m256 x) { - __m256 tmp = _mm256_setzero_ps(), fx; - __m256 one = *reinterpret_cast(_ps256_one); - __m256i imm0; - - x = _mm256_min_ps(x, *reinterpret_cast(_ps256_exp_hi)); - x = _mm256_max_ps(x, *reinterpret_cast(_ps256_exp_lo)); - - /* express exp(x) as exp(g + n*log(2)) */ - fx = _mm256_mul_ps(x, *reinterpret_cast(_ps256_cephes_LOG2EF)); - fx = _mm256_add_ps(fx, *reinterpret_cast(_ps256_0p5)); - - tmp = _mm256_floor_ps(fx); - - /* if greater, substract 1 */ - __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); - mask = _mm256_and_ps(mask, one); - fx = _mm256_sub_ps(tmp, mask); - - tmp = - _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C1)); - __m256 z = - _mm256_mul_ps(fx, *reinterpret_cast(_ps256_cephes_exp_C2)); - x = _mm256_sub_ps(x, tmp); - x = _mm256_sub_ps(x, z); - z = _mm256_mul_ps(x, x); - __m256 y = *reinterpret_cast(_ps256_cephes_exp_p0); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p1)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p2)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p3)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p4)); - y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *reinterpret_cast(_ps256_cephes_exp_p5)); - y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, x); - y = _mm256_add_ps(y, one); - - /* build 2^n */ - imm0 = _mm256_cvttps_epi32(fx); + AVXEXP_BASE; // two AVX2 instructions imm0 = _mm256_add_epi32(imm0, *reinterpret_cast(_pi256_0x7f)); imm0 = _mm256_slli_epi32(imm0, 23);