diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt
index 1c0e8e5bf9ad350e8948e06808c9510e476139bd..cbe3c4b267902f5dcfd0fd111db6d1394f8de8f0 100644
--- a/lite/backends/arm/math/CMakeLists.txt
+++ b/lite/backends/arm/math/CMakeLists.txt
@@ -80,6 +80,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
       conv3x3s2_depthwise_int8.cc
       conv5x5s1_depthwise_int8.cc
       conv5x5s1_depthwise_fp32.cc
+      conv5x5s1_depthwise_fp32_c4.cc
       conv5x5s2_depthwise_int8.cc
       conv5x5s2_depthwise_fp32.cc
       conv3x3_winograd_fp32_c4.cc
diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
index d1d8c31cf47ac73c4d63be82d7dc30f7a125dad8..74000715d31893f7d802546388da1e8d6785c4a4 100644
--- a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
@@ -576,9 +576,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
   "vext.32 q12, q8, q9, #3\n" \
   "vadd.f32 q14, q14, q15\n"
 #define COMPUTE_TWO_LINE_S1_PRE \
-  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
   "vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
+  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vext.32 q10, q8, q9, #1\n" \
   "vext.32 q11, q8, q9, #2\n" \
   "vext.32 q12, q8, q9, #3\n" \
@@ -606,9 +606,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
   "vext.32 q12, q8, q9, #3\n" \
   "vadd.f32 q14, q14, q15\n"
 #define COMPUTE_THREE_LINE_S1_PRE \
-  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
   "vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
+  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vext.32 q10, q8, q9, #1\n" \
   "vext.32 q11, q8, q9, #2\n" \
   "vext.32 q12, q8, q9, #3\n" \
@@ -646,9 +646,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
   "vext.32 q12, q8, q9, #3\n" \
   "vadd.f32 q14, q14, q15\n"
 #define COMPUTE_FOUR_LINE_S1_PRE \
-  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
   "vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
+  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vext.32 q10, q8, q9, #1\n" \
   "vext.32 q11, q8, q9, #2\n" \
   "vext.32 q12, q8, q9, #3\n" \
@@ -696,9 +696,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
   "vext.32 q12, q8, q9, #3\n" \
   "vadd.f32 q14, q14, q15\n"
 #define COMPUTE_FIVE_LINE_S1 \
-  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
   "vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
+  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vext.32 q10, q8, q9, #1\n" \
   "vext.32 q11, q8, q9, #2\n" \
   "vext.32 q12, q8, q9, #3\n" \
@@ -776,9 +776,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
   "vext.32 q12, q8, q9, #3\n" \
   "vadd.f32 q14, q14, q15\n"
 #define COMPUTE_TWO_LINE_S1_POST \
-  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
   "vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
+  "vld1.f32 {d30-d31}, [%[bias]]\n" \
   "vext.32 q10, q8, q9, #1\n" \
   "vext.32 q11, q8, q9, #2\n" \
   "vext.32 q12, q8, q9, #3\n" \
@@ -921,6 +921,110 @@ void conv_depthwise_5x5s1_fp32(float* dout,
   "vext.32 q11, q8, q9, #2\n" \
   "vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \
   "bne 1b"
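+// Two-row ARMv7 variant of COMPUTE_FIVE_LINE_S1: each loop iteration reads six
+// consecutive input rows (din_ptr0..din_ptr5) and accumulates two adjacent
+// output rows, row 0 into q13 and row 1 into q14.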
+#define COMPUTE_FIVE_LINE_S1_1 \
+  "vld1.f32 {d28-d29}, [%[bias]]\n" \
+  "vld1.f32 {d30-d31}, [%[bias]]\n" \
+  "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
+  "vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vext.32 q12, q8, q9, #3\n" \
+  "1: \n" \
+  "subs %[cnt], #1\n" \
+  "vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
+  "vmul.f32 q13, q9, %e[wr5][0]\n" /*4567*wr5[0]*/ \
+  "vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
+  "vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/ \
+  "vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
+  "vmla.f32 q13, q11, %f[wr0][0]\n" /*2345*wr0[2]*/ \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vmla.f32 q15, q12, %f[wr0][1]\n" /*3456*wr0[3]*/ \
+  "vext.32 q12, q8, q9, #3\n" \
+  "vmla.f32 q13, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
+  "vmla.f32 q14, q8, %e[wr0][0]\n" /*0123*wr1[0]*/ \
+  "vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
+  "vmla.f32 q15, q9, %e[wr5][1]\n" /*4567*wr5[1]*/ \
+  "vmla.f32 q14, q9, %e[wr5][0]\n" /*4567*wr5[1]*/ \
+  "vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
+  "vmla.f32 q13, q10, %e[wr1][1]\n" /*1234*wr1[1]*/ \
+  "vmla.f32 q14, q10, %e[wr0][1]\n" /*1234*wr1[1]*/ \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vmla.f32 q15, q11, %f[wr1][0]\n" /*2345*wr1[2]*/ \
+  "vmla.f32 q13, q12, %f[wr1][1]\n" /*3456*wr1[3]*/ \
+  "vmla.f32 q14, q11, %f[wr0][0]\n" /*2345*wr1[2]*/ \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vmla.f32 q14, q12, %f[wr0][1]\n" /*3456*wr1[3]*/ \
+  "vext.32 q12, q8, q9, #3\n" \
+  "vmla.f32 q15, q8, %e[wr2][0]\n" /*0123*wr2[0]*/ \
+  "vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr2[0]*/ \
+  "vld1.f32 {d16-d17}, [%[din_ptr3]]!\n" \
+  "vmla.f32 q13, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
+  "vmla.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \
+  "vld1.f32 {d18-d19}, [%[din_ptr3]]\n" \
+  "vmla.f32 q15, q10, %e[wr2][1]\n" /*1234*wr2[1]*/ \
+  "vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr2[1]*/ \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vmla.f32 q13, q11, %f[wr2][0]\n" /*2345*wr2[2]*/ \
+  "vmla.f32 q14, q11, %f[wr1][0]\n" /*2345*wr2[2]*/ \
+  "vmla.f32 q15, q12, %f[wr2][1]\n" /*3456*wr2[3]*/ \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vmla.f32 q14, q12, %f[wr1][1]\n" /*3456*wr2[3]*/ \
+  "vext.32 q12, q8, q9, #3\n" \
+  "vmla.f32 q13, q8, %e[wr3][0]\n" /*0123*wr3[0]*/ \
+  "vmla.f32 q14, q8, %e[wr2][0]\n" /*0123*wr3[0]*/ \
+  "vld1.f32 {d16-d17}, [%[din_ptr4]]!\n" \
+  "vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
+  "vmla.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \
+  "vld1.f32 {d18-d19}, [%[din_ptr4]]\n" \
+  "vmla.f32 q13, q10, %e[wr3][1]\n" /*1234*wr3[1]*/ \
+  "vmla.f32 q14, q10, %e[wr2][1]\n" /*1234*wr3[1]*/ \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vmla.f32 q15, q11, %f[wr3][0]\n" /*2345*wr3[2]*/ \
+  "vmla.f32 q14, q11, %f[wr2][0]\n" /*2345*wr3[2]*/ \
+  "vmla.f32 q13, q12, %f[wr3][1]\n" /*3456*wr3[3]*/ \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vmla.f32 q14, q12, %f[wr2][1]\n" /*3456*wr3[3]*/ \
+  "vext.32 q12, q8, q9, #3\n" \
+  "vmla.f32 q15, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
+  "vmla.f32 q14, q8, %e[wr3][0]\n" /*0123*wr4[0]*/ \
+  "vld1.f32 {d16-d17}, [%[din_ptr5]]!\n" \
+  "vmla.f32 q13, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
+  "vmla.f32 q14, q9, %f[wr5][1]\n" /*4567*wr6[0]*/ \
+  "vld1.f32 {d18-d19}, [%[din_ptr5]]\n" \
+  "vmla.f32 q15, q10, %e[wr4][1]\n" /*1234*wr4[1]*/ \
+  "vmla.f32 q14, q10, %e[wr3][1]\n" /*1234*wr4[1]*/ \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vmla.f32 q13, q11, %f[wr4][0]\n" /*2345*wr4[2]*/ \
+  "vmla.f32 q14, q11, %f[wr3][0]\n" /*2345*wr4[2]*/ \
+  "vmla.f32 q15, q12, %f[wr4][1]\n" /*3456*wr4[3]*/ \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vmla.f32 q14, q12, %f[wr3][1]\n" /*3456*wr4[3]*/ \
+  "vext.32 q12, q8, q9, #3\n" \
+  "vadd.f32 q13, q13, q15\n" \
+  "vmla.f32 q14, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
+  "vmul.f32 q15, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
+  "vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
+  "vmla.f32 q14, q10, %e[wr4][1]\n" /*1234*wr4[1]*/ \
+  "vld1.f32 {d18-d19}, [%[din_ptr0]]\n" \
+  "vmla.f32 q15, q11, %f[wr4][0]\n" /*2345*wr4[2]*/ \
+  "vext.32 q10, q8, q9, #1\n" \
+  "vext.32 q11, q8, q9, #2\n" \
+  "vmla.f32 q14, q12, %f[wr4][1]\n" /*3456*wr4[3]*/ \
+  "vext.32 q12, q8, q9, #3\n" \
+  "vadd.f32 q14, q14, q15\n"
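+// Relu6 epilogue for the two-row kernel: clamps both accumulated rows to
+// [0, six], stores one row through each output pointer, then reloads the bias
+// into q14/q15 for the next iteration.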
"vadd.f32 q14, q14, q15\n" +#define RESULT_S1_RELU6_1 \ + "vld1.f32 {d30-d31}, [%[six_ptr]]\n" \ + "vmax.f32 q13, q13, %q[vzero]\n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmin.f32 q13, q13, q15\n" \ + "vmin.f32 q14, q14, q15\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vst1.f32 {d26-d27}, [%[dout_ptr0]]!\n" \ + "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \ + "vld1.f32 {d28-d29}, [%[bias]]\n" \ + "bne 1b" + #endif inline float compute_one_data_pre(const float* data, float32x4_t wr, float bias_val, float wei_val, int num) { @@ -2742,6 +2846,139 @@ inline void compute_all_padding_mid_relu6(float* dout, *dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f; } } + +inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1, + const float** din_ptr_arr, + const float* bias, + const float* six, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int pad_left_new, + int pad_right_new, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + // left + for (int w = pad_left; w > 4; w--) { + *dout0++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f; + } + int tmp = num - 1; + for (int i = pad_left_new; i > 0; i--) { + float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + float sum1 = compute_one_data_pre(din_ptr_arr[num + 1], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); + sum1 += compute_one_data_pre(din_ptr_arr[num -k], weights[tmp -k], 0.f, weights[5][tmp - k], 4 - i); + } + *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? 
+  if (cnt > 0) {
+#ifdef __aarch64__
+    asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6
+                 : [cnt] "+r"(cnt),
+                   [din_ptr0] "+r"(din_ptr_arr[0]),
+                   [din_ptr1] "+r"(din_ptr_arr[1]),
+                   [din_ptr2] "+r"(din_ptr_arr[2]),
+                   [din_ptr3] "+r"(din_ptr_arr[3]),
+                   [din_ptr4] "+r"(din_ptr_arr[4]),
+                   [din_ptr5] "+r"(din_ptr_arr[5]),
+                   [dout_ptr0] "+r"(dout0),
+                   [dout_ptr1] "+r"(dout1)
+                 : [wr0] "w"(weights[0]),
+                   [wr1] "w"(weights[1]),
+                   [wr2] "w"(weights[2]),
+                   [wr3] "w"(weights[3]),
+                   [wr4] "w"(weights[4]),
+                   [wr5] "w"(weights[5]),
+                   [wr6] "w"(weights[6]),
+                   [vzero] "w"(vzero),
+                   [vsix] "w"(vsix),
+                   [bias] "r"(bias)
+                 : "cc",
+                   "memory",
+                   "v9",
+                   "v10",
+                   "v11",
+                   "v12",
+                   "v13",
+                   "v14",
+                   "v15",
+                   "v16");
+#else
+    asm volatile(COMPUTE_FIVE_LINE_S1_1 RESULT_S1_RELU6_1
+                 : [cnt] "+r"(cnt),
+                   [din_ptr0] "+r"(din_ptr_arr[0]),
+                   [din_ptr1] "+r"(din_ptr_arr[1]),
+                   [din_ptr2] "+r"(din_ptr_arr[2]),
+                   [din_ptr3] "+r"(din_ptr_arr[3]),
+                   [din_ptr4] "+r"(din_ptr_arr[4]),
+                   [din_ptr5] "+r"(din_ptr_arr[5]),
+                   [dout_ptr0] "+r"(dout0),
+                   [dout_ptr1] "+r"(dout1)
+                 : [wr0] "w"(weights[0]),
+                   [wr1] "w"(weights[1]),
+                   [wr2] "w"(weights[2]),
+                   [wr3] "w"(weights[3]),
+                   [wr4] "w"(weights[4]),
+                   [wr5] "w"(weights[5]),
+                   [wr6] "w"(weights[6]),
+                   [vzero] "w"(vzero),
+                   [six_ptr] "r"(six),
+                   [bias] "r"(bias)
+                 : "cc",
+                   "memory",
+                   "q8",
+                   "q9",
+                   "q10",
+                   "q11",
+                   "q12",
+                   "q13",
+                   "q14",
+                   "q15");
+#endif
+    din_ptr_arr[0] -= 4;
+  }
+  // remain
+  for (int w = 0; w < remain; w++) {
+    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
+    float sum1 = compute_one_data_post(din_ptr_arr[num + 1], weights[num], bias[0], weights[6][0], 4);
+    din_ptr_arr[num + 1]++;
+    for (int i = 0; i < num; i++) {
+      sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
+      sum1 += compute_one_data_post(din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
+      din_ptr_arr[num - i]++;
+    }
+    din_ptr_arr[0]++;
+    *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
+    *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
+  }
+  // right
+  for (int i = 0; i < pad_right_new; i++) {
+    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    float sum1 = compute_one_data_post(din_ptr_arr[num + 1], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    din_ptr_arr[num + 1]++;
+    for (int k = 0; k < num; k++) {
+      sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      sum1 += compute_one_data_post(din_ptr_arr[num - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      din_ptr_arr[num - k]++;
+    }
+    din_ptr_arr[0]++;
+    *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
+    *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
+  }
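+  // Columns lying entirely in the right padding carry only the bias, clamped
+  // to [0, six], in both output rows.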
+  for (int w = pad_right; w > 4; w--) {
+    *dout0++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
+    *dout1++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
+  }
+}
 inline void compute_all_padding_post_relu6(float* dout,
                                            const float** din_ptr_arr,
                                            const float* bias,
@@ -3050,8 +3287,10 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
       const float* din_ptr2 = din_ptr1 + win;
       const float* din_ptr3 = din_ptr2 + win;
       const float* din_ptr4 = din_ptr3 + win;
+      const float* din_ptr5 = din_ptr4 + win;
       float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
       float* dout_ptr = dout_ch;
+      float* dout_ptr1 = dout_ch;
       float32x4_t wr5;
       float32x4_t wr6;
       float32x4_t wr0 = vld1q_f32(weights_ch);
@@ -3064,7 +3303,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
       wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
       wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
       wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
-      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
+      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
       float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
       // top_h
       for (int h = pad_top; h > 4; h--) {
@@ -3082,23 +3321,30 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
         din_ptr_arr[3] = din_ptr3;
         din_ptr_arr[4] = din_ptr4;
       }
+      dout_ptr1 = dout_ptr + wout;
       // mid_h
-      for (int h = 0; h < loop_h; h++) {
-        compute_all_padding_mid_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
+      for (int h = 0; h < loop_h - 1; h += 2) {
+        compute_all_padding_mid_relu6_1(dout_ptr, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero,
                                       win, wout, pad_left, pad_right, pad_left_new,
                                       pad_right_new, cnt, remain, 4);
-        dout_ptr += wout;
-        din_ptr0 = din_ptr1;
-        din_ptr1 = din_ptr2;
-        din_ptr2 = din_ptr3;
-        din_ptr3 = din_ptr4;
-        din_ptr4 = din_ptr4 + win;
+        dout_ptr += 2 * wout;
+        dout_ptr1 += 2 * wout;
+        din_ptr0 = din_ptr2;
+        din_ptr1 = din_ptr3;
+        din_ptr2 = din_ptr4;
+        din_ptr3 = din_ptr5;
+        din_ptr4 = din_ptr5 + win;
+        din_ptr5 = din_ptr4 + win;
         din_ptr_arr[0] = din_ptr0;
         din_ptr_arr[1] = din_ptr1;
         din_ptr_arr[2] = din_ptr2;
        din_ptr_arr[3] = din_ptr3;
         din_ptr_arr[4] = din_ptr4;
+        din_ptr_arr[5] = din_ptr5;
       }
+      if (loop_h % 2) compute_all_padding_mid_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
+                                      win, wout, pad_left, pad_right, pad_left_new,
+                                      pad_right_new, cnt, remain, 4);
       // bottom
       for (int h = 0; h < pad_bottom_new; h++) {
         compute_all_padding_post_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,