未验证 提交 7ba5c331 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

change pooling3x3s2max padding right to intrinsic. test=develop (#4039)

* change pooling3x3s2max padding right to intrinsic. test=develop
上级 ebf6b4bb
...@@ -2193,7 +2193,13 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2193,7 +2193,13 @@ void pooling3x3s2p1_max(const float* din,
w_unroll_size -= 1; w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4; w_unroll_remian = wout - w_unroll_size * 4;
} }
float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest()); int w_needed = wout * 2 + 1;
int pad_right_ = w_needed - win - pad_bottom;
int w_2 = pad_right_ > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmin = vdupq_n_f32(minval);
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
...@@ -2232,6 +2238,11 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2232,6 +2238,11 @@ void pooling3x3s2p1_max(const float* din,
break; break;
} }
} }
auto pr0 = dr0;
auto pr1 = dr1;
auto pr2 = dr2;
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -2285,27 +2296,53 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2285,27 +2296,53 @@ void pooling3x3s2p1_max(const float* din,
"q11", "q11",
"q15"); "q15");
#endif #endif
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
dr2 -= 8; dr2 -= 8;
} } else {
// deal with right pad float tmp = minval;
int wstart = w_unroll_size * 4 * S - P; for (int i = 0; i < 2; i++) {
for (int j = 0; j < w_unroll_remian; ++j) { tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]); tmp = std::max(tmp, dr2[i]);
} }
*(dr_out++) = tmp;
dr0 += S - (st - wstart); dr_out[0] = tmp;
dr1 += S - (st - wstart); dr0++;
dr2 += S - (st - wstart); dr1++;
wstart += S; dr2++;
dr_out++;
}
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (pad_right_) {
float tmp = minval;
for (int i = 1; i < 3; i++) {
tmp = std::max(tmp, std::max(pr0[win - i], pr1[win - i]));
tmp = std::max(tmp, pr2[win - i]);
}
dr_out[0] = tmp;
} }
data_out_channel += wout; data_out_channel += wout;
} }
} }
...@@ -2539,6 +2576,10 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2539,6 +2576,10 @@ void pooling3x3s2p0_max(const float* din,
int remain = w_unroll_remian - 1; int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win; // if need right pad int right = wout * 2 + 1 - win; // if need right pad
int w_2 = right > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
float minval = std::numeric_limits<float>::lowest();
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in; const float* data_in_batch = data_in + n * chin * size_channel_in;
...@@ -2592,18 +2633,24 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2592,18 +2633,24 @@ void pooling3x3s2p0_max(const float* din,
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
dr2 -= 8; dr2 -= 8;
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0; for (int w = 0; w < w_2 - 1; w += 1) {
for (int j = 0; j < w_unroll_remian; ++j) { float32x4_t vr0 = vld1q_f32(dr0);
int wend = std::min(wstart + K, rem); float32x4_t vr1 = vld1q_f32(dr1);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min(); float32x4_t vr2 = vld1q_f32(dr2);
for (int i = wstart; i < wend; i++) { vr0 = vsetq_lane_f32(minval, vr0, 3);
tmp = std::max(tmp, dr0[i]); vr1 = vsetq_lane_f32(minval, vr1, 3);
tmp = std::max(tmp, dr1[i]); vr2 = vsetq_lane_f32(minval, vr2, 3);
tmp = std::max(tmp, dr2[i]); float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
} vmax1 = vmaxq_f32(vmax1, vr2);
*(dr_out++) = tmp; float32x2_t vmax2 =
wstart += S; vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
} }
#else #else
asm volatile( asm volatile(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册