提交 589e852c 编写于 作者: C chenjiaoAngel

fix relu relu6 error. test=develop

上级 01455e09
......@@ -1320,14 +1320,13 @@ inline void compute_all_padding_pre(float* dout,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
int tmp_index = num - 1;
int num_index_left = 4 - pad_left;
for (int i = pad_left_new; i > 0; i--) {
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
......@@ -1337,7 +1336,7 @@ inline void compute_all_padding_pre(float* dout,
weights[5][3 - k],
num_index_left);
}
num_index_left -= 2;
num_index_left += 2;
*dout++ = sum;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -1558,9 +1557,7 @@ inline void compute_all_padding_pre(float* dout,
*dout++ = sum;
}
// right
int num_index_right = 4 - pad_right;
LOG(INFO) << "pad_right_new: " << pad_right_new << ", num_index_right: " << num_index_right;
for (int i = 0; i < pad_right_new; i++) {
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
......@@ -1583,15 +1580,14 @@ inline void compute_all_padding_mid(float* dout,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
// left
int tmp = num - 1;
int num_index_left = 4 - pad_left;
for (int i = pad_left_new; i > 0; i--) {
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
......@@ -1601,7 +1597,7 @@ inline void compute_all_padding_mid(float* dout,
weights[5][tmp - k],
num_index_left);
}
num_index_left -= 2;
num_index_left += 2;
*dout++ = sum;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -1684,19 +1680,19 @@ inline void compute_all_padding_mid(float* dout,
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right_new; i++) {
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][4 - pad_right],
4 - pad_right);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
pad_right += 2;
num_index_right -= 2;
*dout++ = sum;
}
}
......@@ -1708,8 +1704,8 @@ inline void compute_all_padding_mid_out2(float* dout0,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -1717,24 +1713,24 @@ inline void compute_all_padding_mid_out2(float* dout0,
int tmp2 = num + 1;
int tmp = num - 1;
// left
for (int i = pad_left_new; i > 0; i--) {
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - pad_left);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - pad_left);
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - pad_left);
num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - pad_left);
num_index_left);
}
pad_left -= 2;
num_index_left += 2;
*dout0++ = sum;
*dout1++ = sum1;
}
......@@ -1835,26 +1831,26 @@ inline void compute_all_padding_mid_out2(float* dout0,
*dout1++ = sum1;
}
// right
for (int i = 0; i < pad_right_new; i++) {
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right);
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][4 - pad_right],
4 - pad_right);
weights[tmp - k][num_index_right],
num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][4 - pad_right],
4 - pad_right);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp2 - k] += 2;
}
pad_right += 2;
num_index_right -= 2;
din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2;
*dout0++ = sum;
......@@ -1869,22 +1865,22 @@ inline void compute_all_padding_post(float* dout,
bool odds,
int pad_left,
int pad_right,
int pad_left_new,
int pad_right_new,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
// left
int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) {
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - pad_left);
din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - pad_left);
num_index_left);
}
pad_left -= 2;
*dout++ = sum;
......@@ -2101,20 +2097,19 @@ inline void compute_all_padding_post(float* dout,
*dout++ = sum;
}
// right
int num_index = 4 - pad_right;
for (int i = 0; i < pad_right_new; i++) {
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index], num_index);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][num_index],
num_index);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index -= 2;
num_index_right -= 2;
*dout++ = sum;
}
}
......@@ -2161,11 +2156,12 @@ void conv_depthwise_5x5s2_bias(float* dout,
int remain = loop_w & 3;
int n_top_h = 4 - pad_top;
int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 0 : 1);
if (n_right_w == 0) {
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w += 2;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
......@@ -2214,10 +2210,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -2254,10 +2250,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -2284,10 +2280,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -2311,10 +2307,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias,
weights_vec,
odds_w,
pad_left,
n_right_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -2338,20 +2334,23 @@ inline void compute_all_padding_pre_relu(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
int tmp_index = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout++ = sum > 0.f ? sum : 0.f;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -2582,16 +2581,17 @@ inline void compute_all_padding_pre_relu(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
weights[3 - k][num_index_right],
num_index_right);
din_ptr_arr[tmp_index - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : 0.f;
}
}
......@@ -2603,20 +2603,23 @@ inline void compute_all_padding_mid_relu(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout++ = sum > 0.f ? sum : 0.f;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -2702,16 +2705,17 @@ inline void compute_all_padding_mid_relu(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : 0.f;
}
}
......@@ -2724,6 +2728,8 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -2733,21 +2739,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout0++ = sum > 0.f ? sum : 0.f;
*dout1++ = sum1 > 0.f ? sum1 : 0.f;
}
......@@ -2851,23 +2858,24 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp2 - k] += 2;
}
num_index_right -= 2;
din_ptr_arr[0] += 2;
din_ptr_arr[0] += 2;
*dout0++ = sum > 0.f ? sum : 0.f;
......@@ -2882,6 +2890,8 @@ inline void compute_all_padding_post_relu(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -2889,14 +2899,15 @@ inline void compute_all_padding_post_relu(float* dout,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k],
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
pad_left -= 2;
*dout++ = sum > 0.f ? sum : 0.f;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -2913,7 +2924,7 @@ inline void compute_all_padding_post_relu(float* dout,
#ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -2932,7 +2943,7 @@ inline void compute_all_padding_post_relu(float* dout,
#else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -2949,14 +2960,14 @@ inline void compute_all_padding_post_relu(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[3] -= 8;
din_ptr_arr[num] -= 8;
break;
case 1:
#ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -2976,8 +2987,8 @@ inline void compute_all_padding_post_relu(float* dout,
#else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -2995,15 +3006,15 @@ inline void compute_all_padding_post_relu(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[2] -= 8;
din_ptr_arr[tmp] -= 8;
break;
case 2:
#ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -3024,9 +3035,9 @@ inline void compute_all_padding_post_relu(float* dout,
#else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -3045,7 +3056,7 @@ inline void compute_all_padding_post_relu(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[1] -= 8;
din_ptr_arr[tmp - 1] -= 8;
break;
case 3:
#ifdef __aarch64__
......@@ -3102,35 +3113,36 @@ inline void compute_all_padding_post_relu(float* dout,
din_ptr_arr[0] -= 8;
break;
default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support";
LOG(FATAL) << "This num: " << (num + 1) << " does not support";
}
}
// clang-format on
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[num] += 2;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(
din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i] += 2;
din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i] += 2;
}
*dout++ = sum > 0.f ? sum : 0.f;
}
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k],
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k] += 2;
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : 0.f;
}
}
......@@ -3176,7 +3188,19 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
int cnt = loop_w >> 2;
int remain = loop_w & 3;
int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom;
int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size;
......@@ -3223,6 +3247,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -3262,6 +3288,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -3291,6 +3319,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -3317,6 +3347,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -3341,6 +3373,8 @@ inline void compute_all_padding_pre_relu6(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -3351,14 +3385,15 @@ inline void compute_all_padding_pre_relu6(float* dout,
// left
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -3597,16 +3632,17 @@ inline void compute_all_padding_pre_relu6(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
weights[3 - k][num_index_right],
num_index_right);
din_ptr_arr[tmp_index - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
}
......@@ -3619,6 +3655,8 @@ inline void compute_all_padding_mid_relu6(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -3629,14 +3667,15 @@ inline void compute_all_padding_mid_relu6(float* dout,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -3706,7 +3745,7 @@ inline void compute_all_padding_mid_relu6(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[0] -= 4;
din_ptr_arr[0] -= 8;
}
// clang-format on
// remain
......@@ -3724,16 +3763,17 @@ inline void compute_all_padding_mid_relu6(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
}
......@@ -3748,6 +3788,8 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -3761,21 +3803,22 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
// clang-format off
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 -k],
weights[tmp -k],
num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
*dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
}
......@@ -3880,23 +3923,24 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1]++;
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp2 - k] += 2;
}
num_index_right -= 2;
din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2;
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
......@@ -3912,6 +3956,8 @@ inline void compute_all_padding_post_relu6(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -3922,14 +3968,15 @@ inline void compute_all_padding_post_relu6(float* dout,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k],
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
pad_left -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -3946,7 +3993,7 @@ inline void compute_all_padding_post_relu6(float* dout,
#ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -3966,7 +4013,7 @@ inline void compute_all_padding_post_relu6(float* dout,
#else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -3984,14 +4031,14 @@ inline void compute_all_padding_post_relu6(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[3] -= 8;
din_ptr_arr[num] -= 8;
break;
case 1:
#ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -4012,8 +4059,8 @@ inline void compute_all_padding_post_relu6(float* dout,
#else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -4032,15 +4079,15 @@ inline void compute_all_padding_post_relu6(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[2] -= 8;
din_ptr_arr[tmp] -= 8;
break;
case 2:
#ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -4062,9 +4109,9 @@ inline void compute_all_padding_post_relu6(float* dout,
#else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -4084,7 +4131,7 @@ inline void compute_all_padding_post_relu6(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[1] -= 8;
din_ptr_arr[tmp - 1] -= 8;
break;
case 3:
#ifdef __aarch64__
......@@ -4162,16 +4209,17 @@ inline void compute_all_padding_post_relu6(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k],
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k] += 2;
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
}
......@@ -4218,7 +4266,19 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
int cnt = loop_w >> 2;
int remain = loop_w & 3;
int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom;
int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size;
......@@ -4266,6 +4326,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -4306,6 +4368,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -4336,6 +4400,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -4363,6 +4429,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -4387,6 +4455,8 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -4397,14 +4467,15 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
// left
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[5][3 - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -4651,22 +4722,19 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k],
0.f,
weights[3 - k][3 - i],
3 - i);
weights[3 - k][num_index_right],
num_index_right);
din_ptr_arr[tmp_index - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
}
inline void compute_all_padding_mid_leakyRelu(float* dout,
const float** din_ptr_arr,
......@@ -4677,6 +4745,8 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -4687,14 +4757,15 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -4784,16 +4855,17 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
}
......@@ -4807,6 +4879,8 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -4819,21 +4893,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
num_index_left += 2;
*dout0++ = sum > 0.f ? sum : sum * scale[0];
*dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
}
......@@ -4943,23 +5018,24 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp2 - k] += 2;
}
num_index_right -= 2;
din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2;
*dout0++ = sum > 0.f ? sum : sum * scale[0];
......@@ -4975,6 +5051,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
bool odds,
int pad_left,
int pad_right,
int num_index_left,
int num_index_right,
int cnt,
int remain,
int num) {
......@@ -4985,14 +5063,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
int tmp = num - 1;
for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k],
sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[5][tmp - k],
4 - i);
num_index_left);
}
pad_left -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
if (odds) { // origin pad_left is odds, such as ori_pad_left=1
......@@ -5009,7 +5088,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -5031,7 +5110,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -5049,14 +5128,14 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[3] -= 8;
din_ptr_arr[num] -= 8;
break;
case 1:
#ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -5079,8 +5158,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -5099,15 +5178,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[2] -= 8;
din_ptr_arr[tmp] -= 8;
break;
case 2:
#ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -5131,9 +5210,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -5153,7 +5232,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[1] -= 8;
din_ptr_arr[tmp - 1] -= 8;
break;
case 3:
#ifdef __aarch64__
......@@ -5221,28 +5300,29 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[num] += 2;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(
din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i] += 2;
din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i] += 2;
}
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
// right
for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3] += 2;
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k],
sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k],
0.f,
weights[tmp - k][3 - i],
3 - i);
din_ptr_arr[2 - k] += 2;
weights[tmp - k][num_index_right],
num_index_right);
din_ptr_arr[tmp - k] += 2;
}
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
}
......@@ -5289,7 +5369,19 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
int cnt = loop_w >> 2;
int remain = loop_w & 3;
int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom;
int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size;
......@@ -5337,6 +5429,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......@@ -5377,6 +5471,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -5407,6 +5503,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
4);
......@@ -5434,6 +5532,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w,
pad_left_new,
pad_right_new,
n_left_w,
n_right_w,
cnt,
remain,
h_in_num);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册