提交 1b9caa46 编写于 作者: C chenjiaoAngel

fix pad=1 error

上级 047a22cb
......@@ -615,7 +615,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"1: \n" \
"subs %[cnt], #1\n" \
"vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
"vmul.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \
"vmul.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
"vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
......@@ -626,7 +626,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
"vmla.f32 q15, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \
"vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
"vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\
"vext.32 q10, q8, q9, #1\n" \
......@@ -945,7 +945,7 @@ inline float compute_one_data_post(const float* data, float32x4_t wr, float bias
inline void compute_all_padding_pre(float* dout,
const float** din_ptr_arr,
const float* bias,
std::vector<float32x4_t> weights,
float32x4_t* weights,
int win,
int wout,
int pad_left,
......@@ -955,6 +955,7 @@ inline void compute_all_padding_pre(float* dout,
int cnt,
int remain,
int num) {
int tmp_index = num - 1;
// left
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0];
......@@ -962,7 +963,7 @@ inline void compute_all_padding_pre(float* dout,
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
}
*dout++ = sum;
}
......@@ -1158,55 +1159,29 @@ inline void compute_all_padding_pre(float* dout,
default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support";
}
din_ptr_arr[0] -= 4;
}
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4);
din_ptr_arr[num]++;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
din_ptr_arr[num - 1 - i]++;
sum += compute_one_data_post(din_ptr_arr[tmp_index - i], weights[3 - i], 0.f, weights[5][3 - i], 4);
din_ptr_arr[tmp_index - i]++;
}
*dout++ = sum;
}
// right
for (int i = 1; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][4 - i], 4 - i);
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][4 - i], 4 - i);
}
*dout++ = sum;
}
/*
switch (pad_right_new) {
case 1:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3], 3);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][3], 3);
}
*dout++ = sum;
case 2:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][2], 2);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][2], 2);
}
*dout++ = sum;
case 3:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][1], 1);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][1], 1);
}
*dout++ = sum;
case 4:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][0], 0);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][0], 0);
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i);
din_ptr_arr[tmp_index - k]++;
}
*dout++ = sum;
}
*/
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0];
}
......@@ -1215,7 +1190,7 @@ inline void compute_all_padding_pre(float* dout,
inline void compute_all_padding_mid(float* dout,
const float** din_ptr_arr,
const float* bias,
std::vector<float32x4_t> weights,
float32x4_t* weights,
int win,
int wout,
int pad_left,
......@@ -1294,6 +1269,7 @@ inline void compute_all_padding_mid(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[0] -= 4;
}
// remain
for (int w = 0; w < remain; w++) {
......@@ -1309,8 +1285,10 @@ inline void compute_all_padding_mid(float* dout,
// right
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[tmp - k]++;
}
*dout++ = sum;
}
......@@ -1321,7 +1299,7 @@ inline void compute_all_padding_mid(float* dout,
inline void compute_all_padding_post(float* dout,
const float** din_ptr_arr,
const float* bias,
std::vector<float32x4_t> weights,
float32x4_t* weights,
int win,
int wout,
int pad_left,
......@@ -1337,9 +1315,9 @@ inline void compute_all_padding_post(float* dout,
}
int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - i);
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
}
*dout++ = sum;
}
......@@ -1529,23 +1507,26 @@ inline void compute_all_padding_post(float* dout,
default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support";
}
din_ptr_arr[0] -= 4;
}
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[num]++;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i]++;
sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i]++;
}
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 3 - i);
din_ptr_arr[3]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][tmp - i], 3 - i);
din_ptr_arr[2 - k]++;
}
*dout++ = sum;
}
......@@ -1612,7 +1593,7 @@ void conv_depthwise_5x5s1_bias(float* dout,
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
float32x4_t wei_vwc[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h
for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout);
......@@ -1622,16 +1603,27 @@ void conv_depthwise_5x5s1_bias(float* dout,
compute_all_padding_pre(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout;
}
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
// mid_h
for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 4);
dout_ptr += wout;
for (int i = 0; i < 4; i++) {
din_ptr_arr[i] = din_ptr_arr[i + 1];
}
din_ptr_arr[4] += win;
din_ptr_arr[0] = din_ptr1;
din_ptr_arr[1] = din_ptr2;
din_ptr_arr[2] = din_ptr3;
din_ptr_arr[3] = din_ptr4;
din_ptr_arr[4] = din_ptr4 + win
}
// bottom
for (int h = 0; h < pad_bottom_new; h++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册