diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index c3652217ededa10b57e211ba7f5d3dc76e235978..1817e934cc460fdff6f18ec7491838ff1a5ce640 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -2224,7 +2224,13 @@ void pooling3x3s2p1_max(const float* din, w_unroll_size -= 1; w_unroll_remian = wout - w_unroll_size * 4; } - float32x4_t vmin = vdupq_n_f32(std::numeric_limits::lowest()); + int w_needed = wout * 2 + 1; + int need_right = w_needed - win - pad_right; + int w_2 = need_right > 0 ? w_unroll_remian : w_unroll_remian + 1; + w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2; + need_right = wout > 1 ? need_right : 0; + float minval = std::numeric_limits::lowest(); + float32x4_t vmin = vdupq_n_f32(minval); for (int n = 0; n < num; ++n) { float* data_out_batch = data_out + n * chout * size_channel_out; @@ -2263,6 +2269,11 @@ void pooling3x3s2p1_max(const float* din, break; } } + + auto pr0 = dr0; + auto pr1 = dr1; + auto pr2 = dr2; + int cnt_num = w_unroll_size; if (w_unroll_size > 0) { #ifdef __aarch64__ @@ -2316,27 +2327,60 @@ void pooling3x3s2p1_max(const float* din, "q11", "q15"); #endif + dr0 -= 8; dr1 -= 8; dr2 -= 8; - } - // deal with right pad - int wstart = w_unroll_size * 4 * S - P; - for (int j = 0; j < w_unroll_remian; ++j) { - int wend = std::min(wstart + K, win); - int st = wstart > 0 ? wstart : 0; - float tmp = dr0[0]; - for (int i = 0; i < wend - st; i++) { + } else { + float tmp = minval; + int left_ = std::min(2, win); + for (int i = 0; i < left_; i++) { tmp = std::max(tmp, dr0[i]); tmp = std::max(tmp, dr1[i]); tmp = std::max(tmp, dr2[i]); } - *(dr_out++) = tmp; - dr0 += S - (st - wstart); - dr1 += S - (st - wstart); - dr2 += S - (st - wstart); - wstart += S; + + dr_out[0] = tmp; + dr0++; + dr1++; + dr2++; + dr_out++; } + + for (int w = 0; w < w_2 - 1; w += 1) { + float32x4_t vr0 = vld1q_f32(dr0); + float32x4_t vr1 = vld1q_f32(dr1); + float32x4_t vr2 = vld1q_f32(dr2); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + dr_out[0] = vget_lane_f32(vmax, 0); + dr_out++; + + dr0 += 2; + dr1 += 2; + dr2 += 2; + } + + if (need_right) { + float tmp = minval; + int idx = win - 1; + tmp = std::max(tmp, std::max(pr0[idx], pr1[idx])); + tmp = std::max(tmp, pr2[idx]); + dr_out[0] = tmp; + if (win % 2) { + idx = win - 2; + tmp = std::max(tmp, std::max(pr0[idx], pr1[idx])); + tmp = std::max(tmp, pr2[idx]); + dr_out[0] = tmp; + } + } + data_out_channel += wout; } } @@ -2573,6 +2617,7 @@ void pooling3x3s2p0_max(const float* din, int wend = std::min(tmp_val + K, win) - tmp_val; float minval = std::numeric_limits::lowest(); remain = right > 0 ? remain : remain + 1; + for (int n = 0; n < num; ++n) { float* data_out_batch = data_out + n * chout * size_channel_out; const float* data_in_batch = data_in + n * chin * size_channel_in; @@ -2663,13 +2708,14 @@ void pooling3x3s2p0_max(const float* din, vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); float32x2_t vmax = vpmax_f32(vmax2, vmax2); dr_out[0] = vget_lane_f32(vmax, 0); + dr_out++; dr0 += 2; dr1 += 2; dr2 += 2; } - if (right) { - float tmp = dr0[0]; // std::numeric_limits::min(); + if (right > 0) { + float tmp = dr0[0]; for (int i = 0; i < wend; i++) { tmp = std::max(tmp, std::max(dr0[i], dr1[i])); tmp = std::max(tmp, dr2[i]);