未验证 提交 b88e45a1 编写于 作者: C cc 提交者: GitHub

Pool2d supports adaptive param (#3448)

*Pool2d supports adaptive param
上级 913c1caf
...@@ -21,6 +21,17 @@ namespace paddle { ...@@ -21,6 +21,17 @@ namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
int AdaptStartIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
floor(static_cast<double>(ph * input_size) / output_size));
}
int AdaptEndIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
ceil(static_cast<double>((ph + 1) * input_size) / output_size));
}
void pooling_basic(const float* din, void pooling_basic(const float* din,
float* dout, float* dout,
int num, int num,
...@@ -88,15 +99,27 @@ void pooling_basic(const float* din, ...@@ -88,15 +99,27 @@ void pooling_basic(const float* din,
#pragma omp parallel for #pragma omp parallel for
for (int ind_c = 0; ind_c < chin; ++ind_c) { for (int ind_c = 0; ind_c < chin; ++ind_c) {
for (int ind_h = 0; ind_h < hout; ++ind_h) { for (int ind_h = 0; ind_h < hout; ++ind_h) {
int sh = ind_h * stride_h; int sh, eh;
int eh = sh + kernel_h; if (adaptive) {
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; sh = AdaptStartIndex(ind_h, hin, hout);
eh = (eh - pad_h) > hin ? hin : eh - pad_h; eh = AdaptEndIndex(ind_h, hin, hout);
} else {
sh = ind_h * stride_h;
eh = sh + kernel_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > hin ? hin : eh - pad_h;
}
for (int ind_w = 0; ind_w < wout; ++ind_w) { for (int ind_w = 0; ind_w < wout; ++ind_w) {
int sw = ind_w * stride_w; int sw, ew;
int ew = sw + kernel_w; if (adaptive) {
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; sw = AdaptStartIndex(ind_w, win, wout);
ew = (ew - pad_w) > win ? win : ew - pad_w; ew = AdaptEndIndex(ind_w, win, wout);
} else {
sw = ind_w * stride_w;
ew = sw + kernel_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > win ? win : ew - pad_w;
}
float result = static_cast<float>(0); float result = static_cast<float>(0);
int dst_ind = (ind_n * chout + ind_c) * size_channel_out + int dst_ind = (ind_n * chout + ind_c) * size_channel_out +
ind_h * wout + ind_w; ind_h * wout + ind_w;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册