未验证 提交 14e43112 编写于 作者: H HappyAngel 提交者: GitHub

[arm] improve conv3x3_dw performance with relu, relu6 and leakey relu (#3183)


* improve conv_dw profile with rel relu6 leakyrelu, test=develop

* add depthwise, test=develop

* fix ci error, test=develop

* fix cv demo print, test=develop
上级 f7f65134
...@@ -68,6 +68,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR) ...@@ -68,6 +68,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
gemv_arm_int8.cc gemv_arm_int8.cc
conv3x3s1_direct_fp32.cc conv3x3s1_direct_fp32.cc
conv3x3s2_direct_fp32.cc conv3x3s2_direct_fp32.cc
conv3x3s1p01_depthwise_fp32_relu.cc
conv3x3s2p01_depthwise_fp32_relu.cc
conv3x3s1p01_depthwise_fp32.cc conv3x3s1p01_depthwise_fp32.cc
conv3x3s2p01_depthwise_fp32.cc conv3x3s2p01_depthwise_fp32.cc
conv3x3s1px_depthwise_fp32.cc conv3x3s1px_depthwise_fp32.cc
......
...@@ -91,8 +91,19 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -91,8 +91,19 @@ void conv_depthwise_3x3s2_fp32(const float* din,
bool flag_bias, bool flag_bias,
const operators::ActivationParam act_param, const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
bool has_active = act_param.has_active;
bool flag_relu = false;
bool relu6 = false;
if (has_active) {
if (act_param.active_type == lite_api::ActivationType::kRelu) {
flag_relu = true;
} else {
relu6 = true;
}
}
if (pad == 0) { if (pad == 0) {
if (w_in > 8) { if (w_in > 8) {
if (relu6) {
conv_depthwise_3x3s2p0_bias(dout, conv_depthwise_3x3s2p0_bias(dout,
din, din,
weights, weights,
...@@ -107,6 +118,22 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -107,6 +118,22 @@ void conv_depthwise_3x3s2_fp32(const float* din,
act_param, act_param,
ctx); ctx);
} else { } else {
conv_depthwise_3x3s2p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s2p0_bias_s(dout, conv_depthwise_3x3s2p0_bias_s(dout,
din, din,
weights, weights,
...@@ -120,10 +147,26 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -120,10 +147,26 @@ void conv_depthwise_3x3s2_fp32(const float* din,
w_out, w_out,
act_param, act_param,
ctx); ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} }
} }
if (pad == 1) { if (pad == 1) {
if (w_in > 7) { if (w_in > 7) {
if (relu6) {
conv_depthwise_3x3s2p1_bias(dout, conv_depthwise_3x3s2p1_bias(dout,
din, din,
weights, weights,
...@@ -138,6 +181,22 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -138,6 +181,22 @@ void conv_depthwise_3x3s2_fp32(const float* din,
act_param, act_param,
ctx); ctx);
} else { } else {
conv_depthwise_3x3s2p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s2p1_bias_s(dout, conv_depthwise_3x3s2p1_bias_s(dout,
din, din,
weights, weights,
...@@ -151,6 +210,21 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -151,6 +210,21 @@ void conv_depthwise_3x3s2_fp32(const float* din,
w_out, w_out,
act_param, act_param,
ctx); ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} }
} }
} }
...@@ -978,8 +1052,6 @@ void act_switch_3x3s2p1(const float* din0_ptr, ...@@ -978,8 +1052,6 @@ void act_switch_3x3s2p1(const float* din0_ptr,
int cnt, int cnt,
int cnt_remain, int cnt_remain,
const operators::ActivationParam act_param) { const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef; float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha; float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp}; float vsix[4] = {tmp, tmp, tmp, tmp};
...@@ -987,8 +1059,7 @@ void act_switch_3x3s2p1(const float* din0_ptr, ...@@ -987,8 +1059,7 @@ void act_switch_3x3s2p1(const float* din0_ptr,
switch (act_param.active_type) { switch (act_param.active_type) {
case lite_api::ActivationType::kRelu: case lite_api::ActivationType::kRelu:
asm volatile( asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [inptr0] "+r"(din0_ptr), : [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr), [inptr1] "+r"(din1_ptr),
...@@ -1084,8 +1155,8 @@ void act_switch_3x3s2p1(const float* din0_ptr, ...@@ -1084,8 +1155,8 @@ void act_switch_3x3s2p1(const float* din0_ptr,
case lite_api::ActivationType::kLeakyRelu: case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/ /*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU RIGHT_COMPUTE_S2
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU RIGHT_RESULT_S2_LEAKY_RELU
: [inptr0] "+r"(din0_ptr), : [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr), [inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr), [inptr2] "+r"(din2_ptr),
...@@ -1131,55 +1202,9 @@ void act_switch_3x3s2p1(const float* din0_ptr, ...@@ -1131,55 +1202,9 @@ void act_switch_3x3s2p1(const float* din0_ptr,
"v22"); "v22");
break; break;
default: default:
LOG(FATAL) << "this act_type: " LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< static_cast<int>(act_param.active_type)
<< " fuse not support"; << " fuse not support";
} }
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
} }
#endif #endif
/** /**
...@@ -1570,8 +1595,6 @@ void act_switch_3x3s2p0(const float* din0_ptr, ...@@ -1570,8 +1595,6 @@ void act_switch_3x3s2p0(const float* din0_ptr,
int cnt, int cnt,
int cnt_remain, int cnt_remain,
const operators::ActivationParam act_param) { const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef; float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha; float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp}; float vsix[4] = {tmp, tmp, tmp, tmp};
...@@ -1755,65 +1778,9 @@ void act_switch_3x3s2p0(const float* din0_ptr, ...@@ -1755,65 +1778,9 @@ void act_switch_3x3s2p0(const float* din0_ptr,
"v22"); "v22");
break; break;
default: default:
LOG(FATAL) << "this act_type: " LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< static_cast<int>(act_param.active_type)
<< " fuse not support"; << " fuse not support";
} }
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 "4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
} }
#endif #endif
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册