未验证 提交 a74791b6 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

fix pooling3x3s2 max. test=develop (#4411)

* fix pooling3x3s2 max. test=develop

* fix format. test=devleop

* fix format. test=develop
上级 54a75ecb
...@@ -2224,7 +2224,13 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2224,7 +2224,13 @@ void pooling3x3s2p1_max(const float* din,
w_unroll_size -= 1; w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4; w_unroll_remian = wout - w_unroll_size * 4;
} }
float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest()); int w_needed = wout * 2 + 1;
int need_right = w_needed - win - pad_right;
int w_2 = need_right > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
need_right = wout > 1 ? need_right : 0;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmin = vdupq_n_f32(minval);
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
...@@ -2263,6 +2269,11 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2263,6 +2269,11 @@ void pooling3x3s2p1_max(const float* din,
break; break;
} }
} }
auto pr0 = dr0;
auto pr1 = dr1;
auto pr2 = dr2;
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -2316,27 +2327,60 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2316,27 +2327,60 @@ void pooling3x3s2p1_max(const float* din,
"q11", "q11",
"q15"); "q15");
#endif #endif
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
dr2 -= 8; dr2 -= 8;
} } else {
// deal with right pad float tmp = minval;
int wstart = w_unroll_size * 4 * S - P; int left_ = std::min(2, win);
for (int j = 0; j < w_unroll_remian; ++j) { for (int i = 0; i < left_; i++) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]); tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]); tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]); tmp = std::max(tmp, dr2[i]);
} }
*(dr_out++) = tmp;
dr0 += S - (st - wstart); dr_out[0] = tmp;
dr1 += S - (st - wstart); dr0++;
dr2 += S - (st - wstart); dr1++;
wstart += S; dr2++;
dr_out++;
} }
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (need_right) {
float tmp = minval;
int idx = win - 1;
tmp = std::max(tmp, std::max(pr0[idx], pr1[idx]));
tmp = std::max(tmp, pr2[idx]);
dr_out[0] = tmp;
if (win % 2) {
idx = win - 2;
tmp = std::max(tmp, std::max(pr0[idx], pr1[idx]));
tmp = std::max(tmp, pr2[idx]);
dr_out[0] = tmp;
}
}
data_out_channel += wout; data_out_channel += wout;
} }
} }
...@@ -2573,6 +2617,7 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2573,6 +2617,7 @@ void pooling3x3s2p0_max(const float* din,
int wend = std::min(tmp_val + K, win) - tmp_val; int wend = std::min(tmp_val + K, win) - tmp_val;
float minval = std::numeric_limits<float>::lowest(); float minval = std::numeric_limits<float>::lowest();
remain = right > 0 ? remain : remain + 1; remain = right > 0 ? remain : remain + 1;
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in; const float* data_in_batch = data_in + n * chin * size_channel_in;
...@@ -2663,13 +2708,14 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2663,13 +2708,14 @@ void pooling3x3s2p0_max(const float* din,
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2); float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0); dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++; dr_out++;
dr0 += 2; dr0 += 2;
dr1 += 2; dr1 += 2;
dr2 += 2; dr2 += 2;
} }
if (right) { if (right > 0) {
float tmp = dr0[0]; // std::numeric_limits<float>::min(); float tmp = dr0[0];
for (int i = 0; i < wend; i++) { for (int i = 0; i < wend; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i])); tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]); tmp = std::max(tmp, dr2[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册