From 589e852c44a3eabee520886b86945e210c11e4c8 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Thu, 20 Aug 2020 15:35:16 +0800 Subject: [PATCH] fix relu relu6 error. test=develop --- .../arm/math/conv5x5s2_depthwise_fp32.cc | 516 +++++++++++------- 1 file changed, 308 insertions(+), 208 deletions(-) diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index 01ededa033..1813cf593f 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -1320,14 +1320,13 @@ inline void compute_all_padding_pre(float* dout, bool odds, int pad_left, int pad_right, - int pad_left_new, - int pad_right_new, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { int tmp_index = num - 1; - int num_index_left = 4 - pad_left; - for (int i = pad_left_new; i > 0; i--) { + for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { @@ -1337,7 +1336,7 @@ inline void compute_all_padding_pre(float* dout, weights[5][3 - k], num_index_left); } - num_index_left -= 2; + num_index_left += 2; *dout++ = sum; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -1558,9 +1557,7 @@ inline void compute_all_padding_pre(float* dout, *dout++ = sum; } // right - int num_index_right = 4 - pad_right; - LOG(INFO) << "pad_right_new: " << pad_right_new << ", num_index_right: " << num_index_right; - for (int i = 0; i < pad_right_new; i++) { + for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right); din_ptr_arr[num] += 2; @@ -1583,15 +1580,14 @@ inline void compute_all_padding_mid(float* dout, bool odds, int pad_left, int pad_right, - int pad_left_new, - int pad_right_new, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { // left int tmp = num - 1; - int num_index_left = 4 - pad_left; - for (int i = pad_left_new; i > 0; i--) { + for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { @@ -1601,7 +1597,7 @@ inline void compute_all_padding_mid(float* dout, weights[5][tmp - k], num_index_left); } - num_index_left -= 2; + num_index_left += 2; *dout++ = sum; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -1684,19 +1680,19 @@ inline void compute_all_padding_mid(float* dout, *dout++ = sum; } // right - for (int i = 0; i < pad_right_new; i++) { + for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][4 - pad_right], - 4 - pad_right); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp - k] += 2; } - pad_right += 2; + num_index_right -= 2; *dout++ = sum; } } @@ -1708,8 +1704,8 @@ inline void compute_all_padding_mid_out2(float* dout0, bool odds, int pad_left, int pad_right, - int pad_left_new, - int pad_right_new, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -1717,24 +1713,24 @@ inline void compute_all_padding_mid_out2(float* dout0, int tmp2 = num + 1; int tmp = num - 1; // left - for (int i = pad_left_new; i > 0; i--) { + for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - pad_left); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); float sum1 = compute_one_data_pre( - din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - pad_left); + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - pad_left); + num_index_left); sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - pad_left); + num_index_left); } - pad_left -= 2; + num_index_left += 2; *dout0++ = sum; *dout1++ = sum1; } @@ -1835,26 +1831,26 @@ inline void compute_all_padding_mid_out2(float* dout0, *dout1++ = sum1; } // right - for (int i = 0; i < pad_right_new; i++) { + for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); float sum1 = compute_one_data_post( - din_ptr_arr[tmp1], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][4 - pad_right], - 4 - pad_right); + weights[tmp - k][num_index_right], + num_index_right); sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, - weights[tmp - k][4 - pad_right], - 4 - pad_right); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp2 - k] += 2; } - pad_right += 2; + num_index_right -= 2; din_ptr_arr[1] += 2; din_ptr_arr[0] += 2; *dout0++ = sum; @@ -1869,22 +1865,22 @@ inline void compute_all_padding_post(float* dout, bool odds, int pad_left, int pad_right, - int pad_left_new, - int pad_right_new, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { // left int tmp = num - 1; - for (int i = pad_left_new; i > 0; i--) { + for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - pad_left); + din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - pad_left); + num_index_left); } pad_left -= 2; *dout++ = sum; @@ -2101,20 +2097,19 @@ inline void compute_all_padding_post(float* dout, *dout++ = sum; } // right - int num_index = 4 - pad_right; - for (int i = 0; i < pad_right_new; i++) { + for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][num_index], num_index); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][num_index], - num_index); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp - k] += 2; } - num_index -= 2; + num_index_right -= 2; *dout++ = sum; } } @@ -2161,11 +2156,12 @@ void conv_depthwise_5x5s2_bias(float* dout, int remain = loop_w & 3; int n_top_h = 4 - pad_top; int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3); - int n_right_w = odds_w ? pad_right : ((win % 2) ? 0 : 1); - if (n_right_w == 0) { + int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3); + int n_left_w = 4 - pad_left; + if (n_right_w == 4) { remain++; pad_right_new--; - n_right_w += 2; + n_right_w -= 2; } if (n_bottom_h == 4) { loop_h++; @@ -2214,10 +2210,10 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, - pad_left, - n_right_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -2254,10 +2250,10 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, - pad_left, - n_right_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -2284,10 +2280,10 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, - pad_left, - n_right_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -2311,10 +2307,10 @@ void conv_depthwise_5x5s2_bias(float* dout, vbias, weights_vec, odds_w, - pad_left, - n_right_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -2338,20 +2334,23 @@ inline void compute_all_padding_pre_relu(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { int tmp_index = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout++ = sum > 0.f ? sum : 0.f; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -2582,16 +2581,17 @@ inline void compute_all_padding_pre_relu(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, - weights[3 - k][3 - i], - 3 - i); + weights[3 - k][num_index_right], + num_index_right); din_ptr_arr[tmp_index - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? sum : 0.f; } } @@ -2603,20 +2603,23 @@ inline void compute_all_padding_mid_relu(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout++ = sum > 0.f ? sum : 0.f; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -2702,16 +2705,17 @@ inline void compute_all_padding_mid_relu(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? sum : 0.f; } } @@ -2724,6 +2728,8 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -2733,21 +2739,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); float sum1 = compute_one_data_pre( - din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout0++ = sum > 0.f ? sum : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f; } @@ -2851,23 +2858,24 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); float sum1 = compute_one_data_post( - din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp2 - k] += 2; } + num_index_right -= 2; din_ptr_arr[0] += 2; din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? sum : 0.f; @@ -2882,6 +2890,8 @@ inline void compute_all_padding_post_relu(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -2889,14 +2899,15 @@ inline void compute_all_padding_post_relu(float* dout, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left); for (int k = 0; k < num; k++) { - sum += compute_one_data_pre(din_ptr_arr[2 - k], + sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + pad_left -= 2; *dout++ = sum > 0.f ? sum : 0.f; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -2913,7 +2924,7 @@ inline void compute_all_padding_post_relu(float* dout, #ifdef __aarch64__ asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -2932,7 +2943,7 @@ inline void compute_all_padding_post_relu(float* dout, #else asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -2949,14 +2960,14 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 8; + din_ptr_arr[num] -= 8; break; case 1: #ifdef __aarch64__ asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -2976,8 +2987,8 @@ inline void compute_all_padding_post_relu(float* dout, #else asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -2995,15 +3006,15 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 8; + din_ptr_arr[tmp] -= 8; break; case 2: #ifdef __aarch64__ asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -3024,9 +3035,9 @@ inline void compute_all_padding_post_relu(float* dout, #else asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -3045,7 +3056,7 @@ inline void compute_all_padding_post_relu(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 8; + din_ptr_arr[tmp - 1] -= 8; break; case 3: #ifdef __aarch64__ @@ -3102,35 +3113,36 @@ inline void compute_all_padding_post_relu(float* dout, din_ptr_arr[0] -= 8; break; default: - LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + LOG(FATAL) << "This num: " << (num + 1) << " does not support"; } } // clang-format on // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( - din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[2 - i] += 2; + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i] += 2; } *dout++ = sum > 0.f ? sum : 0.f; } // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { - sum += compute_one_data_post(din_ptr_arr[2 - k], + sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); - din_ptr_arr[2 - k] += 2; + weights[tmp - k][num_index_right], + num_index_right); + din_ptr_arr[tmp - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? sum : 0.f; } } @@ -3176,7 +3188,19 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, int cnt = loop_w >> 2; int remain = loop_w & 3; int n_top_h = 4 - pad_top; - int n_bottom_h = 4 -pad_bottom; + int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3); + int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3); + int n_left_w = 4 - pad_left; + if (n_right_w == 4) { + remain++; + pad_right_new--; + n_right_w -= 2; + } + if (n_bottom_h == 4) { + loop_h++; + pad_bottom_new--; + n_bottom_h -= 2; + } float32x4_t vzero = vdupq_n_f32(0.f); for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; @@ -3223,6 +3247,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -3262,6 +3288,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -3291,6 +3319,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -3317,6 +3347,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -3341,6 +3373,8 @@ inline void compute_all_padding_pre_relu6(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -3351,14 +3385,15 @@ inline void compute_all_padding_pre_relu6(float* dout, // left for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -3597,16 +3632,17 @@ inline void compute_all_padding_pre_relu6(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, - weights[3 - k][3 - i], - 3 - i); + weights[3 - k][num_index_right], + num_index_right); din_ptr_arr[tmp_index - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } } @@ -3619,6 +3655,8 @@ inline void compute_all_padding_mid_relu6(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -3629,14 +3667,15 @@ inline void compute_all_padding_mid_relu6(float* dout, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -3706,7 +3745,7 @@ inline void compute_all_padding_mid_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[0] -= 4; + din_ptr_arr[0] -= 8; } // clang-format on // remain @@ -3724,16 +3763,17 @@ inline void compute_all_padding_mid_relu6(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } } @@ -3748,6 +3788,8 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -3761,21 +3803,22 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, // clang-format off for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); float sum1 = compute_one_data_pre( - din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); - sum1 += compute_one_data_pre(din_ptr_arr[tmp2 -k], - weights[tmp -k], + num_index_left); + sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], + weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; } @@ -3880,23 +3923,24 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); float sum1 = compute_one_data_post( - din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[tmp1]++; + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right); + din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp2 - k] += 2; } + num_index_right -= 2; din_ptr_arr[1] += 2; din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; @@ -3912,6 +3956,8 @@ inline void compute_all_padding_post_relu6(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -3922,14 +3968,15 @@ inline void compute_all_padding_post_relu6(float* dout, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left); for (int k = 0; k < num; k++) { - sum += compute_one_data_pre(din_ptr_arr[2 - k], + sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + pad_left -= 2; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -3946,7 +3993,7 @@ inline void compute_all_padding_post_relu6(float* dout, #ifdef __aarch64__ asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -3966,7 +4013,7 @@ inline void compute_all_padding_post_relu6(float* dout, #else asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -3984,14 +4031,14 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 8; + din_ptr_arr[num] -= 8; break; case 1: #ifdef __aarch64__ asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -4012,8 +4059,8 @@ inline void compute_all_padding_post_relu6(float* dout, #else asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -4032,15 +4079,15 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 8; + din_ptr_arr[tmp] -= 8; break; case 2: #ifdef __aarch64__ asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -4062,9 +4109,9 @@ inline void compute_all_padding_post_relu6(float* dout, #else asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6 : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -4084,7 +4131,7 @@ inline void compute_all_padding_post_relu6(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 8; + din_ptr_arr[tmp - 1] -= 8; break; case 3: #ifdef __aarch64__ @@ -4162,16 +4209,17 @@ inline void compute_all_padding_post_relu6(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { - sum += compute_one_data_post(din_ptr_arr[2 - k], + sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); - din_ptr_arr[2 - k] += 2; + weights[tmp - k][num_index_right], + num_index_right); + din_ptr_arr[tmp - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; } } @@ -4218,7 +4266,19 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, int cnt = loop_w >> 2; int remain = loop_w & 3; int n_top_h = 4 - pad_top; - int n_bottom_h = 4 -pad_bottom; + int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3); + int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3); + int n_left_w = 4 - pad_left; + if (n_right_w == 4) { + remain++; + pad_right_new--; + n_right_w -= 2; + } + if (n_bottom_h == 4) { + loop_h++; + pad_bottom_new--; + n_bottom_h -= 2; + } float32x4_t vzero = vdupq_n_f32(0.f); for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; @@ -4266,6 +4326,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -4306,6 +4368,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -4336,6 +4400,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -4363,6 +4429,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -4387,6 +4455,8 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -4397,14 +4467,15 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, // left for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout++ = sum > 0.f ? sum : sum * scale[0]; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -4651,22 +4722,19 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, - weights[3 - k][3 - i], - 3 - i); + weights[3 - k][num_index_right], + num_index_right); din_ptr_arr[tmp_index - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? sum : sum * scale[0]; } - for (int w = pad_right; w > 4; w--) { - *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0]; - } - } inline void compute_all_padding_mid_leakyRelu(float* dout, const float** din_ptr_arr, @@ -4677,6 +4745,8 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -4687,14 +4757,15 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout++ = sum > 0.f ? sum : sum * scale[0]; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -4784,16 +4855,17 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? sum : sum * scale[0]; } } @@ -4807,6 +4879,8 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -4819,21 +4893,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); float sum1 = compute_one_data_pre( - din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left); for (int k = 0; k < num; k++) { sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + num_index_left += 2; *dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; } @@ -4943,23 +5018,24 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); float sum1 = compute_one_data_post( - din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right); din_ptr_arr[tmp1] += 2; for (int k = 0; k < num; k++) { sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); + weights[tmp - k][num_index_right], + num_index_right); din_ptr_arr[tmp2 - k] += 2; } + num_index_right -= 2; din_ptr_arr[1] += 2; din_ptr_arr[0] += 2; *dout0++ = sum > 0.f ? sum : sum * scale[0]; @@ -4975,6 +5051,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout, bool odds, int pad_left, int pad_right, + int num_index_left, + int num_index_right, int cnt, int remain, int num) { @@ -4985,14 +5063,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout, int tmp = num - 1; for (int i = pad_left; i > 0; i--) { float sum = compute_one_data_pre( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left); for (int k = 0; k < num; k++) { - sum += compute_one_data_pre(din_ptr_arr[2 - k], + sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], - 4 - i); + num_index_left); } + pad_left -= 2; *dout++ = sum > 0.f ? sum : sum * scale[0]; } if (odds) { // origin pad_left is odds, such as ori_pad_left=1 @@ -5009,7 +5088,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, #ifdef __aarch64__ asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -5031,7 +5110,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, #else asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr5] "w"(weights[5]), @@ -5049,14 +5128,14 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[3] -= 8; + din_ptr_arr[num] -= 8; break; case 1: #ifdef __aarch64__ asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -5079,8 +5158,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout, #else asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[2]), - [din_ptr1] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp]), + [din_ptr1] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -5099,15 +5178,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[2] -= 8; + din_ptr_arr[tmp] -= 8; break; case 2: #ifdef __aarch64__ asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -5131,9 +5210,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout, #else asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr_arr[1]), - [din_ptr1] "+r"(din_ptr_arr[2]), - [din_ptr2] "+r"(din_ptr_arr[3]), + [din_ptr0] "+r"(din_ptr_arr[tmp - 1]), + [din_ptr1] "+r"(din_ptr_arr[tmp]), + [din_ptr2] "+r"(din_ptr_arr[num]), [dout_ptr] "+r"(dout) : [wr0] "w"(weights[0]), [wr1] "w"(weights[1]), @@ -5153,7 +5232,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, "q14", "q15"); #endif - din_ptr_arr[1] -= 8; + din_ptr_arr[tmp - 1] -= 8; break; case 3: #ifdef __aarch64__ @@ -5221,28 +5300,29 @@ inline void compute_all_padding_post_leakyRelu(float* dout, // remain for (int w = 0; w < remain; w++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[num] += 2; for (int i = 0; i < num; i++) { sum += compute_one_data_post( - din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); - din_ptr_arr[2 - i] += 2; + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i] += 2; } *dout++ = sum > 0.f ? sum : sum * scale[0]; } // right for (int i = 0; i < pad_right; i++) { float sum = compute_one_data_post( - din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); - din_ptr_arr[3] += 2; + din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right); + din_ptr_arr[num] += 2; for (int k = 0; k < num; k++) { - sum += compute_one_data_post(din_ptr_arr[2 - k], + sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, - weights[tmp - k][3 - i], - 3 - i); - din_ptr_arr[2 - k] += 2; + weights[tmp - k][num_index_right], + num_index_right); + din_ptr_arr[tmp - k] += 2; } + num_index_right -= 2; *dout++ = sum > 0.f ? sum : sum * scale[0]; } } @@ -5289,7 +5369,19 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, int cnt = loop_w >> 2; int remain = loop_w & 3; int n_top_h = 4 - pad_top; - int n_bottom_h = 4 -pad_bottom; + int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3); + int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3); + int n_left_w = 4 - pad_left; + if (n_right_w == 4) { + remain++; + pad_right_new--; + n_right_w -= 2; + } + if (n_bottom_h == 4) { + loop_h++; + pad_bottom_new--; + n_bottom_h -= 2; + } float32x4_t vzero = vdupq_n_f32(0.f); for (int n = 0; n < num; n++) { const float* din_batch = din + n * in_channel_size; @@ -5337,6 +5429,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); @@ -5377,6 +5471,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -5407,6 +5503,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, 4); @@ -5434,6 +5532,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, odds_w, pad_left_new, pad_right_new, + n_left_w, + n_right_w, cnt, remain, h_in_num); -- GitLab