提交 1b9caa46 编写于 作者: C chenjiaoAngel

fix pad=1 error

上级 047a22cb
...@@ -615,7 +615,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -615,7 +615,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"1: \n" \ "1: \n" \
"subs %[cnt], #1\n" \ "subs %[cnt], #1\n" \
"vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \ "vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
"vmul.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \ "vmul.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \ "vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
"vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\ "vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \ "vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
...@@ -626,7 +626,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -626,7 +626,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \ "vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \ "vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \ "vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
"vmla.f32 q15, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \ "vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \ "vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
"vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\ "vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\
"vext.32 q10, q8, q9, #1\n" \ "vext.32 q10, q8, q9, #1\n" \
...@@ -945,7 +945,7 @@ inline float compute_one_data_post(const float* data, float32x4_t wr, float bias ...@@ -945,7 +945,7 @@ inline float compute_one_data_post(const float* data, float32x4_t wr, float bias
inline void compute_all_padding_pre(float* dout, inline void compute_all_padding_pre(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
std::vector<float32x4_t> weights, float32x4_t* weights,
int win, int win,
int wout, int wout,
int pad_left, int pad_left,
...@@ -955,6 +955,7 @@ inline void compute_all_padding_pre(float* dout, ...@@ -955,6 +955,7 @@ inline void compute_all_padding_pre(float* dout,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp_index = num - 1;
// left // left
for (int w = pad_left; w > 4; w--) { for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0]; *dout++ = bias[0];
...@@ -962,7 +963,7 @@ inline void compute_all_padding_pre(float* dout, ...@@ -962,7 +963,7 @@ inline void compute_all_padding_pre(float* dout,
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
} }
*dout++ = sum; *dout++ = sum;
} }
...@@ -1158,55 +1159,29 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1158,55 +1159,29 @@ inline void compute_all_padding_pre(float* dout,
default: default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << "does not support";
} }
din_ptr_arr[0] -= 4;
} }
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[5][3 - i], 4); sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
din_ptr_arr[num - 1 - i]++; din_ptr_arr[tmp_index - i]++;
} }
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 1; i < pad_right_new; i++) { for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][4 - i], 4 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][4 - i], 4 - i); sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i);
din_ptr_arr[tmp_index - k]++;
} }
*dout++ = sum; *dout++ = sum;
} }
/*
switch (pad_right_new) {
case 1:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3], 3);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][3], 3);
}
*dout++ = sum;
case 2:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][2], 2);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][2], 2);
}
*dout++ = sum;
case 3:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][1], 1);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][1], 1);
}
*dout++ = sum;
case 4:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][0], 0);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][0], 0);
}
*dout++ = sum;
}
*/
for (int w = pad_right; w > 4; w--) { for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0]; *dout++ = bias[0];
} }
...@@ -1215,7 +1190,7 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1215,7 +1190,7 @@ inline void compute_all_padding_pre(float* dout,
inline void compute_all_padding_mid(float* dout, inline void compute_all_padding_mid(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
std::vector<float32x4_t> weights, float32x4_t* weights,
int win, int win,
int wout, int wout,
int pad_left, int pad_left,
...@@ -1293,7 +1268,8 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1293,7 +1268,8 @@ inline void compute_all_padding_mid(float* dout,
"q13", "q13",
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[0] -= 4;
} }
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
...@@ -1309,8 +1285,10 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1309,8 +1285,10 @@ inline void compute_all_padding_mid(float* dout,
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[tmp - k]++;
} }
*dout++ = sum; *dout++ = sum;
} }
...@@ -1321,7 +1299,7 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1321,7 +1299,7 @@ inline void compute_all_padding_mid(float* dout,
inline void compute_all_padding_post(float* dout, inline void compute_all_padding_post(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
const float* bias, const float* bias,
std::vector<float32x4_t> weights, float32x4_t* weights,
int win, int win,
int wout, int wout,
int pad_left, int pad_left,
...@@ -1337,9 +1315,9 @@ inline void compute_all_padding_post(float* dout, ...@@ -1337,9 +1315,9 @@ inline void compute_all_padding_post(float* dout,
} }
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
} }
*dout++ = sum; *dout++ = sum;
} }
...@@ -1529,23 +1507,26 @@ inline void compute_all_padding_post(float* dout, ...@@ -1529,23 +1507,26 @@ inline void compute_all_padding_post(float* dout,
default: default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << "does not support";
} }
din_ptr_arr[0] -= 4;
} }
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4); float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i]++; din_ptr_arr[2 - i]++;
} }
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 3 - i);
din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i); sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][tmp - i], 3 - i);
din_ptr_arr[2 - k]++;
} }
*dout++ = sum; *dout++ = sum;
} }
...@@ -1612,7 +1593,7 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1612,7 +1593,7 @@ void conv_depthwise_5x5s1_bias(float* dout,
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
float32x4_t wei_vwc[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top; h > 4; h--) { for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout); memset(dout_ptr, bias[0], sizeof(float)*wout);
...@@ -1622,16 +1603,27 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1622,16 +1603,27 @@ void conv_depthwise_5x5s1_bias(float* dout,
compute_all_padding_pre(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_pre(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 4 - h); pad_left_new, pad_right, pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout; dout_ptr += wout;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
} }
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
// mid_h // mid_h
for (int h = 0; h < loop_h; h++) { for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 4); pad_left_new, pad_right, pad_right_new, cnt, remain, 4);
dout_ptr += wout; dout_ptr += wout;
for (int i = 0; i < 4; i++) { din_ptr_arr[0] = din_ptr1;
din_ptr_arr[i] = din_ptr_arr[i + 1]; din_ptr_arr[1] = din_ptr2;
} din_ptr_arr[2] = din_ptr3;
din_ptr_arr[4] += win; din_ptr_arr[3] = din_ptr4;
din_ptr_arr[4] = din_ptr4 + win
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom_new; h++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册