Commit c5925507 — author: chenjiaoAngel

fix profiler

Parent commit: 6edb02a7
...@@ -193,9 +193,9 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -193,9 +193,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
} }
#ifdef __aarch64__ #ifdef __aarch64__
#define COMPUTE_ONE_LINE_S1_PRE \ #define COMPUTE_ONE_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
"ext v12.16b, v9.16b, v10.16b, #8\n" \ "ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
...@@ -213,9 +213,9 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -213,9 +213,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_TWO_LINE_S1_PRE \ #define COMPUTE_TWO_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
"ext v12.16b, v9.16b, v10.16b, #8\n" \ "ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
...@@ -243,9 +243,9 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -243,9 +243,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_THREE_LINE_S1_PRE \ #define COMPUTE_THREE_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
"ext v12.16b, v9.16b, v10.16b, #8\n" \ "ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
...@@ -283,9 +283,9 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -283,9 +283,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FOUR_LINE_S1_PRE \ #define COMPUTE_FOUR_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
"ext v12.16b, v9.16b, v10.16b, #8\n" \ "ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
...@@ -333,9 +333,9 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -333,9 +333,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FIVE_LINE_S1 \ #define COMPUTE_FIVE_LINE_S1 \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
"ext v12.16b, v9.16b, v10.16b, #8\n" \ "ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
...@@ -392,6 +392,98 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -392,6 +392,98 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"fmla v15.4s, v13.4s, %[wr4].s[3]\n" /*3456*wr4[3]*/ \ "fmla v15.4s, v13.4s, %[wr4].s[3]\n" /*3456*wr4[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
// Inner loop of the 5x5 stride-1 depthwise conv, producing TWO output rows
// per iteration from six input rows (din_ptr0..din_ptr5).
// Accumulators: v14/v15 hold output row 0, v16/v17 hold output row 1; each
// input row k feeds row 0 with weight row k and row 1 with weight row k-1.
// The inline /*...*/ comments annotate the row-0 (v14/v15) fmla; the paired
// v16/v17 fmla on the adjacent line uses the same data with the weight row
// shifted by one. A RESULT_S1_*_OUT2 macro must follow: it stores v14/v17,
// reloads bias into v15/v16 and branches back to label 1.
// Fix: the last line previously ended with a '\' continuation, which spliced
// the following "#define COMPUTE_ONE_LINE_S1_POST" directive into this
// macro's body (line splicing happens before directive parsing).
#define COMPUTE_FIVE_LINE_S1_OUT2 \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ld1 {v16.4s}, [%[bias]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"1: \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
"fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0123*wr0[0]*/ \
"fmul v14.4s, v10.4s, %[wr5].s[0]\n" /*4567*wr5[0]*/ \
"ld1 {v9.4s}, [%[din_ptr1]], #16\n" \
"fmla v15.4s, v11.4s, %[wr0].s[1]\n" /*1234*wr0[1]*/ \
"ld1 {v10.4s}, [%[din_ptr1]]\n" \
"fmla v14.4s, v12.4s, %[wr0].s[2]\n" /*2345*wr0[2]*/ \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"fmla v15.4s, v13.4s, %[wr0].s[3]\n" /*3456*wr0[3]*/ \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fmla v16.4s, v9.4s, %[wr0].s[0]\n" /*row1: 0123*wr0[0]*/ \
"fmla v14.4s, v9.4s, %[wr1].s[0]\n" /*0123*wr1[0]*/ \
"ld1 {v9.4s}, [%[din_ptr2]], #16\n" \
"fmul v17.4s, v10.4s, %[wr5].s[0]\n" /*row1: 4567*wr5[0]*/ \
"fmla v15.4s, v10.4s, %[wr5].s[1]\n" /*4567*wr5[1]*/ \
"ld1 {v10.4s}, [%[din_ptr2]]\n" \
"fmla v16.4s, v11.4s, %[wr0].s[1]\n" /*row1: 1234*wr0[1]*/ \
"fmla v14.4s, v11.4s, %[wr1].s[1]\n" /*1234*wr1[1]*/ \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"fmla v17.4s, v12.4s, %[wr0].s[2]\n" /*row1: 2345*wr0[2]*/ \
"fmla v15.4s, v12.4s, %[wr1].s[2]\n" /*2345*wr1[2]*/ \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*row1: 3456*wr0[3]*/ \
"fmla v14.4s, v13.4s, %[wr1].s[3]\n" /*3456*wr1[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fmla v17.4s, v9.4s, %[wr1].s[0]\n" /*row1: 0123*wr1[0]*/ \
"fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0123*wr2[0]*/ \
"ld1 {v9.4s}, [%[din_ptr3]], #16\n" \
"fmla v16.4s, v10.4s, %[wr5].s[1]\n" /*row1: 4567*wr5[1]*/ \
"fmla v14.4s, v10.4s, %[wr5].s[2]\n" /*4567*wr5[2]*/ \
"ld1 {v10.4s}, [%[din_ptr3]]\n" \
"fmla v17.4s, v11.4s, %[wr1].s[1]\n" /*row1: 1234*wr1[1]*/ \
"fmla v15.4s, v11.4s, %[wr2].s[1]\n" /*1234*wr2[1]*/ \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*row1: 2345*wr1[2]*/ \
"fmla v14.4s, v12.4s, %[wr2].s[2]\n" /*2345*wr2[2]*/ \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"fmla v17.4s, v13.4s, %[wr1].s[3]\n" /*row1: 3456*wr1[3]*/ \
"fmla v15.4s, v13.4s, %[wr2].s[3]\n" /*3456*wr2[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fmla v16.4s, v9.4s, %[wr2].s[0]\n" /*row1: 0123*wr2[0]*/ \
"fmla v14.4s, v9.4s, %[wr3].s[0]\n" /*0123*wr3[0]*/ \
"ld1 {v9.4s}, [%[din_ptr4]], #16\n" \
"fmla v17.4s, v10.4s, %[wr5].s[2]\n" /*row1: 4567*wr5[2]*/ \
"fmla v15.4s, v10.4s, %[wr5].s[3]\n" /*4567*wr5[3]*/ \
"ld1 {v10.4s}, [%[din_ptr4]]\n" \
"fmla v16.4s, v11.4s, %[wr2].s[1]\n" /*row1: 1234*wr2[1]*/ \
"fmla v14.4s, v11.4s, %[wr3].s[1]\n" /*1234*wr3[1]*/ \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"fmla v17.4s, v12.4s, %[wr2].s[2]\n" /*row1: 2345*wr2[2]*/ \
"fmla v15.4s, v12.4s, %[wr3].s[2]\n" /*2345*wr3[2]*/ \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*row1: 3456*wr2[3]*/ \
"fmla v14.4s, v13.4s, %[wr3].s[3]\n" /*3456*wr3[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fmla v17.4s, v9.4s, %[wr3].s[0]\n" /*row1: 0123*wr3[0]*/ \
"fmla v15.4s, v9.4s, %[wr4].s[0]\n" /*0123*wr4[0]*/ \
"ld1 {v9.4s}, [%[din_ptr5]], #16\n" \
"fmla v16.4s, v10.4s, %[wr5].s[3]\n" /*row1: 4567*wr5[3]*/ \
"fmla v14.4s, v10.4s, %[wr6].s[0]\n" /*4567*wr6[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]]\n" \
"fmla v17.4s, v11.4s, %[wr3].s[1]\n" /*row1: 1234*wr3[1]*/ \
"fmla v15.4s, v11.4s, %[wr4].s[1]\n" /*1234*wr4[1]*/ \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"fmla v16.4s, v12.4s, %[wr3].s[2]\n" /*row1: 2345*wr3[2]*/ \
"fmla v14.4s, v12.4s, %[wr4].s[2]\n" /*2345*wr4[2]*/ \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"fmla v17.4s, v13.4s, %[wr3].s[3]\n" /*row1: 3456*wr3[3]*/ \
"fmla v15.4s, v13.4s, %[wr4].s[3]\n" /*3456*wr4[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fmla v16.4s, v9.4s, %[wr4].s[0]\n" /*row1: 0123*wr4[0]*/ \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"fmla v17.4s, v10.4s, %[wr6].s[0]\n" /*row1: 4567*wr6[0]*/ \
"fadd v14.4s, v14.4s, v15.4s\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"fmla v16.4s, v11.4s, %[wr4].s[1]\n" /*row1: 1234*wr4[1]*/ \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
"fmla v17.4s, v12.4s, %[wr4].s[2]\n" /*row1: 2345*wr4[2]*/ \
"ext v12.16b, v9.16b, v10.16b, #8\n" \
"fmla v16.4s, v13.4s, %[wr4].s[3]\n" /*row1: 3456*wr4[3]*/ \
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v17.4s, v17.4s, v16.4s\n"
#define COMPUTE_ONE_LINE_S1_POST \ #define COMPUTE_ONE_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]]\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
...@@ -554,6 +646,42 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -554,6 +646,42 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"bif v16.16b, v18.16b, v17.16b\n" \ "bif v16.16b, v18.16b, v17.16b\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \ "st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b" "bne 1b"
// OUT2 store tail (no activation): writes the two accumulated output rows
// (v14 -> dout_ptr0, v17 -> dout_ptr1, post-incrementing both by 16 bytes)
// and preloads bias into v15/v16 for the next pass of loop label 1, whose
// counter was already decremented inside COMPUTE_FIVE_LINE_S1_OUT2.
#define RESULT_S1_OUT2 \
"ld1 {v15.4s}, [%[bias]]\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"ld1 {v16.4s}, [%[bias]]\n" \
"st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b"
// OUT2 ReLU tail: clamps both output rows at zero, stores row 0 to dout_ptr0
// and row 1 to dout_ptr1, preloads bias into v15/v16 and loops back to
// label 1. Fix: the row-1 store used operand %[dout_ptr], which does not
// exist in the OUT2 asm blocks (they bind dout_ptr0/dout_ptr1) -- changed
// to %[dout_ptr1] to match RESULT_S1_OUT2.
// NOTE(review): the armv7 branch names this variant RESULT_S1_RELU_OUT2;
// consider renaming for symmetry -- verify no aarch64 single-row
// RESULT_S1_RELU collides.
#define RESULT_S1_RELU \
"fmax v14.4s, v14.4s, %[vzero].4s\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"fmax v17.4s, v17.4s, %[vzero].4s\n" \
"ld1 {v16.4s}, [%[bias]]\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b"
// OUT2 ReLU6 tail: clamps both output rows to [0, six], stores them to the
// two row pointers, preloads bias into v15/v16 and loops back to label 1.
// Fixes: (1) "fmax 14.4s" was missing the 'v' register prefix and would not
// assemble; (2) the row-1 store used the nonexistent operand %[dout_ptr]
// instead of %[dout_ptr1] (the OUT2 asm blocks bind dout_ptr0/dout_ptr1).
#define RESULT_S1_RELU6 \
"fmax v14.4s, v14.4s, %[vzero].4s\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"fmax v17.4s, v17.4s, %[vzero].4s\n" \
"ld1 {v16.4s}, [%[bias]]\n" \
"fmin v14.4s, v14.4s, %[vsix].4s\n" \
"fmin v17.4s, v17.4s, %[vsix].4s\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b"
// OUT2 leaky-ReLU tail: for each output row, compares against zero (v18/v20
// as lane masks), computes the scaled negative branch (v19/v21), selects
// with bif, stores row 0 to dout_ptr0 and row 1 to dout_ptr1, preloads bias
// into v15/v16 and loops back to label 1.
// Fix: the row-1 store used operand %[dout_ptr], which is not bound by the
// OUT2 asm blocks (they bind dout_ptr0/dout_ptr1) -- changed to %[dout_ptr1].
// NOTE(review): this macro clobbers v18-v21 -- confirm the asm clobber list
// of every user includes them.
#define RESULT_S1_LEAKY_RELU \
"fcmge v18.4s, v14.4s, %[vzero].4s\n" \
"fmul v19.4s, v14.4s, %[vscale].4s\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"fcmge v20.4s, v17.4s, %[vzero].4s\n" \
"fmul v21.4s, v17.4s, %[vscale].4s\n" \
"ld1 {v16.4s}, [%[bias]]\n" \
"bif v14.16b, v19.16b, v18.16b\n" \
"bif v17.16b, v21.16b, v20.16b\n" \
"st1 {v14.4s}, [%[dout_ptr0]], #16\n" \
"st1 {v17.4s}, [%[dout_ptr1]], #16\n" \
"bne 1b"
#else #else
#define COMPUTE_ONE_LINE_S1_PRE \ #define COMPUTE_ONE_LINE_S1_PRE \
"vld1.f32 {d30-d31}, [%[bias]]\n" \ "vld1.f32 {d30-d31}, [%[bias]]\n" \
...@@ -755,6 +883,98 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -755,6 +883,98 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q11, q8, q9, #2\n" \ "vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \ "vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n" "vadd.f32 q14, q14, q15\n"
// ARMv7 twin of the aarch64 COMPUTE_FIVE_LINE_S1_OUT2: one loop iteration of
// the 5x5 stride-1 depthwise conv producing TWO output rows from six input
// rows (din_ptr0..din_ptr5).
// Register roles: q13/q15 accumulate output row 0, q14 accumulates output
// row 1 (bias preloaded into q14/q15 before the loop); q8/q9 hold the
// current 8 input lanes, q10-q12 their 1/2/3-lane shifted windows.
// %e[..] / %f[..] address the low / high d-half of a q weight operand, so
// %f[wrk][0] is lane 2 and %f[wrk][1] is lane 3 of weight row k.
// NOTE(review): several inline comments describe the row-0 multiply; the
// paired q14 vmla on the adjacent line uses the same data with the weight
// row shifted by one -- the comment lane names there look copy-pasted.
// Ends by reloading din_ptr0 and summing q14 += q15 so a RESULT_S1_*_OUT2
// macro can store q13/q14, reload bias and branch back to label 1.
#define COMPUTE_FIVE_LINE_S1_OUT2 \
"vld1.f32 {d28-d29}, [%[bias]]\n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
"1: \n" \
"subs %[cnt], #1\n" \
"vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
"vmul.f32 q13, q9, %e[wr5][0]\n" /*4567*wr5[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
"vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
"vmla.f32 q13, q11, %f[wr0][0]\n" /*2345*wr0[2]*/\
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q15, q12, %f[wr0][1]\n" /*3456*wr0[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q13, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
"vmla.f32 q14, q8, %e[wr0][0]\n" /*0123*wr1[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
"vmla.f32 q15, q9, %e[wr5][1]\n" /*4567*wr5[1]*/ \
"vmla.f32 q14, q9, %e[wr5][0]\n" /*4567*wr5[1]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
"vmla.f32 q13, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\
"vmla.f32 q14, q10, %e[wr0][1]\n" /*1234*wr1[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q15, q11, %f[wr1][0]\n" /*2345*wr1[2]*/\
"vmla.f32 q13, q12, %f[wr1][1]\n" /*3456*wr1[3]*/\
"vmla.f32 q14, q11, %f[wr0][0]\n" /*2345*wr1[2]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr0][1]\n" /*3456*wr1[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q15, q8, %e[wr2][0]\n" /*0123*wr2[0]*/ \
"vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr2[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr3]]!\n" \
"vmla.f32 q13, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
"vmla.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr3]]\n" \
"vmla.f32 q15, q10, %e[wr2][1]\n" /*1234*wr2[1]*/\
"vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr2[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q13, q11, %f[wr2][0]\n" /*2345*wr2[2]*/\
"vmla.f32 q14, q11, %f[wr1][0]\n" /*2345*wr2[2]*/\
"vmla.f32 q15, q12, %f[wr2][1]\n" /*3456*wr2[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr1][1]\n" /*3456*wr2[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q13, q8, %e[wr3][0]\n" /*0123*wr3[0]*/ \
"vmla.f32 q14, q8, %e[wr2][0]\n" /*0123*wr3[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr4]]!\n" \
"vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
"vmla.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr4]]\n" \
"vmla.f32 q13, q10, %e[wr3][1]\n" /*1234*wr3[1]*/\
"vmla.f32 q14, q10, %e[wr2][1]\n" /*1234*wr3[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q15, q11, %f[wr3][0]\n" /*2345*wr3[2]*/\
"vmla.f32 q14, q11, %f[wr2][0]\n" /*2345*wr3[2]*/\
"vmla.f32 q13, q12, %f[wr3][1]\n" /*3456*wr3[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr2][1]\n" /*3456*wr3[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q15, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
"vmla.f32 q14, q8, %e[wr3][0]\n" /*0123*wr4[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr5]]!\n" \
"vmla.f32 q13, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
"vmla.f32 q14, q9, %f[wr5][1]\n" /*4567*wr6[0]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr5]]\n" \
"vmla.f32 q15, q10, %e[wr4][1]\n" /*1234*wr4[1]*/\
"vmla.f32 q14, q10, %e[wr3][1]\n" /*1234*wr4[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q13, q11, %f[wr4][0]\n" /*2345*wr4[2]*/\
"vmla.f32 q14, q11, %f[wr3][0]\n" /*2345*wr4[2]*/\
"vmla.f32 q15, q12, %f[wr4][1]\n" /*3456*wr4[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr3][1]\n" /*3456*wr4[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q13, q13, q15\n" \
"vmla.f32 q14, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
"vmul.f32 q15, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vmla.f32 q14, q10, %e[wr4][1]\n" /*1234*wr4[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr0]]\n" \
"vmla.f32 q15, q11, %f[wr4][0]\n" /*2345*wr4[2]*/\
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr4][1]\n" /*3456*wr4[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define COMPUTE_ONE_LINE_S1_POST \ #define COMPUTE_ONE_LINE_S1_POST \
"vld1.f32 {d30-d31}, [%[bias]]\n" \ "vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \ "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
...@@ -921,99 +1141,21 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -921,99 +1141,21 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q11, q8, q9, #2\n" \ "vext.32 q11, q8, q9, #2\n" \
"vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \ "vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \
"bne 1b" "bne 1b"
#define COMPUTE_FIVE_LINE_S1_1 \ #define RESULT_S1_OUT2 \
"vst1.f32 {d26-d27}, [%[dout_ptr0]]!\n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \
"vld1.f32 {d28-d29}, [%[bias]]\n" \ "vld1.f32 {d28-d29}, [%[bias]]\n" \
"bne 1b"
#define RESULT_S1_RELU_OUT2 \
"vld1.f32 {d30-d31}, [%[bias]]\n" \ "vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \ "vmax.f32 q13, q13, %q[vzero]\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \ "vmax.f32 q14, q14, %q[vzero]\n" \
"vext.32 q10, q8, q9, #1\n" \ "vst1.f32 {d26-d27}, [%[dout_ptr0]]!\n" \
"vext.32 q11, q8, q9, #2\n" \ "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \
"vext.32 q12, q8, q9, #3\n" \ "vld1.f32 {d28-d29}, [%[bias]]\n" \
"1: \n" \ "bne 1b"
"subs %[cnt], #1\n" \ #define RESULT_S1_RELU6_OUT2 \
"vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
"vmul.f32 q13, q9, %e[wr5][0]\n" /*4567*wr5[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
"vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
"vmla.f32 q13, q11, %f[wr0][0]\n" /*2345*wr0[2]*/\
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q15, q12, %f[wr0][1]\n" /*3456*wr0[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q13, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
"vmla.f32 q14, q8, %e[wr0][0]\n" /*0123*wr1[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
"vmla.f32 q15, q9, %e[wr5][1]\n" /*4567*wr5[1]*/ \
"vmla.f32 q14, q9, %e[wr5][0]\n" /*4567*wr5[1]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
"vmla.f32 q13, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\
"vmla.f32 q14, q10, %e[wr0][1]\n" /*1234*wr1[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q15, q11, %f[wr1][0]\n" /*2345*wr1[2]*/\
"vmla.f32 q13, q12, %f[wr1][1]\n" /*3456*wr1[3]*/\
"vmla.f32 q14, q11, %f[wr0][0]\n" /*2345*wr1[2]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr0][1]\n" /*3456*wr1[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q15, q8, %e[wr2][0]\n" /*0123*wr2[0]*/ \
"vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr2[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr3]]!\n" \
"vmla.f32 q13, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
"vmla.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr3]]\n" \
"vmla.f32 q15, q10, %e[wr2][1]\n" /*1234*wr2[1]*/\
"vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr2[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q13, q11, %f[wr2][0]\n" /*2345*wr2[2]*/\
"vmla.f32 q14, q11, %f[wr1][0]\n" /*2345*wr2[2]*/\
"vmla.f32 q15, q12, %f[wr2][1]\n" /*3456*wr2[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr1][1]\n" /*3456*wr2[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q13, q8, %e[wr3][0]\n" /*0123*wr3[0]*/ \
"vmla.f32 q14, q8, %e[wr2][0]\n" /*0123*wr3[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr4]]!\n" \
"vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
"vmla.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr4]]\n" \
"vmla.f32 q13, q10, %e[wr3][1]\n" /*1234*wr3[1]*/\
"vmla.f32 q14, q10, %e[wr2][1]\n" /*1234*wr3[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q15, q11, %f[wr3][0]\n" /*2345*wr3[2]*/\
"vmla.f32 q14, q11, %f[wr2][0]\n" /*2345*wr3[2]*/\
"vmla.f32 q13, q12, %f[wr3][1]\n" /*3456*wr3[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr2][1]\n" /*3456*wr3[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q15, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
"vmla.f32 q14, q8, %e[wr3][0]\n" /*0123*wr4[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr5]]!\n" \
"vmla.f32 q13, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
"vmla.f32 q14, q9, %f[wr5][1]\n" /*4567*wr6[0]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr5]]\n" \
"vmla.f32 q15, q10, %e[wr4][1]\n" /*1234*wr4[1]*/\
"vmla.f32 q14, q10, %e[wr3][1]\n" /*1234*wr4[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q13, q11, %f[wr4][0]\n" /*2345*wr4[2]*/\
"vmla.f32 q14, q11, %f[wr3][0]\n" /*2345*wr4[2]*/\
"vmla.f32 q15, q12, %f[wr4][1]\n" /*3456*wr4[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr3][1]\n" /*3456*wr4[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q13, q13, q15\n" \
"vmla.f32 q14, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
"vmul.f32 q15, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vmla.f32 q14, q10, %e[wr4][1]\n" /*1234*wr4[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr0]]\n" \
"vmla.f32 q15, q11, %f[wr4][0]\n" /*2345*wr4[2]*/\
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr4][1]\n" /*3456*wr4[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define RESULT_S1_RELU6_1 \
"vld1.f32 {d30-d31}, [%[six_ptr]]\n" \ "vld1.f32 {d30-d31}, [%[six_ptr]]\n" \
"vmax.f32 q13, q13, %q[vzero]\n" \ "vmax.f32 q13, q13, %q[vzero]\n" \
"vmax.f32 q14, q14, %q[vzero]\n" \ "vmax.f32 q14, q14, %q[vzero]\n" \
...@@ -1024,7 +1166,22 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -1024,7 +1166,22 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \ "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \
"vld1.f32 {d28-d29}, [%[bias]]\n" \ "vld1.f32 {d28-d29}, [%[bias]]\n" \
"bne 1b" "bne 1b"
// ARMv7 OUT2 leaky-ReLU tail: loads the scale into q15, builds per-lane
// >=0 masks for both output rows (q10/q11), selects between the value and
// its scaled copy with vbif, stores q13 -> dout_ptr0 and q14 -> dout_ptr1,
// then reloads bias into q14/q15 and branches back to label 1.
// The interleaved vext.32 q10-q12 rebuild the shifted input windows from
// q8/q9 (already loaded at the tail of COMPUTE_FIVE_LINE_S1_OUT2) because
// q10-q12 are clobbered here as mask/temporary registers.
#define RESULT_S1_LEAKY_RELU_OUT2 \
"vld1.f32 {d30-d31}, [%[scale_ptr]]\n" \
"vcge.f32 q10, q13, %q[vzero]\n" \
"vcge.f32 q11, q14, %q[vzero]\n" \
"vmul.f32 q12, q13, q15\n" \
"vbif q13, q12, q10\n" \
"vmul.f32 q12, q14, q15\n" \
"vext.32 q10, q8, q9, #1\n" \
"vbif q14, q12, q11\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
"vst1.f32 {d26-d27}, [%[dout_ptr0]]!\n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \
"vld1.f32 {d28-d29}, [%[bias]]\n" \
"bne 1b"
#endif #endif
inline float compute_one_data_pre(const float* data, float32x4_t wr, float bias_val, float wei_val, int num) { inline float compute_one_data_pre(const float* data, float32x4_t wr, float bias_val, float wei_val, int num) {
...@@ -1060,10 +1217,6 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1060,10 +1217,6 @@ inline void compute_all_padding_pre(float* dout,
int remain, int remain,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
// left
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0];
}
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1285,47 +1438,149 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1285,47 +1438,149 @@ inline void compute_all_padding_pre(float* dout,
} }
*dout++ = sum; *dout++ = sum;
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0];
}
} }
inline void compute_all_padding_mid(float* dout, inline void compute_all_padding_mid(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
float32x4_t* weights, float32x4_t* weights,
int win, int win,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new, int pad_left_new,
int pad_right_new, int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
for (int w = pad_left; w > 4; w--) { int tmp = num - 1;
*dout++ = bias[0]; for (int i = pad_left_new; i > 0; i--) {
} float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
}
*dout++ = sum;
}
// mid
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr2] "+r"(din_ptr_arr[2]),
[din_ptr3] "+r"(din_ptr_arr[3]),
[din_ptr4] "+r"(din_ptr_arr[4]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
[wr2] "w"(weights[2]),
[wr3] "w"(weights[3]),
[wr4] "w"(weights[4]),
[wr5] "w"(weights[5]),
[wr6] "w"(weights[6]),
[bias] "r"(bias)
: "cc",
"memory",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr2] "+r"(din_ptr_arr[2]),
[din_ptr3] "+r"(din_ptr_arr[3]),
[din_ptr4] "+r"(din_ptr_arr[4]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
[wr2] "w"(weights[2]),
[wr3] "w"(weights[3]),
[wr4] "w"(weights[4]),
[wr5] "w"(weights[5]),
[wr6] "w"(weights[6]),
[bias] "r"(bias)
: "cc",
"memory",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
din_ptr_arr[0] -= 4;
}
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num]++;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i]++;
}
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[tmp - k]++;
}
*dout++ = sum;
}
}
inline void compute_all_padding_mid_out2(float* dout0,
float* dout1,
const float** din_ptr_arr,
const float* bias,
float32x4_t* weights,
int win,
int wout,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int cnt,
int remain,
int num) {
int tmp1 = num + 1;
int tmp = num - 1; int tmp = num - 1;
// left
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
} }
*dout++ = sum; *dout0++ = sum;
*dout1++ = sum1;
} }
// mid // mid
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_OUT2
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]), [din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr2] "+r"(din_ptr_arr[2]), [din_ptr2] "+r"(din_ptr_arr[2]),
[din_ptr3] "+r"(din_ptr_arr[3]), [din_ptr3] "+r"(din_ptr_arr[3]),
[din_ptr4] "+r"(din_ptr_arr[4]), [din_ptr4] "+r"(din_ptr_arr[4]),
[dout_ptr] "+r"(dout) [din_ptr5] "+r"(din_ptr_arr[5]),
[dout_ptr0] "+r"(dout0),
[dout_ptr1] "+r"(dout1)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
[wr2] "w"(weights[2]), [wr2] "w"(weights[2]),
...@@ -1343,16 +1598,19 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1343,16 +1598,19 @@ inline void compute_all_padding_mid(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17");
#else #else
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_OUT2
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]), [din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr2] "+r"(din_ptr_arr[2]), [din_ptr2] "+r"(din_ptr_arr[2]),
[din_ptr3] "+r"(din_ptr_arr[3]), [din_ptr3] "+r"(din_ptr_arr[3]),
[din_ptr4] "+r"(din_ptr_arr[4]), [din_ptr4] "+r"(din_ptr_arr[4]),
[dout_ptr] "+r"(dout) [din_ptr5] "+r"(din_ptr_arr[5]),
[dout_ptr0] "+r"(dout0),
[dout_ptr1] "+r"(dout1)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
[wr2] "w"(weights[2]), [wr2] "w"(weights[2]),
...@@ -1377,26 +1635,35 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1377,26 +1635,35 @@ inline void compute_all_padding_mid(float* dout,
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[tmp1]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i]++; sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[num - i]++;
} }
*dout++ = sum; din_ptr_arr[0]++;
*dout0++ = sum;
*dout1++ = sum1;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[tmp - k]++; sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[num - k]++;
} }
*dout++ = sum; din_ptr_arr[0]++;
*dout0++ = sum;
*dout1++ = sum1;
} }
for (int w = pad_right; w > 4; w--) { for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0]; *dout0++ = bias[0];
*dout1++ = bias[0];
} }
} }
inline void compute_all_padding_post(float* dout, inline void compute_all_padding_post(float* dout,
...@@ -1636,9 +1903,6 @@ inline void compute_all_padding_post(float* dout, ...@@ -1636,9 +1903,6 @@ inline void compute_all_padding_post(float* dout,
} }
*dout++ = sum; *dout++ = sum;
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0];
}
} }
void conv_depthwise_5x5s1_bias(float* dout, void conv_depthwise_5x5s1_bias(float* dout,
...@@ -1670,6 +1934,7 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1670,6 +1934,7 @@ void conv_depthwise_5x5s1_bias(float* dout,
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
int num_out = wout << 1;
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
float* dout_batch = dout + n * out_channel_size; float* dout_batch = dout + n * out_channel_size;
...@@ -1684,8 +1949,10 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1684,8 +1949,10 @@ void conv_depthwise_5x5s1_bias(float* dout,
const float* din_ptr2 = din_ptr1 + win; const float* din_ptr2 = din_ptr1 + win;
const float* din_ptr3 = din_ptr2 + win; const float* din_ptr3 = din_ptr2 + win;
const float* din_ptr4 = din_ptr3 + win; const float* din_ptr4 = din_ptr3 + win;
const float* din_ptr5 = din_ptr4 + win;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
float* dout_ptr = dout_ch; float* dout_ptr0 = dout_ch;
float* dout_ptr1 = dout_ch;
float32x4_t wr5; float32x4_t wr5;
float32x4_t wr6; float32x4_t wr6;
float32x4_t wr0 = vld1q_f32(weights_ch); float32x4_t wr0 = vld1q_f32(weights_ch);
...@@ -1698,33 +1965,48 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1698,33 +1965,48 @@ void conv_depthwise_5x5s1_bias(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout);
dout_ptr += wout;
}
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top_new; h > 0; h--) {
compute_all_padding_pre(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_pre(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h); pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h; h++) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4); pad_right, pad_left_new, pad_right_new, cnt, remain, 4);
dout_ptr += wout; dout_ptr0 += num_out;
dout_ptr1 += num_out;
din_ptr0 = din_ptr2;
din_ptr1 = din_ptr3;
din_ptr2 = din_ptr4;
din_ptr3 = din_ptr5;
din_ptr4 = din_ptr5 + win;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr5 = din_ptr4 + win;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
din_ptr_arr[5] = din_ptr5;
}
if (loop_h % 2 != 0) {
compute_all_padding_mid(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4);
dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
din_ptr2 = din_ptr3; din_ptr2 = din_ptr3;
din_ptr3 = din_ptr4; din_ptr3 = din_ptr4;
din_ptr4 = din_ptr4 + win; din_ptr4 = din_ptr5;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
...@@ -2108,6 +2390,136 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2108,6 +2390,136 @@ inline void compute_all_padding_mid_relu(float* dout,
*dout++ = bias[0] > 0.f ? bias[0] : 0.f; *dout++ = bias[0] > 0.f ? bias[0] : 0.f;
} }
} }
// Compute two adjacent output rows (dout0 / dout1) of the middle region of a
// 5x5 stride-1 depthwise convolution with fused ReLU.
// din_ptr_arr holds six consecutive input-row pointers: rows [0..num] feed
// dout0 and rows [1..num+1] feed dout1 (num is 4 for the full 5-tap window).
// weights[0..4] are the 5 kernel rows, weights[5]/weights[6] hold the spill
// lanes; cnt counts 4-wide vector iterations, remain the scalar tail columns.
// NOTE(review): row pointers in din_ptr_arr are advanced in place; the caller
// is expected to reset them afterwards.
inline void compute_all_padding_mid_relu_out2(float* dout0,
                                              float* dout1,
                                              const float** din_ptr_arr,
                                              const float* bias,
                                              float32x4_t* weights,
                                              float32x4_t vzero,
                                              int win,
                                              int wout,
                                              int pad_left,
                                              int pad_right,
                                              int pad_left_new,
                                              int pad_right_new,
                                              int cnt,
                                              int remain,
                                              int num) {
  // left: columns fully inside the left padding only see the bias.
  for (int w = pad_left; w > 4; w--) {
    *dout0++ = bias[0] > 0.f ? bias[0] : 0.f;
    *dout1++ = bias[0] > 0.f ? bias[0] : 0.f;
  }
  int tmp = num - 1;
  int tmp1 = num + 1;
  // left border: partial windows with (4 - i) valid taps.
  for (int i = pad_left_new; i > 0; i--) {
    float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
    float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
      sum1 += compute_one_data_pre(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
    }
    *dout0++ = sum > 0.f ? sum : 0.f;
    *dout1++ = sum1 > 0.f ? sum1 : 0.f;
  }
  // vectorized body: 4 output columns per iteration, two rows at once.
  if (cnt > 0) {
#ifdef __aarch64__
    asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU_OUT2
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [din_ptr5] "+r"(din_ptr_arr[5]),
                   [dout_ptr0] "+r"(dout0),
                   [dout_ptr1] "+r"(dout1)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "v9",
                   "v10",
                   "v11",
                   "v12",
                   "v13",
                   "v14",
                   "v15",
                   "v16",
                   "v17");
#else
    asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU_OUT2
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [din_ptr5] "+r"(din_ptr_arr[5]),
                   [dout_ptr0] "+r"(dout0),
                   [dout_ptr1] "+r"(dout1)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "q8",
                   "q9",
                   "q10",
                   "q11",
                   "q12",
                   "q13",
                   "q14",
                   "q15");
#endif
    // the asm over-reads 4 floats on row 0; step it back.
    din_ptr_arr[0] -= 4;
  }
  // remain: scalar tail after the vectorized body.
  for (int w = 0; w < remain; w++) {
    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
    // BUGFIX: the second output row reads rows [1..num+1]; its head row is
    // din_ptr_arr[num + 1] (tmp1), not din_ptr_arr[num - 1] (tmp) — matches
    // the relu6/leakyRelu _out2 variants.
    float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
    din_ptr_arr[tmp1]++;
    for (int i = 0; i < num; i++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
      sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
      // BUGFIX: advance rows num..1 here (row 0 advances below) so every
      // row pointer moves exactly once per output column; the original
      // advanced rows tmp..0, double-stepping rows 3 and 0 and never
      // stepping rows 4 and 5.
      din_ptr_arr[num - i]++;
    }
    din_ptr_arr[0]++;
    *dout0++ = sum > 0.f ? sum : 0.f;
    *dout1++ = sum1 > 0.f ? sum1 : 0.f;
  }
  // right border: partial windows with (3 - i) valid taps.
  for (int i = 0; i < pad_right_new; i++) {
    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
    float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
    din_ptr_arr[tmp1]++;
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
      sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
      din_ptr_arr[num - k]++;
    }
    din_ptr_arr[0]++;
    *dout0++ = sum > 0.f ? sum : 0.f;
    *dout1++ = sum1 > 0.f ? sum1 : 0.f;
  }
  // right: columns fully inside the right padding only see the bias.
  for (int w = pad_right; w > 4; w--) {
    *dout0++ = bias[0] > 0.f ? bias[0] : 0.f;
    *dout1++ = bias[0] > 0.f ? bias[0] : 0.f;
  }
}
inline void compute_all_padding_post_relu(float* dout, inline void compute_all_padding_post_relu(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
...@@ -2388,6 +2800,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2388,6 +2800,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
int num_out = wout << 1;
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
...@@ -2403,8 +2816,10 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2403,8 +2816,10 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
const float* din_ptr2 = din_ptr1 + win; const float* din_ptr2 = din_ptr1 + win;
const float* din_ptr3 = din_ptr2 + win; const float* din_ptr3 = din_ptr2 + win;
const float* din_ptr4 = din_ptr3 + win; const float* din_ptr4 = din_ptr3 + win;
const float* din_ptr5 = din_ptr4 + win;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
float* dout_ptr = dout_ch; float* dout_ptr0 = dout_ch;
float* dout_ptr1 = dout_ch;
float32x4_t wr5; float32x4_t wr5;
float32x4_t wr6; float32x4_t wr6;
float32x4_t wr0 = vld1q_f32(weights_ch); float32x4_t wr0 = vld1q_f32(weights_ch);
...@@ -2417,33 +2832,48 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2417,33 +2832,48 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout);
dout_ptr += wout;
}
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top_new; h > 0; h--) {
compute_all_padding_pre_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_pre_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h); pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h; h++) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4);
dout_ptr0 += num_out;
dout_ptr1 += num_out;
din_ptr0 = din_ptr2;
din_ptr1 = din_ptr3;
din_ptr2 = din_ptr4;
din_ptr3 = din_ptr5;
din_ptr4 = din_ptr5 + win;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr5 = din_ptr4 + win;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
din_ptr_arr[5] = din_ptr5;
}
if (loop_h % 2 != 0) {
compute_all_padding_mid_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_mid_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4); pad_right, pad_left_new, pad_right_new, cnt, remain, 4);
dout_ptr += wout; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
din_ptr2 = din_ptr3; din_ptr2 = din_ptr3;
din_ptr3 = din_ptr4; din_ptr3 = din_ptr4;
din_ptr4 = din_ptr4 + win; din_ptr4 = din_ptr5;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
...@@ -2452,9 +2882,9 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2452,9 +2882,9 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom_new; h++) {
compute_all_padding_post_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_post_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 3 - h); pad_right, pad_left_new, pad_right_new, cnt, remain, 3 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
...@@ -2847,32 +3277,35 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -2847,32 +3277,35 @@ inline void compute_all_padding_mid_relu6(float* dout,
} }
} }
inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, inline void compute_all_padding_mid_relu6_out2(float* dout0,
const float** din_ptr_arr, float* dout1,
const float* bias, const float** din_ptr_arr,
const float* six, const float* bias,
float32x4_t* weights, const float* six,
float32x4_t vzero, float32x4_t* weights,
int win, float32x4_t vzero,
int wout, int win,
int pad_left, int wout,
int pad_right, int pad_left,
int pad_left_new, int pad_right,
int pad_right_new, int pad_left_new,
int cnt, int pad_right_new,
int remain, int cnt,
int num) { int remain,
int num) {
#ifdef __aarch64__ #ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six); float32x4_t vsix = vld1q_f32(six);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) { for (int w = pad_left; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f; *dout0++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
*dout1++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
} }
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[num + 1], weights[num], bias[0], weights[6][0], 4 - i); float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[num -k], weights[tmp -k], 0.f, weights[5][tmp - k], 4 - i); sum1 += compute_one_data_pre(din_ptr_arr[num -k], weights[tmp -k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -2882,7 +3315,7 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, ...@@ -2882,7 +3315,7 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1,
} }
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU6_OUT2
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]), [din_ptr1] "+r"(din_ptr_arr[1]),
...@@ -2890,7 +3323,7 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, ...@@ -2890,7 +3323,7 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1,
[din_ptr3] "+r"(din_ptr_arr[3]), [din_ptr3] "+r"(din_ptr_arr[3]),
[din_ptr4] "+r"(din_ptr_arr[4]), [din_ptr4] "+r"(din_ptr_arr[4]),
[din_ptr5] "+r"(din_ptr_arr[5]), [din_ptr5] "+r"(din_ptr_arr[5]),
[dout_ptr0] "++r"(dout0), [dout_ptr0] "+r"(dout0),
[dout_ptr1] "+r"(dout1) [dout_ptr1] "+r"(dout1)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -2911,9 +3344,10 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, ...@@ -2911,9 +3344,10 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17");
#else #else
asm volatile(COMPUTE_FIVE_LINE_S1_1 RESULT_S1_RELU6_1 asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_RELU6_OUT2
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]), [din_ptr1] "+r"(din_ptr_arr[1]),
...@@ -2949,8 +3383,8 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, ...@@ -2949,8 +3383,8 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1,
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
float sum1 = compute_one_data_post(din_ptr_arr[num + 1], weights[num], bias[0], weights[6][0], 4); float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
din_ptr_arr[num + 1]++; din_ptr_arr[tmp1]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
...@@ -2963,8 +3397,8 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, ...@@ -2963,8 +3397,8 @@ inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1,
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(din_ptr_arr[num + 1], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num+1]++; din_ptr_arr[tmp1]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
...@@ -3272,6 +3706,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3272,6 +3706,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
int num_out = wout << 1;
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
...@@ -3289,7 +3724,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3289,7 +3724,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
const float* din_ptr4 = din_ptr3 + win; const float* din_ptr4 = din_ptr3 + win;
const float* din_ptr5 = din_ptr4 + win; const float* din_ptr5 = din_ptr4 + win;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
float* dout_ptr = dout_ch; float* dout_ptr0 = dout_ch;
float* dout_ptr1 = dout_ch; float* dout_ptr1 = dout_ch;
float32x4_t wr5; float32x4_t wr5;
float32x4_t wr6; float32x4_t wr6;
...@@ -3306,51 +3741,60 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3306,51 +3741,60 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout);
dout_ptr += wout;
}
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top_new; h > 0; h--) {
compute_all_padding_pre_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_pre_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4 - h); pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
dout_ptr1 = dout_ptr + wout; dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu6_1(dout_ptr, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_mid_relu6_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4); pad_right_new, cnt, remain, 4);
dout_ptr += 2 * wout; dout_ptr0 += num_out;
dout_ptr1 += 2 * wout; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
din_ptr1 = din_ptr3; din_ptr1 = din_ptr3;
din_ptr2 = din_ptr4; din_ptr2 = din_ptr4;
din_ptr3 = din_ptr5; din_ptr3 = din_ptr5;
din_ptr4 = din_ptr5 + win; din_ptr4 = din_ptr5 + win;
din_ptr5 = din_ptr4 + win;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
din_ptr5 = din_ptr4 + win;
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
din_ptr_arr[5] = din_ptr5; din_ptr_arr[5] = din_ptr5;
} }
if (loop_h % 2) compute_all_padding_mid_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero, if (loop_h % 2 != 0) {
compute_all_padding_mid_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4); pad_right_new, cnt, remain, 4);
dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2;
din_ptr2 = din_ptr3;
din_ptr3 = din_ptr4;
din_ptr4 = din_ptr5;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom_new; h++) {
compute_all_padding_post_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_post_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 3 - h); pad_right_new, cnt, remain, 3 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
...@@ -3753,6 +4197,261 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -3753,6 +4197,261 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0]; *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
} }
} }
// Compute one middle-region output row of a 5x5 stride-1 depthwise
// convolution with fused LeakyReLU (negative values scaled by scale[0]).
// din_ptr_arr holds the five input rows feeding this output row; row
// pointers are advanced in place.
// NOTE(review): the diff context suggests a function with this exact name
// already exists earlier in the file — confirm this is a replacement and
// not an accidental duplicate definition.
inline void compute_all_padding_mid_leakyRelu(float* dout,
                                              const float** din_ptr_arr,
                                              const float* bias,
                                              const float* scale,
                                              float32x4_t* weights,
                                              float32x4_t vzero,
                                              int win,
                                              int wout,
                                              int pad_left,
                                              int pad_right,
                                              int pad_left_new,
                                              int pad_right_new,
                                              int cnt,
                                              int remain,
                                              int num) {
#ifdef __aarch64__
  float32x4_t vscale = vld1q_f32(scale);
#endif
  // left: columns fully inside the left padding only see the bias.
  for (int w = pad_left; w > 4; w--) {
    *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
  }
  int tmp = num - 1;
  // left border: partial windows with (4 - i) valid taps.
  for (int i = pad_left_new; i > 0; i--) {
    float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
    }
    *dout++ = sum > 0.f ? sum : sum * scale[0];
  }
  // vectorized body: 4 output columns per iteration.
  if (cnt > 0) {
#ifdef __aarch64__
    asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [dout_ptr] "+r"(dout)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [vscale] "w"(vscale),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "v9",
                   "v10",
                   "v11",
                   "v12",
                   "v13",
                   "v14",
                   "v15",
                   "v16",
                   "v17",
                   "v18");
#else
    asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [dout_ptr] "+r"(dout)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [scale_ptr] "r"(scale),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "q8",
                   "q9",
                   "q10",
                   "q11",
                   "q12",
                   "q13",
                   "q14",
                   "q15");
#endif
    // the asm over-reads 4 floats on row 0; step it back.
    din_ptr_arr[0] -= 4;
  }
  // remain: scalar tail after the vectorized body.
  for (int w = 0; w < remain; w++) {
    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
    din_ptr_arr[num]++;
    for (int i = 0; i < num; i++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
      din_ptr_arr[tmp - i]++;
    }
    *dout++ = sum > 0.f ? sum : sum * scale[0];
  }
  // right border: partial windows with (3 - i) valid taps.
  for (int i = 0; i < pad_right_new; i++) {
    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
    din_ptr_arr[num]++;
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
      din_ptr_arr[tmp - k]++;
    }
    *dout++ = sum > 0.f ? sum : sum * scale[0];
  }
  // right: columns fully inside the right padding only see the bias.
  for (int w = pad_right; w > 4; w--) {
    *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
  }
}

// Compute two adjacent output rows (dout0 / dout1) of the middle region of a
// 5x5 stride-1 depthwise convolution with fused LeakyReLU.
// din_ptr_arr holds six consecutive input-row pointers: rows [0..num] feed
// dout0 and rows [1..num+1] feed dout1 (num is 4 for the full window).
// BUGFIX: dout1 was declared `float** dout1`, but it is written via
// `*dout1++ = <float>` and callers pass a `float*` row pointer — corrected
// to `float*` to match the other *_out2 variants.
inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
                                                   float* dout1,
                                                   const float** din_ptr_arr,
                                                   const float* bias,
                                                   const float* scale,
                                                   float32x4_t* weights,
                                                   float32x4_t vzero,
                                                   int win,
                                                   int wout,
                                                   int pad_left,
                                                   int pad_right,
                                                   int pad_left_new,
                                                   int pad_right_new,
                                                   int cnt,
                                                   int remain,
                                                   int num) {
#ifdef __aarch64__
  float32x4_t vscale = vld1q_f32(scale);
#endif
  // left: columns fully inside the left padding only see the bias.
  for (int w = pad_left; w > 4; w--) {
    *dout0++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
    *dout1++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
  }
  int tmp = num - 1;
  int tmp1 = num + 1;
  // left border: partial windows with (4 - i) valid taps.
  for (int i = pad_left_new; i > 0; i--) {
    float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
    float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
      sum1 += compute_one_data_pre(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
    }
    *dout0++ = sum > 0.f ? sum : sum * scale[0];
    *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
  }
  // vectorized body: 4 output columns per iteration, two rows at once.
  if (cnt > 0) {
#ifdef __aarch64__
    asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_LEAKY_RELU_OUT2
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [din_ptr5] "+r"(din_ptr_arr[5]),
                   [dout_ptr0] "+r"(dout0),
                   [dout_ptr1] "+r"(dout1)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [vscale] "w"(vscale),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "v9",
                   "v10",
                   "v11",
                   "v12",
                   "v13",
                   "v14",
                   "v15",
                   "v16",
                   "v17",
                   "v18",
                   "v19",
                   "v20");
#else
    asm volatile(COMPUTE_FIVE_LINE_S1_OUT2 RESULT_S1_LEAKY_RELU_OUT2
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [din_ptr5] "+r"(din_ptr_arr[5]),
                   [dout_ptr0] "+r"(dout0),
                   [dout_ptr1] "+r"(dout1)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [scale_ptr] "r"(scale),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "q8",
                   "q9",
                   "q10",
                   "q11",
                   "q12",
                   "q13",
                   "q14",
                   "q15");
#endif
    // the asm over-reads 4 floats on row 0; step it back.
    din_ptr_arr[0] -= 4;
  }
  // remain: scalar tail after the vectorized body.
  for (int w = 0; w < remain; w++) {
    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
    float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4);
    din_ptr_arr[tmp1]++;
    for (int i = 0; i < num; i++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
      sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
      din_ptr_arr[num - i]++;
    }
    // BUGFIX: row 0 feeds dout0 but was never advanced here, desynchronizing
    // it from rows 1..num+1 for every subsequent column (relu_out2 variant
    // has this increment).
    din_ptr_arr[0]++;
    *dout0++ = sum > 0.f ? sum : sum * scale[0];
    *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
  }
  // right border: partial windows with (3 - i) valid taps.
  for (int i = 0; i < pad_right_new; i++) {
    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
    float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
    din_ptr_arr[tmp1]++;
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
      sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
      din_ptr_arr[num - k]++;
    }
    // BUGFIX: same missing row-0 advance as in the remain loop above.
    din_ptr_arr[0]++;
    *dout0++ = sum > 0.f ? sum : sum * scale[0];
    *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
  }
  // right: columns fully inside the right padding only see the bias.
  for (int w = pad_right; w > 4; w--) {
    *dout0++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
    *dout1++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
  }
}
inline void compute_all_padding_post_leakyRelu(float* dout, inline void compute_all_padding_post_leakyRelu(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
...@@ -4054,6 +4753,7 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4054,6 +4753,7 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
int num_out = wout << 1;
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
...@@ -4069,8 +4769,10 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4069,8 +4769,10 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
const float* din_ptr2 = din_ptr1 + win; const float* din_ptr2 = din_ptr1 + win;
const float* din_ptr3 = din_ptr2 + win; const float* din_ptr3 = din_ptr2 + win;
const float* din_ptr4 = din_ptr3 + win; const float* din_ptr4 = din_ptr3 + win;
const float* din_ptr5 = din_ptr4 + win;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
float* dout_ptr = dout_ch; float* dout_ptr0 = dout_ch;
float* dout_ptr1 = dout_ch;
float32x4_t wr5; float32x4_t wr5;
float32x4_t wr6; float32x4_t wr6;
float32x4_t wr0 = vld1q_f32(weights_ch); float32x4_t wr0 = vld1q_f32(weights_ch);
...@@ -4083,35 +4785,51 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4083,35 +4785,51 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout);
dout_ptr += wout;
}
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top_new; h > 0; h--) {
compute_all_padding_pre_leakyRelu(dout_ptr, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_pre_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4 - h); pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
dout_ptr1 = dout_ptr0 + wout;
// mid_h // mid_h
for (int h = 0; h < loop_h; h++) { for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid_leakyRelu(dout_ptr, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_mid_leakyRelu(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4); pad_right_new, cnt, remain, 4);
dout_ptr += wout; dout_ptr0 += num_out;
dout_ptr1 += num_out;
din_ptr0 = din_ptr2;
din_ptr1 = din_ptr3;
din_ptr2 = din_ptr4;
din_ptr3 = din_ptr5;
din_ptr4 = din_ptr5 + win;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr5 = din_ptr4 + win;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
din_ptr_arr[5] = din_ptr5;
}
if (loop_h % 2 != 0) {
compute_all_padding_mid_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4);
dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
din_ptr2 = din_ptr3; din_ptr2 = din_ptr3;
din_ptr3 = din_ptr4; din_ptr3 = din_ptr4;
din_ptr4 = din_ptr4 + win; din_ptr4 = din_ptr5;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
...@@ -4120,10 +4838,10 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4120,10 +4838,10 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom_new; h++) {
compute_all_padding_post_leakyRelu(dout_ptr, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_post_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 3 - h); pad_right_new, cnt, remain, 3 - h);
dout_ptr += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2; din_ptr_arr[2] = din_ptr2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册