From 4dbda3aa2b227700898ee63421b90c7464234ac8 Mon Sep 17 00:00:00 2001
From: chenjiaoAngel
Date: Fri, 14 Aug 2020 17:07:27 +0800
Subject: [PATCH] add conv 5x5s1_dw relu6/leakyRelu fusion

---
 .../arm/math/conv5x5s1_depthwise_fp32.cc      | 488 +++++++++---------
 1 file changed, 249 insertions(+), 239 deletions(-)

diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
index 14a2d4b683..19e61385d9 100644
--- a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
@@ -1571,31 +1571,31 @@ void conv_depthwise_5x5s1_bias(float* dout,
     float* dout_batch = dout + n * out_channel_size;
 #pragma omp parallel for
     for (int c = 0; c < chin; c++) {
-      const float* din_ch = din_batch + c * in_size;
-      const float* weights_ch = weights + c * weights_size;
-      float* dout_ch = dout_batch + c * out_size;
-      float bias_val = flag_bias ? bias[c] : 0.f;
-      const float* din_ptr0 = din_ch;
-      const float* din_ptr1 = din_ptr0 + win;
-      const float* din_ptr2 = din_ptr1 + win;
-      const float* din_ptr3 = din_ptr2 + win;
-      const float* din_ptr4 = din_ptr3 + win;
-      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
-      float* dout_ptr = dout_ch;
-      float32x4_t wr5;
-      float32x4_t wr6;
-      float32x4_t wr0 = vld1q_f32(weights_ch);
-      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
-      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
-      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
-      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
-      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
-      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
-      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
-      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
-      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
-      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
-      float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
+      const float* din_ch = din_batch + c * in_size;
+      const float* weights_ch = weights + c * weights_size;
+      float* dout_ch = dout_batch + c * out_size;
+      float bias_val = flag_bias ? bias[c] : 0.f;
+      const float* din_ptr0 = din_ch;
+      const float* din_ptr1 = din_ptr0 + win;
+      const float* din_ptr2 = din_ptr1 + win;
+      const float* din_ptr3 = din_ptr2 + win;
+      const float* din_ptr4 = din_ptr3 + win;
+      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      float* dout_ptr = dout_ch;
+      float32x4_t wr5;
+      float32x4_t wr6;
+      float32x4_t wr0 = vld1q_f32(weights_ch);
+      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
+      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
+      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
+      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
+      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
+      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
+      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
+      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
+      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
+      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
+      float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
       // top_h
       for (int h = pad_top; h > 4; h--) {
         memset(dout_ptr, bias[0], sizeof(float)*wout);
@@ -1611,11 +1611,11 @@ void conv_depthwise_5x5s1_bias(float* dout,
         din_ptr_arr[3] = din_ptr3;
         din_ptr_arr[4] = din_ptr4;
       }
-      din_ptr_arr[0] = din_ptr0;
-      din_ptr_arr[1] = din_ptr1;
-      din_ptr_arr[2] = din_ptr2;
-      din_ptr_arr[3] = din_ptr3;
-      din_ptr_arr[4] = din_ptr4;
+      // din_ptr_arr[0] = din_ptr0;
+      // din_ptr_arr[1] = din_ptr1;
+      // din_ptr_arr[2] = din_ptr2;
+      // din_ptr_arr[3] = din_ptr3;
+      // din_ptr_arr[4] = din_ptr4;
       // mid_h
       for (int h = 0; h < loop_h; h++) {
         compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
@@ -2255,27 +2255,6 @@ inline void compute_all_padding_post_relu(float* dout,
     }
     *dout++ = sum > 0.f ? sum : 0.f;
   }
-/*
-  // remain
-  for (int w = 0; w < remain; w++) {
-    float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
-    din_ptr_arr[3]++;
-    for (int i = 0; i < num; i++) {
-      sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
-      din_ptr_arr[2 - i]++;
-    }
-    *dout++ = sum > 0.f ? sum : 0.f;
-  }
-
-  // right
-  for (int i = 0; i < pad_right_new; i++) {
-    float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
-    for (int k = 0; k < num; k++) {
-      sum +=compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
-    }
-    *dout++ = sum > 0.f ? sum : 0.f;
-  }
- */
   for (int w = pad_right; w > 4; w--) {
     *dout++ = bias[0] > 0.f ? bias[0] : 0.f;
   }
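The setup block above packs one channel's 25 kernel weights into seven 4-lane vectors: wr0..wr4 hold the first four taps of each row, wr5 collects the fifth tap of rows 0..3, and wr6 carries w[24] in lane 0. The sketch below restates that mapping with plain arrays so it can be compiled and checked off-device; `Lane4` is a stand-in for float32x4_t and is not part of the patch.

    // Illustration only: how the 5x5 kernel is split into seven 4-lane vectors.
    #include <cstdio>

    struct Lane4 { float v[4]; };

    int main() {
      float w[25];  // one channel's 5x5 kernel, row-major
      for (int i = 0; i < 25; ++i) w[i] = static_cast<float>(i);

      Lane4 wr[5];  // wr0..wr4: first 4 taps of each row
      for (int r = 0; r < 5; ++r)
        for (int c = 0; c < 4; ++c) wr[r].v[c] = w[r * 5 + c];

      Lane4 wr5 = {{w[4], w[9], w[14], w[19]}};  // 5th tap of rows 0..3
      Lane4 wr6 = {{w[24], 0.f, 0.f, 0.f}};      // 5th tap of row 4 in lane 0

      printf("wr5 = %g %g %g %g, wr6[0] = %g\n",
             wr5.v[0], wr5.v[1], wr5.v[2], wr5.v[3], wr6.v[0]);
      return 0;
    }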
@@ -2316,31 +2295,31 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
     float* dout_batch = dout + n * out_channel_size;
 #pragma omp parallel for
     for (int c = 0; c < chin; c++) {
-      const float* din_ch = din_batch + c * in_size;
-      const float* weights_ch = weights + c * weights_size;
-      float* dout_ch = dout_batch + c * out_size;
-      float bias_val = flag_bias ? bias[c] : 0.f;
-      const float* din_ptr0 = din_ch;
-      const float* din_ptr1 = din_ptr0 + win;
-      const float* din_ptr2 = din_ptr1 + win;
-      const float* din_ptr3 = din_ptr2 + win;
-      const float* din_ptr4 = din_ptr3 + win;
-      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
-      float* dout_ptr = dout_ch;
-      float32x4_t wr5;
-      float32x4_t wr6;
-      float32x4_t wr0 = vld1q_f32(weights_ch);
-      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
-      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
-      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
-      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
-      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
-      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
-      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
-      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
-      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
-      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
-      float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
+      const float* din_ch = din_batch + c * in_size;
+      const float* weights_ch = weights + c * weights_size;
+      float* dout_ch = dout_batch + c * out_size;
+      float bias_val = flag_bias ? bias[c] : 0.f;
+      const float* din_ptr0 = din_ch;
+      const float* din_ptr1 = din_ptr0 + win;
+      const float* din_ptr2 = din_ptr1 + win;
+      const float* din_ptr3 = din_ptr2 + win;
+      const float* din_ptr4 = din_ptr3 + win;
+      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      float* dout_ptr = dout_ch;
+      float32x4_t wr5;
+      float32x4_t wr6;
+      float32x4_t wr0 = vld1q_f32(weights_ch);
+      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
+      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
+      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
+      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
+      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
+      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
+      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
+      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
+      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
+      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
+      float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
       // top_h
       for (int h = pad_top; h > 4; h--) {
         memset(dout_ptr, bias[0], sizeof(float)*wout);
@@ -2356,11 +2335,11 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
         din_ptr_arr[3] = din_ptr3;
         din_ptr_arr[4] = din_ptr4;
       }
-      din_ptr_arr[0] = din_ptr0;
-      din_ptr_arr[1] = din_ptr1;
-      din_ptr_arr[2] = din_ptr2;
-      din_ptr_arr[3] = din_ptr3;
-      din_ptr_arr[4] = din_ptr4;
+      // din_ptr_arr[0] = din_ptr0;
+      // din_ptr_arr[1] = din_ptr1;
+      // din_ptr_arr[2] = din_ptr2;
+      // din_ptr_arr[3] = din_ptr3;
+      // din_ptr_arr[4] = din_ptr4;
       // mid_h
       for (int h = 0; h < loop_h; h++) {
         compute_all_padding_mid_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
@@ -2393,10 +2372,10 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
 }
 inline void compute_all_padding_pre_relu6(float* dout,
-                                          std::vector<const float*> din_ptr_arr,
+                                          const float** din_ptr_arr,
                                           const float* bias,
                                           const float* six,
-                                          std::vector<float32x4_t> weights,
+                                          float32x4_t* weights,
                                           float32x4_t vzero,
                                           int win,
                                           int wout,
@@ -2410,6 +2389,7 @@ inline void compute_all_padding_pre_relu6(float* dout,
 #ifdef __aarch64__
   float32x4_t vsix = vld1q_f32(six);
 #endif
+  int tmp_index = num - 1;
   // left
   for (int w = pad_left; w > 4; w--) {
     *dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
   }
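Both the scalar border code above and the vector path (which loads vsix and is passed vzero) implement the ReLU6 fusion as a clamp of the accumulated sum, written in the patch as the ternary `sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f`. A minimal scalar reference follows; the helper name `fuse_bias_relu6` is illustrative only, and in the real kernels the bias is already folded in by compute_one_data_pre/post.

    #include <algorithm>

    // Scalar reference for the fused bias + ReLU6 used by the *_relu6 kernels:
    // out = min(max(conv_sum + bias, 0), six), with six taken from six[0].
    inline float fuse_bias_relu6(float conv_sum, float bias, float six) {
      float sum = conv_sum + bias;
      return std::min(std::max(sum, 0.f), six);
    }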
@@ -2417,7 +2397,7 @@
   for (int i = pad_left_new; i > 0; i--) {
     float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_pre(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
+      sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
@@ -2629,23 +2609,25 @@ inline void compute_all_padding_pre_relu6(float* dout,
     default:
       LOG(FATAL) << "This num: " << (num + 1) << "does not support";
   }
+  din_ptr_arr[0] -= 4;
 }
   // remain
   for (int w = 0; w < remain; w++) {
     float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
     din_ptr_arr[num]++;
     for (int i = 0; i < num; i++) {
-      sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
-      din_ptr_arr[num - 1 - i]++;
+      sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
+      din_ptr_arr[tmp_index - i]++;
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
-  // right
-  for (int i = 1; i < pad_right_new; i++) {
-    float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][4 - i], 4 - i);
+  for (int i = 0; i < pad_right_new; i++) {
+    float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
+    din_ptr_arr[num]++;
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][4 - i], 4 - i);
+      sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i);
+      din_ptr_arr[tmp_index - k]++;
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
@@ -2655,10 +2637,10 @@
 }
 inline void compute_all_padding_mid_relu6(float* dout,
-                                          std::vector<const float*> din_ptr_arr,
+                                          const float** din_ptr_arr,
                                           const float* bias,
                                           const float* six,
-                                          std::vector<float32x4_t> weights,
+                                          float32x4_t* weights,
                                           float32x4_t vzero,
                                           int win,
                                           int wout,
@@ -2743,7 +2725,8 @@ inline void compute_all_padding_mid_relu6(float* dout,
                      "q13",
                      "q14",
                      "q15");
-#endif
+#endif
+    din_ptr_arr[0] -= 4;
   }
   // remain
   for (int w = 0; w < remain; w++) {
@@ -2755,12 +2738,13 @@ inline void compute_all_padding_mid_relu6(float* dout,
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
-  // right
   for (int i = 0; i < pad_right_new; i++) {
     float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    din_ptr_arr[num]++;
     for (int k = 0; k < num; k++) {
       sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      din_ptr_arr[tmp - k]++;
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
@@ -2769,10 +2753,10 @@
   }
 }
 inline void compute_all_padding_post_relu6(float* dout,
-                                           std::vector<const float*> din_ptr_arr,
+                                           const float** din_ptr_arr,
                                            const float* bias,
                                            const float* six,
-                                           std::vector<float32x4_t> weights,
+                                           float32x4_t* weights,
                                            float32x4_t vzero,
                                            int win,
                                            int wout,
@@ -2792,9 +2776,9 @@ inline void compute_all_padding_post_relu6(float* dout,
   }
   int tmp = num - 1;
   for (int i = pad_left_new; i > 0; i--) {
-    float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - i);
+    float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
+      sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
@@ -2805,7 +2789,7 @@
 #ifdef __aarch64__
       asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_RELU6
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
+                     [din_ptr0] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr5] "w"(weights[5]),
@@ -2825,7 +2809,7 @@
 #else
       asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_RELU6
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
+                     [din_ptr0] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr5] "w"(weights[5]),
@@ -2843,13 +2827,14 @@
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[3] -= 4;
       break;
     case 1:
 #ifdef __aarch64__
       asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_RELU6
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
+                     [din_ptr0] "+r"(din_ptr_arr[2]),
+                     [din_ptr1] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -2870,8 +2855,8 @@
 #else
       asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_RELU6
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
+                     [din_ptr0] "+r"(din_ptr_arr[2]),
+                     [din_ptr1] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -2890,14 +2875,15 @@
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[2] -= 4;
       break;
     case 2:
 #ifdef __aarch64__
      asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_RELU6
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
-                     [din_ptr2] "+r"(din_ptr_arr[2]),
+                     [din_ptr0] "+r"(din_ptr_arr[1]),
+                     [din_ptr1] "+r"(din_ptr_arr[2]),
+                     [din_ptr2] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -2919,9 +2905,9 @@
 #else
       asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_RELU6
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
-                     [din_ptr2] "+r"(din_ptr_arr[2]),
+                     [din_ptr0] "+r"(din_ptr_arr[1]),
+                     [din_ptr1] "+r"(din_ptr_arr[2]),
+                     [din_ptr2] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -2941,6 +2927,7 @@
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[1] -= 4;
       break;
     case 3:
 #ifdef __aarch64__
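The re-bound asm operands above (din_ptr0 → din_ptr_arr[3], then [2]/[3], then [1]/[2]/[3]) reflect that din_ptr_arr always carries five row pointers, while only num + 1 of them are still inside the image at the bottom border, so the patch now takes the rows from the tail of the array. A small sketch of that index choice, assuming case 3 keeps using rows 0..3 as before:

    #include <vector>

    // Rows of din_ptr_arr consumed by the bottom-padding kernels after this patch.
    // num = 0 -> {3}, num = 1 -> {2, 3}, num = 2 -> {1, 2, 3}, num = 3 -> {0, 1, 2, 3}.
    std::vector<int> rows_used_for_bottom(int num) {
      std::vector<int> rows;
      for (int r = 3 - num; r <= 3; ++r) rows.push_back(r);
      return rows;
    }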
@@ -2996,6 +2983,7 @@ inline void compute_all_padding_post_relu6(float* dout,
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[0] -= 4;
       break;
     default:
       LOG(FATAL) << "This num: " << (num + 1) << "does not support";
@@ -3003,20 +2991,22 @@ inline void compute_all_padding_post_relu6(float* dout,
   }
   // remain
   for (int w = 0; w < remain; w++) {
-    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
-    din_ptr_arr[num]++;
+    float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
+    din_ptr_arr[3]++;
     for (int i = 0; i < num; i++) {
-      sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
-      din_ptr_arr[tmp - i]++;
+      sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
+      din_ptr_arr[2 - i]++;
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
   // right
-  for (int i = 0; i < pad_right_new; i++) {
-    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
+  for (int i = 0; i < pad_right_new; i++) {
+    float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    din_ptr_arr[3]++;
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      din_ptr_arr[2 - k]++;
     }
     *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
   }
@@ -3061,43 +3051,31 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
     float* dout_batch = dout + n * out_channel_size;
 #pragma omp parallel for
     for (int c = 0; c < chin; c++) {
-      const float* din_ch = din_batch + c * in_size;
-      const float* weights_ch = weights + c * weights_size;
-      float* dout_ch = dout_batch + c * out_size;
-      float bias_val = flag_bias ? bias[c] : 0.f;
-      const float* din_ptr0 = din_ch;
-      const float* din_ptr1 = din_ptr0 + win;
-      const float* din_ptr2 = din_ptr1 + win;
-      const float* din_ptr3 = din_ptr2 + win;
-      const float* din_ptr4 = din_ptr3 + win;
-      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
-      float* dout_ptr = dout_ch;
-      float32x4_t wr5;
-      float32x4_t wr6;
-      float32x4_t wr0 = vld1q_f32(weights_ch);
-      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
-      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
-      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
-      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
-      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
-      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
-      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
-      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
-      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
-      std::vector<const float*> din_ptr_arr;
-      std::vector<float32x4_t> weights_vec;
-      din_ptr_arr.push_back(din_ptr0);
-      din_ptr_arr.push_back(din_ptr1);
-      din_ptr_arr.push_back(din_ptr2);
-      din_ptr_arr.push_back(din_ptr3);
-      din_ptr_arr.push_back(din_ptr4);
-      weights_vec.push_back(wr0);
-      weights_vec.push_back(wr1);
-      weights_vec.push_back(wr2);
-      weights_vec.push_back(wr3);
-      weights_vec.push_back(wr4);
-      weights_vec.push_back(wr5);
-      weights_vec.push_back(wr6);
+      const float* din_ch = din_batch + c * in_size;
+      const float* weights_ch = weights + c * weights_size;
+      float* dout_ch = dout_batch + c * out_size;
+      float bias_val = flag_bias ? bias[c] : 0.f;
+      const float* din_ptr0 = din_ch;
+      const float* din_ptr1 = din_ptr0 + win;
+      const float* din_ptr2 = din_ptr1 + win;
+      const float* din_ptr3 = din_ptr2 + win;
+      const float* din_ptr4 = din_ptr3 + win;
+      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      float* dout_ptr = dout_ch;
+      float32x4_t wr5;
+      float32x4_t wr6;
+      float32x4_t wr0 = vld1q_f32(weights_ch);
+      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
+      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
+      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
+      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
+      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
+      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
+      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
+      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
+      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
+      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
+      float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
       // top_h
       for (int h = pad_top; h > 4; h--) {
         memset(dout_ptr, bias[0], sizeof(float)*wout);
@@ -3105,37 +3083,53 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
       }
       for (int h = pad_top_new; h > 0; h--) {
         compute_all_padding_pre_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
-                                      win, wout, pad_left, pad_left_new, pad_right,
+                                      win, wout, pad_left, pad_right, pad_left_new,
                                       pad_right_new, cnt, remain, 4 - h);
         dout_ptr += wout;
+        din_ptr_arr[0] = din_ptr0;
+        din_ptr_arr[1] = din_ptr1;
+        din_ptr_arr[2] = din_ptr2;
+        din_ptr_arr[3] = din_ptr3;
+        din_ptr_arr[4] = din_ptr4;
       }
       // mid_h
       for (int h = 0; h < loop_h; h++) {
         compute_all_padding_mid_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
-                                      win, wout, pad_left, pad_left_new, pad_right,
+                                      win, wout, pad_left, pad_right, pad_left_new,
                                       pad_right_new, cnt, remain, 4);
         dout_ptr += wout;
-        for (int i = 0; i < 4; i++) {
-          din_ptr_arr[i] = din_ptr_arr[i + 1];
-        }
-        din_ptr_arr[4] += win;
+        din_ptr0 = din_ptr1;
+        din_ptr1 = din_ptr2;
+        din_ptr2 = din_ptr3;
+        din_ptr3 = din_ptr4;
+        din_ptr4 = din_ptr4 + win;
+        din_ptr_arr[0] = din_ptr0;
+        din_ptr_arr[1] = din_ptr1;
+        din_ptr_arr[2] = din_ptr2;
+        din_ptr_arr[3] = din_ptr3;
+        din_ptr_arr[4] = din_ptr4;
       }
       // bottom
       for (int h = 0; h < pad_bottom_new; h++) {
         compute_all_padding_post_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
-                                       win, wout, pad_left, pad_left_new, pad_right,
+                                       win, wout, pad_left, pad_right, pad_left_new,
                                        pad_right_new, cnt, remain, 3 - h);
         dout_ptr += wout;
+        din_ptr_arr[0] = din_ptr0;
+        din_ptr_arr[1] = din_ptr1;
+        din_ptr_arr[2] = din_ptr2;
+        din_ptr_arr[3] = din_ptr3;
+        din_ptr_arr[4] = din_ptr4;
       }
     }
   }
 }
 inline void compute_all_padding_pre_leakyRelu(float* dout,
-                                              std::vector<const float*> din_ptr_arr,
+                                              const float** din_ptr_arr,
                                               const float* bias,
                                               const float* scale,
-                                              std::vector<float32x4_t> weights,
+                                              float32x4_t* weights,
                                               float32x4_t vzero,
                                               int win,
                                               int wout,
@@ -3149,6 +3143,7 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
 #ifdef __aarch64__
   float32x4_t vscale = vld1q_f32(scale);
 #endif
+  int tmp_index = num - 1;
   // left
   for (int w = pad_left; w > 4; w--) {
     *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
   }
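For reference, the leaky-ReLU fusion applied throughout these kernels is the scalar selection seen above, `x > 0.f ? x : x * scale[0]`. The sketch below shows the same per-lane selection with NEON intrinsics; it is a common idiom and only an assumption about what the RESULT_S1_LEAKY_RELU macro reduces to, not a copy of it, and like the rest of this file it builds only for ARM targets.

    #include <arm_neon.h>

    // Leaky ReLU on a 4-lane vector: keep x where x >= 0, else x * scale.
    static inline float32x4_t leaky_relu_f32x4(float32x4_t x, float32x4_t vscale) {
      uint32x4_t pos_mask = vcgeq_f32(x, vdupq_n_f32(0.f));  // lanes where x >= 0
      float32x4_t scaled = vmulq_f32(x, vscale);              // x * scale
      return vbslq_f32(pos_mask, x, scaled);                  // per-lane select
    }

    // Scalar form used by the border code in this file:
    static inline float leaky_relu(float x, float scale) {
      return x > 0.f ? x : x * scale;
    }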
@@ -3156,7 +3151,7 @@
   for (int i = pad_left_new; i > 0; i--) {
     float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_pre(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
+      sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
     }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
@@ -3368,23 +3363,25 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
     default:
       LOG(FATAL) << "This num: " << (num + 1) << "does not support";
   }
+  din_ptr_arr[0] -= 4;
 }
   // remain
   for (int w = 0; w < remain; w++) {
     float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
     din_ptr_arr[num]++;
     for (int i = 0; i < num; i++) {
-      sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
-      din_ptr_arr[num - 1 - i]++;
+      sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
+      din_ptr_arr[tmp_index - i]++;
    }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
-  // right
-  for (int i = 1; i < pad_right_new; i++) {
-    float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][4 - i], 4 - i);
+  for (int i = 0; i < pad_right_new; i++) {
+    float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
+    din_ptr_arr[num]++;
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][4 - i], 4 - i);
+      sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i);
+      din_ptr_arr[tmp_index - k]++;
     }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
@@ -3394,10 +3391,10 @@
 }
 inline void compute_all_padding_mid_leakyRelu(float* dout,
-                                              std::vector<const float*> din_ptr_arr,
+                                              const float** din_ptr_arr,
                                               const float* bias,
                                               const float* scale,
-                                              std::vector<float32x4_t> weights,
+                                              float32x4_t* weights,
                                               float32x4_t vzero,
                                               int win,
                                               int wout,
@@ -3482,7 +3479,8 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
                      "q13",
                      "q14",
                      "q15");
-#endif
+#endif
+    din_ptr_arr[0] -= 4;
   }
   // remain
   for (int w = 0; w < remain; w++) {
@@ -3498,8 +3496,10 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
    }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
   // right
   for (int i = 0; i < pad_right_new; i++) {
     float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    din_ptr_arr[num]++;
     for (int k = 0; k < num; k++) {
       sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      din_ptr_arr[tmp - k]++;
     }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
@@ -3508,10 +3508,10 @@
   }
 }
 inline void compute_all_padding_post_leakyRelu(float* dout,
-                                               std::vector<const float*> din_ptr_arr,
+                                               const float** din_ptr_arr,
                                                const float* bias,
                                                const float* scale,
-                                               std::vector<float32x4_t> weights,
+                                               float32x4_t* weights,
                                                float32x4_t vzero,
                                                int win,
                                                int wout,
@@ -3531,9 +3531,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
   }
   int tmp = num - 1;
   for (int i = pad_left_new; i > 0; i--) {
-    float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - i);
+    float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
+      sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
     }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
@@ -3544,7 +3544,7 @@
 #ifdef __aarch64__
       asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_LEAKY_RELU
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
+                     [din_ptr0] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr5] "w"(weights[5]),
@@ -3564,7 +3564,7 @@
 #else
       asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_LEAKY_RELU
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
+                     [din_ptr0] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr5] "w"(weights[5]),
@@ -3582,13 +3582,14 @@
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[3] -= 4;
       break;
     case 1:
 #ifdef __aarch64__
       asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_LEAKY_RELU
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
+                     [din_ptr0] "+r"(din_ptr_arr[2]),
+                     [din_ptr1] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -3609,8 +3610,8 @@
 #else
       asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_LEAKY_RELU
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
+                     [din_ptr0] "+r"(din_ptr_arr[2]),
+                     [din_ptr1] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -3629,14 +3630,15 @@
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[2] -= 4;
       break;
     case 2:
 #ifdef __aarch64__
       asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_LEAKY_RELU
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
-                     [din_ptr2] "+r"(din_ptr_arr[2]),
+                     [din_ptr0] "+r"(din_ptr_arr[1]),
+                     [din_ptr1] "+r"(din_ptr_arr[2]),
+                     [din_ptr2] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -3658,9 +3660,9 @@
 #else
       asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_LEAKY_RELU
                    : [cnt] "+r"(cnt),
-                     [din_ptr0] "+r"(din_ptr_arr[0]),
-                     [din_ptr1] "+r"(din_ptr_arr[1]),
-                     [din_ptr2] "+r"(din_ptr_arr[2]),
+                     [din_ptr0] "+r"(din_ptr_arr[1]),
+                     [din_ptr1] "+r"(din_ptr_arr[2]),
+                     [din_ptr2] "+r"(din_ptr_arr[3]),
                      [dout_ptr] "+r"(dout)
                    : [wr0] "w"(weights[0]),
                      [wr1] "w"(weights[1]),
@@ -3680,6 +3682,7 @@
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[1] -= 4;
       break;
     case 3:
 #ifdef __aarch64__
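Each of the new `din_ptr_arr[x] -= 4;` statements above follows an inline-asm block and steps back only the pointer that block bound as `din_ptr0`. The sketch below just shows the shape of that bookkeeping; the reason in the comment (the block leaving that pointer one 4-float group ahead of what the scalar remain/right tails expect) is an assumption made for illustration, since the actual advance is hidden inside the COMPUTE_*_LINE_S1_POST macros.

    // Illustrative only. first_row is the index bound to the asm block's
    // din_ptr0 operand (3, 2, 1 or 0 depending on num in the code above).
    static inline void fixup_after_asm_block(const float** din_ptr_arr,
                                             int first_row) {
      // Assumption: the asm block advances this pointer 4 floats past where
      // the scalar tail loops need to resume reading.
      din_ptr_arr[first_row] -= 4;
    }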
@@ -3735,6 +3738,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
                      "q14",
                      "q15");
 #endif
+      din_ptr_arr[0] -= 4;
       break;
     default:
       LOG(FATAL) << "This num: " << (num + 1) << "does not support";
@@ -3742,20 +3746,22 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
   }
   // remain
   for (int w = 0; w < remain; w++) {
-    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
-    din_ptr_arr[num]++;
+    float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
+    din_ptr_arr[3]++;
     for (int i = 0; i < num; i++) {
-      sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
-      din_ptr_arr[tmp - i]++;
+      sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
+      din_ptr_arr[2 - i]++;
     }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
   // right
   for (int i = 0; i < pad_right_new; i++) {
-    float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
+    din_ptr_arr[3]++;
     for (int k = 0; k < num; k++) {
-      sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
+      din_ptr_arr[2 - k]++;
     }
     *dout++ = sum > 0.f ? sum : sum * scale[0];
   }
@@ -3800,43 +3806,31 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
     float* dout_batch = dout + n * out_channel_size;
 #pragma omp parallel for
     for (int c = 0; c < chin; c++) {
-      const float* din_ch = din_batch + c * in_size;
-      const float* weights_ch = weights + c * weights_size;
-      float* dout_ch = dout_batch + c * out_size;
-      float bias_val = flag_bias ? bias[c] : 0.f;
-      const float* din_ptr0 = din_ch;
-      const float* din_ptr1 = din_ptr0 + win;
-      const float* din_ptr2 = din_ptr1 + win;
-      const float* din_ptr3 = din_ptr2 + win;
-      const float* din_ptr4 = din_ptr3 + win;
-      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
-      float* dout_ptr = dout_ch;
-      float32x4_t wr5;
-      float32x4_t wr6;
-      float32x4_t wr0 = vld1q_f32(weights_ch);
-      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
-      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
-      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
-      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
-      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
-      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
-      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
-      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
-      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
-      std::vector<const float*> din_ptr_arr;
-      std::vector<float32x4_t> weights_vec;
-      din_ptr_arr.push_back(din_ptr0);
-      din_ptr_arr.push_back(din_ptr1);
-      din_ptr_arr.push_back(din_ptr2);
-      din_ptr_arr.push_back(din_ptr3);
-      din_ptr_arr.push_back(din_ptr4);
-      weights_vec.push_back(wr0);
-      weights_vec.push_back(wr1);
-      weights_vec.push_back(wr2);
-      weights_vec.push_back(wr3);
-      weights_vec.push_back(wr4);
-      weights_vec.push_back(wr5);
-      weights_vec.push_back(wr6);
+      const float* din_ch = din_batch + c * in_size;
+      const float* weights_ch = weights + c * weights_size;
+      float* dout_ch = dout_batch + c * out_size;
+      float bias_val = flag_bias ? bias[c] : 0.f;
+      const float* din_ptr0 = din_ch;
+      const float* din_ptr1 = din_ptr0 + win;
+      const float* din_ptr2 = din_ptr1 + win;
+      const float* din_ptr3 = din_ptr2 + win;
+      const float* din_ptr4 = din_ptr3 + win;
+      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      float* dout_ptr = dout_ch;
+      float32x4_t wr5;
+      float32x4_t wr6;
+      float32x4_t wr0 = vld1q_f32(weights_ch);
+      float32x4_t wr1 = vld1q_f32(weights_ch + 5);
+      float32x4_t wr2 = vld1q_f32(weights_ch + 10);
+      float32x4_t wr3 = vld1q_f32(weights_ch + 15);
+      float32x4_t wr4 = vld1q_f32(weights_ch + 20);
+      wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0);
+      wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1);
+      wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
+      wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
+      wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
+      const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
+      float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
       // top_h
       for (int h = pad_top; h > 4; h--) {
         memset(dout_ptr, bias[0], sizeof(float)*wout);
@@ -3844,27 +3838,43 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
       }
      for (int h = pad_top_new; h > 0; h--) {
         compute_all_padding_pre_leakyRelu(dout_ptr, din_ptr_arr, vbias, scale, weights_vec, vzero,
-                                          win, wout, pad_left, pad_left_new, pad_right,
+                                          win, wout, pad_left, pad_right, pad_left_new,
                                           pad_right_new, cnt, remain, 4 - h);
         dout_ptr += wout;
+        din_ptr_arr[0] = din_ptr0;
+        din_ptr_arr[1] = din_ptr1;
+        din_ptr_arr[2] = din_ptr2;
+        din_ptr_arr[3] = din_ptr3;
+        din_ptr_arr[4] = din_ptr4;
       }
       // mid_h
       for (int h = 0; h < loop_h; h++) {
         compute_all_padding_mid_leakyRelu(dout_ptr, din_ptr_arr, vbias, scale, weights_vec, vzero,
-                                          win, wout, pad_left, pad_left_new, pad_right,
+                                          win, wout, pad_left, pad_right, pad_left_new,
                                           pad_right_new, cnt, remain, 4);
         dout_ptr += wout;
-        for (int i = 0; i < 4; i++) {
-          din_ptr_arr[i] = din_ptr_arr[i + 1];
-        }
-        din_ptr_arr[4] += win;
+        din_ptr0 = din_ptr1;
+        din_ptr1 = din_ptr2;
+        din_ptr2 = din_ptr3;
+        din_ptr3 = din_ptr4;
+        din_ptr4 = din_ptr4 + win;
+        din_ptr_arr[0] = din_ptr0;
+        din_ptr_arr[1] = din_ptr1;
+        din_ptr_arr[2] = din_ptr2;
+        din_ptr_arr[3] = din_ptr3;
+        din_ptr_arr[4] = din_ptr4;
       }
       // bottom
       for (int h = 0; h < pad_bottom_new; h++) {
         compute_all_padding_post_leakyRelu(dout_ptr, din_ptr_arr, vbias, scale, weights_vec, vzero,
-                                           win, wout, pad_left, pad_left_new, pad_right,
+                                           win, wout, pad_left, pad_right, pad_left_new,
                                            pad_right_new, cnt, remain, 3 - h);
         dout_ptr += wout;
+        din_ptr_arr[0] = din_ptr0;
+        din_ptr_arr[1] = din_ptr1;
+        din_ptr_arr[2] = din_ptr2;
+        din_ptr_arr[3] = din_ptr3;
+        din_ptr_arr[4] = din_ptr4;
       }
     }
   }
-- 
GitLab
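The mid_h loops in the new drivers rotate the five named row pointers and then re-seed din_ptr_arr from them after every compute_* call; this appears to be needed because the array is now passed as `const float**`, so the callee advances the shared entries in place, whereas the old `std::vector` argument was copied by value. The sketch below restates that per-output-row update directly from the patch; `win` is the input row stride in floats.

    // Mirrors the rotation in the mid_h loops above: drop the top row, shift
    // the remaining four up, append the next input row, then refresh the array.
    static inline void slide_rows_down(const float*& din_ptr0, const float*& din_ptr1,
                                       const float*& din_ptr2, const float*& din_ptr3,
                                       const float*& din_ptr4,
                                       const float** din_ptr_arr, int win) {
      din_ptr0 = din_ptr1;
      din_ptr1 = din_ptr2;
      din_ptr2 = din_ptr3;
      din_ptr3 = din_ptr4;
      din_ptr4 = din_ptr4 + win;
      // din_ptr_arr entries may have been advanced by the compute_* call,
      // so they are re-seeded from the untouched row pointers each iteration.
      din_ptr_arr[0] = din_ptr0;
      din_ptr_arr[1] = din_ptr1;
      din_ptr_arr[2] = din_ptr2;
      din_ptr_arr[3] = din_ptr3;
      din_ptr_arr[4] = din_ptr4;
    }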