提交 f7c30837 编写于 作者: C chenjiaoAngel

fix c4

上级 ae8cfcec
......@@ -80,6 +80,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv3x3s2_depthwise_int8.cc
conv5x5s1_depthwise_int8.cc
conv5x5s1_depthwise_fp32.cc
conv5x5s1_depthwise_fp32_c4.cc
conv5x5s2_depthwise_int8.cc
conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc
......
......@@ -576,9 +576,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define COMPUTE_TWO_LINE_S1_PRE \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
......@@ -606,9 +606,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define COMPUTE_THREE_LINE_S1_PRE \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
......@@ -646,9 +646,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define COMPUTE_FOUR_LINE_S1_PRE \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
......@@ -696,9 +696,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define COMPUTE_FIVE_LINE_S1 \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
......@@ -776,9 +776,9 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
#define COMPUTE_TWO_LINE_S1_POST \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vld1.f32 {d30-d31}, [%[bias]]\n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
......@@ -921,6 +921,110 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q11, q8, q9, #2\n" \
"vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \
"bne 1b"
// ARM (32-bit NEON) loop body that accumulates TWO output rows of a 5x5
// stride-1 depthwise convolution, 4 output columns per iteration.
//   out0 (rows din_ptr0..din_ptr4): partial sums in q13 and q15, merged
//        into q13 by the first vadd near the end of the body.
//   out1 (rows din_ptr1..din_ptr5): accumulated in q14, with the last
//        kernel-column/row products gathered in q15 and added at the end.
// Register roles for the row currently loaded: q8 = columns 0123,
// q9 = columns 4567, q10/q11/q12 = vext windows 1234/2345/3456.
// %e[wrK][i] selects lane i of wrK, %f[wrK][i] selects lane i + 2.
// The inline comments read "<input columns>*<weight lane>"; the ones
// prefixed with "out1:" belong to the second output row (q14).
// This macro opens the loop at label "1:" and decrements %[cnt]; it does
// not branch back itself. In this file it is paired with
// RESULT_S1_RELU6_1, which ends with "bne 1b".
// The body finishes by pre-loading the next 8 floats from din_ptr0 into
// q8/q9 (post-incrementing din_ptr0 by 4 floats) for the next iteration.
#define COMPUTE_FIVE_LINE_S1_1 \
"vld1.f32 {d28-d29}, [%[bias]]\n" /*q14 = bias, accumulator for out1*/ \
"vld1.f32 {d30-d31}, [%[bias]]\n" /*q15 = bias, accumulator for out0*/ \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vld1.f32 {d18-d19}, [%[din_ptr0]] \n" \
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vext.32 q12, q8, q9, #3\n" \
"1: \n" \
"subs %[cnt], #1\n" \
"vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
"vmul.f32 q13, q9, %e[wr5][0]\n" /*4567*wr5[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
"vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
"vmla.f32 q13, q11, %f[wr0][0]\n" /*2345*wr0[2]*/\
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q15, q12, %f[wr0][1]\n" /*3456*wr0[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q13, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
"vmla.f32 q14, q8, %e[wr0][0]\n" /*out1: 0123*wr0[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
"vmla.f32 q15, q9, %e[wr5][1]\n" /*4567*wr5[1]*/ \
"vmla.f32 q14, q9, %e[wr5][0]\n" /*out1: 4567*wr5[0]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
"vmla.f32 q13, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\
"vmla.f32 q14, q10, %e[wr0][1]\n" /*out1: 1234*wr0[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q15, q11, %f[wr1][0]\n" /*2345*wr1[2]*/\
"vmla.f32 q13, q12, %f[wr1][1]\n" /*3456*wr1[3]*/\
"vmla.f32 q14, q11, %f[wr0][0]\n" /*out1: 2345*wr0[2]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr0][1]\n" /*out1: 3456*wr0[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q15, q8, %e[wr2][0]\n" /*0123*wr2[0]*/ \
"vmla.f32 q14, q8, %e[wr1][0]\n" /*out1: 0123*wr1[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr3]]!\n" \
"vmla.f32 q13, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
"vmla.f32 q14, q9, %e[wr5][1]\n" /*out1: 4567*wr5[1]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr3]]\n" \
"vmla.f32 q15, q10, %e[wr2][1]\n" /*1234*wr2[1]*/\
"vmla.f32 q14, q10, %e[wr1][1]\n" /*out1: 1234*wr1[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q13, q11, %f[wr2][0]\n" /*2345*wr2[2]*/\
"vmla.f32 q14, q11, %f[wr1][0]\n" /*out1: 2345*wr1[2]*/\
"vmla.f32 q15, q12, %f[wr2][1]\n" /*3456*wr2[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr1][1]\n" /*out1: 3456*wr1[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q13, q8, %e[wr3][0]\n" /*0123*wr3[0]*/ \
"vmla.f32 q14, q8, %e[wr2][0]\n" /*out1: 0123*wr2[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr4]]!\n" \
"vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
"vmla.f32 q14, q9, %f[wr5][0]\n" /*out1: 4567*wr5[2]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr4]]\n" \
"vmla.f32 q13, q10, %e[wr3][1]\n" /*1234*wr3[1]*/\
"vmla.f32 q14, q10, %e[wr2][1]\n" /*out1: 1234*wr2[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q15, q11, %f[wr3][0]\n" /*2345*wr3[2]*/\
"vmla.f32 q14, q11, %f[wr2][0]\n" /*out1: 2345*wr2[2]*/\
"vmla.f32 q13, q12, %f[wr3][1]\n" /*3456*wr3[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr2][1]\n" /*out1: 3456*wr2[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q15, q8, %e[wr4][0]\n" /*0123*wr4[0]*/ \
"vmla.f32 q14, q8, %e[wr3][0]\n" /*out1: 0123*wr3[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr5]]!\n" \
"vmla.f32 q13, q9, %e[wr6][0]\n" /*4567*wr6[0]*/ \
"vmla.f32 q14, q9, %f[wr5][1]\n" /*out1: 4567*wr5[3]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr5]]\n" \
"vmla.f32 q15, q10, %e[wr4][1]\n" /*1234*wr4[1]*/\
"vmla.f32 q14, q10, %e[wr3][1]\n" /*out1: 1234*wr3[1]*/\
"vext.32 q10, q8, q9, #1\n" \
"vmla.f32 q13, q11, %f[wr4][0]\n" /*2345*wr4[2]*/\
"vmla.f32 q14, q11, %f[wr3][0]\n" /*out1: 2345*wr3[2]*/\
"vmla.f32 q15, q12, %f[wr4][1]\n" /*3456*wr4[3]*/\
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr3][1]\n" /*out1: 3456*wr3[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q13, q13, q15\n" /*out0 complete in q13; q15 is free again*/ \
"vmla.f32 q14, q8, %e[wr4][0]\n" /*out1: 0123*wr4[0]*/ \
"vmul.f32 q15, q9, %e[wr6][0]\n" /*out1: 4567*wr6[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr0]]!\n" \
"vmla.f32 q14, q10, %e[wr4][1]\n" /*out1: 1234*wr4[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr0]]\n" \
"vmla.f32 q15, q11, %f[wr4][0]\n" /*out1: 2345*wr4[2]*/\
"vext.32 q10, q8, q9, #1\n" \
"vext.32 q11, q8, q9, #2\n" \
"vmla.f32 q14, q12, %f[wr4][1]\n" /*out1: 3456*wr4[3]*/\
"vext.32 q12, q8, q9, #3\n" \
"vadd.f32 q14, q14, q15\n"
// Epilogue paired with COMPUTE_FIVE_LINE_S1_1 (two output rows in q13/q14).
// Clamps both results to the range [vzero, values at six_ptr], stores
// 4 floats to each of dout_ptr0 / dout_ptr1 (post-incrementing both),
// reloads the bias into the q14/q15 accumulators for the next iteration
// and branches back to label 1 while the flags left by the
// "subs %[cnt], #1" of the compute macro indicate a non-zero counter.
// q15 is used as scratch for the upper bound, so its bias value has to be
// reloaded before looping.
#define RESULT_S1_RELU6_1 \
"vld1.f32 {d30-d31}, [%[six_ptr]]\n" /*q15 = upper clamp bound*/ \
"vmax.f32 q13, q13, %q[vzero]\n" /*out0 = max(out0, vzero)*/ \
"vmax.f32 q14, q14, %q[vzero]\n" /*out1 = max(out1, vzero)*/ \
"vmin.f32 q13, q13, q15\n" /*out0 = min(out0, six)*/ \
"vmin.f32 q14, q14, q15\n" /*out1 = min(out1, six)*/ \
"vld1.f32 {d30-d31}, [%[bias]]\n" /*q15 = bias for next out0*/ \
"vst1.f32 {d26-d27}, [%[dout_ptr0]]!\n" /*store out0 (q13)*/ \
"vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" /*store out1 (q14)*/ \
"vld1.f32 {d28-d29}, [%[bias]]\n" /*q14 = bias for next out1*/ \
"bne 1b"
#endif
inline float compute_one_data_pre(const float* data, float32x4_t wr, float bias_val, float wei_val, int num) {
......@@ -2742,6 +2846,139 @@ inline void compute_all_padding_mid_relu6(float* dout,
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
}
// Computes two vertically adjacent output rows of the 5x5 stride-1
// depthwise convolution with bias and a [0, six] clamp applied.
//
// Row 0 is written through dout0 and reads din_ptr_arr[0..num];
// row 1 is written through dout1 and reads din_ptr_arr[1..num + 1].
// Both rows use the same weights.
//
// Parameters:
//   dout0, dout1   - destinations of the first / second output row.
//   din_ptr_arr    - input row pointers; entries 0..num + 1 are read and
//                    advanced in place (the asm path uses entries 0..5).
//   bias           - bias values; bias[0] is used by the scalar paths and
//                    4 floats are loaded from it by the asm path.
//   six            - upper clamp bound; six[0] is used by the scalar paths
//                    and 4 floats are loaded from it by the asm path.
//   weights        - weights[0..6]; weights[5] / weights[6] carry the
//                    per-row scalars passed as `wei_val` to the helpers.
//   vzero          - lower clamp bound handed to the asm path.
//   win, wout      - not used by this function.
//   pad_left,
//   pad_right      - columns beyond the first 4 on each side are filled
//                    with the clamped bias (no input is read for them).
//   pad_left_new,
//   pad_right_new  - number of border columns computed by the scalar
//                    helpers with a reduced window.
//   cnt            - iteration count of the asm loop (4 outputs per row
//                    per iteration on the ARMv7 path).
//   remain         - number of full-window columns computed one by one.
//   num            - index of the last input row of output row 0; the
//                    visible caller passes 4.
//
// compute_one_data_pre / compute_one_data_post are defined elsewhere in
// this file; presumably they return bias_val plus a partial dot product
// of `data` with `wr` -- TODO confirm against their definitions.
inline void compute_all_padding_mid_relu6_1(float* dout0, float* dout1,
                                            const float** din_ptr_arr,
                                            const float* bias,
                                            const float* six,
                                            float32x4_t* weights,
                                            float32x4_t vzero,
                                            int win,
                                            int wout,
                                            int pad_left,
                                            int pad_right,
                                            int pad_left_new,
                                            int pad_right_new,
                                            int cnt,
                                            int remain,
                                            int num) {
#ifdef __aarch64__
  float32x4_t vsix = vld1q_f32(six);
#endif
  // Value of an output whose window lies completely inside the padding:
  // relu6(bias).
  const float pad_val =
      bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
  // left
  // Both rows must receive the padding-only columns; writing only dout0
  // here would shift every following value of the second row.
  for (int w = pad_left; w > 4; w--) {
    *dout0++ = pad_val;
    *dout1++ = pad_val;
  }
  int tmp = num - 1;
  // Left border columns: reduced window, input pointers are not advanced.
  for (int i = pad_left_new; i > 0; i--) {
    float sum = compute_one_data_pre(
        din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
    float sum1 = compute_one_data_pre(
        din_ptr_arr[num + 1], weights[num], bias[0], weights[6][0], 4 - i);
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_pre(din_ptr_arr[tmp - k],
                                  weights[tmp - k],
                                  0.f,
                                  weights[5][tmp - k],
                                  4 - i);
      // Row 1 uses the same weights on the input row one below.
      sum1 += compute_one_data_pre(din_ptr_arr[num - k],
                                   weights[tmp - k],
                                   0.f,
                                   weights[5][tmp - k],
                                   4 - i);
    }
    *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
    *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
  }
  if (cnt > 0) {
#ifdef __aarch64__
    // NOTE(review): this branch still uses the single-row macros
    // COMPUTE_FIVE_LINE_S1 / RESULT_S1_RELU6 whose definitions are not
    // visible here; confirm that they reference dout_ptr0 / dout_ptr1 /
    // din_ptr5 and actually produce the second row.
    asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [din_ptr5] "+r"(din_ptr_arr[5]),
                   [dout_ptr0] "+r"(dout0),
                   [dout_ptr1] "+r"(dout1)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [vsix] "w"(vsix),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "v9",
                   "v10",
                   "v11",
                   "v12",
                   "v13",
                   "v14",
                   "v15",
                   "v16");
#else
    asm volatile(COMPUTE_FIVE_LINE_S1_1 RESULT_S1_RELU6_1
                 : [cnt] "+r"(cnt),
                   [din_ptr0] "+r"(din_ptr_arr[0]),
                   [din_ptr1] "+r"(din_ptr_arr[1]),
                   [din_ptr2] "+r"(din_ptr_arr[2]),
                   [din_ptr3] "+r"(din_ptr_arr[3]),
                   [din_ptr4] "+r"(din_ptr_arr[4]),
                   [din_ptr5] "+r"(din_ptr_arr[5]),
                   [dout_ptr0] "+r"(dout0),
                   [dout_ptr1] "+r"(dout1)
                 : [wr0] "w"(weights[0]),
                   [wr1] "w"(weights[1]),
                   [wr2] "w"(weights[2]),
                   [wr3] "w"(weights[3]),
                   [wr4] "w"(weights[4]),
                   [wr5] "w"(weights[5]),
                   [wr6] "w"(weights[6]),
                   [vzero] "w"(vzero),
                   [six_ptr] "r"(six),
                   [bias] "r"(bias)
                 : "cc",
                   "memory",
                   "q8",
                   "q9",
                   "q10",
                   "q11",
                   "q12",
                   "q13",
                   "q14",
                   "q15");
#endif
    // The asm loop pre-loads row 0 for the next iteration, leaving
    // din_ptr0 four floats ahead of the other row pointers.
    din_ptr_arr[0] -= 4;
  }
  // remain
  for (int w = 0; w < remain; w++) {
    float sum = compute_one_data_post(
        din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4);
    float sum1 = compute_one_data_post(
        din_ptr_arr[num + 1], weights[num], bias[0], weights[6][0], 4);
    din_ptr_arr[num + 1]++;
    for (int i = 0; i < num; i++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - i],
                                   weights[tmp - i],
                                   0.f,
                                   weights[5][tmp - i],
                                   4);
      sum1 += compute_one_data_post(din_ptr_arr[num - i],
                                    weights[tmp - i],
                                    0.f,
                                    weights[5][tmp - i],
                                    4);
      // Advance rows num..1 only after both sums have read them.
      din_ptr_arr[num - i]++;
    }
    din_ptr_arr[0]++;
    *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
    *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
  }
  // right
  for (int i = 0; i < pad_right_new; i++) {
    float sum = compute_one_data_post(
        din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
    float sum1 = compute_one_data_post(din_ptr_arr[num + 1],
                                       weights[num],
                                       bias[0],
                                       weights[num][3 - i],
                                       3 - i);
    din_ptr_arr[num + 1]++;
    for (int k = 0; k < num; k++) {
      sum += compute_one_data_post(din_ptr_arr[tmp - k],
                                   weights[tmp - k],
                                   0.f,
                                   weights[tmp - k][3 - i],
                                   3 - i);
      sum1 += compute_one_data_post(din_ptr_arr[num - k],
                                    weights[tmp - k],
                                    0.f,
                                    weights[tmp - k][3 - i],
                                    3 - i);
      din_ptr_arr[num - k]++;
    }
    din_ptr_arr[0]++;
    *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
    *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
  }
  for (int w = pad_right; w > 4; w--) {
    *dout0++ = pad_val;
    *dout1++ = pad_val;
  }
}
inline void compute_all_padding_post_relu6(float* dout,
const float** din_ptr_arr,
const float* bias,
......@@ -3050,8 +3287,10 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
const float* din_ptr2 = din_ptr1 + win;
const float* din_ptr3 = din_ptr2 + win;
const float* din_ptr4 = din_ptr3 + win;
const float* din_ptr5 = din_ptr4 + win;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
float* dout_ptr = dout_ch;
float* dout_ptr1 = dout_ch;
float32x4_t wr5;
float32x4_t wr6;
float32x4_t wr0 = vld1q_f32(weights_ch);
......@@ -3064,7 +3303,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2);
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h
for (int h = pad_top; h > 4; h--) {
......@@ -3082,23 +3321,30 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
dout_ptr1 = dout_ptr + wout;
// mid_h
for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu6_1(dout_ptr, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4);
dout_ptr += wout;
din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2;
din_ptr2 = din_ptr3;
din_ptr3 = din_ptr4;
din_ptr4 = din_ptr4 + win;
dout_ptr += 2 * wout;
dout_ptr1 += 2 * wout;
din_ptr0 = din_ptr2;
din_ptr1 = din_ptr3;
din_ptr2 = din_ptr4;
din_ptr3 = din_ptr5;
din_ptr4 = din_ptr5 + win;
din_ptr5 = din_ptr4 + win;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
din_ptr_arr[5] = din_ptr5;
}
if (loop_h % 2) compute_all_padding_mid_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new,
pad_right_new, cnt, remain, 4);
// bottom
for (int h = 0; h < pad_bottom_new; h++) {
compute_all_padding_post_relu6(dout_ptr, din_ptr_arr, vbias, six, weights_vec, vzero,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册