From b88e45a13cb2e116456379617c7a81d4fc116089 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 21 Apr 2020 16:49:43 +0800 Subject: [PATCH] Pool2d supports adaptive param (#3448) *Pool2d supports adaptive param --- lite/backends/arm/math/pooling.cc | 39 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index 0955b09d92..fdcbc7394b 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -21,6 +21,17 @@ namespace paddle { namespace lite { namespace arm { namespace math { + +int AdaptStartIndex(int ph, int input_size, int output_size) { + return static_cast( + floor(static_cast(ph * input_size) / output_size)); +} + +int AdaptEndIndex(int ph, int input_size, int output_size) { + return static_cast( + ceil(static_cast((ph + 1) * input_size) / output_size)); +} + void pooling_basic(const float* din, float* dout, int num, @@ -88,15 +99,27 @@ void pooling_basic(const float* din, #pragma omp parallel for for (int ind_c = 0; ind_c < chin; ++ind_c) { for (int ind_h = 0; ind_h < hout; ++ind_h) { - int sh = ind_h * stride_h; - int eh = sh + kernel_h; - sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; - eh = (eh - pad_h) > hin ? hin : eh - pad_h; + int sh, eh; + if (adaptive) { + sh = AdaptStartIndex(ind_h, hin, hout); + eh = AdaptEndIndex(ind_h, hin, hout); + } else { + sh = ind_h * stride_h; + eh = sh + kernel_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > hin ? hin : eh - pad_h; + } for (int ind_w = 0; ind_w < wout; ++ind_w) { - int sw = ind_w * stride_w; - int ew = sw + kernel_w; - sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; - ew = (ew - pad_w) > win ? win : ew - pad_w; + int sw, ew; + if (adaptive) { + sw = AdaptStartIndex(ind_w, win, wout); + ew = AdaptEndIndex(ind_w, win, wout); + } else { + sw = ind_w * stride_w; + ew = sw + kernel_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > win ? win : ew - pad_w; + } float result = static_cast(0); int dst_ind = (ind_n * chout + ind_c) * size_channel_out + ind_h * wout + ind_w; -- GitLab