未验证 提交 a74791b6 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

fix pooling3x3s2 max. test=develop (#4411)

* fix pooling3x3s2 max. test=develop

* fix format. test=devleop

* fix format. test=develop
上级 54a75ecb
......@@ -2224,7 +2224,13 @@ void pooling3x3s2p1_max(const float* din,
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest());
int w_needed = wout * 2 + 1;
int need_right = w_needed - win - pad_right;
int w_2 = need_right > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
need_right = wout > 1 ? need_right : 0;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmin = vdupq_n_f32(minval);
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
......@@ -2263,6 +2269,11 @@ void pooling3x3s2p1_max(const float* din,
break;
}
}
auto pr0 = dr0;
auto pr1 = dr1;
auto pr2 = dr2;
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
......@@ -2316,27 +2327,60 @@ void pooling3x3s2p1_max(const float* din,
"q11",
"q15");
#endif
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = dr0[0];
for (int i = 0; i < wend - st; i++) {
} else {
float tmp = minval;
int left_ = std::min(2, win);
for (int i = 0; i < left_; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
dr2 += S - (st - wstart);
wstart += S;
dr_out[0] = tmp;
dr0++;
dr1++;
dr2++;
dr_out++;
}
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (need_right) {
float tmp = minval;
int idx = win - 1;
tmp = std::max(tmp, std::max(pr0[idx], pr1[idx]));
tmp = std::max(tmp, pr2[idx]);
dr_out[0] = tmp;
if (win % 2) {
idx = win - 2;
tmp = std::max(tmp, std::max(pr0[idx], pr1[idx]));
tmp = std::max(tmp, pr2[idx]);
dr_out[0] = tmp;
}
}
data_out_channel += wout;
}
}
......@@ -2573,6 +2617,7 @@ void pooling3x3s2p0_max(const float* din,
int wend = std::min(tmp_val + K, win) - tmp_val;
float minval = std::numeric_limits<float>::lowest();
remain = right > 0 ? remain : remain + 1;
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
......@@ -2663,13 +2708,14 @@ void pooling3x3s2p0_max(const float* din,
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (right) {
float tmp = dr0[0]; // std::numeric_limits<float>::min();
if (right > 0) {
float tmp = dr0[0];
for (int i = 0; i < wend; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册