From ee2a7f1b8c96e75db5747e0419a63d55637ae0c7 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 15 Nov 2018 06:41:13 +0000 Subject: [PATCH] refine exp and fix error on avx test=develop --- paddle/fluid/operators/math/jit_code.cc | 33 +++++++++++-------------- paddle/fluid/operators/math/jit_code.h | 1 - 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index dd79949eca..0d94a639b4 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; static int g_tmp_mem[16] ALIGN32 = {0}; void VExpJitCode::generate() { - preCode(); - // push some? // in: ymm0, out: ymm1 - // use ymm 0~5 (and ymm 14~15 if avx only) + // use ymm 0~5, rax int offset = 0; vmovups(ymm_src, ptr[param1 + offset]); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); @@ -222,7 +220,8 @@ void VExpJitCode::generate() { vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); vmulps(ymm_fy, ymm_fx, ymm_tmp); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); - vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask + ymm_t ymm_z = ymm_t(ymm_mask.getIdx()); + vmulps(ymm_z, ymm_fx, ymm_tmp); vsubps(ymm_src, ymm_src, ymm_fy); vsubps(ymm_src, ymm_src, ymm_z); vmulps(ymm_z, ymm_src, ymm_src); @@ -240,7 +239,6 @@ void VExpJitCode::generate() { vaddps(ymm_dst, ymm_dst, ymm_src); vmovaps(ymm_tmp, ptr[reg_ptr_global]); vaddps(ymm_dst, ymm_dst, ymm_tmp); - // build 2^n ymm_t ymm_int = ymm_fx; vcvttps2dq(ymm_int, ymm_fx); @@ -250,31 +248,30 @@ void VExpJitCode::generate() { vpaddd(ymm_int, ymm_int, ymm_tmp); vpslld(ymm_int, ymm_int, 23); } else if (MayIUse(avx)) { - // use ymm_int, ymm_tmp and reg_ptr_global - xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int - xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp - mov(reg_ptr_global, reinterpret_cast(g_tmp_mem)); - vmovdqa(ptr[reg_ptr_global], ymm_int); - vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); + xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); + xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx()); + reg64_t reg_ptr_tmp = reg_ptr_global; + mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); + vmovdqa(ptr[reg_ptr_tmp], ymm_int); + vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); vpaddd(xtmp1, xtmp1, xtmp2); vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_global], xtmp1); + vmovdqa(ptr[reg_ptr_tmp], xtmp1); // next 128bits - vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]); + vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]); vmovdqa(xtmp2, - ptr[reg_ptr_global + + ptr[reg_ptr_tmp + (AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); vpaddd(xtmp1, xtmp1, xtmp2); vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1); + vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1); // load out - vmovdqa(ymm_int, ptr[reg_ptr_global]); + vmovdqa(ymm_int, ptr[reg_ptr_tmp]); } vmulps(ymm_dst, ymm_dst, ymm_int); vmovups(ptr[param2 + offset], ymm_dst); - // ret(); - postCode(); + ret(); } } // namespace gen diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 984bd15a22..8296de9b72 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -128,7 +128,6 @@ class VExpJitCode : public JitCode { ymm_t ymm_fx = ymm_t(2); ymm_t ymm_fy = ymm_t(3); ymm_t ymm_mask = ymm_t(4); - ymm_t ymm_z = ymm_t(4); ymm_t ymm_tmp = ymm_t(5); }; -- GitLab