提交 9348cb41 编写于 作者: C chenjiaoAngel

fix format

上级 9a09cf28
...@@ -1211,13 +1211,11 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1211,13 +1211,11 @@ inline void compute_all_padding_pre(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
...@@ -1429,7 +1427,7 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1429,7 +1427,7 @@ inline void compute_all_padding_pre(float* dout,
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1447,14 +1445,12 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1447,14 +1445,12 @@ inline void compute_all_padding_mid(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -1531,7 +1527,7 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1531,7 +1527,7 @@ inline void compute_all_padding_mid(float* dout,
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1550,15 +1546,13 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1550,15 +1546,13 @@ inline void compute_all_padding_mid_out2(float* dout0,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp1 = num + 1; int tmp1 = num + 1;
int tmp = num - 1; int tmp = num - 1;
// left // left
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1647,7 +1641,7 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1647,7 +1641,7 @@ inline void compute_all_padding_mid_out2(float* dout0,
*dout1++ = sum1; *dout1++ = sum1;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
...@@ -1660,11 +1654,6 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1660,11 +1654,6 @@ inline void compute_all_padding_mid_out2(float* dout0,
*dout0++ = sum; *dout0++ = sum;
*dout1++ = sum1; *dout1++ = sum1;
} }
for (int w = pad_right; w > 4; w--) {
*dout0++ = bias[0];
*dout1++ = bias[0];
}
} }
inline void compute_all_padding_post(float* dout, inline void compute_all_padding_post(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -1674,17 +1663,12 @@ inline void compute_all_padding_post(float* dout, ...@@ -1674,17 +1663,12 @@ inline void compute_all_padding_post(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
/* for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0];
}*/
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -1894,7 +1878,7 @@ inline void compute_all_padding_post(float* dout, ...@@ -1894,7 +1878,7 @@ inline void compute_all_padding_post(float* dout,
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1927,10 +1911,6 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1927,10 +1911,6 @@ void conv_depthwise_5x5s1_bias(float* dout,
int out_size = wout * hout; int out_size = wout * hout;
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int pad_left_new = pad_left > 4 ? 4 : pad_left;
int pad_right_new = pad_right > 4 ? 4 : pad_right;
int pad_top_new = pad_top > 4 ? 4 : pad_top;
int pad_bottom_new = pad_bottom > 4 ? 4 : pad_bottom;
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
...@@ -1968,9 +1948,9 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1968,9 +1948,9 @@ void conv_depthwise_5x5s1_bias(float* dout,
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_pre(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h); pad_right, cnt, remain, 4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -1982,7 +1962,7 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1982,7 +1962,7 @@ void conv_depthwise_5x5s1_bias(float* dout,
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4); pad_right, cnt, remain, 4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -2000,7 +1980,7 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -2000,7 +1980,7 @@ void conv_depthwise_5x5s1_bias(float* dout,
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4); pad_right, cnt, remain, 4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -2014,9 +1994,9 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -2014,9 +1994,9 @@ void conv_depthwise_5x5s1_bias(float* dout,
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_post(dout_ptr0, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 3 - h); pad_right, cnt, remain, 3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -2037,17 +2017,11 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2037,17 +2017,11 @@ inline void compute_all_padding_pre_relu(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
// left for (int i = pad_left; i > 0; i--) {
/* for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : 0.f;
}*/
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
...@@ -2267,7 +2241,7 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2267,7 +2241,7 @@ inline void compute_all_padding_pre_relu(float* dout,
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -2276,10 +2250,6 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2276,10 +2250,6 @@ inline void compute_all_padding_pre_relu(float* dout,
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
/*
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : 0.f;
}*/
} }
inline void compute_all_padding_mid_relu(float* dout, inline void compute_all_padding_mid_relu(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -2290,17 +2260,11 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2290,17 +2260,11 @@ inline void compute_all_padding_mid_relu(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left
/* for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : 0.f;
}*/
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -2378,7 +2342,7 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2378,7 +2342,7 @@ inline void compute_all_padding_mid_relu(float* dout,
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -2387,10 +2351,6 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2387,10 +2351,6 @@ inline void compute_all_padding_mid_relu(float* dout,
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
/*
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : 0.f;
}*/
} }
inline void compute_all_padding_mid_relu_out2(float* dout0, inline void compute_all_padding_mid_relu_out2(float* dout0,
float* dout1, float* dout1,
...@@ -2402,19 +2362,13 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2402,19 +2362,13 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
for (int w = pad_left; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? bias[0] : 0.f;
*dout1++ = bias[0] > 0.f ? bias[0] : 0.f;
}
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1; int tmp1 = num + 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -2505,7 +2459,7 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2505,7 +2459,7 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
*dout1++ = sum1 > 0.f ? sum1 : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
...@@ -2518,10 +2472,6 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2518,10 +2472,6 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
*dout0++ = sum > 0.f ? sum : 0.f; *dout0++ = sum > 0.f ? sum : 0.f;
*dout1++ = sum1 > 0.f ? sum1 : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f;
} }
for (int w = pad_right; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? bias[0] : 0.f;
*dout1++ = bias[0] > 0.f ? bias[0] : 0.f;
}
} }
inline void compute_all_padding_post_relu(float* dout, inline void compute_all_padding_post_relu(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -2532,17 +2482,12 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2532,17 +2482,12 @@ inline void compute_all_padding_post_relu(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
/*for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : 0.f;
}*/
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -2760,7 +2705,7 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2760,7 +2705,7 @@ inline void compute_all_padding_post_relu(float* dout,
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -2769,10 +2714,6 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2769,10 +2714,6 @@ inline void compute_all_padding_post_relu(float* dout,
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
/*
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : 0.f;
}*/
} }
void conv_depthwise_5x5s1_bias_relu(float* dout, void conv_depthwise_5x5s1_bias_relu(float* dout,
...@@ -2797,10 +2738,6 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2797,10 +2738,6 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
int out_size = wout * hout; int out_size = wout * hout;
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int pad_left_new = pad_left > 4 ? 4 : pad_left;
int pad_right_new = pad_right > 4 ? 4 : pad_right;
int pad_top_new = pad_top > 4 ? 4 : pad_top;
int pad_bottom_new = pad_bottom > 4 ? 4 : pad_bottom;
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
...@@ -2839,9 +2776,9 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2839,9 +2776,9 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_pre_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h); pad_right, cnt, remain, 4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -2853,7 +2790,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2853,7 +2790,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_mid_relu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4); pad_right, cnt, remain, 4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -2871,7 +2808,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2871,7 +2808,7 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_mid_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 4); pad_right, cnt, remain, 4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -2885,9 +2822,9 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2885,9 +2822,9 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_post_relu(dout_ptr0, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
pad_right, pad_left_new, pad_right_new, cnt, remain, 3 - h); pad_right, cnt, remain, 3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -2909,8 +2846,6 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -2909,8 +2846,6 @@ inline void compute_all_padding_pre_relu6(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -2919,10 +2854,7 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -2919,10 +2854,7 @@ inline void compute_all_padding_pre_relu6(float* dout,
#endif #endif
int tmp_index = num - 1; int tmp_index = num - 1;
// left // left
for (int w = pad_left; w > 4; w--) { for (int i = pad_left; i > 0; i--) {
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
...@@ -3150,7 +3082,7 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -3150,7 +3082,7 @@ inline void compute_all_padding_pre_relu6(float* dout,
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -3159,10 +3091,6 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -3159,10 +3091,6 @@ inline void compute_all_padding_pre_relu6(float* dout,
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
} }
inline void compute_all_padding_mid_relu6(float* dout, inline void compute_all_padding_mid_relu6(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -3174,8 +3102,6 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3174,8 +3102,6 @@ inline void compute_all_padding_mid_relu6(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3183,11 +3109,8 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3183,11 +3109,8 @@ inline void compute_all_padding_mid_relu6(float* dout,
float32x4_t vsix = vld1q_f32(six); float32x4_t vsix = vld1q_f32(six);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -3267,7 +3190,7 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3267,7 +3190,7 @@ inline void compute_all_padding_mid_relu6(float* dout,
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -3276,9 +3199,6 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3276,9 +3199,6 @@ inline void compute_all_padding_mid_relu6(float* dout,
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
} }
inline void compute_all_padding_mid_relu6_out2(float* dout0, inline void compute_all_padding_mid_relu6_out2(float* dout0,
...@@ -3292,8 +3212,6 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3292,8 +3212,6 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3301,13 +3219,9 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3301,13 +3219,9 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
float32x4_t vsix = vld1q_f32(six); float32x4_t vsix = vld1q_f32(six);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
*dout1++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1; int tmp1 = num + 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -3399,7 +3313,7 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3399,7 +3313,7 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
*dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
...@@ -3412,10 +3326,6 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3412,10 +3326,6 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
*dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
} }
for (int w = pad_right; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
*dout1++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
} }
inline void compute_all_padding_post_relu6(float* dout, inline void compute_all_padding_post_relu6(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -3427,8 +3337,6 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3427,8 +3337,6 @@ inline void compute_all_padding_post_relu6(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3436,11 +3344,8 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3436,11 +3344,8 @@ inline void compute_all_padding_post_relu6(float* dout,
float32x4_t vsix = vld1q_f32(six); float32x4_t vsix = vld1q_f32(six);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -3666,7 +3571,7 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3666,7 +3571,7 @@ inline void compute_all_padding_post_relu6(float* dout,
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -3675,9 +3580,6 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3675,9 +3580,6 @@ inline void compute_all_padding_post_relu6(float* dout,
} }
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? (bias[0] < six[0] ? bias[0] : six[0]) : 0.f;
}
} }
void conv_depthwise_5x5s1_bias_relu6(float* dout, void conv_depthwise_5x5s1_bias_relu6(float* dout,
...@@ -3703,10 +3605,6 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3703,10 +3605,6 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
int out_size = wout * hout; int out_size = wout * hout;
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int pad_left_new = pad_left > 4 ? 4 : pad_left;
int pad_right_new = pad_right > 4 ? 4 : pad_right;
int pad_top_new = pad_top > 4 ? 4 : pad_top;
int pad_bottom_new = pad_bottom > 4 ? 4 : pad_bottom;
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
...@@ -3745,10 +3643,9 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3745,10 +3643,9 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_pre_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 4 - h);
pad_right_new, cnt, remain, 4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -3760,8 +3657,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3760,8 +3657,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_relu6_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_mid_relu6_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 4);
pad_right_new, cnt, remain, 4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -3779,8 +3675,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3779,8 +3675,7 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_mid_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 4);
pad_right_new, cnt, remain, 4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -3794,10 +3689,9 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout, ...@@ -3794,10 +3689,9 @@ void conv_depthwise_5x5s1_bias_relu6(float* dout,
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero, compute_all_padding_post_relu6(dout_ptr0, din_ptr_arr, vbias, six, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 3 - h);
pad_right_new, cnt, remain, 3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -3819,8 +3713,6 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3819,8 +3713,6 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3829,10 +3721,7 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3829,10 +3721,7 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
#endif #endif
int tmp_index = num - 1; int tmp_index = num - 1;
// left // left
for (int w = pad_left; w > 4; w--) { for (int i = pad_left; i > 0; i--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], weights[3 - k], 0.f, weights[5][3 - k], 4 - i);
...@@ -4068,7 +3957,7 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -4068,7 +3957,7 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -4092,8 +3981,6 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4092,8 +3981,6 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4105,7 +3992,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4105,7 +3992,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0]; *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
} }
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -4188,7 +4075,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4188,7 +4075,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++; din_ptr_arr[num]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -4197,9 +4084,6 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4197,9 +4084,6 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
} }
inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
float* dout1, float* dout1,
...@@ -4212,8 +4096,6 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4212,8 +4096,6 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4221,13 +4103,9 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4221,13 +4103,9 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
float32x4_t vscale = vld1q_f32(scale); float32x4_t vscale = vld1q_f32(scale);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
*dout1++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
int tmp = num - 1; int tmp = num - 1;
int tmp1 = num + 1; int tmp1 = num + 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); float sum1 = compute_one_data_pre(din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -4322,7 +4200,7 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4322,7 +4200,7 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
*dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum1 = compute_one_data_post(din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1]++;
...@@ -4335,10 +4213,6 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4335,10 +4213,6 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
*dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout0++ = sum > 0.f ? sum : sum * scale[0];
*dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
} }
for (int w = pad_right; w > 4; w--) {
*dout0++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
*dout1++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
} }
inline void compute_all_padding_post_leakyRelu(float* dout, inline void compute_all_padding_post_leakyRelu(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -4350,8 +4224,6 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4350,8 +4224,6 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
int wout, int wout,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new,
int pad_right_new,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4359,11 +4231,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4359,11 +4231,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
float32x4_t vscale = vld1q_f32(scale); float32x4_t vscale = vld1q_f32(scale);
#endif #endif
// left // left
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i); sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
...@@ -4597,7 +4466,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4597,7 +4466,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++; din_ptr_arr[3]++;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -4606,9 +4475,6 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4606,9 +4475,6 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
} }
void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
...@@ -4634,10 +4500,6 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4634,10 +4500,6 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
int out_size = wout * hout; int out_size = wout * hout;
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int pad_left_new = pad_left > 4 ? 4 : pad_left;
int pad_right_new = pad_right > 4 ? 4 : pad_right;
int pad_top_new = pad_top > 4 ? 4 : pad_top;
int pad_bottom_new = pad_bottom > 4 ? 4 : pad_bottom;
int in_channel_size = chin * in_size; int in_channel_size = chin * in_size;
int out_channel_size = chin * out_size; int out_channel_size = chin * out_size;
int weights_size = 25; int weights_size = 25;
...@@ -4676,10 +4538,9 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4676,10 +4538,9 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5};
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h // top_h
for (int h = pad_top_new; h > 0; h--) { for (int h = pad_top; h > 0; h--) {
compute_all_padding_pre_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_pre_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 4 - h);
pad_right_new, cnt, remain, 4 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
...@@ -4691,8 +4552,7 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4691,8 +4552,7 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
// mid_h // mid_h
for (int h = 0; h < loop_h - 1; h += 2) { for (int h = 0; h < loop_h - 1; h += 2) {
compute_all_padding_mid_leakyRelu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_mid_leakyRelu_out2(dout_ptr0, dout_ptr1, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 4);
pad_right_new, cnt, remain, 4);
dout_ptr0 += num_out; dout_ptr0 += num_out;
dout_ptr1 += num_out; dout_ptr1 += num_out;
din_ptr0 = din_ptr2; din_ptr0 = din_ptr2;
...@@ -4710,8 +4570,7 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4710,8 +4570,7 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
} }
if (loop_h % 2 != 0) { if (loop_h % 2 != 0) {
compute_all_padding_mid_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_mid_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 4);
pad_right_new, cnt, remain, 4);
dout_ptr0 = dout_ptr1; dout_ptr0 = dout_ptr1;
din_ptr0 = din_ptr1; din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2; din_ptr1 = din_ptr2;
...@@ -4725,10 +4584,9 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout, ...@@ -4725,10 +4584,9 @@ void conv_depthwise_5x5s1_bias_leakyRelu(float* dout,
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
// bottom // bottom
for (int h = 0; h < pad_bottom_new; h++) { for (int h = 0; h < pad_bottom; h++) {
compute_all_padding_post_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero, compute_all_padding_post_leakyRelu(dout_ptr0, din_ptr_arr, vbias, scale, weights_vec, vzero,
win, wout, pad_left, pad_right, pad_left_new, win, wout, pad_left, pad_right, cnt, remain, 3 - h);
pad_right_new, cnt, remain, 3 - h);
dout_ptr0 += wout; dout_ptr0 += wout;
din_ptr_arr[0] = din_ptr0; din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1; din_ptr_arr[1] = din_ptr1;
......
...@@ -736,6 +736,7 @@ void conv_depthwise_5x5_fp32(const void* din, ...@@ -736,6 +736,7 @@ void conv_depthwise_5x5_fp32(const void* din,
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
bool ch_four = ch_in > 4 * w_in; bool ch_four = ch_in > 4 * w_in;
bool pads_five = (pad_h < 5) || (pad_w < 5);
ctx->ExtendWorkspace((w_in + w_out) * sizeof(float)); ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
bool flag_act = act_param.has_active; bool flag_act = act_param.has_active;
if (stride == 2) { if (stride == 2) {
...@@ -754,7 +755,7 @@ void conv_depthwise_5x5_fp32(const void* din, ...@@ -754,7 +755,7 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param, act_param,
ctx); ctx);
} else if (stride == 1) { } else if (stride == 1) {
if (ch_four || h_in < 5 || w_in < 5) { if (ch_four || !pads_five || h_in < 5 || w_in < 5) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout), conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din), reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
......
...@@ -59,8 +59,9 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -59,8 +59,9 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
#endif #endif
} else if (kw == 5) { } else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32"; // VLOG(5) << "invoke 5x5 dw conv fp32";
bool pads_five = (paddings[0] < 5) || (paddings[2] < 5);
auto strides = param.strides; auto strides = param.strides;
if (ch_four && win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) { if (ch_four && pads_five && win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) {
flag_trans_weights_ = false; flag_trans_weights_ = false;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32; impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册