diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index 00c12587fa510f63e989803f283df67c742b2344..01ededa033de82ca0219ccd05ee2487b31c8141a 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -1320,20 +1320,24 @@ inline void compute_all_padding_pre(float* dout, bool odds, int pad_left, int pad_right, + int pad_left_new, + int pad_right_new, int cnt, int remain, int num) { int tmp_index = num - 1; - for (int i = pad_left; i > 0; i--) { + int num_index_left = 4 - pad_left; + for (int i = pad_left_new; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], - 4 - i); + num_index_left); } + num_index_left -= 2; *dout++ = sum; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -1554,18 +1558,21 @@ inline void compute_all_padding_pre(float* dout, *dout++ = sum; } // right - for (int i = 0; i < pad_right; i++) { + int num_index_right = 4 - pad_right; + LOG(INFO) << "pad_right_new: " << pad_right_new << ", num_index_right: " << num_index_right; + for (int i = 0; i < pad_right_new; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, - weights[3 - k][3 - i], - 3 - i); + weights[3 - k][num_index_right], + num_index_right); din_ptr_arr[tmp_index - k] += 2; } + num_index_right -= 2; *dout++ = sum; } } @@ -1576,21 +1583,25 @@ inline void compute_all_padding_mid(float* dout, bool odds, int pad_left, int pad_right, + int pad_left_new, + int pad_right_new, int cnt, int remain, int num) { // left int tmp = num - 1; - for (int i = pad_left; i > 0; i--) { + int num_index_left = 4 - pad_left; + for (int i = pad_left_new; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left -= 2; *dout++ = sum; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -1673,18 +1684,19 @@ inline void compute_all_padding_mid(float* dout, *dout++ = sum; } // right - for (int i = 0; i < pad_right; i++) { + for (int i = 0; i < pad_right_new; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][4 - pad_right], + 4 - pad_right); din_ptr_arr[tmp - k] += 2; } + pad_right += 2; *dout++ = sum; } } @@ -1696,6 +1708,8 @@ inline void compute_all_padding_mid_out2(float* dout0, bool odds, int pad_left, int pad_right, + int pad_left_new, + int pad_right_new, int cnt, int remain, int num) { @@ -1703,23 +1717,24 @@ inline void compute_all_padding_mid_out2(float* dout0, int tmp2 = num + 1; int tmp = num - 1; // left - for (int i = pad_left; i > 0; i--) { + for (int i = pad_left_new; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - pad_left); float sum1 = compute_one_data_pre( - din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - pad_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + 4 - pad_left); sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + 4 - pad_left); } + pad_left -= 2; *dout0++ = sum; *dout1++ = sum1; } @@ -1820,25 +1835,26 @@ inline void compute_all_padding_mid_out2(float* dout0, *dout1++ = sum1; } // right - for (int i = 0; i < pad_right; i++) { + for (int i = 0; i < pad_right_new; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); float sum1 = compute_one_data_post( - din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][4 - pad_right], + 4 - pad_right); sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][4 - pad_right], + 4 - pad_right); din_ptr_arr[tmp2 - k] += 2; } + pad_right += 2; din_ptr_arr[1] += 2; din_ptr_arr[0] += 2; *dout0++ = sum; @@ -1853,21 +1869,24 @@ inline void compute_all_padding_post(float* dout, bool odds, int pad_left, int pad_right, + int pad_left_new, + int pad_right_new, int cnt, int remain, int num) { // left int tmp = num - 1; - for (int i = pad_left; i > 0; i--) { + for (int i = pad_left_new; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - pad_left); for (int k = 0; k < num; k++) { - sum += compute_one_data_pre(din_ptr_arr[2 - k], + sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + 4 - pad_left); } + pad_left -= 2; *dout++ = sum; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -1884,7 +1903,7 @@ inline void compute_all_padding_post(float* dout, #ifdef __aarch64__ asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -1902,7 +1921,7 @@ inline void compute_all_padding_post(float* dout, #else asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -1918,14 +1937,14 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 8; + din_ptr_arr[num] -= 8; break; case 1: #ifdef __aarch64__ asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1944,8 +1963,8 @@ inline void compute_all_padding_post(float* dout, #else asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1962,15 +1981,15 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 8; + din_ptr_arr[tmp] -= 8; break; case 2: #ifdef __aarch64__ asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -1990,9 +2009,9 @@ inline void compute_all_padding_post(float* dout, #else asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -2010,7 +2029,7 @@ inline void compute_all_padding_post(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 8; + din_ptr_arr[tmp - 1] -= 8; break; case 3: #ifdef __aarch64__ @@ -2072,28 +2091,30 @@ inline void compute_all_padding_post(float* dout, // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( - din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[2 - i] += 2; + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i] += 2; } *dout++ = sum; } // right - for (int i = 0; i < pad_right; i++) { + int num_index = 4 - pad_right; + for (int i = 0; i < pad_right_new; i++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index], num_index); + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { - sum += compute_one_data_post(din_ptr_arr[2 - k], + sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); - din_ptr_arr[2 - k] += 2; + weights[tmp - k][num_index], + num_index); + din_ptr_arr[tmp - k] += 2; } + num_index -= 2; *dout++ = sum; } } @@ -2139,7 +2160,18 @@ void conv_depthwise_5x5s2_bias(float* dout, int cnt = loop_w >> 2; int remain = loop_w & 3; int n_top_h = 4 - pad_top; - int n_bottom_h = 4 -pad_bottom; + int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3); + int n_right_w = odds_w ? pad_right : ((win % 2) ? 0 : 1); + if (n_right_w == 0) { + remain++; + pad_right_new--; + n_right_w += 2; + } + if (n_bottom_h == 4) { + loop_h++; + pad_bottom_new--; + n_bottom_h -= 2; + } for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; float* dout_batch = dout + n * out_channel_size; @@ -2182,6 +2214,8 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, + pad_left, + n_right_w, pad_left_new, pad_right_new, cnt, @@ -2220,6 +2254,8 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, + pad_left, + n_right_w, pad_left_new, pad_right_new, cnt, @@ -2248,6 +2284,8 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, + pad_left, + n_right_w, pad_left_new, pad_right_new, cnt, @@ -2273,6 +2311,8 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, + pad_left, + n_right_w, pad_left_new, pad_right_new, cnt,