From 5ccb1c1a2a37c38183796d9d69861a94c8eea80a Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Wed, 19 Aug 2020 17:43:49 +0800 Subject: [PATCH] add 5x5s2_dw relu6/leakyyrelu --- .../arm/math/conv5x5s2_depthwise_fp32.cc | 834 +++++++++--------- 1 file changed, 403 insertions(+), 431 deletions(-) diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index 9a4599e4c8..a3276962b5 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -535,110 +535,6 @@ void conv_depthwise_5x5s2_fp32(float* dout, "fadd v16.4s, v16.4s, v15.4s\n" \ "fadd v18.4s, v18.4s, v17.4s\n" -#define COMPUTE_FIVE_LINE_S2_OUT2_1 \ - "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ - "ld1 {v15.4s}, [%[bias]]\n" \ - "ld1 {v17.4s}, [%[bias]]\n" \ - "ld1 {v11.4s}, [%[din_ptr0]]\n" /*891011*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "1: \n" \ - "subs %w[cnt], %w[cnt], #1 \n" \ - "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ - "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ - "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ - "mov v13.s[3], v11.s[1]\n" /*3579*/ \ - "mov v14.s[3], v11.s[2]\n" /*46810*/ \ - "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ - "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ - "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ - "fmla v17.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ - "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ - "fmul v18.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ - "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ - "mov v13.s[3], v11.s[1]\n" /*3579*/ \ - "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ - "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ - "fmla v17.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ - "mov v14.s[3], v11.s[2]\n" /*46810*/ \ - "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ - "fmla v18.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "fmla v16.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ - "fmla v17.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ - "fmla v18.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ - "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ - "fmla v17.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ - "ld2 {v9.4s, v10.4s}, [%[din_ptr3]], #32\n" \ - "mov v13.s[3], v11.s[1]\n" /*3579*/ \ - "ldr d22, [%[din_ptr3]]\n" /*891011*/ \ - "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ - "fmla v18.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ - "mov v14.s[3], v11.s[2]\n" /*46810*/ \ - "fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ - "fmla v17.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "fmla v15.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ - "fmla v18.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "fmla v16.4s, v9.4s, %[wr3].s[0]\n" /*0246*wr3[0]*/ \ - "fmla v17.4s, v9.4s, 
%[wr2].s[0]\n" /*0246*wr2[0]*/ \ - "fmla v15.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ - "fmla v18.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ - "ld2 {v9.4s, v10.4s}, [%[din_ptr4]], #32\n" \ - "mov v13.s[3], v11.s[1]\n" /*3579*/ \ - "ldr d22, [%[din_ptr4]]\n" /*891011*/ \ - "fmla v16.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ - "fmla v17.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ - "mov v14.s[3], v11.s[2]\n" /*46810*/ \ - "fmla v15.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ - "fmla v18.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "fmla v16.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ - "fmla v17.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "fmla v15.4s, v9.4s, %[wr4].s[0]\n" /*0246*wr4[0]*/ \ - "fmla v18.4s, v9.4s, %[wr3].s[0]\n" /*0246*wr3[0]*/ \ - "fmla v16.4s, v10.4s, %[wr4].s[1]\n" /*1357*wr4[1]*/ \ - "fmla v17.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ - "ld2 {v9.4s, v10.4s}, [%[din_ptr5]], #32\n" \ - "mov v13.s[3], v11.s[1]\n" /*3579*/ \ - "ldr d22, [%[din_ptr5]]\n" /*891011*/ \ - "fmla v15.4s, v12.4s, %[wr4].s[2]\n" /*2468*wr4[2]*/ \ - "fmla v18.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ - "mov v14.s[3], v11.s[2]\n" /*46810*/ \ - "fmla v16.4s, v13.4s, %[wr4].s[3]\n" /*3579*wr4[3]*/ \ - "fmla v17.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "fmla v15.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ - "fmla v18.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "fmla v17.4s, v9.4s, %[wr4].s[0]\n" /*0246*wr4[0]*/ \ - "fmla v18.4s, v10.4s, %[wr4].s[1]\n" /*1357*wr4[1]*/\ - "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ - "mov v13.s[3], v11.s[1]\n" /*3579*/ \ - "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ - "fmla v17.4s, v12.4s, %[wr4].s[2]\n" /*2468*wr4[2]*/ \ - "mov v14.s[3], v11.s[2]\n" /*46810*/ \ - "fmla v18.4s, v13.4s, %[wr4].s[3]\n" /*3579*wr4[3]*/ \ - "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ - "fmla v17.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ - "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ - "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ - "fadd v16.4s, v16.4s, v15.4s\n" \ - "fadd v18.4s, v18.4s, v17.4s\n" #define COMPUTE_ONE_LINE_S2_POST \ "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ "ld1 {v15.4s}, [%[bias]]\n" \ @@ -1177,103 +1073,6 @@ void conv_depthwise_5x5s2_fp32(float* dout, "vext.32 q13, q8, q10, #2\n" /*46810*/ \ "vmla.f32 q14, q12, %f[wr4][1]\n" /*3579*wr4[3]*/\ "vext.32 d25, d19, d21, #1\n" /*57-79*/ -#define COMPUTE_FIVE_LINE_S2_OUT2_1 \ - "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ - "vld1.f32 {d30-d31}, [%[bias]]\n" \ - "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "1: \n" \ - "subs %[cnt], #1\n" \ - "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ - "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ - "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ - "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ - "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ - "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ - "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vadd.f32 q15, q15, q14\n" \ - "vld1.f32 {d28-d29}, [%[bias]]\n" \ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "vmla.f32 q15, q12, 
%f[wr0][1]\n" /*3579*wr0[3]*/\ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ - "vmla.f32 q15, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ - "vmla.f32 q14, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ - "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ - "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ - "vmla.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ - "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ - "vmla.f32 q15, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ - "vmla.f32 q14, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ - "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ - "vmla.f32 q15, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ - "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "vmla.f32 q15, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ - "vmla.f32 q14, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ - "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ - "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ - "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ - "vmla.f32 q15, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ - "vmla.f32 q14, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ - "vld2.f32 {d16-d19}, [%[din_ptr3]]!\n" \ - "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ - "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ - "vld2.f32 {d20-d21}, [%[din_ptr3]]\n" /*810911*/\ - "vmla.f32 q15, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ - "vmla.f32 q14, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ - "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "vmla.f32 q15, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ - "vmla.f32 q14, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ - "vmla.f32 q15, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ - "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ - "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ - "vld2.f32 {d16-d19}, [%[din_ptr4]]!\n" \ - "vmla.f32 q15, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ - "vmla.f32 q14, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ - "vld2.f32 {d20-d21}, [%[din_ptr4]]\n" /*810911*/\ - "vmla.f32 q15, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ - "vmla.f32 q14, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vmla.f32 q15, q12, %f[wr3][1]\n" /*3579*wr3[3]*/\ - "vmla.f32 q14, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ - "vmla.f32 q15, q8, %e[wr4][0]\n" /*0246*wr4[0]*/ \ - "vmla.f32 q14, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ - "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ - "vmla.f32 q15, q9, %e[wr4][1]\n" /*1357*wr4[1]*/ \ - "vmla.f32 q14, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ - "vld2.f32 {d16-d19}, [%[din_ptr5]]!\n" \ - "vmla.f32 q15, q11, %f[wr4][0]\n" /*2468*wr4[2]*/\ - "vmla.f32 q14, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ - "vld2.f32 {d20-d21}, [%[din_ptr5]]\n" /*810911*/\ - "vmla.f32 q15, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ - "vmla.f32 q14, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vmla.f32 q15, q12, %f[wr4][1]\n" /*3579*wr4[3]*/\ - "vmla.f32 q14, q12, %f[wr3][1]\n" /*3579*wr4[3]*/\ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ - "vmla.f32 q14, q8, %e[wr4][0]\n" /*0246*wr4[0]*/ \ - "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ - "vmla.f32 q14, q9, %e[wr4][1]\n" /*1357*wr4[1]*/ \ - "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ - "vmla.f32 q14, q11, %f[wr4][0]\n" /*2468*wr4[2]*/\ - "vld2.f32 {d20-d21}, 
[%[din_ptr0]]\n" /*810911*/\ - "vmla.f32 q14, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ - "vext.32 q11, q8, q10, #1\n" /*2468*/ \ - "vext.32 q13, q8, q10, #2\n" /*46810*/ \ - "vmla.f32 q14, q12, %f[wr4][1]\n" /*3579*wr4[3]*/\ - "vext.32 d25, d19, d21, #1\n" /*57-79*/ #define COMPUTE_ONE_LINE_S2_POST \ "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ "vld1.f32 {d30-d31}, [%[bias]]\n" \ @@ -1543,19 +1342,6 @@ inline void compute_all_padding_pre(float* dout, din_ptr_arr[tmp_index - k]++; } } -#if 0 //def __aarch64__ - LOG(INFO) << "data: " << din_ptr_arr[0][0] << ", " << din_ptr_arr[0][1] << ", " << din_ptr_arr[0][2] << ", " << din_ptr_arr[0][3]; - LOG(INFO) << "----"; - asm volatile( - "ld1 {v11.4s}, [%[din_ptr]]\n" - "ld1 {v14.4s}, [%[din_ptr]]\n" - "mov v14.s[3], v11.s[2]\n" - "st1 {v14.4s}, [%[din_ptr]]\n" - :[din_ptr] "+r"(din_ptr_arr[0]) - : - : "cc", "memory", "v10", "v11"); -LOG(INFO) << "data: " << din_ptr_arr[0][0] << ", " << din_ptr_arr[0][1] << ", " << din_ptr_arr[0][2] << ", " << din_ptr_arr[0][3]; -#endif // mid // clang-format off if (cnt > 0) { @@ -2132,7 +1918,7 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 4; + din_ptr_arr[3] -= 8; break; case 1: #ifdef __aarch64__ @@ -2176,7 +1962,7 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 4; + din_ptr_arr[2] -= 8; break; case 2: #ifdef __aarch64__ @@ -2224,7 +2010,7 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 4; + din_ptr_arr[1] -= 8; break; case 3: #ifdef __aarch64__ @@ -2481,7 +2267,7 @@ void conv_depthwise_5x5s2_bias(float* dout, } // bottom h_in_num = n_bottom_h; - for (int h = 0; h < pad_bottom; h++) { + for (int h = 0; h < pad_bottom_new; h++) { compute_all_padding_post(dout_ptr0, din_ptr_arr, vbias, @@ -2509,8 +2295,7 @@ inline void compute_all_padding_pre_relu(float* dout, const float* bias, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -2529,6 +2314,12 @@ inline void compute_all_padding_pre_relu(float* dout, } *dout++ = sum > 0.f ? sum : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } // clang-format off // mid if (cnt > 0) { @@ -2730,21 +2521,21 @@ inline void compute_all_padding_pre_relu(float* dout, default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; } - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); - din_ptr_arr[tmp_index - i]++; + din_ptr_arr[tmp_index - i] += 2; } *dout++ = sum > 0.f ? sum : 0.f; } @@ -2752,14 +2543,14 @@ inline void compute_all_padding_pre_relu(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); - din_ptr_arr[tmp_index - k]++; + din_ptr_arr[tmp_index - k] += 2; } *dout++ = sum > 0.f ? 
sum : 0.f; } @@ -2769,8 +2560,7 @@ inline void compute_all_padding_mid_relu(float* dout, const float* bias, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -2789,6 +2579,12 @@ inline void compute_all_padding_mid_relu(float* dout, } *dout++ = sum > 0.f ? sum : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp - k]++; + } + } // clang-format off if (cnt > 0) { #ifdef __aarch64__ @@ -2848,18 +2644,18 @@ inline void compute_all_padding_mid_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[tmp - i]++; + din_ptr_arr[tmp - i] += 2; } *dout++ = sum > 0.f ? sum : 0.f; } @@ -2867,14 +2663,14 @@ inline void compute_all_padding_mid_relu(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[tmp - k]++; + din_ptr_arr[tmp - k] += 2; } *dout++ = sum > 0.f ? sum : 0.f; } @@ -2885,16 +2681,16 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, const float* bias, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, int remain, int num) { // left + int tmp1 = num + 2; + int tmp2 = num + 1; int tmp = num - 1; - int tmp1 = num + 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); @@ -2906,7 +2702,7 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, 0.f, weights[5][tmp - k], 4 - i); - sum1 += compute_one_data_pre(din_ptr_arr[num - k], + sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], @@ -2915,6 +2711,14 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, *dout0++ = sum > 0.f ? sum : 0.f; *dout1++ = sum1 > 0.f ? 
sum1 : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp2 - k]++; + } + din_ptr_arr[1]++; + din_ptr_arr[0]++; + } // clang-format off if (cnt > 0) { #ifdef __aarch64__ @@ -2948,7 +2752,8 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, "v14", "v15", "v16", - "v17"); + "v17", + "v18"); #else asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU_OUT2 : [cnt] "+r"(cnt), @@ -2981,7 +2786,7 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain @@ -2990,15 +2795,16 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum1 = compute_one_data_post( din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); - din_ptr_arr[tmp1]++; + din_ptr_arr[tmp1] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum1 += compute_one_data_post( - din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[num - i]++; + din_ptr_arr[tmp2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp2 - i] += 2; } - din_ptr_arr[0]++; + din_ptr_arr[1] += 2; + din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? sum : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f; } @@ -3008,21 +2814,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post( din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[tmp1]++; + din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - sum1 += compute_one_data_post(din_ptr_arr[num - k], + sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[num - k]++; + din_ptr_arr[tmp2 - k] += 2; } - din_ptr_arr[0]++; + din_ptr_arr[0] += 2; + din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? sum : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f; } @@ -3032,8 +2839,7 @@ inline void compute_all_padding_post_relu(float* dout, const float* bias, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -3053,6 +2859,12 @@ inline void compute_all_padding_post_relu(float* dout, } *dout++ = sum > 0.f ? 
sum : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp - k]++; + } + } // clang-format off // mid if (cnt > 0) { @@ -3097,7 +2909,7 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 4; + din_ptr_arr[3] -= 8; break; case 1: #ifdef __aarch64__ @@ -3143,7 +2955,7 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 4; + din_ptr_arr[2] -= 8; break; case 2: #ifdef __aarch64__ @@ -3193,7 +3005,7 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 4; + din_ptr_arr[1] -= 8; break; case 3: #ifdef __aarch64__ @@ -3247,7 +3059,7 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; break; default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; @@ -3258,11 +3070,11 @@ inline void compute_all_padding_post_relu(float* dout, for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[3]++; + din_ptr_arr[3] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[2 - i]++; + din_ptr_arr[2 - i] += 2; } *dout++ = sum > 0.f ? sum : 0.f; } @@ -3270,14 +3082,14 @@ inline void compute_all_padding_post_relu(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3]++; + din_ptr_arr[3] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[2 - k]++; + din_ptr_arr[2 - k] += 2; } *dout++ = sum > 0.f ? 
sum : 0.f; } @@ -3299,16 +3111,32 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, int pad_left, int pad_right, ARMContext* ctx) { - int loop_w = wout - pad_left - pad_right; - int loop_h = hout - pad_top - pad_bottom; int in_size = win * hin; int out_size = wout * hout; - int cnt = loop_w >> 2; - int remain = loop_w & 3; + int pad_left_new = (pad_left + 1) / 2; + int pad_right_new = pad_right / 2; + int pad_top_new = (pad_top + 1) / 2; + int pad_bottom_new = pad_bottom / 2; int in_channel_size = chin * in_size; int out_channel_size = chin * out_size; int weights_size = 25; int num_out = wout << 1; + int loop_w = wout - pad_left_new - pad_right_new; + int loop_h = hout - pad_top_new - pad_bottom_new; + bool odds_w = pad_left % 2; + bool odds_h = pad_top % 2; + if (loop_w != ((win - 4) / 2)) { + loop_w--; + pad_right_new++; + } + if (loop_h != ((hin - 4) / 2)) { + loop_h--; + pad_bottom_new++; + } + int cnt = loop_w >> 2; + int remain = loop_w & 3; + int n_top_h = 4 - pad_top; + int n_bottom_h = 4 -pad_bottom; float32x4_t vzero = vdupq_n_f32(0.f); for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; @@ -3325,6 +3153,7 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, const float* din_ptr3 = din_ptr2 + win; const float* din_ptr4 = din_ptr3 + win; const float* din_ptr5 = din_ptr4 + win; + const float* din_ptr6 = din_ptr5 + win; float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float* dout_ptr0 = dout_ch; float* dout_ptr1 = dout_ch; @@ -3341,10 +3170,11 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); const float* din_ptr_arr[] = { - din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5, din_ptr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; // top_h - for (int h = pad_top; h > 0; h--) { + int h_in_num = n_top_h; + for (int h = pad_top_new; h > 0; h--) { compute_all_padding_pre_relu(dout_ptr0, din_ptr_arr, vbias, @@ -3352,17 +3182,34 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, vzero, win, wout, - pad_left, - pad_right, + pad_left_new, + pad_right_new, cnt, remain, - 4 - h); + h_in_num); dout_ptr0 += wout; + h_in_num += 2; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + if (odds_h) { + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr5; + din_ptr5 = din_ptr6; + din_ptr6 += win; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; din_ptr_arr[3] = din_ptr3; din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; + din_ptr_arr[6] = din_ptr6; } dout_ptr1 = dout_ptr0 + wout; // mid_h @@ -3373,27 +3220,28 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, vbias, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, 4); dout_ptr0 += num_out; dout_ptr1 += num_out; - din_ptr0 = din_ptr2; - din_ptr1 = din_ptr3; - din_ptr2 = din_ptr4; - din_ptr3 = din_ptr5; - din_ptr4 = din_ptr5 + win; + din_ptr0 = din_ptr4; + din_ptr1 = din_ptr5; + din_ptr2 = din_ptr6; + din_ptr3 = din_ptr6 + win; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; + din_ptr4 = din_ptr3 + win; din_ptr_arr[2] = din_ptr2; din_ptr5 = din_ptr4 + win; din_ptr_arr[3] = din_ptr3; + din_ptr6 = din_ptr5 + win; din_ptr_arr[4] = din_ptr4; din_ptr_arr[5] = 
din_ptr5; + din_ptr_arr[6] = din_ptr6; } if (loop_h % 2 != 0) { compute_all_padding_mid_relu(dout_ptr0, @@ -3401,19 +3249,18 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, vbias, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, 4); dout_ptr0 = dout_ptr1; - din_ptr0 = din_ptr1; - din_ptr1 = din_ptr2; - din_ptr2 = din_ptr3; - din_ptr3 = din_ptr4; - din_ptr4 = din_ptr5; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr6; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; @@ -3421,7 +3268,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, din_ptr_arr[4] = din_ptr4; } // bottom - for (int h = 0; h < pad_bottom; h++) { + h_in_num = n_bottom_h; + for (int h = 0; h < pad_bottom_new; h++) { compute_all_padding_post_relu(dout_ptr0, din_ptr_arr, vbias, @@ -3429,12 +3277,13 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, vzero, win, wout, - pad_left, - pad_right, + pad_left_new, + pad_right_new, cnt, remain, - 3 - h); + h_in_num); dout_ptr0 += wout; + h_in_num -= 2; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; @@ -3451,8 +3300,7 @@ inline void compute_all_padding_pre_relu6(float* dout, const float* six, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -3475,6 +3323,12 @@ inline void compute_all_padding_pre_relu6(float* dout, } *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } // clang-format off // mid if (cnt > 0) { @@ -3684,21 +3538,21 @@ inline void compute_all_padding_pre_relu6(float* dout, default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; } - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); - din_ptr_arr[tmp_index - i]++; + din_ptr_arr[tmp_index - i] += 2; } *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } @@ -3706,14 +3560,14 @@ inline void compute_all_padding_pre_relu6(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); - din_ptr_arr[tmp_index - k]++; + din_ptr_arr[tmp_index - k] += 2; } *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } @@ -3724,8 +3578,7 @@ inline void compute_all_padding_mid_relu6(float* dout, const float* six, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -3748,6 +3601,12 @@ inline void compute_all_padding_mid_relu6(float* dout, } *dout++ = sum > 0.f ? (sum < six[0] ? 
sum : six[0]) : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp - k]++; + } + } // clang-format off if (cnt > 0) { #ifdef __aarch64__ @@ -3816,11 +3675,11 @@ inline void compute_all_padding_mid_relu6(float* dout, for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[tmp - i]++; + din_ptr_arr[tmp - i] += 2; } *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } @@ -3828,14 +3687,14 @@ inline void compute_all_padding_mid_relu6(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[tmp - k]++; + din_ptr_arr[tmp - k] += 2; } *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } @@ -3848,8 +3707,7 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, const float* six, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -3859,8 +3717,9 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, float32x4_t vsix = vld1q_f32(six); #endif // left + int tmp1 = num + 2; + int tmp2 = num + 1; int tmp = num - 1; - int tmp1 = num + 1; // clang-format off for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( @@ -3873,7 +3732,7 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, 0.f, weights[5][tmp - k], 4 - i); - sum1 += compute_one_data_pre(din_ptr_arr[num -k], + sum1 += compute_one_data_pre(din_ptr_arr[tmp2 -k], weights[tmp -k], 0.f, weights[5][tmp - k], @@ -3882,6 +3741,14 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? 
sum1 : six[0]) : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp2 - k]++; + } + din_ptr_arr[1]++; + din_ptr_arr[0]++; + } if (cnt > 0) { #ifdef __aarch64__ asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU6_OUT2 @@ -3915,7 +3782,8 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, "v14", "v15", "v16", - "v17"); + "v17", + "v18"); #else asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU6_OUT2 : [cnt] "+r"(cnt), @@ -3949,7 +3817,7 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain @@ -3958,15 +3826,16 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum1 = compute_one_data_post( din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); - din_ptr_arr[tmp1]++; + din_ptr_arr[tmp1] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum1 += compute_one_data_post( - din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[num - i]++; + din_ptr_arr[tmp2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp2 - i] += 2; } - din_ptr_arr[0]++; + din_ptr_arr[1] += 2; + din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; } @@ -3983,14 +3852,15 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, 0.f, weights[tmp - k][3 - i], 3 - i); - sum1 += compute_one_data_post(din_ptr_arr[num - k], + sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[num - k]++; + din_ptr_arr[tmp2 - k] += 2; } - din_ptr_arr[0]++; + din_ptr_arr[1] += 2; + din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; } @@ -4001,8 +3871,7 @@ inline void compute_all_padding_post_relu6(float* dout, const float* six, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -4025,6 +3894,12 @@ inline void compute_all_padding_post_relu6(float* dout, } *dout++ = sum > 0.f ? (sum < six[0] ? 
sum : six[0]) : 0.f; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp - k]++; + } + } // clang-format off // mid if (cnt > 0) { @@ -4071,7 +3946,7 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 4; + din_ptr_arr[3] -= 8; break; case 1: #ifdef __aarch64__ @@ -4119,7 +3994,7 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 4; + din_ptr_arr[2] -= 8; break; case 2: #ifdef __aarch64__ @@ -4171,7 +4046,7 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 4; + din_ptr_arr[1] -= 8; break; case 3: #ifdef __aarch64__ @@ -4227,7 +4102,7 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; break; default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; @@ -4238,11 +4113,11 @@ inline void compute_all_padding_post_relu6(float* dout, for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[3]++; + din_ptr_arr[3] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[2 - i]++; + din_ptr_arr[2 - i] += 2; } *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } @@ -4250,14 +4125,14 @@ inline void compute_all_padding_post_relu6(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3]++; + din_ptr_arr[3] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[2 - k]++; + din_ptr_arr[2 - k] += 2; } *dout++ = sum > 0.f ? (sum < six[0] ? 
sum : six[0]) : 0.f; } @@ -4280,16 +4155,32 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, int pad_left, int pad_right, ARMContext* ctx) { - int loop_w = wout - pad_left - pad_right; - int loop_h = hout - pad_top - pad_bottom; int in_size = win * hin; int out_size = wout * hout; - int cnt = loop_w >> 2; - int remain = loop_w & 3; + int pad_left_new = (pad_left + 1) / 2; + int pad_right_new = pad_right / 2; + int pad_top_new = (pad_top + 1) / 2; + int pad_bottom_new = pad_bottom / 2; int in_channel_size = chin * in_size; int out_channel_size = chin * out_size; int weights_size = 25; int num_out = wout << 1; + int loop_w = wout - pad_left_new - pad_right_new; + int loop_h = hout - pad_top_new - pad_bottom_new; + bool odds_w = pad_left % 2; + bool odds_h = pad_top % 2; + if (loop_w != ((win - 4) / 2)) { + loop_w--; + pad_right_new++; + } + if (loop_h != ((hin - 4) / 2)) { + loop_h--; + pad_bottom_new++; + } + int cnt = loop_w >> 2; + int remain = loop_w & 3; + int n_top_h = 4 - pad_top; + int n_bottom_h = 4 -pad_bottom; float32x4_t vzero = vdupq_n_f32(0.f); for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; @@ -4306,6 +4197,7 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, const float* din_ptr3 = din_ptr2 + win; const float* din_ptr4 = din_ptr3 + win; const float* din_ptr5 = din_ptr4 + win; + const float* din_ptr6 = din_ptr5 + win; float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float* dout_ptr0 = dout_ch; float* dout_ptr1 = dout_ch; @@ -4322,29 +4214,46 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); const float* din_ptr_arr[] = { - din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5, din_ptr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; // top_h - for (int h = pad_top; h > 0; h--) { + int h_in_num = n_top_h; + for (int h = pad_top_new; h > 0; h--) { compute_all_padding_pre_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, - 4 - h); + h_in_num); dout_ptr0 += wout; + h_in_num += 2; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + if (odds_h) { + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr5; + din_ptr5 = din_ptr6; + din_ptr6 += win; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; din_ptr_arr[3] = din_ptr3; din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; + din_ptr_arr[6] = din_ptr6; } dout_ptr1 = dout_ptr0 + wout; // mid_h @@ -4356,27 +4265,28 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, six, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, 4); dout_ptr0 += num_out; dout_ptr1 += num_out; - din_ptr0 = din_ptr2; - din_ptr1 = din_ptr3; - din_ptr2 = din_ptr4; - din_ptr3 = din_ptr5; - din_ptr4 = din_ptr5 + win; + din_ptr0 = din_ptr4; + din_ptr1 = din_ptr5; + din_ptr2 = din_ptr6; + din_ptr3 = din_ptr6 + win; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; + din_ptr4 = din_ptr3 + win; din_ptr_arr[2] = din_ptr2; din_ptr5 = din_ptr4 + win; din_ptr_arr[3] = din_ptr3; + din_ptr6 = din_ptr5 + win; din_ptr_arr[4] = din_ptr4; din_ptr_arr[5] = din_ptr5; + din_ptr_arr[6] = 
din_ptr6; } if (loop_h % 2 != 0) { compute_all_padding_mid_relu6(dout_ptr0, @@ -4385,19 +4295,18 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, six, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, 4); dout_ptr0 = dout_ptr1; - din_ptr0 = din_ptr1; - din_ptr1 = din_ptr2; - din_ptr2 = din_ptr3; - din_ptr3 = din_ptr4; - din_ptr4 = din_ptr5; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr6; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; @@ -4405,21 +4314,22 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, din_ptr_arr[4] = din_ptr4; } // bottom - for (int h = 0; h < pad_bottom; h++) { + h_in_num = n_bottom_h; + for (int h = 0; h < pad_bottom_new; h++) { compute_all_padding_post_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, - 3 - h); + h_in_num); dout_ptr0 += wout; + h_in_num -= 2; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; @@ -4436,8 +4346,7 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, const float* scale, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -4460,6 +4369,12 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, } *dout++ = sum > 0.f ? sum : sum * scale[0]; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } // clang-format off // mid if (cnt > 0) { @@ -4677,21 +4592,21 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; } - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4); - din_ptr_arr[tmp_index - i]++; + din_ptr_arr[tmp_index - i] += 2; } *dout++ = sum > 0.f ? sum : sum * scale[0]; } @@ -4699,14 +4614,14 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); - din_ptr_arr[tmp_index - k]++; + din_ptr_arr[tmp_index - k] += 2; } *dout++ = sum > 0.f ? sum : sum * scale[0]; } @@ -4721,8 +4636,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, const float* scale, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds_w, int pad_left, int pad_right, int cnt, @@ -4745,6 +4659,12 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, } *dout++ = sum > 0.f ? 
sum : sum * scale[0]; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } // clang-format off if (cnt > 0) { #ifdef __aarch64__ @@ -4808,18 +4728,18 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[tmp - i]++; + din_ptr_arr[tmp - i] += 2; } *dout++ = sum > 0.f ? sum : sum * scale[0]; } @@ -4827,14 +4747,14 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[num]++; + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[tmp - k]++; + din_ptr_arr[tmp - k] += 2; } *dout++ = sum > 0.f ? sum : sum * scale[0]; } @@ -4846,8 +4766,7 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, const float* scale, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -4857,8 +4776,9 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, float32x4_t vscale = vld1q_f32(scale); #endif // left + int tmp1 = num + 2; + int tmp2 = num + 1; int tmp = num - 1; - int tmp1 = num + 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); @@ -4870,7 +4790,7 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, 0.f, weights[5][tmp - k], 4 - i); - sum1 += compute_one_data_pre(din_ptr_arr[num - k], + sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], @@ -4879,6 +4799,14 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, *dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout1++ = sum1 > 0.f ? 
sum1 : sum1 * scale[0]; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp2 - k]++; + } + din_ptr_arr[1]++; + din_ptr_arr[0]++; + } // clang-format off if (cnt > 0) { #ifdef __aarch64__ @@ -4916,7 +4844,9 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, "v17", "v18", "v19", - "v20"); + "v20", + "v21", + "v22"); #else asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_LEAKY_RELU_OUT2 : [cnt] "+r"(cnt), @@ -4950,7 +4880,7 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain @@ -4959,15 +4889,16 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum1 = compute_one_data_post( din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); - din_ptr_arr[tmp1]++; + din_ptr_arr[tmp1] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum1 += compute_one_data_post( - din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[num - i]++; + din_ptr_arr[tmp2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp2 - i] += 2; } - din_ptr_arr[0]++; + din_ptr_arr[1] += 2; + din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; } @@ -4977,21 +4908,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post( din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[tmp1]++; + din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - sum1 += compute_one_data_post(din_ptr_arr[num - k], + sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[num - k]++; + din_ptr_arr[tmp2 - k] += 2; } - din_ptr_arr[0]++; + din_ptr_arr[1] += 2; + din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; } @@ -5002,8 +4934,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, const float* scale, float32x4_t* weights, float32x4_t vzero, - int win, - int wout, + bool odds, int pad_left, int pad_right, int cnt, @@ -5026,6 +4957,12 @@ inline void compute_all_padding_post_leakyRelu(float* dout, } *dout++ = sum > 0.f ? 
sum : sum * scale[0]; } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp - k]++; + } + } // clang-format off // mid if (cnt > 0) { @@ -5074,7 +5011,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 4; + din_ptr_arr[3] -= 8; break; case 1: #ifdef __aarch64__ @@ -5124,7 +5061,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 4; + din_ptr_arr[2] -= 8; break; case 2: #ifdef __aarch64__ @@ -5178,7 +5115,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 4; + din_ptr_arr[1] -= 8; break; case 3: #ifdef __aarch64__ @@ -5236,7 +5173,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; break; default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; @@ -5247,11 +5184,11 @@ inline void compute_all_padding_post_leakyRelu(float* dout, for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[3]++; + din_ptr_arr[3] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[2 - i]++; + din_ptr_arr[2 - i] += 2; } *dout++ = sum > 0.f ? sum : sum * scale[0]; } @@ -5259,14 +5196,14 @@ inline void compute_all_padding_post_leakyRelu(float* dout, for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3]++; + din_ptr_arr[3] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); - din_ptr_arr[2 - k]++; + din_ptr_arr[2 - k] += 2; } *dout++ = sum > 0.f ? 
sum : sum * scale[0]; } @@ -5289,16 +5226,32 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, int pad_left, int pad_right, ARMContext* ctx) { - int loop_w = wout - pad_left - pad_right; - int loop_h = hout - pad_top - pad_bottom; int in_size = win * hin; int out_size = wout * hout; - int cnt = loop_w >> 2; - int remain = loop_w & 3; + int pad_left_new = (pad_left + 1) / 2; + int pad_right_new = pad_right / 2; + int pad_top_new = (pad_top + 1) / 2; + int pad_bottom_new = pad_bottom / 2; int in_channel_size = chin * in_size; int out_channel_size = chin * out_size; int weights_size = 25; int num_out = wout << 1; + int loop_w = wout - pad_left_new - pad_right_new; + int loop_h = hout - pad_top_new - pad_bottom_new; + bool odds_w = pad_left % 2; + bool odds_h = pad_top % 2; + if (loop_w != ((win - 4) / 2)) { + loop_w--; + pad_right_new++; + } + if (loop_h != ((hin - 4) / 2)) { + loop_h--; + pad_bottom_new++; + } + int cnt = loop_w >> 2; + int remain = loop_w & 3; + int n_top_h = 4 - pad_top; + int n_bottom_h = 4 -pad_bottom; float32x4_t vzero = vdupq_n_f32(0.f); for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; @@ -5315,6 +5268,7 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, const float* din_ptr3 = din_ptr2 + win; const float* din_ptr4 = din_ptr3 + win; const float* din_ptr5 = din_ptr4 + win; + const float* din_ptr6 = din_ptr5 + win; float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float* dout_ptr0 = dout_ch; float* dout_ptr1 = dout_ch; @@ -5331,30 +5285,47 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); const float* din_ptr_arr[] = { - din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5, din_ptr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; // top_h - for (int h = pad_top; h > 0; h--) { + int h_in_num = n_top_h; + for (int h = pad_top_new; h > 0; h--) { compute_all_padding_pre_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, - 4 - h); + h_in_num); dout_ptr0 += wout; + h_in_num += 2; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; din_ptr_arr[3] = din_ptr3; din_ptr_arr[4] = din_ptr4; } + if (odds_h) { + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr5; + din_ptr5 = din_ptr6; + din_ptr6 += win; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; + din_ptr_arr[6] = din_ptr6; + } dout_ptr1 = dout_ptr0 + wout; // mid_h for (int h = 0; h < loop_h - 1; h += 2) { @@ -5365,27 +5336,28 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, scale, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, 4); dout_ptr0 += num_out; dout_ptr1 += num_out; - din_ptr0 = din_ptr2; - din_ptr1 = din_ptr3; - din_ptr2 = din_ptr4; - din_ptr3 = din_ptr5; - din_ptr4 = din_ptr5 + win; + din_ptr0 = din_ptr4; + din_ptr1 = din_ptr5; + din_ptr2 = din_ptr6; + din_ptr3 = din_ptr6 + win; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; + din_ptr4 = din_ptr3 + win; din_ptr_arr[2] = din_ptr2; din_ptr5 = din_ptr4 + win; din_ptr_arr[3] = din_ptr3; + din_ptr6 = din_ptr5 + win; 
din_ptr_arr[4] = din_ptr4; din_ptr_arr[5] = din_ptr5; + din_ptr_arr[6] = din_ptr6; } if (loop_h % 2 != 0) { compute_all_padding_mid_leakyRelu(dout_ptr0, @@ -5394,19 +5366,18 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, scale, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, 4); dout_ptr0 = dout_ptr1; - din_ptr0 = din_ptr1; - din_ptr1 = din_ptr2; - din_ptr2 = din_ptr3; - din_ptr3 = din_ptr4; - din_ptr4 = din_ptr5; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr6; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; @@ -5414,21 +5385,22 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, din_ptr_arr[4] = din_ptr4; } // bottom - for (int h = 0; h < pad_bottom; h++) { + h_in_num = n_bottom_h; + for (int h = 0; h < pad_bottom_new; h++) { compute_all_padding_post_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, - win, - wout, - pad_left, - pad_right, + odds_w, + pad_left_new, + pad_right_new, cnt, remain, - 3 - h); + h_in_num); dout_ptr0 += wout; + h_in_num -= 2; din_ptr_arr[0] = din_ptr0; din_ptr_arr[1] = din_ptr1; din_ptr_arr[2] = din_ptr2; -- GitLab
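
Note appended for review context (not part of the patch): the assembly macros above implement a 5x5, stride-2 depthwise convolution that fuses bias plus ReLU / ReLU6 / LeakyReLU, and the C++ changes remap the original paddings to stride-2 bookkeeping — pad_left_new = (pad_left + 1) / 2, pad_right_new = pad_right / 2, odds_w = pad_left % 2, scalar pointers now stepping by 2 floats per output and the post-loop rewinds changing from 4 to 8 floats. A minimal scalar sketch of the intended per-channel arithmetic follows; the function name, the Act enum and the parameter list are illustrative assumptions for review and do not appear in conv5x5s2_depthwise_fp32.cc.

#include <algorithm>

// Hypothetical activation selector; the real code dispatches to separate
// conv_depthwise_5x5s2_bias / _relu / _relu6 / _leakyRelu entry points.
enum class Act { kNone, kRelu, kRelu6, kLeakyRelu };

// Scalar reference for one channel of the 5x5, stride-2 depthwise conv.
// din:  hin x win input plane      dout: hout x wout output plane
// weights: 25 taps in row-major order, pad_*: original (un-halved) pads.
void dw5x5s2_reference(const float* din, float* dout, const float* weights,
                       float bias, int hin, int win, int hout, int wout,
                       int pad_top, int pad_left, Act act, float six,
                       float neg_slope) {
  for (int oh = 0; oh < hout; ++oh) {
    for (int ow = 0; ow < wout; ++ow) {
      float sum = bias;
      for (int kh = 0; kh < 5; ++kh) {
        for (int kw = 0; kw < 5; ++kw) {
          // Stride-2 mapping: output (oh, ow) reads input element
          // (2*oh - pad_top + kh, 2*ow - pad_left + kw); out-of-range taps
          // are zero padding.  Because one output column consumes two input
          // columns, the patch halves the pads (rounding the leading side
          // up) and keeps an "odds" flag when the original pad is odd.
          int ih = 2 * oh - pad_top + kh;
          int iw = 2 * ow - pad_left + kw;
          if (ih < 0 || ih >= hin || iw < 0 || iw >= win) continue;
          sum += din[ih * win + iw] * weights[kh * 5 + kw];
        }
      }
      // Fused activations, matching the RESULT_S2_* store macros used above.
      switch (act) {
        case Act::kRelu:      sum = std::max(sum, 0.f); break;
        case Act::kRelu6:     sum = std::min(std::max(sum, 0.f), six); break;
        case Act::kLeakyRelu: sum = sum > 0.f ? sum : sum * neg_slope; break;
        case Act::kNone:      break;
      }
      dout[oh * wout + ow] = sum;
    }
  }
}

Comparing the kernel output against a reference like this over a few (pad_top, pad_left, win, hin) combinations that mix odd and even pads is enough to exercise the new odds_w / odds_h paths introduced by this change.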