From c93e044ae0d34f4456b0400529ebe925bda2fc7f Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 26 Oct 2018 16:16:46 +0800 Subject: [PATCH] add inclusive/exclusive mode in PoolOp avg pool type --- paddle/fluid/operators/math/pooling.cc | 30 +++++----- paddle/fluid/operators/math/pooling.cu | 55 ++++++++++--------- paddle/fluid/operators/math/pooling.h | 8 +-- paddle/fluid/operators/pool_cudnn_op.cu.cc | 6 +- paddle/fluid/operators/pool_op.cc | 12 ++++ paddle/fluid/operators/pool_op.h | 14 +++-- paddle/fluid/operators/spp_op.h | 8 ++- paddle/fluid/platform/cudnn_helper.h | 11 +++- python/paddle/fluid/layers/nn.py | 18 ++++-- .../fluid/tests/unittests/test_pool2d_op.py | 28 ++++++++-- .../fluid/tests/unittests/test_pool3d_op.py | 28 ++++++++-- 11 files changed, 145 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/operators/math/pooling.cc b/paddle/fluid/operators/math/pooling.cc index b87185179..dba687be9 100644 --- a/paddle/fluid/operators/math/pooling.cc +++ b/paddle/fluid/operators/math/pooling.cc @@ -29,8 +29,8 @@ class Pool2dFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, PoolProcess pool_process, + const std::vector& strides, const std::vector& paddings, + PoolProcess pool_process, bool exclusive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; @@ -68,7 +68,8 @@ class Pool2dFunctor { pool_process.compute(input_data[h * input_width + w], &ele); } } - int pool_size = (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); output_data[ph * output_width + pw] = ele; } @@ -93,7 +94,7 @@ class Pool2dGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_grad_process, - framework::Tensor* input_grad) { + bool exclusive, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -124,7 +125,8 @@ class Pool2dGradFunctor { int wstart = pw * stride_width - padding_width; int wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); - int pool_size = (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; float scale = 1.0 / pool_size; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -247,9 +249,9 @@ class Pool3dFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, PoolProcess pool_process, - framework::Tensor* output) { + const std::vector& strides, const std::vector& paddings, + PoolProcess pool_process, + bool exclusive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -299,8 +301,9 @@ class Pool3dFunctor { } } } - int pool_size = - (dend - dstart) * (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? + (dend - dstart) * (hend - hstart) * (wend - wstart) + : ksize_depth * ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); output_data[output_idx] = ele; } @@ -326,7 +329,7 @@ class Pool3dGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_grad_process, - framework::Tensor* input_grad) { + bool exclusive, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -368,8 +371,9 @@ class Pool3dGradFunctor { int wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); - int pool_size = - (dend - dstart) * (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? + (dend - dstart) * (hend - hstart) * (wend - wstart) + : ksize_depth * ksize_height * ksize_width; float scale = 1.0 / pool_size; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index b1c76350d..437d7039a 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, PoolProcess pool_process, - T* output_data) { + bool exclusive, T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, pool_process.compute(input_data[h * input_width + w], &ele); } } - int pool_size = (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); output_data[index] = ele; } @@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad( const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, - PoolProcess pool_process, T* input_grad) { + PoolProcess pool_process, bool exclusive, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad( int wend = min(wstart + ksize_width, input_width); hstart = max(hstart, 0); wstart = max(wstart, 0); - int pool_size = (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? (hend - hstart) * (wend - wstart) + : ksize_height * ksize_width; int output_sub_idx = ph * output_width + pw; pool_process.compute(input, output_data[output_sub_idx], output_grad[output_sub_idx], @@ -163,7 +165,7 @@ class Pool2dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - framework::Tensor* output) { + bool exclusive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -189,7 +191,8 @@ class Pool2dFunctor { KernelPool2D<<>>( nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, - stride_width, padding_height, padding_width, pool_process, output_data); + stride_width, padding_height, padding_width, pool_process, exclusive, + output_data); } }; @@ -208,7 +211,7 @@ class Pool2dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - framework::Tensor* input_grad) { + bool exclusive, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -236,7 +239,7 @@ class Pool2dGradFunctor { nthreads, input_data, output_data, output_grad_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, - pool_process, input_grad_data); + pool_process, exclusive, input_grad_data); } }; @@ -313,16 +316,14 @@ template class Pool2dGradFunctor; template -__global__ void KernelPool3D(const int nthreads, const T* input_data, - const int channels, const int input_depth, - const int input_height, const int input_width, - const int output_depth, const int output_height, - const int output_width, const int ksize_depth, - const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, - const int stride_width, const int padding_depth, - const int padding_height, const int padding_width, - PoolProcess pool_process, T* output_data) { +__global__ void KernelPool3D( + const int nthreads, const T* input_data, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + PoolProcess pool_process, bool exclusive, T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data, } } } - int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? + (dend - dstart) * (hend - hstart) * (wend - wstart) + : ksize_depth * ksize_height * ksize_width; pool_process.finalize(static_cast(pool_size), &ele); output_data[index] = ele; } @@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad( const int ksize_height, const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width, PoolProcess pool_process, - T* input_grad) { + bool exclusive, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad( dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); - int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + int pool_size = exclusive ? + (dend - dstart) * (hend - hstart) * (wend - wstart) + : ksize_depth * ksize_height * ksize_width; int output_sub_idx = (pd * output_height + ph) * output_width + pw; pool_process.compute(input, output_data[output_sub_idx], output_grad[output_sub_idx], @@ -484,7 +489,7 @@ class Pool3dFunctor { const framework::Tensor& input, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - framework::Tensor* output) { + bool exclusive, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -518,7 +523,7 @@ class Pool3dFunctor { input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, stride_width, padding_depth, padding_height, padding_width, pool_process, - output_data); + exclusive, output_data); } }; @@ -537,7 +542,7 @@ class Pool3dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_process, - framework::Tensor* input_grad) { + bool exclusive, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -573,7 +578,7 @@ class Pool3dGradFunctor { input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, stride_width, padding_depth, padding_height, - padding_width, pool_process, input_grad_data); + padding_width, pool_process, exclusive, input_grad_data); } }; diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h index 120f59198..0f64e321b 100644 --- a/paddle/fluid/operators/math/pooling.h +++ b/paddle/fluid/operators/math/pooling.h @@ -89,7 +89,7 @@ class Pool2dFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - framework::Tensor* output); + bool exclusive, framework::Tensor* output); }; template @@ -101,7 +101,7 @@ class Pool2dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - framework::Tensor* input_grad); + bool exclusive, framework::Tensor* input_grad); }; template @@ -123,7 +123,7 @@ class Pool3dFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - framework::Tensor* output); + bool exclusive, framework::Tensor* output); }; template @@ -135,7 +135,7 @@ class Pool3dGradFunctor { const std::vector& ksize, const std::vector& strides, const std::vector& paddings, PoolProcess pool_compute, - framework::Tensor* input_grad); + bool exclusive, framework::Tensor* input_grad); }; template diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 31f083565..4365805b9 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel { T *output_data = output->mutable_data(ctx.GetPlace()); std::string pooling_type = ctx.Attr("pooling_type"); + bool exclusive = ctx.Attr("exclusive"); std::vector ksize = ctx.Attr>("ksize"); std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); @@ -72,7 +73,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel { if (pooling_type == "max") { pooling_mode = PoolingMode::kMaximum; } else { - pooling_mode = PoolingMode::kAverage; + pooling_mode = exclusive ? PoolingMode::kAverageExclusive : PoolingMode::kAverageInclusive; } cudnnPoolingDescriptor_t cudnn_pool_desc = @@ -101,6 +102,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { Tensor *input_grad = ctx.Output(framework::GradVarName("X")); std::string pooling_type = ctx.Attr("pooling_type"); + bool exclusive = ctx.Attr("exclusive"); std::vector ksize = ctx.Attr>("ksize"); std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); @@ -141,7 +143,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { pooling_mode = PoolingMode::kMaximum; } } else { - pooling_mode = PoolingMode::kAverage; + pooling_mode = exclusive ? PoolingMode::kAverageExclusive : PoolingMode::kAverageInclusive; } cudnnPoolingDescriptor_t cudnn_pool_desc = diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 24a5346b0..27c7e2ae8 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -180,6 +180,12 @@ void Pool2dOpMaker::Make() { "operator." "If global_pooling = true, paddings and ksize will be ignored.") .SetDefault({0, 0}); + AddAttr( + "exclusive", + "(bool, default True) When true, will exclude the zero-padding in the " + "averaging calculating, otherwise, include the zero-padding. Note, it " + "is only used when pooling_type is avg. The defalut is True.") + .SetDefault(true); AddAttr( "use_cudnn", "(bool, default false) Only used in cudnn kernel, need install cudnn") @@ -283,6 +289,12 @@ void Pool3dOpMaker::Make() { "If global_pooling = true, ksize and paddings will be ignored.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) + AddAttr( + "exclusive", + "(bool, default True) When true, will exclude the zero-padding in the " + "averaging calculating, otherwise, include the zero-padding. Note, it " + "is only used when pooling_type is avg. The defalut is True.") + .SetDefault(true); AddAttr( "use_cudnn", diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index a63963ca9..c0594b7e3 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -69,6 +69,7 @@ class PoolKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + bool exclusive = context.Attr("exclusive"); if (context.Attr("global_pooling")) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; @@ -84,7 +85,7 @@ class PoolKernel : public framework::OpKernel { pool2d_forward; paddle::operators::math::MaxPool pool_process; pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - out); + true, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool2dFunctor< @@ -92,7 +93,7 @@ class PoolKernel : public framework::OpKernel { pool2d_forward; paddle::operators::math::AvgPool pool_process; pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - out); + exclusive, out); } } break; case 3: { @@ -102,14 +103,14 @@ class PoolKernel : public framework::OpKernel { pool3d_forward; paddle::operators::math::MaxPool pool_process; pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - out); + true, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool3dFunctor< DeviceContext, paddle::operators::math::AvgPool, T> pool3d_forward; paddle::operators::math::AvgPool pool_process; pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process, - out); + exclusive, out); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } @@ -131,6 +132,7 @@ class PoolGradKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + bool exclusive = context.Attr("exclusive"); if (context.Attr("global_pooling")) { for (size_t i = 0; i < ksize.size(); ++i) { @@ -157,7 +159,7 @@ class PoolGradKernel : public framework::OpKernel { pool2d_backward; paddle::operators::math::AvgPoolGrad pool_process; pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, - paddings, pool_process, in_x_grad); + paddings, pool_process, exclusive, in_x_grad); } } break; case 3: { @@ -172,7 +174,7 @@ class PoolGradKernel : public framework::OpKernel { pool3d_backward; paddle::operators::math::AvgPoolGrad pool_process; pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, - paddings, pool_process, in_x_grad); + paddings, pool_process, exclusive, in_x_grad); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } diff --git a/paddle/fluid/operators/spp_op.h b/paddle/fluid/operators/spp_op.h index 08cb7849d..35d9737ee 100644 --- a/paddle/fluid/operators/spp_op.h +++ b/paddle/fluid/operators/spp_op.h @@ -56,12 +56,14 @@ class SppKernel : public framework::OpKernel { math::Pool2dFunctor, T> pool_forward; math::MaxPool max_process; pool_forward(context.template device_context(), *in_x, - kernel_size, strides, paddings, max_process, &out_level); + kernel_size, strides, paddings, max_process, true, + &out_level); } else if (pooling_type == "avg") { math::Pool2dFunctor, T> pool_forward; math::AvgPool avg_process; pool_forward(context.template device_context(), *in_x, - kernel_size, strides, paddings, avg_process, &out_level); + kernel_size, strides, paddings, avg_process, true, + &out_level); } // flatten pooling output shape int output_flatten_w = in_x->dims()[1] * bins * bins; @@ -154,7 +156,7 @@ class SppGradKernel : public framework::OpKernel { math::AvgPoolGrad avg_process; pool_backward(context.template device_context(), *in_x, *&out_level, *&outgrad_level, kernel_size, strides, - paddings, avg_process, in_x_grad); + paddings, avg_process, true, in_x_grad); } } } diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index bb8b14bb9..1d1ec08b2 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -76,8 +76,9 @@ enum class DataLayout { // Not use enum class PoolingMode { kMaximum, - kAverage, kMaximumDeterministic, + kAverageExclusive, + kAverageInclusive, }; #if CUDNN_VERSION < 6000 @@ -91,8 +92,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { switch (mode) { case PoolingMode::kMaximumDeterministic: return CUDNN_POOLING_MAX; - case PoolingMode::kAverage: + case PoolingMode::kAverageExclusive: return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + case PoolingMode::kAverageInclusive: + return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; case PoolingMode::kMaximum: return CUDNN_POOLING_MAX; default: @@ -105,8 +108,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { switch (mode) { case PoolingMode::kMaximumDeterministic: return CUDNN_POOLING_MAX_DETERMINISTIC; - case PoolingMode::kAverage: + case PoolingMode::kAverageExclusive: return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + case PoolingMode::kAverageInclusive: + return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; case PoolingMode::kMaximum: return CUDNN_POOLING_MAX; default: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4bfa89d9f..692084813 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2067,6 +2067,7 @@ def pool2d(input, global_pooling=False, use_cudnn=True, ceil_mode=False, + exclusive=True, name=None): """ ${comment} @@ -2081,9 +2082,11 @@ def pool2d(input, pool_type: ${pooling_type_comment} pool_stride (int): stride of the pooling layer. pool_padding (int): padding size. - global_pooling: ${global_pooling_comment} - use_cudnn: ${use_cudnn_comment} - ceil_mode: ${ceil_mode_comment} + global_pooling (bool): ${global_pooling_comment} + use_cudnn (bool): ${use_cudnn_comment} + ceil_mode (bool): ${ceil_mode_comment} + exclusive (bool): Whether to exclude padding points in average pooling + mode, default is true name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2143,7 +2146,8 @@ def pool2d(input, "paddings": pool_padding, "use_cudnn": use_cudnn, "ceil_mode": ceil_mode, - "use_mkldnn": False + "use_mkldnn": False, + "exclusive": exclusive, }) return pool_out @@ -2157,6 +2161,7 @@ def pool3d(input, global_pooling=False, use_cudnn=True, ceil_mode=False, + exclusive=True, name=None): """ This function adds the operator for pooling in 3-dimensions, using the @@ -2171,6 +2176,8 @@ def pool3d(input, global_pooling (bool): ${global_pooling_comment} use_cudnn (bool): ${use_cudnn_comment} ceil_mode (bool): ${ceil_mode_comment} + exclusive (bool): Whether to exclude padding points in average pooling + mode, default is true name (str): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2211,7 +2218,8 @@ def pool3d(input, "paddings": pool_padding, "use_cudnn": use_cudnn, "ceil_mode": ceil_mode, - "use_mkldnn": False + "use_mkldnn": False, + "exclusive": exclusive, }) return pool_out diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index 26969bd52..c627336f4 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -26,7 +26,8 @@ def max_pool2D_forward_naive(x, strides, paddings, global_pool=0, - ceil_mode=False): + ceil_mode=False, + exclusive=True): N, C, H, W = x.shape if global_pool == 1: ksize = [H, W] @@ -54,7 +55,8 @@ def avg_pool2D_forward_naive(x, strides, paddings, global_pool=0, - ceil_mode=False): + ceil_mode=False, + exclusive=True): N, C, H, W = x.shape if global_pool == 1: ksize = [H, W] @@ -73,8 +75,9 @@ def avg_pool2D_forward_naive(x, c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) x_masked = x[:, :, r_start:r_end, c_start:c_end] - out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / ( - (r_end - r_start) * (c_end - c_start)) + field_size = ((r_end - r_start) * (c_end - c_start)) if exclusive \ + else (ksize[0] * ksize[1]) + out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size return out @@ -89,12 +92,13 @@ class TestPool2d_Op(OpTest): self.init_kernel_type() self.init_pool_type() self.init_ceil_mode() + self.init_exclusive() if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] input = np.random.random(self.shape).astype(self.dtype) output = self.pool2D_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool, - self.ceil_mode).astype(self.dtype) + self.ceil_mode, self.exclusive).astype(self.dtype) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { @@ -106,7 +110,8 @@ class TestPool2d_Op(OpTest): 'use_cudnn': self.use_cudnn, 'use_mkldnn': self.use_mkldnn, 'ceil_mode': self.ceil_mode, - 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter + 'data_format': 'AnyLayout', # TODO(dzhwinter) : should be fix latter + 'exclusive': self.exclusive } self.outputs = {'Out': output} @@ -150,6 +155,9 @@ class TestPool2d_Op(OpTest): def init_ceil_mode(self): self.ceil_mode = False + def init_exclusive(self): + self.exclusive = True + class TestCase1(TestPool2d_Op): def init_test_case(self): @@ -321,6 +329,14 @@ class TestCeilModeCase4(TestCase2): def init_ceil_mode(self): self.ceil_mode = True +class TestAvgInclude(TestCase2): + def init_exclusive(self): + self.exclusive = False + +class TestCUDNNAvgInclude(TestCUDNNCase3): + def init_exclusive(self): + self.exclusive = False + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py index 77045c130..20dc2eefa 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py @@ -26,7 +26,8 @@ def max_pool3D_forward_naive(x, strides, paddings, global_pool=0, - ceil_mode=False): + ceil_mode=False, + exclusive=True): N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] @@ -60,7 +61,8 @@ def avg_pool3D_forward_naive(x, strides, paddings, global_pool=0, - ceil_mode=False): + ceil_mode=False, + exclusive=True): N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] @@ -85,8 +87,9 @@ def avg_pool3D_forward_naive(x, w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] - out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / ( - (d_end - d_start) * (h_end - h_start) * (w_end - w_start)) + field_size = (d_end - d_start) * (h_end - h_start) * (w_end - w_start) \ + if exclusive else ksize[0] * ksize[1] * ksize[2] + out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / field_size return out @@ -100,13 +103,14 @@ class TestPool3d_Op(OpTest): self.init_kernel_type() self.init_pool_type() self.init_ceil_mode() + self.init_exclusive() if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] input = np.random.random(self.shape).astype(self.dtype) output = self.pool3D_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool, - self.ceil_mode).astype(self.dtype) + self.ceil_mode, self.exclusive).astype(self.dtype) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { @@ -117,7 +121,8 @@ class TestPool3d_Op(OpTest): 'global_pooling': self.global_pool, 'use_cudnn': self.use_cudnn, 'ceil_mode': self.ceil_mode, - 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter + 'data_format': 'AnyLayout', # TODO(dzhwinter) : should be fix latter + 'exclusive': self.exclusive } self.outputs = {'Out': output} @@ -161,6 +166,9 @@ class TestPool3d_Op(OpTest): def init_ceil_mode(self): self.ceil_mode = False + def init_exclusive(self): + self.exclusive = True + class TestCase1(TestPool3d_Op): def init_test_case(self): @@ -332,6 +340,14 @@ class TestCeilModeCase4(TestCase2): def init_ceil_mode(self): self.ceil_mode = True +class TestAvgInclude(TestCase2): + def init_exclusive(self): + self.exclusive = False + +class TestCUDNNAvgInclude(TestCUDNNCase3): + def init_exclusive(self): + self.exclusive = False + if __name__ == '__main__': unittest.main() -- GitLab