diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index fd4cf92d85d5daa891d602d4365122c870920bba..87ed586aad9af9609840217f334bc0524d557865 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -77,6 +77,8 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'] paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)) paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)) +paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=('max', False, True, None)) +paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=('max', False, True, None)) paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)) paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) diff --git a/paddle/fluid/operators/math/pooling.cc b/paddle/fluid/operators/math/pooling.cc index 68fed9fd4ebbcb9658fb29a36f1ae3c57d0dc897..b4ee82add3133021f5e4bd24dbfeb0e58f9b20ff 100644 --- a/paddle/fluid/operators/math/pooling.cc +++ b/paddle/fluid/operators/math/pooling.cc @@ -61,24 +61,26 @@ class Pool2dFunctor { const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); + int hstart, hend; + int wstart, wend; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); } else { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); hstart = std::max(hstart, 0); } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); } @@ -136,24 +138,26 @@ class Pool2dGradFunctor { const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + int hstart, hend; + int wstart, wend; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); } else { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); hstart = std::max(hstart, 0); } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); } int pool_size = (exclusive || adaptive) @@ -308,33 +312,36 @@ class Pool3dFunctor { const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); + int dstart, dend; + int hstart, hend; + int wstart, wend; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { if (adaptive) { - int dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - int dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + dend = ADAPT_END_INDEX(pd, input_depth, output_depth); } else { - int dstart = pd * stride_depth - padding_depth; - int dend = std::min(dstart + ksize_depth, input_depth); + dstart = pd * stride_depth - padding_depth; + dend = std::min(dstart + ksize_depth, input_depth); dstart = std::max(dstart, 0); } for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); } else { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); hstart = std::max(hstart, 0); } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); } int output_idx = (pd * output_height + ph) * output_width + pw; @@ -403,33 +410,36 @@ class Pool3dGradFunctor { const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + int dstart, dend; + int hstart, hend; + int wstart, wend; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { if (adaptive) { - int dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - int dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + dend = ADAPT_END_INDEX(pd, input_depth, output_depth); } else { - int dstart = pd * stride_depth - padding_depth; - int dend = std::min(dstart + ksize_depth, input_depth); + dstart = pd * stride_depth - padding_depth; + dend = std::min(dstart + ksize_depth, input_depth); dstart = std::max(dstart, 0); } for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); } else { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); hstart = std::max(hstart, 0); } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); } @@ -599,24 +609,26 @@ class MaxPool2dWithIndexFunctor { T1* output_data = output->mutable_data(context.GetPlace()); T2* mask_data = mask->mutable_data(context.GetPlace()); + int hstart, hend; + int wstart, wend; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); } else { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); hstart = std::max(hstart, 0); } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); } @@ -655,7 +667,7 @@ class MaxPool2dWithIndexGradFunctor { const framework::Tensor& output_grad, const framework::Tensor& mask, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, + const std::vector& paddings, bool adaptive, framework::Tensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_height = input_grad->dims()[2]; @@ -708,8 +720,8 @@ class MaxPool3dWithIndexFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output, - framework::Tensor* mask) { + const std::vector& paddings, bool adaptive, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -734,33 +746,36 @@ class MaxPool3dWithIndexFunctor { T1* output_data = output->mutable_data(context.GetPlace()); T2* mask_data = mask->mutable_data(context.GetPlace()); + int dstart, dend; + int hstart, hend; + int wstart, wend; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { if (adaptive) { - int dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - int dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + dend = ADAPT_END_INDEX(pd, input_depth, output_depth); } else { - int dstart = pd * stride_depth - padding_depth; - int dend = std::min(dstart + ksize_depth, input_depth); + dstart = pd * stride_depth - padding_depth; + dend = std::min(dstart + ksize_depth, input_depth); dstart = std::max(dstart, 0); } for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); } else { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); hstart = std::max(hstart, 0); } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); } @@ -804,7 +819,7 @@ class MaxPool3dWithIndexGradFunctor { const framework::Tensor& output_grad, const framework::Tensor& mask, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, + const std::vector& paddings, bool adaptive, framework::Tensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_depth = input_grad->dims()[2]; diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index 06e92665c70421b4ddf330d97359876c5233b6cb..5f3b82ed553d6b96a312fdc025ecb3da8b14fdcd 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -21,6 +21,18 @@ namespace paddle { namespace operators { namespace math { +__device__ __forceinline__ int ADAPT_START_INDEX(int ph, int input_size, + int output_size) { + return static_cast( + floor(static_cast(ph * input_size) / output_size)); +} + +__device__ __forceinline__ int ADAPT_END_INDEX(int ph, int input_size, + int output_size) { + return static_cast( + ceil(static_cast((ph + 1) * input_size) / output_size)); +} + template __global__ void KernelPool2D(const int nthreads, const T* input_data, const int channels, const int input_height, @@ -37,19 +49,21 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, int c = (index / output_width / output_height) % channels; int batch_idx = index / output_width / output_height / channels; + int hstart, hend; + int wstart, wend; if (adaptive) { - int hstart = ADAPT_START_INDEX(ph, input_height, output_height); - int hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); - int wstart = ADAPT_START_INDEX(pw, input_width, output_width); - int wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); } else { - int hstart = ph * stride_height - padding_height; - int hend = min(hstart + ksize_height, input_height); + hstart = ph * stride_height - padding_height; + hend = min(hstart + ksize_height, input_height); hstart = max(hstart, 0); - int wstart = pw * stride_width - padding_width; - int wend = min(wstart + ksize_width, input_width); + wstart = pw * stride_width - padding_width; + wend = min(wstart + ksize_width, input_width); wstart = max(wstart, 0); } @@ -74,7 +88,7 @@ __global__ void KernelPool2DGrad( const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, - PoolProcess pool_process, bool exclusive, T* input_grad) { + PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -82,14 +96,24 @@ __global__ void KernelPool2DGrad( int offsetC = (index / input_width / input_height) % channels; int batch_idx = index / input_width / input_height / channels; - int phstart = (offsetH < ksize_height) - ? 0 - : (offsetH - ksize_height) / stride_height + 1; - int pwstart = (offsetW < ksize_width) - ? 0 - : (offsetW - ksize_width) / stride_width + 1; - int phend = min(offsetH / stride_height + 1, output_height); - int pwend = min(offsetW / stride_width + 1, output_width); + int phstart, phend; + int pwstart, pwend; + if (adaptive) { + phstart = offsetH * output_height / input_height; + phend = + min((offsetH + 1) * output_height / input_height + 1, output_height); + pwstart = offsetW * output_width / input_width; + pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + } else { + phstart = (offsetH < ksize_height) + ? 0 + : (offsetH - ksize_height) / stride_height + 1; + pwstart = (offsetW < ksize_width) + ? 0 + : (offsetW - ksize_width) / stride_width + 1; + phend = min(offsetH / stride_height + 1, output_height); + pwend = min(offsetW / stride_width + 1, output_width); + } T gradient = 0; T input = input_data[index]; int output_idx = @@ -98,14 +122,22 @@ __global__ void KernelPool2DGrad( output_grad += output_idx; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { - int hstart = ph * stride_height - padding_height; - int wstart = pw * stride_width - padding_width; - int hend = min(hstart + ksize_height, input_height); - int wend = min(wstart + ksize_width, input_width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - int pool_size = exclusive ? (hend - hstart) * (wend - wstart) - : ksize_height * ksize_width; + int pool_size; + if (adaptive) { + pool_size = static_cast(ceil(static_cast(input_height) / + ksize_height)) * + static_cast( + ceil(static_cast(input_width) / ksize_width)); + } else { + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + pool_size = exclusive ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; + } int output_sub_idx = ph * output_width + pw; pool_process.compute(input, output_data[output_sub_idx], output_grad[output_sub_idx], @@ -189,7 +221,7 @@ void Pool2dDirectCUDAFunctor::operator()( KernelPool2D<<>>( nthreads, input, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, - padding_height, padding_width, pool_compute, exclusive, output); + padding_height, padding_width, pool_compute, exclusive, false, output); } /* @@ -204,7 +236,7 @@ class Pool2dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - bool exclusive, framework::Tensor* output) { + bool exclusive, bool adaptive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -231,7 +263,7 @@ class Pool2dFunctor { nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, pool_process, exclusive, - output_data); + adaptive, output_data); } }; @@ -250,7 +282,8 @@ class Pool2dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - bool exclusive, framework::Tensor* input_grad) { + bool exclusive, bool adaptive, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -278,7 +311,7 @@ class Pool2dGradFunctor { nthreads, input_data, output_data, output_grad_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, - pool_process, exclusive, input_grad_data); + pool_process, exclusive, adaptive, input_grad_data); } }; @@ -367,7 +400,7 @@ __global__ void KernelPool3D( const int ksize_depth, const int ksize_height, const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width, - PoolProcess pool_process, bool exclusive, T* output_data) { + PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -376,15 +409,30 @@ __global__ void KernelPool3D( int c = (index / output_width / output_height / output_depth) % channels; int batch_idx = index / output_width / output_height / output_depth / channels; - int dstart = pd * stride_depth - padding_depth; - int hstart = ph * stride_height - padding_height; - int wstart = pw * stride_width - padding_width; - int dend = min(dstart + ksize_depth, input_depth); - int hend = min(hstart + ksize_height, input_height); - int wend = min(wstart + ksize_width, input_width); - dstart = max(dstart, 0); - hstart = max(hstart, 0); - wstart = max(wstart, 0); + + int dstart, dend; + int hstart, hend; + int wstart, wend; + if (adaptive) { + dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); + + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + dstart = pd * stride_depth - padding_depth; + hstart = ph * stride_height - padding_height; + wstart = pw * stride_width - padding_width; + dend = min(dstart + ksize_depth, input_depth); + hend = min(hstart + ksize_height, input_height); + wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + } T ele = pool_process.initial(); input_data += (batch_idx * channels + c) * input_depth * input_height * input_width; @@ -396,7 +444,7 @@ __global__ void KernelPool3D( } } } - int pool_size = exclusive + int pool_size = (exclusive || adaptive) ? (dend - dstart) * (hend - hstart) * (wend - wstart) : ksize_depth * ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); @@ -413,7 +461,7 @@ __global__ void KernelPool3DGrad( const int ksize_height, const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width, PoolProcess pool_process, - bool exclusive, T* input_grad) { + bool exclusive, bool adaptive, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -423,18 +471,31 @@ __global__ void KernelPool3DGrad( int offsetC = (index / input_width / input_height / input_depth) % channels; int batch_idx = index / input_width / input_height / input_depth / channels; - int pdstart = (offsetD < ksize_depth) - ? 0 - : (offsetD - ksize_depth) / stride_depth + 1; - int phstart = (offsetH < ksize_height) - ? 0 - : (offsetH - ksize_height) / stride_height + 1; - int pwstart = (offsetW < ksize_width) - ? 0 - : (offsetW - ksize_width) / stride_width + 1; - int pdend = min((offsetD) / stride_depth + 1, output_depth); - int phend = min((offsetH) / stride_height + 1, output_height); - int pwend = min((offsetW) / stride_width + 1, output_width); + int pdstart, pdend; + int phstart, phend; + int pwstart, pwend; + if (adaptive) { + pdstart = offsetD * output_depth / input_depth; + pdend = min((offsetD + 1) * output_depth / input_depth + 1, output_depth); + phstart = offsetH * output_height / input_height; + phend = + min((offsetH + 1) * output_height / input_height + 1, output_height); + pwstart = offsetW * output_width / input_width; + pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + } else { + pdstart = (offsetD < ksize_depth) + ? 0 + : (offsetD - ksize_depth) / stride_depth + 1; + phstart = (offsetH < ksize_height) + ? 0 + : (offsetH - ksize_height) / stride_height + 1; + pwstart = (offsetW < ksize_width) + ? 0 + : (offsetW - ksize_width) / stride_width + 1; + pdend = min((offsetD) / stride_depth + 1, output_depth); + phend = min((offsetH) / stride_height + 1, output_height); + pwend = min((offsetW) / stride_width + 1, output_width); + } T gradient = 0; T input = input_data[index]; @@ -447,18 +508,29 @@ __global__ void KernelPool3DGrad( for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size - int dstart = pd * stride_depth - padding_depth; - int hstart = ph * stride_height - padding_height; - int wstart = pw * stride_width - padding_width; - int dend = min(dstart + ksize_depth, input_depth); - int hend = min(hstart + ksize_height, input_height); - int wend = min(wstart + ksize_width, input_width); - dstart = max(dstart, 0); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - int pool_size = - exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart) - : ksize_depth * ksize_height * ksize_width; + int pool_size; + if (adaptive) { + pool_size = + static_cast( + ceil(static_cast(input_depth) / ksize_depth)) * + static_cast( + ceil(static_cast(input_height) / ksize_height)) * + static_cast( + ceil(static_cast(input_width) / ksize_width)); + } else { + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + pool_size = + exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart) + : ksize_depth * ksize_height * ksize_width; + } int output_sub_idx = (pd * output_height + ph) * output_width + pw; pool_process.compute(input, output_data[output_sub_idx], output_grad[output_sub_idx], @@ -533,7 +605,7 @@ class Pool3dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - bool exclusive, framework::Tensor* output) { + bool exclusive, bool adaptive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -567,7 +639,7 @@ class Pool3dFunctor { input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, stride_width, padding_depth, padding_height, padding_width, pool_process, exclusive, - output_data); + adaptive, output_data); } }; @@ -586,7 +658,8 @@ class Pool3dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - bool exclusive, framework::Tensor* input_grad) { + bool exclusive, bool adaptive, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -622,7 +695,7 @@ class Pool3dGradFunctor { input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, stride_width, padding_depth, padding_height, - padding_width, pool_process, exclusive, input_grad_data); + padding_width, pool_process, exclusive, adaptive, input_grad_data); } }; @@ -711,7 +784,7 @@ __global__ void KernelMaxPool2dWithIdx( const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, - const int padding_width, T1* output_data, T2* mask_data) { + const int padding_width, bool adaptive, T1* output_data, T2* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -719,13 +792,23 @@ __global__ void KernelMaxPool2dWithIdx( int c = (index / output_width / output_height) % channels; int batch_idx = index / output_width / output_height / channels; - int hstart = ph * stride_height - padding_height; - int hend = min(hstart + ksize_height, input_height); - hstart = max(hstart, 0); + int hstart, hend; + int wstart, wend; + if (adaptive) { + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); - int wstart = pw * stride_width - padding_width; - int wend = min(wstart + ksize_width, input_width); - wstart = max(wstart, 0); + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + hstart = ph * stride_height - padding_height; + hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + wstart = pw * stride_width - padding_width; + wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + } input_data += (batch_idx * channels + c) * input_height * input_width; T1 ele = -FLT_MAX; @@ -750,36 +833,46 @@ __global__ void KernelMaxPool2DWithIdxGrad( const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, T1* input_grad) { + const int padding_height, const int padding_width, bool adaptive, + T1* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int w_offset = index % input_width; - int h_offset = (index / input_width) % input_height; - int c_offset = (index / input_width / input_height) % channels; + int offsetW = index % input_width; + int offsetH = (index / input_width) % input_height; + int offsetC = (index / input_width / input_height) % channels; int batch_idx = index / input_width / input_height / channels; - int ph_start = - (h_offset + padding_height < ksize_height) - ? 0 - : (h_offset + padding_height - ksize_height) / stride_height + 1; - int pw_start = - (w_offset + padding_width < ksize_width) - ? 0 - : (w_offset + padding_width - ksize_width) / stride_width + 1; - int ph_end = - min((h_offset + padding_height) / stride_height + 1, output_height); - int pw_end = - min((w_offset + padding_width) / stride_width + 1, output_width); + int phstart, phend; + int pwstart, pwend; + if (adaptive) { + phstart = offsetH * output_height / input_height; + phend = + min((offsetH + 1) * output_height / input_height + 1, output_height); + pwstart = offsetW * output_width / input_width; + pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + } else { + phstart = + (offsetH + padding_height < ksize_height) + ? 0 + : (offsetH + padding_height - ksize_height) / stride_height + 1; + pwstart = + (offsetW + padding_width < ksize_width) + ? 0 + : (offsetW + padding_width - ksize_width) / stride_width + 1; + phend = + min((offsetH + padding_height) / stride_height + 1, output_height); + pwend = min((offsetW + padding_width) / stride_width + 1, output_width); + } T1 gradient = 0; - int input_current_featuremap_idx = h_offset * input_width + w_offset; + int input_current_featuremap_idx = offsetH * input_width + offsetW; int output_idx = - (batch_idx * channels + c_offset) * output_height * output_width; + (batch_idx * channels + offsetC) * output_height * output_width; mask_data += output_idx; output_grad += output_idx; - for (int ph = ph_start; ph < ph_end; ++ph) { - for (int pw = pw_start; pw < pw_end; ++pw) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { if (mask_data[ph * output_width + pw] == input_current_featuremap_idx) gradient += output_grad[ph * output_width + pw]; } @@ -799,8 +892,8 @@ class MaxPool2dWithIndexFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output, - framework::Tensor* mask) { + const std::vector& paddings, bool adaptive, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -827,7 +920,8 @@ class MaxPool2dWithIndexFunctor { KernelMaxPool2dWithIdx<<>>( nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, - stride_width, padding_height, padding_width, output_data, mask_data); + stride_width, padding_height, padding_width, adaptive, output_data, + mask_data); } }; @@ -843,7 +937,7 @@ class MaxPool2dWithIndexGradFunctor { const framework::Tensor& output_grad, const framework::Tensor& mask, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, + const std::vector& paddings, bool adaptive, framework::Tensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_channels = input_grad->dims()[1]; @@ -870,7 +964,7 @@ class MaxPool2dWithIndexGradFunctor { KernelMaxPool2DWithIdxGrad<<>>( nthreads, output_grad_data, mask_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, - stride_height, stride_width, padding_height, padding_width, + stride_height, stride_width, padding_height, padding_width, adaptive, input_grad_data); } }; @@ -892,7 +986,7 @@ __global__ void KernelMaxPool3DWithIdx( const int ksize_depth, const int ksize_height, const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width, - T1* output_data, T2* mask_data) { + bool adaptive, T1* output_data, T2* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -902,15 +996,29 @@ __global__ void KernelMaxPool3DWithIdx( int batch_idx = index / output_width / output_height / output_depth / channels; - int dstart = pd * stride_depth - padding_depth; - int hstart = ph * stride_height - padding_height; - int wstart = pw * stride_width - padding_width; - int dend = min(dstart + ksize_depth, input_depth); - int hend = min(hstart + ksize_height, input_height); - int wend = min(wstart + ksize_width, input_width); - dstart = max(dstart, 0); - hstart = max(hstart, 0); - wstart = max(wstart, 0); + int dstart, dend; + int hstart, hend; + int wstart, wend; + if (adaptive) { + dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); + dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + + hstart = ADAPT_START_INDEX(ph, input_height, output_height); + hend = ADAPT_END_INDEX(ph, input_height, output_height); + + wstart = ADAPT_START_INDEX(pw, input_width, output_width); + wend = ADAPT_END_INDEX(pw, input_width, output_width); + } else { + dstart = pd * stride_depth - padding_depth; + hstart = ph * stride_height - padding_height; + wstart = pw * stride_width - padding_width; + dend = min(dstart + ksize_depth, input_depth); + hend = min(hstart + ksize_height, input_height); + wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + } T1 ele = -FLT_MAX; int max_index = -1; @@ -940,46 +1048,56 @@ __global__ void KernelMaxPool3DWithIdxGrad( const int output_width, const int ksize_depth, const int ksize_height, const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, - const int padding_width, T1* input_grad) { + const int padding_width, bool adaptive, T1* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int w_offset = index % input_width; - int h_offset = (index / input_width) % input_height; - int d_offset = (index / input_width / input_height) % input_depth; - int c_offset = - (index / input_width / input_height / input_depth) % channels; + int offsetW = index % input_width; + int offsetH = (index / input_width) % input_height; + int offsetD = (index / input_width / input_height) % input_depth; + int offsetC = (index / input_width / input_height / input_depth) % channels; int batch_idx = index / input_width / input_height / input_depth / channels; - int pd_start = - (d_offset + padding_depth < ksize_depth) - ? 0 - : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; - int ph_start = - (h_offset + padding_height < ksize_height) - ? 0 - : (h_offset + padding_height - ksize_height) / stride_height + 1; - int pw_start = - (w_offset + padding_width < ksize_width) - ? 0 - : (w_offset + padding_width - ksize_width) / stride_width + 1; - int pd_end = - min((d_offset + padding_depth) / stride_depth + 1, output_depth); - int ph_end = - min((h_offset + padding_height) / stride_height + 1, output_height); - int pw_end = - min((w_offset + padding_width) / stride_width + 1, output_width); + int pdstart, pdend; + int phstart, phend; + int pwstart, pwend; + if (adaptive) { + pdstart = offsetD * output_depth / input_depth; + pdend = min((offsetD + 1) * output_depth / input_depth + 1, output_depth); + phstart = offsetH * output_height / input_height; + phend = + min((offsetH + 1) * output_height / input_height + 1, output_height); + pwstart = offsetW * output_width / input_width; + pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + } else { + pdstart = + (offsetD + padding_depth < ksize_depth) + ? 0 + : (offsetD + padding_depth - ksize_depth) / stride_depth + 1; + phstart = + (offsetH + padding_height < ksize_height) + ? 0 + : (offsetH + padding_height - ksize_height) / stride_height + 1; + pwstart = + (offsetW + padding_width < ksize_width) + ? 0 + : (offsetW + padding_width - ksize_width) / stride_width + 1; + pdend = min((offsetD + padding_depth) / stride_depth + 1, output_depth); + phend = + min((offsetH + padding_height) / stride_height + 1, output_height); + pwend = min((offsetW + padding_width) / stride_width + 1, output_width); + } T1 gradient = 0; int input_current_feature_map_idx = - (d_offset * input_height + h_offset) * input_width + w_offset; - int output_idx = (batch_idx * channels + c_offset) * output_depth * + (offsetD * input_height + offsetH) * input_width + offsetW; + int output_idx = (batch_idx * channels + offsetC) * output_depth * output_height * output_width; mask += output_idx; output_grad += output_idx; - for (int pd = pd_start; pd < pd_end; ++pd) { - for (int ph = ph_start; ph < ph_end; ++ph) { - for (int pw = pw_start; pw < pw_end; ++pw) { + for (int pd = pdstart; pd < pdend; ++pd) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { if (mask[(pd * output_height + ph) * output_width + pw] == input_current_feature_map_idx) gradient += @@ -1002,8 +1120,8 @@ class MaxPool3dWithIndexFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output, - framework::Tensor* mask) { + const std::vector& paddings, bool adaptive, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -1037,7 +1155,8 @@ class MaxPool3dWithIndexFunctor { nthreads, input_data, input_channels, input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, stride_width, - padding_depth, padding_height, padding_width, output_data, mask_data); + padding_depth, padding_height, padding_width, adaptive, output_data, + mask_data); } }; @@ -1053,7 +1172,7 @@ class MaxPool3dWithIndexGradFunctor { const framework::Tensor& output_grad, const framework::Tensor& mask, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, + const std::vector& paddings, bool adaptive, framework::Tensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_channels = input_grad->dims()[1]; @@ -1087,7 +1206,7 @@ class MaxPool3dWithIndexGradFunctor { nthreads, output_grad_data, mask_data, input_channels, input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, + stride_width, padding_depth, padding_height, padding_width, adaptive, input_grad_data); } }; diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h index 923babd4c248364b735bb09def7bf12f2762f305..d123af8924b77243e5a0d4d1212d35e16a44c3ae 100644 --- a/paddle/fluid/operators/math/pooling.h +++ b/paddle/fluid/operators/math/pooling.h @@ -102,7 +102,7 @@ class Pool2dFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - bool exclusive, framework::Tensor* output); + bool exclusive, bool adaptive, framework::Tensor* output); }; template @@ -114,7 +114,7 @@ class Pool2dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - bool exclusive, framework::Tensor* input_grad); + bool exclusive, bool adaptive, framework::Tensor* input_grad); }; template @@ -136,7 +136,7 @@ class Pool3dFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - bool exclusive, framework::Tensor* output); + bool exclusive, bool adaptive, framework::Tensor* output); }; template @@ -148,7 +148,7 @@ class Pool3dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - bool exclusive, framework::Tensor* input_grad); + bool exclusive, bool adaptive, framework::Tensor* input_grad); }; template @@ -176,8 +176,8 @@ class MaxPool2dWithIndexFunctor { void operator()(const DeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output, - framework::Tensor* mask); + const std::vector& paddings, bool adaptive, + framework::Tensor* output, framework::Tensor* mask); }; template @@ -187,7 +187,7 @@ class MaxPool2dWithIndexGradFunctor { const framework::Tensor& output_grad, const framework::Tensor& mask, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, + const std::vector& paddings, bool adaptive, framework::Tensor* input_grad); }; @@ -197,8 +197,8 @@ class MaxPool3dWithIndexFunctor { void operator()(const DeviceContext& context, const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output, - framework::Tensor* mask); + const std::vector& paddings, bool adaptive, + framework::Tensor* output, framework::Tensor* mask); }; template @@ -208,7 +208,7 @@ class MaxPool3dWithIndexGradFunctor { const framework::Tensor& output_grad, const framework::Tensor& mask, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, + const std::vector& paddings, bool adaptive, framework::Tensor* input_grad); }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 52b607df74446866c535751f3faa11765cb6f247..11b5c493230de8390892a0001ea72af9c2eddd9d 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -52,6 +52,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const { std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); bool ceil_mode = ctx->Attrs().Get("ceil_mode"); + bool adaptive = ctx->Attrs().Get("adaptive"); PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, "Pooling intput should be 4-D or 5-D tensor."); @@ -72,9 +73,13 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const { "Paddings size and pooling size should be the same."); std::vector output_shape({in_x_dims[0], in_x_dims[1]}); - for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back(PoolOutputSize(in_x_dims[i + 2], ksize[i], - paddings[i], strides[i], ceil_mode)); + if (adaptive) { + output_shape.insert(output_shape.end(), ksize.begin(), ksize.end()); + } else { + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back(PoolOutputSize( + in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode)); + } } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->ShareLoD("X", "Out"); @@ -186,6 +191,14 @@ void Pool2dOpMaker::Make() { "averaging calculating, otherwise, include the zero-padding. Note, it " "is only used when pooling_type is avg. The defalut is True.") .SetDefault(true); + AddAttr( + "adaptive", + "(bool, default False) When true, will perform adaptive pooling instead, " + "output shape in H and W dimensions will be same as ksize, input data " + "will be divided into grids specify by ksize averagely and perform " + "pooling in each grid area to get output pooling value.") + .SetDefault(false); + AddAttr( "use_cudnn", "(bool, default false) Only used in cudnn kernel, need install cudnn") @@ -325,6 +338,13 @@ void Pool3dOpMaker::Make() { "averaging calculating, otherwise, include the zero-padding. Note, it " "is only used when pooling_type is avg. The defalut is True.") .SetDefault(true); + AddAttr( + "adaptive", + "(bool, default False) When true, will perform adaptive pooling instead, " + "output shape in H and W dimensions will be same as ksize, input data " + "will be divided into grids specify by ksize averagely and perform " + "pooling in each grid area to get output pooling value.") + .SetDefault(false); AddAttr( "use_cudnn", diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index c0594b7e3cc5602a44bb01951a22c2135ba5c7ce..6c5900bd0f55bb817834de6d1f3c5e4eb7f282b9 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -70,6 +70,7 @@ class PoolKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); bool exclusive = context.Attr("exclusive"); + bool adaptive = context.Attr("adaptive"); if (context.Attr("global_pooling")) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; @@ -85,7 +86,7 @@ class PoolKernel : public framework::OpKernel { pool2d_forward; paddle::operators::math::MaxPool pool_process; pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - true, out); + true, false, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool2dFunctor< @@ -93,7 +94,7 @@ class PoolKernel : public framework::OpKernel { pool2d_forward; paddle::operators::math::AvgPool pool_process; pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - exclusive, out); + exclusive, adaptive, out); } } break; case 3: { @@ -103,14 +104,14 @@ class PoolKernel : public framework::OpKernel { pool3d_forward; paddle::operators::math::MaxPool pool_process; pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - true, out); + true, false, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool3dFunctor< DeviceContext, paddle::operators::math::AvgPool, T> pool3d_forward; paddle::operators::math::AvgPool pool_process; pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - exclusive, out); + exclusive, adaptive, out); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } @@ -133,6 +134,7 @@ class PoolGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); bool exclusive = context.Attr("exclusive"); + bool adaptive = context.Attr("adaptive"); if (context.Attr("global_pooling")) { for (size_t i = 0; i < ksize.size(); ++i) { @@ -159,7 +161,8 @@ class PoolGradKernel : public framework::OpKernel { pool2d_backward; paddle::operators::math::AvgPoolGrad pool_process; pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, - paddings, pool_process, exclusive, in_x_grad); + paddings, pool_process, exclusive, adaptive, + in_x_grad); } } break; case 3: { @@ -174,7 +177,8 @@ class PoolGradKernel : public framework::OpKernel { pool3d_backward; paddle::operators::math::AvgPoolGrad pool_process; pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, - paddings, pool_process, exclusive, in_x_grad); + paddings, pool_process, exclusive, adaptive, + in_x_grad); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 873706593e4c856f0079738654a9e7e59a1c0cd8..f9e25277e5cfa1f83567b2456c36cb60321d61d4 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -40,6 +40,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); + bool adaptive = ctx->Attrs().Get("adaptive"); PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, "Pooling intput should be 4-D or 5-D tensor."); @@ -60,9 +61,13 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { "Paddings size and pooling size should be the same."); std::vector output_shape({in_x_dims[0], in_x_dims[1]}); - for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back(MaxPoolOutputSize(in_x_dims[i + 2], ksize[i], - paddings[i], strides[i])); + if (adaptive) { + output_shape.insert(output_shape.end(), ksize.begin(), ksize.end()); + } else { + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back(MaxPoolOutputSize(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); + } } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); @@ -133,6 +138,14 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default:false) Whether to use the global pooling. " "If global_pooling = true, ksize and paddings will be ignored.") .SetDefault(false); + AddAttr( + "adaptive", + "(bool, default False) When true, will perform adaptive pooling " + "instead, " + "output shape in H and W dimensions will be same as ksize, input data " + "will be divided into grids specify by ksize averagely and perform " + "pooling in each grid area to get output pooling value.") + .SetDefault(false); AddAttr>("strides", "(vector, default {1, 1}), strides(height, " "width) of pooling operator.") @@ -209,6 +222,14 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Whether to use the global pooling. " "If global_pooling = true, ksize and paddings will be ignored.") .SetDefault(false); + AddAttr( + "adaptive", + "(bool, default False) When true, will perform adaptive pooling " + "instead, " + "output shape in H and W dimensions will be same as ksize, input data " + "will be divided into grids specify by ksize averagely and perform " + "pooling in each grid area to get output pooling value.") + .SetDefault(false); AddAttr>("strides", "(vector, default {1,1,1}), strides(depth, " "height, width) of pooling operator.") diff --git a/paddle/fluid/operators/pool_with_index_op.h b/paddle/fluid/operators/pool_with_index_op.h index b55fa76eae34c3179d40f31ed6a57d3ecbbaaccf..a6bec121d4ff002ec80a0f47510e4431176e0ddc 100644 --- a/paddle/fluid/operators/pool_with_index_op.h +++ b/paddle/fluid/operators/pool_with_index_op.h @@ -36,6 +36,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + bool adaptive = context.Attr("adaptive"); auto& dev_ctx = context.template device_context(); if (context.Attr("global_pooling")) { @@ -50,13 +51,15 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { paddle::operators::math::MaxPool2dWithIndexFunctor pool2d_forward; - pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, out, mask); + pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, adaptive, out, + mask); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexFunctor pool3d_forward; - pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, out, mask); + pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, adaptive, out, + mask); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } @@ -75,6 +78,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + bool adaptive = context.Attr("adaptive"); if (context.Attr("global_pooling")) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; @@ -93,14 +97,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { T1, T2> pool2d_backward; pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides, - paddings, in_x_grad); + paddings, adaptive, in_x_grad); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexGradFunctor pool3d_backward; pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides, - paddings, in_x_grad); + paddings, adaptive, in_x_grad); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } diff --git a/paddle/fluid/operators/spp_op.h b/paddle/fluid/operators/spp_op.h index 35d9737ee01fe1505cbe30e8ed735e6b92cb8df2..3c2d51ec9111e649632dda89290f21a0988db6dd 100644 --- a/paddle/fluid/operators/spp_op.h +++ b/paddle/fluid/operators/spp_op.h @@ -56,13 +56,13 @@ class SppKernel : public framework::OpKernel { math::Pool2dFunctor, T> pool_forward; math::MaxPool max_process; pool_forward(context.template device_context(), *in_x, - kernel_size, strides, paddings, max_process, true, + kernel_size, strides, paddings, max_process, true, false, &out_level); } else if (pooling_type == "avg") { math::Pool2dFunctor, T> pool_forward; math::AvgPool avg_process; pool_forward(context.template device_context(), *in_x, - kernel_size, strides, paddings, avg_process, true, + kernel_size, strides, paddings, avg_process, true, false, &out_level); } // flatten pooling output shape @@ -156,7 +156,7 @@ class SppGradKernel : public framework::OpKernel { math::AvgPoolGrad avg_process; pool_backward(context.template device_context(), *in_x, *&out_level, *&outgrad_level, kernel_size, strides, - paddings, avg_process, true, in_x_grad); + paddings, avg_process, true, false, in_x_grad); } } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e25eaaa9fda6add9d8e81d9e6bdfb711cee3648e..61794f0d49a13a1796ef3f6eecdcd8c299327920 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -52,6 +52,8 @@ __all__ = [ 'softmax', 'pool2d', 'pool3d', + 'adaptive_pool2d', + 'adaptive_pool3d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', @@ -2499,6 +2501,190 @@ def pool3d(input, return pool_out +@templatedoc(op_type="pool2d") +def adaptive_pool2d(input, + pool_size, + pool_type="max", + require_index=False, + use_cudnn=True, + name=None): + """ + ${comment} + + Args: + input (Variable): The input tensor of pooling operator. The format of + input tensor is NCHW, where N is batch size, C is + the number of channels, H is the height of the + feature, and W is the width of the feature. + pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two integers, (pool_size_Height, pool_size_Width). + pool_type: ${pooling_type_comment} + require_index (bool): If true, the index of max pooling point along with outputs. + it cannot be set in average pooling type. + use_cudnn (bool): ${use_cudnn_comment} + name (str|None): A name for this layer(optional). If set None, the + layer will be named automatically. + + Returns: + Variable: The pooling result. + + Raises: + ValueError: 'pool_type' is not 'max' nor 'avg'. + ValueError: 'use_cudnn' is not a bool value. + ValueError: invalid setting 'require_index' true when 'pool_type' is 'avg'. + ValueError: 'pool_size' should be a list or tuple with length as 2. + + Examples: + + .. code-block:: python + + data = fluid.layers.data( + name='data', shape=[3, 32, 32], dtype='float32') + conv2d = fluid.layers.pool2d( + input=data, + pool_size=[3, 3], + pool_type='max', + require_index=True) + """ + if pool_type not in ["max", "avg"]: + raise ValueError( + "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", + str(pool_type)) + + if pool_type == "avg" and require_index: + raise ValueError( + "invalid setting 'require_index' true when 'pool_type' is 'avg'.") + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(pool_size) or len(pool_size) != 2: + raise ValueError( + "'pool_size' should be a list or tuple with length as 2.") + + if not isinstance(use_cudnn, bool): + raise ValueError("use_cudnn should be True or False.") + + if pool_type == "max": + l_type = 'max_pool2d_with_index' + else: + l_type = "pool2d" + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + outputs = {"Out": pool_out} + if pool_type == "max": + mask = helper.create_variable_for_type_inference(dtype) + outputs["Mask"] = mask + + helper.append_op( + type=l_type, + inputs={"X": input}, + outputs=outputs, + attrs={ + "pooling_type": pool_type, + "ksize": pool_size, + "use_cudnn": use_cudnn, + "adaptive": True, + }) + + return pool_out + + +@templatedoc(op_type="pool3d") +def adaptive_pool3d(input, + pool_size, + pool_type="max", + require_index=False, + use_cudnn=True, + name=None): + """ + ${comment} + + Args: + input (Variable): The input tensor of pooling operator. The format of + input tensor is NCHW, where N is batch size, C is + the number of channels, H is the height of the + feature, and W is the width of the feature. + pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two integers, (Depth, Height, Width). + pool_type: ${pooling_type_comment} + require_index (bool): If true, the index of max pooling point along with outputs. + it cannot be set in average pooling type. + use_cudnn (bool): ${use_cudnn_comment} + name (str|None): A name for this layer(optional). If set None, the + layer will be named automatically. + + Returns: + Variable: The pooling result. + + Raises: + ValueError: 'pool_type' is not 'max' nor 'avg'. + ValueError: 'use_cudnn' is not a bool value. + ValueError: invalid setting 'require_index' true when 'pool_type' is 'avg'. + ValueError: 'pool_size' should be a list or tuple with length as 2. + + Examples: + + .. code-block:: python + + data = fluid.layers.data( + name='data', shape=[3, 32, 32], dtype='float32') + conv2d = fluid.layers.pool2d( + input=data, + pool_size=[3, 3], + pool_type='max', + require_index=True) + """ + if pool_type not in ["max", "avg"]: + raise ValueError( + "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", + str(pool_type)) + + if pool_type == "avg" and require_index: + raise ValueError( + "invalid setting 'require_index' true when 'pool_type' is 'avg'.") + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(pool_size) or len(pool_size) != 3: + raise ValueError( + "'pool_size' should be a list or tuple with length as 3.") + + if not isinstance(use_cudnn, bool): + raise ValueError("use_cudnn should be True or False.") + + if pool_type == "max": + l_type = 'max_pool3d_with_index' + else: + l_type = "pool3d" + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + outputs = {"Out": pool_out} + if pool_type == "max": + mask = helper.create_variable_for_type_inference(dtype) + outputs["Mask"] = mask + + helper.append_op( + type=l_type, + inputs={"X": input}, + outputs=outputs, + attrs={ + "pooling_type": pool_type, + "ksize": pool_size, + "use_cudnn": use_cudnn, + "adaptive": True, + }) + + return pool_out + + def batch_norm(input, act=None, is_test=False, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 10e8bb5a86691d8654c5ae48794e49f30f47500d..9785b5063cdd1d90462fa63a4dc2d567926fc061 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -233,6 +233,28 @@ class TestBook(unittest.TestCase): pool_stride=[1, 2], pool_padding=(2, 1))) + def test_adaptive_pool2d(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[3, 224, 224], dtype='float32') + self.assertIsNotNone( + layers.adaptive_pool2d( + x, [3, 3], require_index=True)) + self.assertIsNotNone( + layers.adaptive_pool2d( + x, [3, 3], pool_type='avg')) + + def test_adaptive_pool3d(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[3, 244, 224, 224], dtype='float32') + self.assertIsNotNone( + layers.adaptive_pool3d( + x, [3, 3, 3], require_index=True)) + self.assertIsNotNone( + layers.adaptive_pool3d( + x, [3, 3, 3], pool_type='avg')) + def test_lstm_unit(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index 47b2e71a4e52a327831fde7494bd7a2306b6f2ea..5ccdf082e8a4f8aabcd55b6b470a77690ee6f61f 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import print_function +from __future__ import division import unittest import numpy as np @@ -21,29 +22,47 @@ import paddle.fluid.core as core from op_test import OpTest +def adaptive_start_index(index, input_size, output_size): + return int(np.floor(index * input_size / output_size)) + + +def adaptive_end_index(index, input_size, output_size): + return int(np.ceil((index + 1) * input_size / output_size)) + + def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0, ceil_mode=False, - exclusive=True): + exclusive=True, + adaptive=False): N, C, H, W = x.shape if global_pool == 1: ksize = [H, W] - H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) // strides[0] + 1 if ceil_mode else ( - H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 - W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) // strides[1] + 1 if ceil_mode else ( - W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + if adaptive: + H_out, W_out = ksize + else: + H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 out = np.zeros((N, C, H_out, W_out)) for i in range(H_out): for j in range(W_out): - r_start = np.max((i * strides[0] - paddings[0], 0)) - r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) - c_start = np.max((j * strides[1] - paddings[1], 0)) - c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + if adaptive: + r_start = adaptive_start_index(i, H, ksize[0]) + r_end = adaptive_end_index(i, H, ksize[0]) + c_start = adaptive_start_index(j, W, ksize[1]) + c_end = adaptive_end_index(j, W, ksize[1]) + else: + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) x_masked = x[:, :, r_start:r_end, c_start:c_end] out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) @@ -56,27 +75,37 @@ def avg_pool2D_forward_naive(x, paddings, global_pool=0, ceil_mode=False, - exclusive=True): + exclusive=True, + adaptive=False): N, C, H, W = x.shape if global_pool == 1: ksize = [H, W] - H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) // strides[0] + 1 if ceil_mode else ( - H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 - W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) // strides[1] + 1 if ceil_mode else ( - W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + if adaptive: + H_out, W_out = ksize + else: + H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 out = np.zeros((N, C, H_out, W_out)) for i in range(H_out): for j in range(W_out): - r_start = np.max((i * strides[0] - paddings[0], 0)) - r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) - c_start = np.max((j * strides[1] - paddings[1], 0)) - c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + if adaptive: + r_start = adaptive_start_index(i, H, ksize[0]) + r_end = adaptive_end_index(i, H, ksize[0]) + c_start = adaptive_start_index(j, W, ksize[1]) + c_end = adaptive_end_index(j, W, ksize[1]) + else: + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) x_masked = x[:, :, r_start:r_end, c_start:c_end] - field_size = ((r_end - r_start) * (c_end - c_start)) if exclusive \ - else (ksize[0] * ksize[1]) + field_size = ((r_end - r_start) * (c_end - c_start)) \ + if (exclusive or adaptive) else (ksize[0] * ksize[1]) out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size return out @@ -93,12 +122,13 @@ class TestPool2D_Op(OpTest): self.init_pool_type() self.init_ceil_mode() self.init_exclusive() + self.init_adaptive() if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] input = np.random.random(self.shape).astype(self.dtype) output = self.pool2D_forward_naive( input, self.ksize, self.strides, self.paddings, self.global_pool, - self.ceil_mode, self.exclusive).astype(self.dtype) + self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { @@ -112,7 +142,8 @@ class TestPool2D_Op(OpTest): 'ceil_mode': self.ceil_mode, 'data_format': 'AnyLayout', # TODO(dzhwinter) : should be fix latter - 'exclusive': self.exclusive + 'exclusive': self.exclusive, + 'adaptive': self.adaptive } self.outputs = {'Out': output} @@ -159,6 +190,9 @@ class TestPool2D_Op(OpTest): def init_exclusive(self): self.exclusive = True + def init_adaptive(self): + self.adaptive = False + class TestCase1(TestPool2D_Op): def init_test_case(self): @@ -315,5 +349,10 @@ class TestCUDNNAvgInclude(TestCase2): self.exclusive = False +class TestAvgPoolAdaptive(TestCase1): + def init_adaptive(self): + self.adaptive = True + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py index f05f8ccb3985be162d89da099496d5b2baf4afdc..47a5b2d1abe11a37d24624ff52d05ea135befe7c 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import print_function +from __future__ import division import unittest import numpy as np @@ -21,35 +22,59 @@ import paddle.fluid.core as core from op_test import OpTest +def adaptive_start_index(index, input_size, output_size): + return int(np.floor(index * input_size / output_size)) + + +def adaptive_end_index(index, input_size, output_size): + return int(np.ceil((index + 1) * input_size / output_size)) + + def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0, ceil_mode=False, - exclusive=True): + exclusive=True, + adaptive=False): N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] - D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) // strides[0] + 1 if ceil_mode else ( - H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 - H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) // strides[1] + 1 if ceil_mode else ( - W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 - W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 - ) // strides[2] + 1 if ceil_mode else ( - W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 + if adaptive: + D_out, H_out, W_out = ksize + else: + D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 + ) // strides[2] + 1 if ceil_mode else ( + W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): - d_start = np.max((k * strides[0] - paddings[0], 0)) - d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + if adaptive: + d_start = adaptive_start_index(k, D, ksize[0]) + d_end = adaptive_end_index(k, D, ksize[0]) + else: + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) for i in range(H_out): - h_start = np.max((i * strides[0] - paddings[0], 0)) - h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + if adaptive: + h_start = adaptive_start_index(i, H, ksize[1]) + h_end = adaptive_end_index(i, H, ksize[1]) + else: + h_start = np.max((i * strides[1] - paddings[1], 0)) + h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H)) for j in range(W_out): - w_start = np.max((j * strides[1] - paddings[1], 0)) - w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + if adaptive: + w_start = adaptive_start_index(j, W, ksize[2]) + w_end = adaptive_end_index(j, W, ksize[2]) + else: + w_start = np.max((j * strides[2] - paddings[2], 0)) + w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W)) x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) @@ -62,33 +87,49 @@ def avg_pool3D_forward_naive(x, paddings, global_pool=0, ceil_mode=False, - exclusive=True): + exclusive=True, + adaptive=False): N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] - D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 - ) // strides[0] + 1 if ceil_mode else ( - H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 - H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 - ) // strides[1] + 1 if ceil_mode else ( - W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 - W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 - ) // strides[2] + 1 if ceil_mode else ( - W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 + if adaptive: + D_out, H_out, W_out = ksize + else: + D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1 + ) // strides[0] + 1 if ceil_mode else ( + H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1 + ) // strides[1] + 1 if ceil_mode else ( + W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1 + ) // strides[2] + 1 if ceil_mode else ( + W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): - d_start = np.max((k * strides[0] - paddings[0], 0)) - d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + if adaptive: + d_start = adaptive_start_index(k, D, ksize[0]) + d_end = adaptive_end_index(k, D, ksize[0]) + else: + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) for i in range(H_out): - h_start = np.max((i * strides[0] - paddings[0], 0)) - h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + if adaptive: + h_start = adaptive_start_index(i, H, ksize[1]) + h_end = adaptive_end_index(i, H, ksize[1]) + else: + h_start = np.max((i * strides[1] - paddings[1], 0)) + h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H)) for j in range(W_out): - w_start = np.max((j * strides[1] - paddings[1], 0)) - w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + if adaptive: + w_start = adaptive_start_index(j, W, ksize[2]) + w_end = adaptive_end_index(j, W, ksize[2]) + else: + w_start = np.max((j * strides[2] - paddings[2], 0)) + w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W)) x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] field_size = (d_end - d_start) * (h_end - h_start) * (w_end - w_start) \ - if exclusive else ksize[0] * ksize[1] * ksize[2] + if (exclusive or adaptive) else ksize[0] * ksize[1] * ksize[2] out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / field_size return out @@ -105,13 +146,14 @@ class TestPool3d_Op(OpTest): self.init_pool_type() self.init_ceil_mode() self.init_exclusive() + self.init_adaptive() if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] input = np.random.random(self.shape).astype(self.dtype) output = self.pool3D_forward_naive( input, self.ksize, self.strides, self.paddings, self.global_pool, - self.ceil_mode, self.exclusive).astype(self.dtype) + self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { @@ -124,7 +166,8 @@ class TestPool3d_Op(OpTest): 'ceil_mode': self.ceil_mode, 'data_format': 'AnyLayout', # TODO(dzhwinter) : should be fix latter - 'exclusive': self.exclusive + 'exclusive': self.exclusive, + 'adaptive': self.adaptive } self.outputs = {'Out': output} @@ -171,6 +214,9 @@ class TestPool3d_Op(OpTest): def init_exclusive(self): self.exclusive = True + def init_adaptive(self): + self.adaptive = False + class TestCase1(TestPool3d_Op): def init_test_case(self): @@ -353,5 +399,10 @@ class TestCUDNNAvgInclude(TestCUDNNCase3): self.exclusive = False +class TestAvgPoolAdaptive(TestCase1): + def init_adaptive(self): + self.adaptive = True + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool_max_op.py b/python/paddle/fluid/tests/unittests/test_pool_max_op.py index 488ff431d4f2ef76ce0c9486d8c307b4e01b5544..6575c408eeaa43d4f7caf257b2ebd77a942aecda 100644 --- a/python/paddle/fluid/tests/unittests/test_pool_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool_max_op.py @@ -13,33 +13,62 @@ # limitations under the License. from __future__ import print_function +from __future__ import division import unittest import numpy as np from op_test import OpTest -def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False): +def adaptive_start_index(index, input_size, output_size): + return int(np.floor(index * input_size / output_size)) + + +def adaptive_end_index(index, input_size, output_size): + return int(np.ceil((index + 1) * input_size / output_size)) + + +def max_pool3D_forward_naive(x, + ksize, + strides, + paddings, + global_pool=False, + adaptive=False): N, C, D, H, W = x.shape if global_pool: ksize = [D, H, W] paddings = [0, 0, 0] - D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1 - H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1 - W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 + if adaptive: + D_out, H_out, W_out = ksize + else: + D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 out = np.zeros((N, C, D_out, H_out, W_out)) mask = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): - d_start = np.max((k * strides[0] - paddings[0], 0)) - d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + if adaptive: + d_start = adaptive_start_index(k, D, ksize[0]) + d_end = adaptive_end_index(k, D, ksize[0]) + else: + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) for i in range(H_out): - h_start = np.max((i * strides[0] - paddings[0], 0)) - h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + if adaptive: + h_start = adaptive_start_index(i, H, ksize[1]) + h_end = adaptive_end_index(i, H, ksize[1]) + else: + h_start = np.max((i * strides[1] - paddings[1], 0)) + h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H)) for j in range(W_out): - w_start = np.max((j * strides[1] - paddings[1], 0)) - w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + if adaptive: + w_start = adaptive_start_index(j, W, ksize[2]) + w_end = adaptive_end_index(j, W, ksize[2]) + else: + w_start = np.max((j * strides[2] - paddings[2], 0)) + w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W)) x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) @@ -58,23 +87,37 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False): return out, mask -def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=False): +def max_pool2D_forward_naive(x, + ksize, + strides, + paddings, + global_pool=False, + adaptive=False): N, C, H, W = x.shape if global_pool: ksize = [H, W] paddings = [0, 0] - H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 - W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + if adaptive: + H_out, W_out = ksize + else: + H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 out = np.zeros((N, C, H_out, W_out)) mask = np.zeros((N, C, H_out, W_out)) for i in range(H_out): for j in range(W_out): - r_start = np.max((i * strides[0] - paddings[0], 0)) - r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) - c_start = np.max((j * strides[1] - paddings[1], 0)) - c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + if adaptive: + r_start = adaptive_start_index(i, H, ksize[0]) + r_end = adaptive_end_index(i, H, ksize[0]) + c_start = adaptive_start_index(j, W, ksize[1]) + c_end = adaptive_end_index(j, W, ksize[1]) + else: + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) x_masked = x[:, :, r_start:r_end, c_start:c_end] out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) @@ -95,10 +138,12 @@ class TestMaxPoolWithIndex_Op(OpTest): def setUp(self): self.init_test_case() self.init_global() + self.init_adaptive() input = np.random.random(self.shape).astype("float32") output, mask = self.pool_forward_naive(input, self.ksize, self.strides, - self.paddings, self.global_pool) + self.paddings, self.global_pool, + self.adaptive) output = output.astype("float32") mask = mask.astype("int32") @@ -107,6 +152,7 @@ class TestMaxPoolWithIndex_Op(OpTest): 'paddings': self.paddings, 'ksize': self.ksize, 'global_pooling': self.global_pool, + 'adaptive': self.adaptive, } self.inputs = {'X': input} @@ -129,6 +175,9 @@ class TestMaxPoolWithIndex_Op(OpTest): def init_global(self): self.global_pool = False + def init_adaptive(self): + self.adaptive = False + class TestCase1(TestMaxPoolWithIndex_Op): def init_global(self): @@ -190,5 +239,15 @@ class TestCase7(TestCase6): self.global_pool = False +class TestCastAdaptive2d(TestCase6): + def init_adaptive(self): + self.adaptive = True + + +class TestCastAdaptive3d(TestMaxPoolWithIndex_Op): + def init_adaptive(self): + self.adaptive = True + + if __name__ == '__main__': unittest.main()