提交 74843558 编写于 作者: T tensor-tang

clean code exp avx

上级 b4751a34
...@@ -141,50 +141,52 @@ typedef union imm_xmm_union { ...@@ -141,50 +141,52 @@ typedef union imm_xmm_union {
AVX2_BITOP_USING_SSE2(slli_epi32); AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32); AVX2_INTOP_USING_SSE2(add_epi32);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */ \
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */ \
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */ \
imm0 = _mm256_cvttps_epi32(fx)
__m256 ExpAVX(__m256 x) { __m256 ExpAVX(__m256 x) {
__m256 tmp = _mm256_setzero_ps(), fx; AVXEXP_BASE;
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
__m256i imm0;
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
/* express exp(x) as exp(g + n*log(2)) */
fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
tmp = _mm256_floor_ps(fx);
/* if greater, substract 1 */
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
mask = _mm256_and_ps(mask, one);
fx = _mm256_sub_ps(tmp, mask);
tmp =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
__m256 z =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
x = _mm256_sub_ps(x, tmp);
x = _mm256_sub_ps(x, z);
z = _mm256_mul_ps(x, x);
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
y = _mm256_mul_ps(y, z);
y = _mm256_add_ps(y, x);
y = _mm256_add_ps(y, one);
/* build 2^n */
imm0 = _mm256_cvttps_epi32(fx);
// two AVX2 instructions using SSE2 // two AVX2 instructions using SSE2
imm0 = avx2_mm256_add_epi32(imm0, imm0 = avx2_mm256_add_epi32(imm0,
*reinterpret_cast<const __m256i*>(_pi256_0x7f)); *reinterpret_cast<const __m256i*>(_pi256_0x7f));
...@@ -197,48 +199,7 @@ __m256 ExpAVX(__m256 x) { ...@@ -197,48 +199,7 @@ __m256 ExpAVX(__m256 x) {
#ifdef __AVX2__ #ifdef __AVX2__
__m256 ExpAVX2(__m256 x) { __m256 ExpAVX2(__m256 x) {
__m256 tmp = _mm256_setzero_ps(), fx; AVXEXP_BASE;
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
__m256i imm0;
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
/* express exp(x) as exp(g + n*log(2)) */
fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
tmp = _mm256_floor_ps(fx);
/* if greater, substract 1 */
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
mask = _mm256_and_ps(mask, one);
fx = _mm256_sub_ps(tmp, mask);
tmp =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
__m256 z =
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
x = _mm256_sub_ps(x, tmp);
x = _mm256_sub_ps(x, z);
z = _mm256_mul_ps(x, x);
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
y = _mm256_mul_ps(y, x);
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
y = _mm256_mul_ps(y, z);
y = _mm256_add_ps(y, x);
y = _mm256_add_ps(y, one);
/* build 2^n */
imm0 = _mm256_cvttps_epi32(fx);
// two AVX2 instructions // two AVX2 instructions
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f)); imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = _mm256_slli_epi32(imm0, 23); imm0 = _mm256_slli_epi32(imm0, 23);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册