提交 ee2a7f1b 编写于 作者: tensor-tang

refine exp and fix error on avx

test=develop
上级 1e06a32a
......@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0};
void VExpJitCode::generate() {
preCode();
// push some?
// in: ymm0, out: ymm1
// use ymm 0~5 (and ymm 14~15 if avx only)
// use ymm 0~5, rax
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
......@@ -222,7 +220,8 @@ void VExpJitCode::generate() {
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(ymm_fy, ymm_fx, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask
ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
vmulps(ymm_z, ymm_fx, ymm_tmp);
vsubps(ymm_src, ymm_src, ymm_fy);
vsubps(ymm_src, ymm_src, ymm_z);
vmulps(ymm_z, ymm_src, ymm_src);
......@@ -240,7 +239,6 @@ void VExpJitCode::generate() {
vaddps(ymm_dst, ymm_dst, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
// build 2^n
ymm_t ymm_int = ymm_fx;
vcvttps2dq(ymm_int, ymm_fx);
......@@ -250,31 +248,30 @@ void VExpJitCode::generate() {
vpaddd(ymm_int, ymm_int, ymm_tmp);
vpslld(ymm_int, ymm_int, 23);
} else if (MayIUse(avx)) {
// use ymm_int, ymm_tmp and reg_ptr_global
xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int
xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp
mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_global], ymm_int);
vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
reg64_t reg_ptr_tmp = reg_ptr_global;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_global], xtmp1);
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
// next 128 bits: repeat the integer add and shift on the second half
vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp2,
ptr[reg_ptr_global +
ptr[reg_ptr_tmp +
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
// load the full 256-bit result from the temp memory back into ymm_int
vmovdqa(ymm_int, ptr[reg_ptr_global]);
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
}
vmulps(ymm_dst, ymm_dst, ymm_int);
vmovups(ptr[param2 + offset], ymm_dst);
// ret();
postCode();
ret();
}
} // namespace gen
......
......@@ -128,7 +128,6 @@ class VExpJitCode : public JitCode {
ymm_t ymm_fx = ymm_t(2);
ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_z = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册