diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc index 402d879eb54877a597688adf575ded5db3fc4e90..e9e91c77e8e10fb0dc13b3114097aab1ce810b79 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc @@ -615,7 +615,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, "1: \n" \ "subs %[cnt], #1\n" \ "vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \ - "vmul.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \ + "vmul.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \ "vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \ "vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\ "vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \ @@ -626,7 +626,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, "vext.32 q12, q8, q9, #3\n" \ "vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \ "vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \ - "vmla.f32 q15, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \ + "vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \ "vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \ "vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\ "vext.32 q10, q8, q9, #1\n" \ @@ -945,7 +945,7 @@ inline float compute_one_data_post(const float* data, float32x4_t wr, float bias inline void compute_all_padding_pre(float* dout, const float** din_ptr_arr, const float* bias, - std::vector weights, + float32x4_t* weights, int win, int wout, int pad_left, @@ -959,6 +959,7 @@ inline void compute_all_padding_pre(float* dout, for (int w = pad_left; w > 4; w--) { *dout++ = bias[0]; } + LOG(INFO) << "pad_left_new: " << pad_left_new; for (int i = pad_left_new; i > 0; i--) { float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); for (int k = 0; k < num; k++) { @@ -1158,6 +1159,7 @@ inline void compute_all_padding_pre(float* dout, default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; } + din_ptr_arr[0] -= 4; } // remain for (int w = 0; w < remain; w++) { @@ -1169,44 +1171,17 @@ inline void compute_all_padding_pre(float* dout, } *dout++ = sum; } - + LOG(INFO) << " pad_right_new: " << pad_right_new; // right - for (int i = 1; i < pad_right_new; i++) { - float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][4 - i], 4 - i); + for (int i = 0; i < pad_right_new; i++) { + float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num]++; for (int k = 0; k < num; k++) { - sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][4 - i], 4 - i); + sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i); + din_ptr_arr[num - 1 - k]++; } *dout++ = sum; } - /* - switch (pad_right_new) { - case 1: - float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3], 3); - for (int i = 0; i < num; i++) { - sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][3], 3); - } - *dout++ = sum; - case 2: - float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][2], 2); - for (int i = 0; i < num; i++) { - sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][2], 2); - } - *dout++ = sum; - case 3: - float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][1], 1); - for (int i = 0; i < num; i++) { - sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][1], 1); - } - *dout++ = sum; - case 4: - float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][0], 0); - for (int i = 0; i < num; i++) { - sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][0], 0); - } - *dout++ = sum; - - } - */ for (int w = pad_right; w > 4; w--) { *dout++ = bias[0]; } @@ -1215,7 +1190,7 @@ inline void compute_all_padding_pre(float* dout, inline void compute_all_padding_mid(float* dout, const float** din_ptr_arr, const float* bias, - std::vector weights, + float32x4_t* weights, int win, int wout, int pad_left, @@ -1293,7 +1268,8 @@ inline void compute_all_padding_mid(float* dout, "q13", "q14", "q15"); -#endif +#endif + din_ptr_arr[0] -= 4; } // remain for (int w = 0; w < remain; w++) { @@ -1305,15 +1281,17 @@ inline void compute_all_padding_mid(float* dout, } *dout++ = sum; } - // right for (int i = 0; i < pad_right_new; i++) { float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num]++; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); + din_ptr_arr[tmp - k]++; } *dout++ = sum; } + for (int w = pad_right; w > 4; w--) { *dout++ = bias[0]; } @@ -1321,7 +1299,7 @@ inline void compute_all_padding_mid(float* dout, inline void compute_all_padding_post(float* dout, const float** din_ptr_arr, const float* bias, - std::vector weights, + float32x4_t* weights, int win, int wout, int pad_left, @@ -1337,9 +1315,9 @@ inline void compute_all_padding_post(float* dout, } int tmp = num - 1; for (int i = pad_left_new; i > 0; i--) { - float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - i); + float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); for (int k = 0; k < num; k++) { - sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); + sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); } *dout++ = sum; } @@ -1350,7 +1328,7 @@ inline void compute_all_padding_post(float* dout, #ifdef __aarch64__ asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr0] "+r"(din_ptr_arr[3]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -1368,7 +1346,7 @@ inline void compute_all_padding_post(float* dout, #else asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr0] "+r"(din_ptr_arr[3]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -1384,13 +1362,14 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif + din_ptr_arr[3] -= 4; break; case 1: #ifdef __aarch64__ asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[0]), - [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1409,8 +1388,8 @@ inline void compute_all_padding_post(float* dout, #else asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[0]), - [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1427,14 +1406,15 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif + din_ptr_arr[2] -= 4; break; case 2: #ifdef __aarch64__ asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[0]), - [din_ptr1] "+r"(din_ptr_arr[1]), - [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1454,9 +1434,9 @@ inline void compute_all_padding_post(float* dout, #else asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[0]), - [din_ptr1] "+r"(din_ptr_arr[1]), - [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1474,6 +1454,7 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif + din_ptr_arr[1] -= 4; break; case 3: #ifdef __aarch64__ @@ -1525,6 +1506,7 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif + din_ptr_arr[0] -= 4; break; default: LOG(FATAL) << "This num: " << (num + 1) << "does not support"; @@ -1532,20 +1514,22 @@ inline void compute_all_padding_post(float* dout, } // remain for (int w = 0; w < remain; w++) { - float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[num]++; + float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[3]++; for (int i = 0; i < num; i++) { - sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[tmp - i]++; + sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[2 - i]++; } *dout++ = sum; } // right for (int i = 0; i < pad_right_new; i++) { - float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[3]++; for (int k = 0; k < num; k++) { - sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); + sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); + din_ptr_arr[2 - k]++; } *dout++ = sum; } @@ -1572,6 +1556,8 @@ void conv_depthwise_5x5s1_bias(float* dout, ARMContext* ctx){ int loop_w = wout - pad_left - pad_right; int loop_h = hout - pad_top - pad_bottom; + LOG(INFO) << "pad_top: " << pad_top << ", pad_bottom: " << pad_bottom; + LOG(INFO) << "pad_left: " << pad_left << ", pad_right: " << pad_right; int in_size = win * hin; int out_size = wout * hout; int cnt = loop_w >> 2; @@ -1612,32 +1598,54 @@ void conv_depthwise_5x5s1_bias(float* dout, wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4}; - float32x4_t wei_vwc[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; - // top_h + float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; + // top_h for (int h = pad_top; h > 4; h--) { memset(dout_ptr, bias[0], sizeof(float)*wout); dout_ptr += wout; } for (int h = pad_top_new; h > 0; h--) { compute_all_padding_pre(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, - pad_left_new, pad_right, pad_right_new, cnt, remain, 4 - h); + pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h); dout_ptr += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; } - // mid_h + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; for (int h = 0; h < loop_h; h++) { compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, - pad_left_new, pad_right, pad_right_new, cnt, remain, 4); + pad_right, pad_left_new, pad_right_new, cnt, remain, 4); dout_ptr += wout; - for (int i = 0; i < 4; i++) { - din_ptr_arr[i] = din_ptr_arr[i + 1]; - } - din_ptr_arr[4] += win; + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr4 + win; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; } // bottom for (int h = 0; h < pad_bottom_new; h++) { + for (int i = 0; i < 5; i++) LOG(INFO) << "i: " << i << ", ptr: " << din_ptr_arr[i]; + LOG(INFO) << "num: " << (3 -h); compute_all_padding_post(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, - pad_left_new, pad_right, pad_right_new, cnt, remain, 3 - h); + pad_right, pad_left_new, pad_right_new, cnt, remain, 3 - h); dout_ptr += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; } } } diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index e5173723c52012f765a83f528c9b9d1edbc92508..5ce112156f88b41bd70392b80f57193f5ff2e943 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -751,7 +751,7 @@ void conv_depthwise_5x5_fp32(const void* din, act_param, ctx); } else if (stride == 1) { -#if 1 +#if 0 conv_depthwise_5x5s1_fp32(reinterpret_cast(dout), reinterpret_cast(din), reinterpret_cast(weights), diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index e65591b0c8de340e46d3c36b52033f6377e0d10f..ef939d19b179c187a949bc6a466344a5f3c38234 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -56,8 +56,16 @@ void DepthwiseConv::PrepareForRun() { } else if (kw == 5) { // VLOG(5) << "invoke 5x5 dw conv fp32"; auto strides = param.strides; - if ((strides[0] == 1 && strides[1] == 1) || - (strides[0] == 2 && strides[1] == 2)) { + auto hin = param.x->dims()[2]; + auto win = param.x->dims()[3]; + if (win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) { + flag_trans_weights_ = false; + impl_ = lite::arm::math::conv_depthwise_5x5_fp32; +#ifdef LITE_WITH_PROFILE + kernel_func_name_ = "conv_depthwise_5x5_fp32"; +#endif + } else if ((strides[0] == 1 && strides[1] == 1) || + (strides[0] == 2 && strides[1] == 2)) { // trans weights constexpr int cblock = 4; auto oc = w_dims[0];