diff --git a/paddle/fluid/operators/math/pooling.cc b/paddle/fluid/operators/math/pooling.cc index 8df43bb616179e2487534e0acabb71b09b87e1af..68fed9fd4ebbcb9658fb29a36f1ae3c57d0dc897 100644 --- a/paddle/fluid/operators/math/pooling.cc +++ b/paddle/fluid/operators/math/pooling.cc @@ -19,6 +19,16 @@ namespace paddle { namespace operators { namespace math { +static inline int ADAPT_START_INDEX(int ph, int input_size, int output_size) { + return static_cast( + floor(static_cast(ph * input_size) / output_size)); +} + +static inline int ADAPT_END_INDEX(int ph, int input_size, int output_size) { + return static_cast( + ceil(static_cast((ph + 1) * input_size) / output_size)); +} + /* * All tensors are in NCHW format. * Ksize, strides, paddings are two elements. These two elements represent @@ -31,7 +41,7 @@ class Pool2dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - bool exclusive, framework::Tensor* output) { + bool exclusive, bool adaptive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -54,13 +64,23 @@ class Pool2dFunctor { for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); + } else { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + } for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); + if (adaptive) { + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + } T ele = pool_process.initial(); for (int h = hstart; h < hend; ++h) { @@ -68,8 +88,9 @@ class Pool2dFunctor { pool_process.compute(input_data[h * input_width + w], &ele); } } - int pool_size = exclusive ? (hend - hstart) * (wend - wstart) - : ksize_height * ksize_width; + int pool_size = (exclusive || adaptive) + ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); output_data[ph * output_width + pw] = ele; } @@ -94,7 +115,7 @@ class Pool2dGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_grad_process, - bool exclusive, framework::Tensor* input_grad) { + bool exclusive, bool adaptive, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -118,15 +139,26 @@ class Pool2dGradFunctor { for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); + } else { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + } for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); - int pool_size = exclusive ? (hend - hstart) * (wend - wstart) - : ksize_height * ksize_width; + if (adaptive) { + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + } + int pool_size = (exclusive || adaptive) + ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; float scale = 1.0 / pool_size; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -251,7 +283,7 @@ class Pool3dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - bool exclusive, framework::Tensor* output) { + bool exclusive, bool adaptive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -279,17 +311,32 @@ class Pool3dFunctor { for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { - int dstart = pd * stride_depth - padding_depth; - int dend = std::min(dstart + ksize_depth, input_depth); - dstart = std::max(dstart, 0); + if (adaptive) { + int dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + int dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + } else { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + } for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); + } else { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + } for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); + if (adaptive) { + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + } int output_idx = (pd * output_height + ph) * output_width + pw; T ele = pool_process.initial(); for (int d = dstart; d < dend; ++d) { @@ -302,7 +349,7 @@ class Pool3dFunctor { } } int pool_size = - exclusive + (exclusive || adaptive) ? (dend - dstart) * (hend - hstart) * (wend - wstart) : ksize_depth * ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); @@ -330,7 +377,7 @@ class Pool3dGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_grad_process, - bool exclusive, framework::Tensor* input_grad) { + bool exclusive, bool adaptive, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -359,21 +406,35 @@ class Pool3dGradFunctor { for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { - int dstart = pd * stride_depth - padding_depth; - int dend = std::min(dstart + ksize_depth, input_depth); - dstart = std::max(dstart, 0); + if (adaptive) { + int dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + int dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + } else { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + } for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); - + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); + } else { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + } for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); + if (adaptive) { + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + } int pool_size = - exclusive + (exclusive || adaptive) ? (dend - dstart) * (hend - hstart) * (wend - wstart) : ksize_depth * ksize_height * ksize_width; float scale = 1.0 / pool_size; @@ -517,8 +578,8 @@ class MaxPool2dWithIndexFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output, - framework::Tensor* mask) { + const std::vector& paddings, bool adaptive, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -541,13 +602,23 @@ class MaxPool2dWithIndexFunctor { for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); + } else { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + } for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); + if (adaptive) { + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + } T1 ele = static_cast(-FLT_MAX); int index = -1; @@ -666,17 +737,32 @@ class MaxPool3dWithIndexFunctor { for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { - int dstart = pd * stride_depth - padding_depth; - int dend = std::min(dstart + ksize_depth, input_depth); - dstart = std::max(dstart, 0); + if (adaptive) { + int dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + int dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + } else { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + } for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); + } else { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + } for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); + if (adaptive) { + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + } int output_idx = (pd * output_height + ph) * output_width + pw; T1 ele = static_cast(-FLT_MAX); diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index cdc79e207aa9a2e59e25a07002134c12ad5a1df8..06e92665c70421b4ddf330d97359876c5233b6cb 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, PoolProcess pool_process, - bool exclusive, T* output_data) { + bool exclusive, bool adaptive, T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -37,13 +37,21 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, int c = (index / output_width / output_height) % channels; int batch_idx = index / output_width / output_height / channels; - int hstart = ph * stride_height - padding_height; - int hend = min(hstart + ksize_height, input_height); - hstart = max(hstart, 0); + if (adaptive) { + int hstart = ADAPT_START_INDEX(ph, input_height, output_height); + int hend = ADAPT_END_INDEX(ph, input_height, output_height); - int wstart = pw * stride_width - padding_width; - int wend = min(wstart + ksize_width, input_width); - wstart = max(wstart, 0); + int wstart = ADAPT_START_INDEX(pw, input_width, output_width); + int wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + } input_data += (batch_idx * channels + c) * input_height * input_width; T ele = pool_process.initial(); @@ -52,8 +60,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, pool_process.compute(input_data[h * input_width + w], &ele); } } - int pool_size = exclusive ? (hend - hstart) * (wend - wstart) - : ksize_height * ksize_width; + int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); output_data[index] = ele; }