diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index dd79949eca70edfc68fba52cc838b71c912a70ed..0d94a639b4a344e2c2ab39cd4485818619fe7618 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; static int g_tmp_mem[16] ALIGN32 = {0}; void VExpJitCode::generate() { - preCode(); - // push some? // in: ymm0, out: ymm1 - // use ymm 0~5 (and ymm 14~15 if avx only) + // use ymm 0~5, rax int offset = 0; vmovups(ymm_src, ptr[param1 + offset]); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); @@ -222,7 +220,8 @@ void VExpJitCode::generate() { vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); vmulps(ymm_fy, ymm_fx, ymm_tmp); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); - vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask + ymm_t ymm_z = ymm_t(ymm_mask.getIdx()); + vmulps(ymm_z, ymm_fx, ymm_tmp); vsubps(ymm_src, ymm_src, ymm_fy); vsubps(ymm_src, ymm_src, ymm_z); vmulps(ymm_z, ymm_src, ymm_src); @@ -240,7 +239,6 @@ void VExpJitCode::generate() { vaddps(ymm_dst, ymm_dst, ymm_src); vmovaps(ymm_tmp, ptr[reg_ptr_global]); vaddps(ymm_dst, ymm_dst, ymm_tmp); - // build 2^n ymm_t ymm_int = ymm_fx; vcvttps2dq(ymm_int, ymm_fx); @@ -250,31 +248,30 @@ void VExpJitCode::generate() { vpaddd(ymm_int, ymm_int, ymm_tmp); vpslld(ymm_int, ymm_int, 23); } else if (MayIUse(avx)) { - // use ymm_int, ymm_tmp and reg_ptr_global - xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int - xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp - mov(reg_ptr_global, reinterpret_cast(g_tmp_mem)); - vmovdqa(ptr[reg_ptr_global], ymm_int); - vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); + xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); + xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx()); + reg64_t reg_ptr_tmp = reg_ptr_global; + mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); + vmovdqa(ptr[reg_ptr_tmp], ymm_int); + vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); vpaddd(xtmp1, xtmp1, xtmp2); vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_global], xtmp1); + vmovdqa(ptr[reg_ptr_tmp], xtmp1); // next 128bits - vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]); + vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]); vmovdqa(xtmp2, - ptr[reg_ptr_global + + ptr[reg_ptr_tmp + (AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); vpaddd(xtmp1, xtmp1, xtmp2); vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1); + vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1); // load out - vmovdqa(ymm_int, ptr[reg_ptr_global]); + vmovdqa(ymm_int, ptr[reg_ptr_tmp]); } vmulps(ymm_dst, ymm_dst, ymm_int); vmovups(ptr[param2 + offset], ymm_dst); - // ret(); - postCode(); + ret(); } } // namespace gen diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 984bd15a22a20cdb34207803dd55a0e2cf26c928..8296de9b72d725fc6d6021b000f31fa41d09e7b0 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -128,7 +128,6 @@ class VExpJitCode : public JitCode { ymm_t ymm_fx = ymm_t(2); ymm_t ymm_fy = ymm_t(3); ymm_t ymm_mask = ymm_t(4); - ymm_t ymm_z = ymm_t(4); ymm_t ymm_tmp = ymm_t(5); };