提交 4e67fe6a 编写于 作者: T tensor-tang

refine act and vxx with all size

上级 ba3eaed7
......@@ -60,60 +60,53 @@ void VXXJitCode::generate() {
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) {
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
int block = XMM_FLOAT_BLOCK;
while (rest > 0) {
if (rest >= 4) {
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
} else if (rest >= 2) {
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
}
} else {
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
}
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
switch (type_) {
case operand_type::mul:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::add:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddss(xmm_dst, xmm_src1, xmm_src2);
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovss(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * block;
rest -= block;
block /= 2;
}
ret();
}
......@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx);
if (type == operand_type::relu) {
if (type == operand_type::relu || type == operand_type::exp) {
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return ok;
} else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256
return ok; //&& d % 4 == 0 && d < 256;
} else {
// TODO(TJ): support more
return ok && d % 8 == 0;
......@@ -412,24 +403,15 @@ void VActJitCode::generate() {
return;
}
int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]);
switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
int block = XMM_FLOAT_BLOCK;
while (rest > 0) {
if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) {
vmovq(xmm_src, ptr[param1 + offset]);
} else {
vmovss(xmm_src, ptr[param1 + offset]);
}
vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src, ptr[param1 + offset]);
switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
......@@ -440,25 +422,16 @@ void VActJitCode::generate() {
default:
break;
}
vmovq(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
// vmovups();
vmovss(xmm_src, ptr[param1 + offset]);
switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst);
}
vmovss(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * block;
rest -= block;
block /= 2;
}
ret();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册