提交 4e67fe6a 编写于 作者: T tensor-tang

Refine act and vxx JIT code to handle all sizes

上级 ba3eaed7
...@@ -60,6 +60,8 @@ void VXXJitCode::generate() { ...@@ -60,6 +60,8 @@ void VXXJitCode::generate() {
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
int block = XMM_FLOAT_BLOCK;
while (rest > 0) {
if (rest >= 4) { if (rest >= 4) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src1, ptr[param1 + offset]);
...@@ -67,54 +69,45 @@ void VXXJitCode::generate() { ...@@ -67,54 +69,45 @@ void VXXJitCode::generate() {
if (scalar_index_ != 2) { if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]); vmovups(xmm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { } else if (rest >= 2) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src1, ptr[param1 + offset]);
} }
if (scalar_index_ != 2) { if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]); vmovq(xmm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { } else {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
if (scalar_index_ != 1) { if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src1, ptr[param1 + offset]);
} }
if (scalar_index_ != 2) { if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]); vmovss(xmm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { }
vmulss(xmm_dst, xmm_src1, xmm_src2); switch (type_) {
} else if (type_ == operand_type::add) { case operand_type::mul:
vaddss(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::add:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
} }
if (with_relu_) { if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst); vmaxps(xmm_dst, xmm_zero, xmm_dst);
} }
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst); vmovss(ptr[param3 + offset], xmm_dst);
} }
offset += sizeof(float) * block;
rest -= block;
block /= 2;
}
ret(); ret();
} }
...@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0}; ...@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
bool VActJitCode::init(int d, operand_type type) { bool VActJitCode::init(int d, operand_type type) {
bool ok = MayIUse(avx); bool ok = MayIUse(avx);
if (type == operand_type::relu) { if (type == operand_type::relu || type == operand_type::exp) {
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return ok; return ok;
} else if (type == operand_type::exp) {
// exp is slower than mkl when d >= 256
return ok; //&& d % 4 == 0 && d < 256;
} else { } else {
// TODO(TJ): support more // TODO(TJ): support more
return ok && d % 8 == 0; return ok && d % 8 == 0;
...@@ -412,42 +403,15 @@ void VActJitCode::generate() { ...@@ -412,42 +403,15 @@ void VActJitCode::generate() {
return; return;
} }
int rest = num_ % YMM_FLOAT_BLOCK; int rest = num_ % YMM_FLOAT_BLOCK;
int block = XMM_FLOAT_BLOCK;
while (rest > 0) {
if (rest >= 4) { if (rest >= 4) {
vmovups(xmm_src, ptr[param1 + offset]); vmovups(xmm_src, ptr[param1 + offset]);
switch (type_) { } else if (rest >= 2) {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
vmovups(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src, ptr[param1 + offset]); vmovq(xmm_src, ptr[param1 + offset]);
switch (type_) { } else {
case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero);
break;
case operand_type::exp:
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
break;
default:
break;
}
vmovq(ptr[param2 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
}
if (rest > 0) {
// vmovups();
vmovss(xmm_src, ptr[param1 + offset]); vmovss(xmm_src, ptr[param1 + offset]);
}
switch (type_) { switch (type_) {
case operand_type::relu: case operand_type::relu:
relu_xmm(xmm_dst, xmm_src, xmm_zero); relu_xmm(xmm_dst, xmm_src, xmm_zero);
...@@ -458,8 +422,17 @@ void VActJitCode::generate() { ...@@ -458,8 +422,17 @@ void VActJitCode::generate() {
default: default:
break; break;
} }
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst); vmovss(ptr[param2 + offset], xmm_dst);
} }
offset += sizeof(float) * block;
rest -= block;
block /= 2;
}
ret(); ret();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册