提交 247f0c7f 编写于 作者: C chenjiaoAngel

fix format. test=develop

上级 9348cb41
...@@ -112,11 +112,30 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -112,11 +112,30 @@ void conv_depthwise_5x5s1_fp32(float* dout,
float vscale[4] = {ss, ss, ss, ss}; float vscale[4] = {ss, ss, ss, ss};
if (has_active) { if (has_active) {
switch (act_type) { switch (act_type) {
case lite_api::ActivationType::kRelu: case lite_api::ActivationType::kRelu:
conv_depthwise_5x5s1_bias_relu(dout, conv_depthwise_5x5s1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
num,
chin,
hin,
win,
hout,
wout,
pad_top,
pad_bottom,
pad_left,
pad_right,
ctx);
break;
case lite_api::ActivationType::kRelu6:
conv_depthwise_5x5s1_bias_relu6(dout,
din, din,
weights, weights,
bias, bias,
vsix,
flag_bias, flag_bias,
num, num,
chin, chin,
...@@ -129,48 +148,29 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -129,48 +148,29 @@ void conv_depthwise_5x5s1_fp32(float* dout,
pad_left, pad_left,
pad_right, pad_right,
ctx); ctx);
break; break;
case lite_api::ActivationType::kRelu6: case lite_api::ActivationType::kLeakyRelu:
conv_depthwise_5x5s1_bias_relu6(dout, conv_depthwise_5x5s1_bias_leakyRelu(dout,
din, din,
weights, weights,
bias, bias,
vsix, vscale,
flag_bias, flag_bias,
num, num,
chin, chin,
hin, hin,
win, win,
hout, hout,
wout, wout,
pad_top, pad_top,
pad_bottom, pad_bottom,
pad_left, pad_left,
pad_right, pad_right,
ctx); ctx);
break; break;
case lite_api::ActivationType::kLeakyRelu: default:
conv_depthwise_5x5s1_bias_leakyRelu(dout, LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
din, << " fuse not support";
weights,
bias,
vscale,
flag_bias,
num,
chin,
hin,
win,
hout,
wout,
pad_top,
pad_bottom,
pad_left,
pad_right,
ctx);
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
<< " fuse not support";
} }
} else { } else {
conv_depthwise_5x5s1_bias(dout, conv_depthwise_5x5s1_bias(dout,
...@@ -191,6 +191,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -191,6 +191,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
ctx); ctx);
} }
} }
// clang-format off
#ifdef __aarch64__ #ifdef __aarch64__
#define COMPUTE_ONE_LINE_S1_PRE \ #define COMPUTE_ONE_LINE_S1_PRE \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
...@@ -201,7 +202,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -201,7 +202,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"1: \n" \ "1: \n" \
"subs %w[cnt], %w[cnt], #1 \n" \ "subs %w[cnt], %w[cnt], #1 \n" \
"fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0123*wr0[0]*/ \ "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0123*wr0[0]*/ \
"fmul v14.4s, v10.4s, %[wr6].s[0]\n" /*4567*wr6[0*/ \ "fmul v14.4s, v10.4s, %[wr6].s[0]\n" /*4567*wr6[0*/ \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"fmla v15.4s, v11.4s, %[wr0].s[1]\n" /*1234*wr0[1]*/ \ "fmla v15.4s, v11.4s, %[wr0].s[1]\n" /*1234*wr0[1]*/ \
...@@ -483,7 +484,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -483,7 +484,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v12.16b, v9.16b, v10.16b, #8\n" \ "ext v12.16b, v9.16b, v10.16b, #8\n" \
"fmla v16.4s, v13.4s, %[wr4].s[3]\n" /*3456*wr4[3]*/ \ "fmla v16.4s, v13.4s, %[wr4].s[3]\n" /*3456*wr4[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v17.4s, v17.4s, v16.4s\n" \ "fadd v17.4s, v17.4s, v16.4s\n"
#define COMPUTE_ONE_LINE_S1_POST \ #define COMPUTE_ONE_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]]\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
...@@ -652,25 +653,25 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -652,25 +653,25 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ld1 {v16.4s}, [%[bias]]\n" \ "ld1 {v16.4s}, [%[bias]]\n" \
"st1 {v17.4s}, [%[dout_ptr1]], #16\n" \ "st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b" "bne 1b"
#define RESULT_S1_RELU \ #define RESULT_S1_RELU_OUT2 \
"fmax v14.4s, v14.4s, %[vzero].4s\n" \ "fmax v14.4s, v14.4s, %[vzero].4s\n" \
"ld1 {v15.4s}, [%[bias]]\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"fmax v17.4s, v17.4s, %[vzero].4s\n" \ "fmax v17.4s, v17.4s, %[vzero].4s\n" \
"ld1 {v16.4s}, [%[bias]]\n" \ "ld1 {v16.4s}, [%[bias]]\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \ "st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"st1 {v17.4s}, [%[dout_ptr]], #16\n" \ "st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b" "bne 1b"
#define RESULT_S1_RELU6 \ #define RESULT_S1_RELU6_OUT2 \
"fmax 14.4s, v14.4s, %[vzero].4s\n" \ "fmax v14.4s, v14.4s, %[vzero].4s\n" \
"ld1 {v15.4s}, [%[bias]]\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"fmax v17.4s, v17.4s, %[vzero].4s\n" \ "fmax v17.4s, v17.4s, %[vzero].4s\n" \
"ld1 {v16.4s}, [%[bias]]\n" \ "ld1 {v16.4s}, [%[bias]]\n" \
"fmin v14.4s, v14.4s, %[vsix].4s\n" \ "fmin v14.4s, v14.4s, %[vsix].4s\n" \
"fmin v17.4s, v17.4s, %[vsix].4s\n" \ "fmin v17.4s, v17.4s, %[vsix].4s\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \ "st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"st1 {v17.4s}, [%[dout_ptr]], #16\n" \ "st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b" "bne 1b"
#define RESULT_S1_LEAKY_RELU \ #define RESULT_S1_LEAKY_RELU_OUT2 \
"fcmge v18.4s, v14.4s, %[vzero].4s\n" \ "fcmge v18.4s, v14.4s, %[vzero].4s\n" \
"fmul v19.4s, v14.4s, %[vscale].4s\n" \ "fmul v19.4s, v14.4s, %[vscale].4s\n" \
"ld1 {v15.4s}, [%[bias]]\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
...@@ -680,7 +681,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -680,7 +681,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"bif v14.16b, v19.16b, v18.16b\n" \ "bif v14.16b, v19.16b, v18.16b\n" \
"bif v17.16b, v21.16b, v20.16b\n" \ "bif v17.16b, v21.16b, v20.16b\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \ "st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"st1 {v17.4s}, [%[dout_ptr]], #16\n" \ "st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b" "bne 1b"
#else #else
#define COMPUTE_ONE_LINE_S1_PRE \ #define COMPUTE_ONE_LINE_S1_PRE \
...@@ -1184,20 +1185,23 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -1184,20 +1185,23 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"bne 1b" "bne 1b"
#endif #endif
inline float compute_one_data_pre(const float* data, float32x4_t wr, float bias_val, float wei_val, int num) { // clang-format on
inline float compute_one_data_pre(
const float* data, float32x4_t wr, float bias_val, float wei_val, int num) {
float sum = bias_val; float sum = bias_val;
int index = 4 - num; int index = 4 - num;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += data[i] * wr[index + i]; sum += data[i] * wr[index + i];
} }
sum += data[num] * wei_val; sum += data[num] * wei_val;
return sum; return sum;
} }
inline float compute_one_data_post(const float* data, float32x4_t wr, float bias_val, float wei_val, int num) { inline float compute_one_data_post(
const float* data, float32x4_t wr, float bias_val, float wei_val, int num) {
float sum = bias_val; float sum = bias_val;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += data[i] * wr[i]; sum += data[i] * wr[i];
} }
sum += data[num] * wei_val; sum += data[num] * wei_val;
return sum; return sum;
...@@ -1216,13 +1220,19 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1216,13 +1220,19 @@ inline void compute_all_padding_pre(float* dout,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
} }
*dout++ = sum; *dout++ = sum;
} }
// mid // mid
// clang-format off
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
case 0: case 0:
...@@ -1416,22 +1426,33 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1416,22 +1426,33 @@ inline void compute_all_padding_pre(float* dout,
} }
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp_index - i],
din_ptr_arr[tmp_index - i]++; weights[3 - i],
0.f,
weights[5][3 - i],
4);
din_ptr_arr[tmp_index - i]++;
} }
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
din_ptr_arr[tmp_index - k]++; din_ptr_arr[tmp_index - k]++;
} }
*dout++ = sum; *dout++ = sum;
...@@ -1451,12 +1472,18 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1451,12 +1472,18 @@ inline void compute_all_padding_mid(float* dout,
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum; *dout++ = sum;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -1516,22 +1543,33 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1516,22 +1543,33 @@ inline void compute_all_padding_mid(float* dout,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
din_ptr_arr[tmp - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[tmp - i]++;
} }
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[tmp - k]++; din_ptr_arr[tmp - k]++;
} }
*dout++ = sum; *dout++ = sum;
...@@ -1553,15 +1591,26 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1553,15 +1591,26 @@ inline void compute_all_padding_mid_out2(float* dout0,
int tmp = num - 1; int tmp = num - 1;
// left // left
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
sum1 += compute_one_data_pre(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout0++ = sum; *dout0++ = sum;
*dout1++ = sum1; *dout1++ = sum1;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -1626,14 +1675,25 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1626,14 +1675,25 @@ inline void compute_all_padding_mid_out2(float* dout0,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
sum1 += compute_one_data_post(din_ptr_arr[num - i],
weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[num - i]++; din_ptr_arr[num - i]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
...@@ -1642,12 +1702,22 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1642,12 +1702,22 @@ inline void compute_all_padding_mid_out2(float* dout0,
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
sum1 += compute_one_data_post(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[num - k]++; din_ptr_arr[num - k]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
...@@ -1655,6 +1725,7 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1655,6 +1725,7 @@ inline void compute_all_padding_mid_out2(float* dout0,
*dout1++ = sum1; *dout1++ = sum1;
} }
} }
inline void compute_all_padding_post(float* dout, inline void compute_all_padding_post(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
...@@ -1669,12 +1740,18 @@ inline void compute_all_padding_post(float* dout, ...@@ -1669,12 +1740,18 @@ inline void compute_all_padding_post(float* dout,
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum; *dout++ = sum;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -1866,23 +1943,34 @@ inline void compute_all_padding_post(float* dout, ...@@ -1866,23 +1943,34 @@ inline void compute_all_padding_post(float* dout,
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << "does not support";
} }
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[2 - i],
din_ptr_arr[2 - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[2 - i]++;
} }
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k]++; din_ptr_arr[2 - k]++;
} }
*dout++ = sum; *dout++ = sum;
...@@ -1904,7 +1992,7 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1904,7 +1992,7 @@ void conv_depthwise_5x5s1_bias(float* dout,
int pad_bottom, int pad_bottom,
int pad_left, int pad_left,
int pad_right, int pad_right,
ARMContext* ctx){ ARMContext* ctx) {
int loop_w = wout - pad_left - pad_right; int loop_w = wout - pad_left - pad_right;
int loop_h = hout - pad_top - pad_bottom; int loop_h = hout - pad_top - pad_bottom;
int in_size = win * hin; int in_size = win * hin;
...@@ -1945,12 +2033,22 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1945,12 +2033,22 @@ void conv_depthwise_5x5s1_bias(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {
din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_pre(dout_ptr0,
pad_right, cnt, remain, 4 - h); din_ptr_arr,
vbias,
weights_vec,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -1961,8 +2059,18 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1961,8 +2059,18 @@ void conv_depthwise_5x5s1_bias(float* dout,
dout_ptr1 = dout_ptr0 + wout; dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid_out2(dout_ptr0,
pad_right, cnt, remain, 4); dout_ptr1,
din_ptr_arr,
vbias,
weights_vec,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -1979,8 +2087,17 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1979,8 +2087,17 @@ void conv_depthwise_5x5s1_bias(float* dout,
din_ptr_arr[5] = din_ptr5; din_ptr_arr[5] = din_ptr5;
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid(dout_ptr0,
pad_right, cnt, remain, 4); din_ptr_arr,
vbias,
weights_vec,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -1995,8 +2112,17 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1995,8 +2112,17 @@ void conv_depthwise_5x5s1_bias(float* dout,
} }
// bottom // bottom
for (int h = 0; h < pad_bottom; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_post(dout_ptr0,
pad_right, cnt, remain, 3 - h); din_ptr_arr,
vbias,
weights_vec,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -2022,12 +2148,18 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2022,12 +2148,18 @@ inline void compute_all_padding_pre_relu(float* dout,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -2230,22 +2362,33 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2230,22 +2362,33 @@ inline void compute_all_padding_pre_relu(float* dout,
} }
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp_index - i],
weights[3 - i],
0.f,
weights[5][3 - i],
4);
din_ptr_arr[tmp_index - i]++; din_ptr_arr[tmp_index - i]++;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
din_ptr_arr[tmp_index - k]++; din_ptr_arr[tmp_index - k]++;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
...@@ -2265,12 +2408,18 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2265,12 +2408,18 @@ inline void compute_all_padding_mid_relu(float* dout,
int num) { int num) {
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// clang-format off
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU
...@@ -2331,22 +2480,33 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2331,22 +2480,33 @@ inline void compute_all_padding_mid_relu(float* dout,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
din_ptr_arr[tmp - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[tmp - i]++;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[tmp - k]++; din_ptr_arr[tmp - k]++;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
...@@ -2369,15 +2529,26 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2369,15 +2529,26 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1; int tmp1 = num + 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
sum1 += compute_one_data_pre(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout0++ = sum > 0.f ? sum : 0.f; *dout0++ = sum > 0.f ? sum : 0.f;
*dout1++ = sum1 > 0.f ? sum1 : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f;
} }
// clang-format off
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU_OUT2 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU_OUT2
...@@ -2443,16 +2614,26 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2443,16 +2614,26 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i],
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); weights[tmp - i],
sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); 0.f,
din_ptr_arr[num - i]++; weights[5][tmp - i],
4);
sum1 += compute_one_data_post(din_ptr_arr[num - i],
weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[num - i]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
*dout0++ = sum > 0.f ? sum : 0.f; *dout0++ = sum > 0.f ? sum : 0.f;
...@@ -2460,12 +2641,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2460,12 +2641,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
sum1 += compute_one_data_post(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[num - k]++; din_ptr_arr[num - k]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
...@@ -2488,12 +2679,18 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2488,12 +2679,18 @@ inline void compute_all_padding_post_relu(float* dout,
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -2693,26 +2890,37 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2693,26 +2890,37 @@ inline void compute_all_padding_post_relu(float* dout,
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << "does not support";
} }
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[2 - i],
din_ptr_arr[2 - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[2 - i]++;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k]++; din_ptr_arr[2 - k]++;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
} }
...@@ -2731,7 +2939,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2731,7 +2939,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
int pad_bottom, int pad_bottom,
int pad_left, int pad_left,
int pad_right, int pad_right,
ARMContext* ctx){ ARMContext* ctx) {
int loop_w = wout - pad_left - pad_right; int loop_w = wout - pad_left - pad_right;
int loop_h = hout - pad_top - pad_bottom; int loop_h = hout - pad_top - pad_bottom;
int in_size = win * hin; int in_size = win * hin;
...@@ -2773,12 +2981,23 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2773,12 +2981,23 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {
din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_pre_relu(dout_ptr0,
pad_right, cnt, remain, 4 - h); din_ptr_arr,
vbias,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -2789,8 +3008,19 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2789,8 +3008,19 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
dout_ptr1 = dout_ptr0 + wout; dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_mid_relu_out2(dout_ptr0,
pad_right, cnt, remain, 4); dout_ptr1,
din_ptr_arr,
vbias,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -2807,8 +3037,18 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2807,8 +3037,18 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
din_ptr_arr[5] = din_ptr5; din_ptr_arr[5] = din_ptr5;
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_mid_relu(dout_ptr0,
pad_right, cnt, remain, 4); din_ptr_arr,
vbias,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -2823,8 +3063,18 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2823,8 +3063,18 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
} }
// bottom // bottom
for (int h = 0; h < pad_bottom; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_post_relu(dout_ptr0,
pad_right, cnt, remain, 3 - h); din_ptr_arr,
vbias,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -2855,12 +3105,18 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -2855,12 +3105,18 @@ inline void compute_all_padding_pre_relu6(float* dout,
int tmp_index = num - 1; int tmp_index = num - 1;
// left // left
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -3071,22 +3327,33 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -3071,22 +3327,33 @@ inline void compute_all_padding_pre_relu6(float* dout,
} }
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp_index - i],
din_ptr_arr[tmp_index - i]++; weights[3 - i],
0.f,
weights[5][3 - i],
4);
din_ptr_arr[tmp_index - i]++;
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
din_ptr_arr[tmp_index - k]++; din_ptr_arr[tmp_index - k]++;
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
...@@ -3111,12 +3378,18 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3111,12 +3378,18 @@ inline void compute_all_padding_mid_relu6(float* dout,
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// clang-format off
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6 asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6
...@@ -3179,22 +3452,33 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3179,22 +3452,33 @@ inline void compute_all_padding_mid_relu6(float* dout,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
din_ptr_arr[tmp - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[tmp - i]++;
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[tmp - k]++; din_ptr_arr[tmp - k]++;
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
...@@ -3222,15 +3506,26 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3222,15 +3506,26 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1; int tmp1 = num + 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
sum1 += compute_one_data_pre(din_ptr_arr[num -k], weights[tmp -k], 0.f, weights[5][tmp - k], 4 - i); weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[num -k],
weights[tmp -k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
*dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
} }
// clang-format off
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU6_OUT2 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU6_OUT2
...@@ -3298,15 +3593,26 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3298,15 +3593,26 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); weights[tmp - i],
din_ptr_arr[num - i]++; 0.f,
weights[5][tmp - i],
4);
sum1 += compute_one_data_post(din_ptr_arr[num - i],
weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[num - i]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
...@@ -3314,12 +3620,22 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3314,12 +3620,22 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
sum1 += compute_one_data_post(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[num - k]++; din_ptr_arr[num - k]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
...@@ -3346,12 +3662,18 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3346,12 +3662,18 @@ inline void compute_all_padding_post_relu6(float* dout,
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -3559,23 +3881,34 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3559,23 +3881,34 @@ inline void compute_all_padding_post_relu6(float* dout,
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << "does not support";
} }
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[2 - i],
din_ptr_arr[2 - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[2 - i]++;
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k]++; din_ptr_arr[2 - k]++;
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
...@@ -3640,12 +3973,24 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3640,12 +3973,24 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {
din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_pre_relu6(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 4 - h); din_ptr_arr,
vbias,
six,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -3656,8 +4001,20 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3656,8 +4001,20 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
dout_ptr1 = dout_ptr0 + wout; dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu6_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_mid_relu6_out2(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 4); dout_ptr1,
din_ptr_arr,
vbias,
six,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -3674,8 +4031,19 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3674,8 +4031,19 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
din_ptr_arr[5] = din_ptr5; din_ptr_arr[5] = din_ptr5;
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_mid_relu6(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 4); din_ptr_arr,
vbias,
six,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -3690,8 +4058,19 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3690,8 +4058,19 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
} }
// bottom // bottom
for (int h = 0; h < pad_bottom; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_post_relu6(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 3 - h); din_ptr_arr,
vbias,
six,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -3722,12 +4101,18 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3722,12 +4101,18 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
int tmp_index = num - 1; int tmp_index = num - 1;
// left // left
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -3946,28 +4331,39 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3946,28 +4331,39 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
} }
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp_index - i],
weights[3 - i],
0.f,
weights[5][3 - i],
4);
din_ptr_arr[tmp_index - i]++; din_ptr_arr[tmp_index - i]++;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
din_ptr_arr[tmp_index - k]++; din_ptr_arr[tmp_index - k]++;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
for (int w = pad_right; w > 4; w--) { for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0]; *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
} }
} }
...@@ -3988,17 +4384,20 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -3988,17 +4384,20 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
float32x4_t vscale = vld1q_f32(scale); float32x4_t vscale = vld1q_f32(scale);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// clang-format off
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
...@@ -4063,23 +4462,34 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4063,23 +4462,34 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
din_ptr_arr[tmp - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[tmp - i]++;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[tmp - k]++; din_ptr_arr[tmp - k]++;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
...@@ -4106,15 +4516,29 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4106,15 +4516,29 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1; int tmp1 = num + 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1],
weights[num],
bias[0],
weights[6][0],
4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k],
sum1 += compute_one_data_pre(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout0++ = sum > 0.f ? sum : sum * scale[0];
*dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
} }
// clang-format off
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_LEAKY_RELU_OUT2 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_LEAKY_RELU_OUT2
...@@ -4185,15 +4609,26 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4185,15 +4609,26 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 4;
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i],
sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); weights[tmp - i],
din_ptr_arr[num - i]++; 0.f,
weights[5][tmp - i],
4);
sum1 += compute_one_data_post(din_ptr_arr[num - i],
weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[num - i]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
*dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout0++ = sum > 0.f ? sum : sum * scale[0];
...@@ -4201,12 +4636,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4201,12 +4636,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k],
sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
sum1 += compute_one_data_post(din_ptr_arr[num - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[num - k]++; din_ptr_arr[num - k]++;
} }
din_ptr_arr[0]++; din_ptr_arr[0]++;
...@@ -4233,12 +4678,18 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4233,12 +4678,18 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// clang-format off
// mid // mid
if (cnt > 0) { if (cnt > 0) {
switch (num) { switch (num) {
...@@ -4454,23 +4905,34 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4454,23 +4905,34 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << "does not support";
} }
} }
// clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[2 - i],
din_ptr_arr[2 - i]++; weights[tmp - i],
0.f,
weights[5][tmp - i],
4);
din_ptr_arr[2 - i]++;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k]++; din_ptr_arr[2 - k]++;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
...@@ -4535,12 +4997,24 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4535,12 +4997,24 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {
din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_pre_leakyRelu(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 4 - h); din_ptr_arr,
vbias,
scale,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -4551,8 +5025,20 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4551,8 +5025,20 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
dout_ptr1 = dout_ptr0 + wout; dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_leakyRelu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_mid_leakyRelu_out2(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 4); dout_ptr1,
din_ptr_arr,
vbias,
scale,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -4569,8 +5055,19 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4569,8 +5055,19 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
din_ptr_arr[5] = din_ptr5; din_ptr_arr[5] = din_ptr5;
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_mid_leakyRelu(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 4); din_ptr_arr,
vbias,
scale,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -4585,8 +5082,19 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4585,8 +5082,19 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
} }
// bottom // bottom
for (int h = 0; h < pad_bottom; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_post_leakyRelu(dout_ptr0,
win, wout, pad_left, pad_right, cnt, remain, 3 - h); din_ptr_arr,
vbias,
scale,
weights_vec,
vzero,
win,
wout,
pad_left,
pad_right,
cnt,
remain,
3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
......
...@@ -622,7 +622,8 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -622,7 +622,8 @@ void conv_depthwise_3x3_fp32(const void* din,
bool ch_four = ch_in <= 4 * w_in; bool ch_four = ch_in <= 4 * w_in;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (stride == 1) { if (stride == 1) {
if (ch_four && pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] if (ch_four && pads_less && (pad_h == pad_w) &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
...@@ -677,7 +678,8 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -677,7 +678,8 @@ void conv_depthwise_3x3_fp32(const void* din,
#endif #endif
} }
} else if (stride == 2) { } else if (stride == 2) {
if (ch_four && pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] if (ch_four && pads_less && pad_h == pad_w &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
......
...@@ -61,7 +61,8 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -61,7 +61,8 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// VLOG(5) << "invoke 5x5 dw conv fp32"; // VLOG(5) << "invoke 5x5 dw conv fp32";
bool pads_five = (paddings[0] < 5) || (paddings[2] < 5); bool pads_five = (paddings[0] < 5) || (paddings[2] < 5);
auto strides = param.strides; auto strides = param.strides;
if (ch_four && pads_five && win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) { if (ch_four && pads_five && win >= kw && hin >= kw &&
(strides[0] == 1 && strides[1] == 1)) {
flag_trans_weights_ = false; flag_trans_weights_ = false;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32; impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册