提交 4e67fe6a 编写于 作者: T tensor-tang

Refine act and vxx to support all sizes

上级 ba3eaed7
...@@ -60,60 +60,53 @@ void VXXJitCode::generate() { ...@@ -60,60 +60,53 @@ void VXXJitCode::generate() {
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { int block = XMM_FLOAT_BLOCK;
if (scalar_index_ != 1) { while (rest > 0) {
vmovups(xmm_src1, ptr[param1 + offset]); if (rest >= 4) {
} if (scalar_index_ != 1) {
if (scalar_index_ != 2) { vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]); }
} if (scalar_index_ != 2) {
if (type_ == operand_type::mul) { vmovups(xmm_src2, ptr[param2 + offset]);
vmulps(xmm_dst, xmm_src1, xmm_src2); }
} else if (type_ == operand_type::add) { } else if (rest >= 2) {
vaddps(xmm_dst, xmm_src1, xmm_src2); if (scalar_index_ != 1) {
} vmovq(xmm_src1, ptr[param1 + offset]);
if (with_relu_) { }
vmaxps(xmm_dst, xmm_zero, xmm_dst); if (scalar_index_ != 2) {
} vmovq(xmm_src2, ptr[param2 + offset]);
vmovups(ptr[param3 + offset], xmm_dst); }
offset += sizeof(float) * 4; } else {
rest -= 4; if (scalar_index_ != 1) {
} vmovss(xmm_src1, ptr[param1 + offset]);
if (rest >= 2) { }
if (scalar_index_ != 1) { if (scalar_index_ != 2) {
vmovq(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src2, ptr[param2 + offset]);
} }
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { switch (type_) {
vmulps(xmm_dst, xmm_src1, xmm_src2); case operand_type::mul:
} else if (type_ == operand_type::add) { vmulps(xmm_dst, xmm_src1, xmm_src2);
vaddps(xmm_dst, xmm_src1, xmm_src2); break;
case operand_type::add:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
} }
if (with_relu_) { if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst); vmaxps(xmm_dst, xmm_zero, xmm_dst);
} }
vmovq(ptr[param3 + offset], xmm_dst); if (rest >= 4) {
offset += sizeof(float) * 2; vmovups(ptr[param3 + offset], xmm_dst);
rest -= 2; } else if (rest >= 2) {
} vmovq(ptr[param3 + offset], xmm_dst);
if (rest > 0) { } else {
if (scalar_index_ != 1) { vmovss(ptr[param3 + offset], xmm_dst);
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddss(xmm_dst, xmm_src1, xmm_src2);
} }
if (with_relu_) { offset += sizeof(float) * block;
vmaxps(xmm_dst, xmm_zero, xmm_dst); rest -= block;
} block /= 2;
vmovss(ptr[param3 + offset], xmm_dst);
} }
ret(); ret();
} }
...@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0}; ...@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) { bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx); bool ok = MayIUse(avx);
if (type == operand_type::relu) { if (type == operand_type::relu || type == operand_type::exp) {
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return ok; return ok;
} else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256
return ok; //&& d % 4 == 0 && d < 256;
} else { } else {
// TODO(TJ): support more // TODO(TJ): support more
return ok && d % 8 == 0; return ok && d % 8 == 0;
...@@ -412,24 +403,15 @@ void VActJitCode::generate() { ...@@ -412,24 +403,15 @@ void VActJitCode::generate() {
return; return;
} }
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
if (rest >= 4) { int block = XMM_FLOAT_BLOCK;
vmovups(xmm_src, ptr[param1 + offset]); while (rest > 0) {
switch (type_) { if (rest >= 4) {
case operand_type::relu: vmovups(xmm_src, ptr[param1 + offset]);
relu_xmm(xmm_dst, xmm_src, xmm_zero); } else if (rest >= 2) {
break; vmovq(xmm_src, ptr[param1 + offset]);
case operand_type::exp: } else {
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); vmovss(xmm_src, ptr[param1 + offset]);
break;
default:
break;
} }
vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src, ptr[param1 + offset]);
switch (type_) { switch (type_) {
case operand_type::relu: case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero); relu_xmm(xmm_dst, xmm_src, xmm_zero);
...@@ -440,25 +422,16 @@ void VActJitCode::generate() { ...@@ -440,25 +422,16 @@ void VActJitCode::generate() {
default: default:
break; break;
} }
vmovq(ptr[param2 + offset], xmm_dst); if (rest >= 4) {
offset += sizeof(float) * 2; vmovups(ptr[param2 + offset], xmm_dst);
rest -= 2; } else if (rest >= 2) {
} vmovq(ptr[param2 + offset], xmm_dst);
if (rest > 0) { } else {
// vmovups(); vmovss(ptr[param2 + offset], xmm_dst);
vmovss(xmm_src, ptr[param1 + offset]);
switch (type_) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
} }
vmovss(ptr[param2 + offset], xmm_dst); offset += sizeof(float) * block;
rest -= block;
block /= 2;
} }
ret(); ret();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册