提交 589e852c 作者: chenjiaoAngel

Fix relu and relu6 errors. test=develop

上级 01455e09
...@@ -1320,14 +1320,13 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1320,14 +1320,13 @@ inline void compute_all_padding_pre(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new, int num_index_left,
int pad_right_new, int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
int num_index_left = 4 - pad_left; for (int i = pad_left; i > 0; i--) {
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left); din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1337,7 +1336,7 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1337,7 +1336,7 @@ inline void compute_all_padding_pre(float* dout,
weights[5][3 - k], weights[5][3 - k],
num_index_left); num_index_left);
} }
num_index_left -= 2; num_index_left += 2;
*dout++ = sum; *dout++ = sum;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -1558,9 +1557,7 @@ inline void compute_all_padding_pre(float* dout, ...@@ -1558,9 +1557,7 @@ inline void compute_all_padding_pre(float* dout,
*dout++ = sum; *dout++ = sum;
} }
// right // right
int num_index_right = 4 - pad_right; for (int i = 0; i < pad_right; i++) {
LOG(INFO) << "pad_right_new: " << pad_right_new << ", num_index_right: " << num_index_right;
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right); din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
...@@ -1583,15 +1580,14 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1583,15 +1580,14 @@ inline void compute_all_padding_mid(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new, int num_index_left,
int pad_right_new, int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
int tmp = num - 1; int tmp = num - 1;
int num_index_left = 4 - pad_left; for (int i = pad_left; i > 0; i--) {
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
...@@ -1601,7 +1597,7 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1601,7 +1597,7 @@ inline void compute_all_padding_mid(float* dout,
weights[5][tmp - k], weights[5][tmp - k],
num_index_left); num_index_left);
} }
num_index_left -= 2; num_index_left += 2;
*dout++ = sum; *dout++ = sum;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -1684,19 +1680,19 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1684,19 +1680,19 @@ inline void compute_all_padding_mid(float* dout,
*dout++ = sum; *dout++ = sum;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][4 - pad_right], weights[tmp - k][num_index_right],
4 - pad_right); num_index_right);
din_ptr_arr[tmp - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
pad_right += 2; num_index_right -= 2;
*dout++ = sum; *dout++ = sum;
} }
} }
...@@ -1708,8 +1704,8 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1708,8 +1704,8 @@ inline void compute_all_padding_mid_out2(float* dout0,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new, int num_index_left,
int pad_right_new, int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -1717,24 +1713,24 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1717,24 +1713,24 @@ inline void compute_all_padding_mid_out2(float* dout0,
int tmp2 = num + 1; int tmp2 = num + 1;
int tmp = num - 1; int tmp = num - 1;
// left // left
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - pad_left); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre( float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - pad_left); din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - pad_left); num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - pad_left); num_index_left);
} }
pad_left -= 2; num_index_left += 2;
*dout0++ = sum; *dout0++ = sum;
*dout1++ = sum1; *dout1++ = sum1;
} }
...@@ -1835,26 +1831,26 @@ inline void compute_all_padding_mid_out2(float* dout0, ...@@ -1835,26 +1831,26 @@ inline void compute_all_padding_mid_out2(float* dout0,
*dout1++ = sum1; *dout1++ = sum1;
} }
// right // right
for (int i = 0; i < pad_right_new; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post( float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][4 - pad_right], 4 - pad_right); din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2; din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][4 - pad_right], weights[tmp - k][num_index_right],
4 - pad_right); num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][4 - pad_right], weights[tmp - k][num_index_right],
4 - pad_right); num_index_right);
din_ptr_arr[tmp2 - k] += 2; din_ptr_arr[tmp2 - k] += 2;
} }
pad_right += 2; num_index_right -= 2;
din_ptr_arr[1] += 2; din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2; din_ptr_arr[0] += 2;
*dout0++ = sum; *dout0++ = sum;
...@@ -1869,22 +1865,22 @@ inline void compute_all_padding_post(float* dout, ...@@ -1869,22 +1865,22 @@ inline void compute_all_padding_post(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int pad_left_new, int num_index_left,
int pad_right_new, int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
// left // left
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - pad_left); din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - pad_left); num_index_left);
} }
pad_left -= 2; pad_left -= 2;
*dout++ = sum; *dout++ = sum;
...@@ -2101,20 +2097,19 @@ inline void compute_all_padding_post(float* dout, ...@@ -2101,20 +2097,19 @@ inline void compute_all_padding_post(float* dout,
*dout++ = sum; *dout++ = sum;
} }
// right // right
int num_index = 4 - pad_right; for (int i = 0; i < pad_right; i++) {
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][num_index], num_index); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][num_index], weights[tmp - k][num_index_right],
num_index); num_index_right);
din_ptr_arr[tmp - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index -= 2; num_index_right -= 2;
*dout++ = sum; *dout++ = sum;
} }
} }
...@@ -2161,11 +2156,12 @@ void conv_depthwise_5x5s2_bias(float* dout, ...@@ -2161,11 +2156,12 @@ void conv_depthwise_5x5s2_bias(float* dout,
int remain = loop_w & 3; int remain = loop_w & 3;
int n_top_h = 4 - pad_top; int n_top_h = 4 - pad_top;
int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3); int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 0 : 1); int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
if (n_right_w == 0) { int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++; remain++;
pad_right_new--; pad_right_new--;
n_right_w += 2; n_right_w -= 2;
} }
if (n_bottom_h == 4) { if (n_bottom_h == 4) {
loop_h++; loop_h++;
...@@ -2214,10 +2210,10 @@ void conv_depthwise_5x5s2_bias(float* dout, ...@@ -2214,10 +2210,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias, vbias,
weights_vec, weights_vec,
odds_w, odds_w,
pad_left,
n_right_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -2254,10 +2250,10 @@ void conv_depthwise_5x5s2_bias(float* dout, ...@@ -2254,10 +2250,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias, vbias,
weights_vec, weights_vec,
odds_w, odds_w,
pad_left,
n_right_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -2284,10 +2280,10 @@ void conv_depthwise_5x5s2_bias(float* dout, ...@@ -2284,10 +2280,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias, vbias,
weights_vec, weights_vec,
odds_w, odds_w,
pad_left,
n_right_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -2311,10 +2307,10 @@ void conv_depthwise_5x5s2_bias(float* dout, ...@@ -2311,10 +2307,10 @@ void conv_depthwise_5x5s2_bias(float* dout,
vbias, vbias,
weights_vec, weights_vec,
odds_w, odds_w,
pad_left,
n_right_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -2338,20 +2334,23 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2338,20 +2334,23 @@ inline void compute_all_padding_pre_relu(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp_index = num - 1; int tmp_index = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k], weights[3 - k],
0.f, 0.f,
weights[5][3 - k], weights[5][3 - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -2582,16 +2581,17 @@ inline void compute_all_padding_pre_relu(float* dout, ...@@ -2582,16 +2581,17 @@ inline void compute_all_padding_pre_relu(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k], weights[3 - k],
0.f, 0.f,
weights[3 - k][3 - i], weights[3 - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp_index - k] += 2; din_ptr_arr[tmp_index - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
} }
...@@ -2603,20 +2603,23 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2603,20 +2603,23 @@ inline void compute_all_padding_mid_relu(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -2702,16 +2705,17 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -2702,16 +2705,17 @@ inline void compute_all_padding_mid_relu(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
} }
...@@ -2724,6 +2728,8 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2724,6 +2728,8 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -2733,21 +2739,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2733,21 +2739,22 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre( float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout0++ = sum > 0.f ? sum : 0.f; *dout0++ = sum > 0.f ? sum : 0.f;
*dout1++ = sum1 > 0.f ? sum1 : 0.f; *dout1++ = sum1 > 0.f ? sum1 : 0.f;
} }
...@@ -2851,23 +2858,24 @@ inline void compute_all_padding_mid_relu_out2(float* dout0, ...@@ -2851,23 +2858,24 @@ inline void compute_all_padding_mid_relu_out2(float* dout0,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post( float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2; din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp2 - k] += 2; din_ptr_arr[tmp2 - k] += 2;
} }
num_index_right -= 2;
din_ptr_arr[0] += 2; din_ptr_arr[0] += 2;
din_ptr_arr[0] += 2; din_ptr_arr[0] += 2;
*dout0++ = sum > 0.f ? sum : 0.f; *dout0++ = sum > 0.f ? sum : 0.f;
...@@ -2882,6 +2890,8 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2882,6 +2890,8 @@ inline void compute_all_padding_post_relu(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -2889,14 +2899,15 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2889,14 +2899,15 @@ inline void compute_all_padding_post_relu(float* dout,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
pad_left -= 2;
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -2913,7 +2924,7 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2913,7 +2924,7 @@ inline void compute_all_padding_post_relu(float* dout,
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]), [din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]), [wr5] "w"(weights[5]),
...@@ -2932,7 +2943,7 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2932,7 +2943,7 @@ inline void compute_all_padding_post_relu(float* dout,
#else #else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]), [din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]), [wr5] "w"(weights[5]),
...@@ -2949,14 +2960,14 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2949,14 +2960,14 @@ inline void compute_all_padding_post_relu(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[3] -= 8; din_ptr_arr[num] -= 8;
break; break;
case 1: case 1:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]), [din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[3]), [din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -2976,8 +2987,8 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2976,8 +2987,8 @@ inline void compute_all_padding_post_relu(float* dout,
#else #else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]), [din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[3]), [din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -2995,15 +3006,15 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -2995,15 +3006,15 @@ inline void compute_all_padding_post_relu(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[2] -= 8; din_ptr_arr[tmp] -= 8;
break; break;
case 2: case 2:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]), [din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[2]), [din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[3]), [din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -3024,9 +3035,9 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -3024,9 +3035,9 @@ inline void compute_all_padding_post_relu(float* dout,
#else #else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]), [din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[2]), [din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[3]), [din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -3045,7 +3056,7 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -3045,7 +3056,7 @@ inline void compute_all_padding_post_relu(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[1] -= 8; din_ptr_arr[tmp - 1] -= 8;
break; break;
case 3: case 3:
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -3102,35 +3113,36 @@ inline void compute_all_padding_post_relu(float* dout, ...@@ -3102,35 +3113,36 @@ inline void compute_all_padding_post_relu(float* dout,
din_ptr_arr[0] -= 8; din_ptr_arr[0] -= 8;
break; break;
default: default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support"; LOG(FATAL) << "This num: " << (num + 1) << " does not support";
} }
} }
// clang-format on // clang-format on
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3] += 2; din_ptr_arr[num] += 2;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post( sum += compute_one_data_post(
din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i] += 2; din_ptr_arr[tmp - i] += 2;
} }
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[3] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[2 - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
} }
...@@ -3176,7 +3188,19 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, ...@@ -3176,7 +3188,19 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int n_top_h = 4 - pad_top; int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom; int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
...@@ -3223,6 +3247,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, ...@@ -3223,6 +3247,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -3262,6 +3288,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, ...@@ -3262,6 +3288,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -3291,6 +3319,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, ...@@ -3291,6 +3319,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -3317,6 +3347,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout, ...@@ -3317,6 +3347,8 @@ void conv_depthwise_5x5s2_bias_relu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -3341,6 +3373,8 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -3341,6 +3373,8 @@ inline void compute_all_padding_pre_relu6(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3351,14 +3385,15 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -3351,14 +3385,15 @@ inline void compute_all_padding_pre_relu6(float* dout,
// left // left
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k], weights[3 - k],
0.f, 0.f,
weights[5][3 - k], weights[5][3 - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -3597,16 +3632,17 @@ inline void compute_all_padding_pre_relu6(float* dout, ...@@ -3597,16 +3632,17 @@ inline void compute_all_padding_pre_relu6(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k], weights[3 - k],
0.f, 0.f,
weights[3 - k][3 - i], weights[3 - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp_index - k] += 2; din_ptr_arr[tmp_index - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
} }
...@@ -3619,6 +3655,8 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3619,6 +3655,8 @@ inline void compute_all_padding_mid_relu6(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3629,14 +3667,15 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3629,14 +3667,15 @@ inline void compute_all_padding_mid_relu6(float* dout,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -3706,7 +3745,7 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3706,7 +3745,7 @@ inline void compute_all_padding_mid_relu6(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[0] -= 4; din_ptr_arr[0] -= 8;
} }
// clang-format on // clang-format on
// remain // remain
...@@ -3724,16 +3763,17 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -3724,16 +3763,17 @@ inline void compute_all_padding_mid_relu6(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
} }
...@@ -3748,6 +3788,8 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3748,6 +3788,8 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3761,21 +3803,22 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3761,21 +3803,22 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
// clang-format off // clang-format off
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre( float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 -k], sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp -k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
*dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f;
} }
...@@ -3880,23 +3923,24 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0, ...@@ -3880,23 +3923,24 @@ inline void compute_all_padding_mid_relu6_out2(float* dout0,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post( float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1]++; din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp2 - k] += 2; din_ptr_arr[tmp2 - k] += 2;
} }
num_index_right -= 2;
din_ptr_arr[1] += 2; din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2; din_ptr_arr[0] += 2;
*dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
...@@ -3912,6 +3956,8 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3912,6 +3956,8 @@ inline void compute_all_padding_post_relu6(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -3922,14 +3968,15 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3922,14 +3968,15 @@ inline void compute_all_padding_post_relu6(float* dout,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
pad_left -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -3946,7 +3993,7 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3946,7 +3993,7 @@ inline void compute_all_padding_post_relu6(float* dout,
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6 asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]), [din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]), [wr5] "w"(weights[5]),
...@@ -3966,7 +4013,7 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3966,7 +4013,7 @@ inline void compute_all_padding_post_relu6(float* dout,
#else #else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6 asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]), [din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]), [wr5] "w"(weights[5]),
...@@ -3984,14 +4031,14 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -3984,14 +4031,14 @@ inline void compute_all_padding_post_relu6(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[3] -= 8; din_ptr_arr[num] -= 8;
break; break;
case 1: case 1:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6 asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]), [din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[3]), [din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -4012,8 +4059,8 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -4012,8 +4059,8 @@ inline void compute_all_padding_post_relu6(float* dout,
#else #else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6 asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]), [din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[3]), [din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -4032,15 +4079,15 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -4032,15 +4079,15 @@ inline void compute_all_padding_post_relu6(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[2] -= 8; din_ptr_arr[tmp] -= 8;
break; break;
case 2: case 2:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6 asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]), [din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[2]), [din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[3]), [din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -4062,9 +4109,9 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -4062,9 +4109,9 @@ inline void compute_all_padding_post_relu6(float* dout,
#else #else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6 asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]), [din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[2]), [din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[3]), [din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -4084,7 +4131,7 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -4084,7 +4131,7 @@ inline void compute_all_padding_post_relu6(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[1] -= 8; din_ptr_arr[tmp - 1] -= 8;
break; break;
case 3: case 3:
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -4162,16 +4209,17 @@ inline void compute_all_padding_post_relu6(float* dout, ...@@ -4162,16 +4209,17 @@ inline void compute_all_padding_post_relu6(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[3] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[2 - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
} }
...@@ -4218,7 +4266,19 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, ...@@ -4218,7 +4266,19 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int n_top_h = 4 - pad_top; int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom; int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
...@@ -4266,6 +4326,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, ...@@ -4266,6 +4326,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -4306,6 +4368,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, ...@@ -4306,6 +4368,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -4336,6 +4400,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, ...@@ -4336,6 +4400,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -4363,6 +4429,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout, ...@@ -4363,6 +4429,8 @@ void conv_depthwise_5x5s2_bias_relu6(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -4387,6 +4455,8 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -4387,6 +4455,8 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4397,14 +4467,15 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -4397,14 +4467,15 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
// left // left
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[4], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], sum += compute_one_data_pre(din_ptr_arr[tmp_index - k],
weights[3 - k], weights[3 - k],
0.f, 0.f,
weights[5][3 - k], weights[5][3 - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -4651,22 +4722,19 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -4651,22 +4722,19 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); din_ptr_arr[num], weights[4], bias[0], weights[4][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp_index - k], sum += compute_one_data_post(din_ptr_arr[tmp_index - k],
weights[3 - k], weights[3 - k],
0.f, 0.f,
weights[3 - k][3 - i], weights[3 - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp_index - k] += 2; din_ptr_arr[tmp_index - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0];
}
} }
inline void compute_all_padding_mid_leakyRelu(float* dout, inline void compute_all_padding_mid_leakyRelu(float* dout,
const float** din_ptr_arr, const float** din_ptr_arr,
...@@ -4677,6 +4745,8 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4677,6 +4745,8 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4687,14 +4757,15 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4687,14 +4757,15 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -4784,16 +4855,17 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -4784,16 +4855,17 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[num] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
} }
...@@ -4807,6 +4879,8 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4807,6 +4879,8 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4819,21 +4893,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4819,21 +4893,22 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[6][0], num_index_left);
float sum1 = compute_one_data_pre( float sum1 = compute_one_data_pre(
din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_pre(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
num_index_left += 2;
*dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout0++ = sum > 0.f ? sum : sum * scale[0];
*dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0];
} }
...@@ -4943,23 +5018,24 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, ...@@ -4943,23 +5018,24 @@ inline void compute_all_padding_mid_leakyRelu_out2(float* dout0,
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
float sum1 = compute_one_data_post( float sum1 = compute_one_data_post(
din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[tmp1], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[tmp1] += 2; din_ptr_arr[tmp1] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k], sum1 += compute_one_data_post(din_ptr_arr[tmp2 - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[tmp2 - k] += 2; din_ptr_arr[tmp2 - k] += 2;
} }
num_index_right -= 2;
din_ptr_arr[1] += 2; din_ptr_arr[1] += 2;
din_ptr_arr[0] += 2; din_ptr_arr[0] += 2;
*dout0++ = sum > 0.f ? sum : sum * scale[0]; *dout0++ = sum > 0.f ? sum : sum * scale[0];
...@@ -4975,6 +5051,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4975,6 +5051,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
bool odds, bool odds,
int pad_left, int pad_left,
int pad_right, int pad_right,
int num_index_left,
int num_index_right,
int cnt, int cnt,
int remain, int remain,
int num) { int num) {
...@@ -4985,14 +5063,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -4985,14 +5063,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
int tmp = num - 1; int tmp = num - 1;
for (int i = pad_left; i > 0; i--) { for (int i = pad_left; i > 0; i--) {
float sum = compute_one_data_pre( float sum = compute_one_data_pre(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); din_ptr_arr[num], weights[num], bias[0], weights[5][num], num_index_left);
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[2 - k], sum += compute_one_data_pre(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[5][tmp - k], weights[5][tmp - k],
4 - i); num_index_left);
} }
pad_left -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
if (odds) { // origin pad_left is odds, such as ori_pad_left=1 if (odds) { // origin pad_left is odds, such as ori_pad_left=1
...@@ -5009,7 +5088,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5009,7 +5088,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]), [din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]), [wr5] "w"(weights[5]),
...@@ -5031,7 +5110,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5031,7 +5110,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#else #else
asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[3]), [din_ptr0] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]), [wr5] "w"(weights[5]),
...@@ -5049,14 +5128,14 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5049,14 +5128,14 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[3] -= 8; din_ptr_arr[num] -= 8;
break; break;
case 1: case 1:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]), [din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[3]), [din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -5079,8 +5158,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5079,8 +5158,8 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#else #else
asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[2]), [din_ptr0] "+r"(din_ptr_arr[tmp]),
[din_ptr1] "+r"(din_ptr_arr[3]), [din_ptr1] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -5099,15 +5178,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5099,15 +5178,15 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[2] -= 8; din_ptr_arr[tmp] -= 8;
break; break;
case 2: case 2:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]), [din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[2]), [din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[3]), [din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -5131,9 +5210,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5131,9 +5210,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
#else #else
asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[1]), [din_ptr0] "+r"(din_ptr_arr[tmp - 1]),
[din_ptr1] "+r"(din_ptr_arr[2]), [din_ptr1] "+r"(din_ptr_arr[tmp]),
[din_ptr2] "+r"(din_ptr_arr[3]), [din_ptr2] "+r"(din_ptr_arr[num]),
[dout_ptr] "+r"(dout) [dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]), : [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]), [wr1] "w"(weights[1]),
...@@ -5153,7 +5232,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5153,7 +5232,7 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
din_ptr_arr[1] -= 8; din_ptr_arr[tmp - 1] -= 8;
break; break;
case 3: case 3:
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -5221,28 +5300,29 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -5221,28 +5300,29 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
// remain // remain
for (int w = 0; w < remain; w++) { for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3] += 2; din_ptr_arr[num] += 2;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
sum += compute_one_data_post( sum += compute_one_data_post(
din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i] += 2; din_ptr_arr[tmp - i] += 2;
} }
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
// right // right
for (int i = 0; i < pad_right; i++) { for (int i = 0; i < pad_right; i++) {
float sum = compute_one_data_post( float sum = compute_one_data_post(
din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); din_ptr_arr[num], weights[num], bias[0], weights[num][num_index_right], num_index_right);
din_ptr_arr[3] += 2; din_ptr_arr[num] += 2;
for (int k = 0; k < num; k++) { for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[2 - k], sum += compute_one_data_post(din_ptr_arr[tmp - k],
weights[tmp - k], weights[tmp - k],
0.f, 0.f,
weights[tmp - k][3 - i], weights[tmp - k][num_index_right],
3 - i); num_index_right);
din_ptr_arr[2 - k] += 2; din_ptr_arr[tmp - k] += 2;
} }
num_index_right -= 2;
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
} }
...@@ -5289,7 +5369,19 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, ...@@ -5289,7 +5369,19 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
int cnt = loop_w >> 2; int cnt = loop_w >> 2;
int remain = loop_w & 3; int remain = loop_w & 3;
int n_top_h = 4 - pad_top; int n_top_h = 4 - pad_top;
int n_bottom_h = 4 -pad_bottom; int n_bottom_h = odds_h ? (4 - pad_bottom) : ((hin % 2) ? 4 : 3);
int n_right_w = odds_w ? pad_right : ((win % 2) ? 4 : 3);
int n_left_w = 4 - pad_left;
if (n_right_w == 4) {
remain++;
pad_right_new--;
n_right_w -= 2;
}
if (n_bottom_h == 4) {
loop_h++;
pad_bottom_new--;
n_bottom_h -= 2;
}
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; n++) { for (int n = 0; n < num; n++) {
const float* din_batch = din + n * in_channel_size; const float* din_batch = din + n * in_channel_size;
...@@ -5337,6 +5429,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, ...@@ -5337,6 +5429,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
...@@ -5377,6 +5471,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, ...@@ -5377,6 +5471,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -5407,6 +5503,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, ...@@ -5407,6 +5503,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
4); 4);
...@@ -5434,6 +5532,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, ...@@ -5434,6 +5532,8 @@ void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
odds_w, odds_w,
pad_left_new, pad_left_new,
pad_right_new, pad_right_new,
n_left_w,
n_right_w,
cnt, cnt,
remain, remain,
h_in_num); h_in_num);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册