提交 ee2a7f1b 编写于 作者: T tensor-tang

Refine exp and fix an error on AVX

test=develop
上级 1e06a32a
...@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; ...@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static int g_tmp_mem[16] ALIGN32 = {0}; static int g_tmp_mem[16] ALIGN32 = {0};
void VExpJitCode::generate() { void VExpJitCode::generate() {
preCode();
// push some?
// in: ymm0, out: ymm1 // in: ymm0, out: ymm1
// use ymm 0~5 (and ymm 14~15 if avx only) // use ymm 0~5, rax
int offset = 0; int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]); vmovups(ymm_src, ptr[param1 + offset]);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts)); mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
...@@ -222,7 +220,8 @@ void VExpJitCode::generate() { ...@@ -222,7 +220,8 @@ void VExpJitCode::generate() {
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(ymm_fy, ymm_fx, ymm_tmp); vmulps(ymm_fy, ymm_fx, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
vmulps(ymm_z, ymm_fx, ymm_tmp);
vsubps(ymm_src, ymm_src, ymm_fy); vsubps(ymm_src, ymm_src, ymm_fy);
vsubps(ymm_src, ymm_src, ymm_z); vsubps(ymm_src, ymm_src, ymm_z);
vmulps(ymm_z, ymm_src, ymm_src); vmulps(ymm_z, ymm_src, ymm_src);
...@@ -240,7 +239,6 @@ void VExpJitCode::generate() { ...@@ -240,7 +239,6 @@ void VExpJitCode::generate() {
vaddps(ymm_dst, ymm_dst, ymm_src); vaddps(ymm_dst, ymm_dst, ymm_src);
vmovaps(ymm_tmp, ptr[reg_ptr_global]); vmovaps(ymm_tmp, ptr[reg_ptr_global]);
vaddps(ymm_dst, ymm_dst, ymm_tmp); vaddps(ymm_dst, ymm_dst, ymm_tmp);
// build 2^n // build 2^n
ymm_t ymm_int = ymm_fx; ymm_t ymm_int = ymm_fx;
vcvttps2dq(ymm_int, ymm_fx); vcvttps2dq(ymm_int, ymm_fx);
...@@ -250,31 +248,30 @@ void VExpJitCode::generate() { ...@@ -250,31 +248,30 @@ void VExpJitCode::generate() {
vpaddd(ymm_int, ymm_int, ymm_tmp); vpaddd(ymm_int, ymm_int, ymm_tmp);
vpslld(ymm_int, ymm_int, 23); vpslld(ymm_int, ymm_int, 23);
} else if (MayIUse(avx)) { } else if (MayIUse(avx)) {
// use ymm_int, ymm_tmp and reg_ptr_global xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp reg64_t reg_ptr_tmp = reg_ptr_global;
mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem)); mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_global], ymm_int); vmovdqa(ptr[reg_ptr_tmp], ymm_int);
vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp); vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2); vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23); vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_global], xtmp1); vmovdqa(ptr[reg_ptr_tmp], xtmp1);
// next 128bits // next 128bits
vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]); vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
vmovdqa(xtmp2, vmovdqa(xtmp2,
ptr[reg_ptr_global + ptr[reg_ptr_tmp +
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); (AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2); vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23); vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1); vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
// load out // load out
vmovdqa(ymm_int, ptr[reg_ptr_global]); vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
} }
vmulps(ymm_dst, ymm_dst, ymm_int); vmulps(ymm_dst, ymm_dst, ymm_int);
vmovups(ptr[param2 + offset], ymm_dst); vmovups(ptr[param2 + offset], ymm_dst);
// ret(); ret();
postCode();
} }
} // namespace gen } // namespace gen
......
...@@ -128,7 +128,6 @@ class VExpJitCode : public JitCode { ...@@ -128,7 +128,6 @@ class VExpJitCode : public JitCode {
ymm_t ymm_fx = ymm_t(2); ymm_t ymm_fx = ymm_t(2);
ymm_t ymm_fy = ymm_t(3); ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4); ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_z = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5); ymm_t ymm_tmp = ymm_t(5);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册