提交 01455e09 编写于 作者: C chenjiaoAngel

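Fix compute error in the 5x5 stride-2 depthwise convolution: correct the left/right padding index computation in the compute_all_padding_* helpers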

上级 07109803
......@@ -1320,20 +1320,24 @@ inline void compute_all_padding_pre(float* dout,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int cnt,
int remain,
int num) {
int tmp_index = num - 1;
for (int i = pad_left; i > 0; i--) {
int num_index_left = 4 - pad_left;
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
num_index_left);
}
num_index_left -= 2;
*dout++ = sum;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -1554,18 +1558,21 @@ inline void compute_all_padding_pre(float* dout,
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right; i++) {
int num_index_right = 4 - pad_right;
LOG(INFO) << "pad_right_new: " << pad_right_new << ", num_index_right: " << num_index_right;
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
weights[3 - k][num_index_right],
num_index_right);
din_ptr_arr[tmp_index - k] += 2;
}
num_index_right -= 2;
*dout++ = sum;
}
}
......@@ -1576,21 +1583,25 @@ inline void compute_all_padding_mid(float* dout,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int cnt,
int remain,
int num) {
// left
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
int num_index_left = 4 - pad_left;
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left -= 2;
*dout++ = sum;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -1673,18 +1684,19 @@ inline void compute_all_padding_mid(float* dout,
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right; i++) {
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][4 - pad_right],
4 - pad_right);
din_ptr_arr[tmp - k] += 2;
}
pad_right += 2;
*dout++ = sum;
}
}
......@@ -1696,6 +1708,8 @@ inline void compute_all_padding_mid_out2(float* dout0,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int cnt,
int remain,
int num) {
......@@ -1703,23 +1717,24 @@ inline void compute_all_padding_mid_out2(float* dout0,
int tmp2 = num + 1;
int tmp = num - 1;
// left
for (int i = pad_left; i > 0; i--) {
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - pad_left);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - pad_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
4 - pad_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
4 - pad_left);
}
pad_left -= 2;
*dout0++ = sum;
*dout1++ = sum1;
}
......@@ -1820,25 +1835,26 @@ inline void compute_all_padding_mid_out2(float* dout0,
*dout1++ = sum1;
}
// right
for (int i = 0; i < pad_right; i++) {
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right);
din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][4 - pad_right],
4 - pad_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][4 - pad_right],
4 - pad_right);
din_ptr_arr[tmp2 - k] += 2;
}
pad_right += 2;
din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2;
*dout0++ = sum;
......@@ -1853,21 +1869,24 @@ inline void compute_all_padding_post(float* dout,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int cnt,
int remain,
int num) {
// left
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - pad_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k],
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
4 - pad_left);
}
pad_left -= 2;
*dout++ = sum;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -1884,7 +1903,7 @@ inline void compute_all_padding_post(float* dout,
#ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -1902,7 +1921,7 @@ inline void compute_all_padding_post(float* dout,
#else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -1918,14 +1937,14 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[3] -= 8;
din_ptr_arr[num] -= 8;
break;
case 1:
#ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1944,8 +1963,8 @@ inline void compute_all_padding_post(float* dout,
#else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1962,15 +1981,15 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[2] -= 8;
din_ptr_arr[tmp] -= 8;
break;
case 2:
#ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1990,9 +2009,9 @@ inline void compute_all_padding_post(float* dout,
#else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -2010,7 +2029,7 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[1] -= 8;
din_ptr_arr[tmp - 1] -= 8;
break;
case 3:
#ifdef __aarch64__
......@@ -2072,28 +2091,30 @@ inline void compute_all_padding_post(float* dout,
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[num] += 2;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(
din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i] += 2;
din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i] += 2;
}
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right; i++) {
int num_index = 4 - pad_right;
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index], num_index);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k],
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k] += 2;
weights[tmp - k][num_index],
num_index);
din_ptr_arr[tmp - k] += 2;
}
num_index -= 2;
*dout++ = sum;
}
}
......@@ -2139,7 +2160,18 @@ void conv_depthwise_5x5s2_bias(float* dout,
int cnt = loop_w >> 2;
int remain = loop_w & 3;
int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom;
int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 0 : 1);
if (n_right_w == 0) {
remain++;
pad_right_new--;
n_right_w += 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size;
float* dout_batch = dout + n * out_channel_size;
......@@ -2182,6 +2214,8 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
cnt,
......@@ -2220,6 +2254,8 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
cnt,
......@@ -2248,6 +2284,8 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
cnt,
......@@ -2273,6 +2311,8 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
cnt,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册