diff --git a/paddle/fluid/operators/math/pooling.cc b/paddle/fluid/operators/math/pooling.cc index b4ee82add3133021f5e4bd24dbfeb0e58f9b20ff..30873e9f87f22fa5b39cbf519760a9ec3979f98b 100644 --- a/paddle/fluid/operators/math/pooling.cc +++ b/paddle/fluid/operators/math/pooling.cc @@ -19,16 +19,6 @@ namespace paddle { namespace operators { namespace math { -static inline int ADAPT_START_INDEX(int ph, int input_size, int output_size) { - return static_cast( - floor(static_cast(ph * input_size) / output_size)); -} - -static inline int ADAPT_END_INDEX(int ph, int input_size, int output_size) { - return static_cast( - ceil(static_cast((ph + 1) * input_size) / output_size)); -} - /* * All tensors are in NCHW format. * Ksize, strides, paddings are two elements. These two elements represent @@ -67,8 +57,8 @@ class Pool2dFunctor { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -76,8 +66,8 @@ class Pool2dFunctor { } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -144,8 +134,8 @@ class Pool2dGradFunctor { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -153,8 +143,8 @@ class Pool2dGradFunctor { } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -319,8 +309,8 @@ class Pool3dFunctor { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { if (adaptive) { - dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = AdaptStartIndex(pd, input_depth, output_depth); + dend = AdaptEndIndex(pd, input_depth, output_depth); } else { dstart = pd * stride_depth - padding_depth; dend = std::min(dstart + ksize_depth, input_depth); @@ -328,8 +318,8 @@ class Pool3dFunctor { } for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -337,8 +327,8 @@ class Pool3dFunctor { } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -417,8 +407,8 @@ class Pool3dGradFunctor { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { if (adaptive) { - dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = AdaptStartIndex(pd, input_depth, output_depth); + dend = AdaptEndIndex(pd, input_depth, output_depth); } else { dstart = pd * stride_depth - padding_depth; dend = std::min(dstart + ksize_depth, input_depth); @@ -426,8 +416,8 @@ class Pool3dGradFunctor { } for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -435,8 +425,8 @@ class Pool3dGradFunctor { } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -615,8 +605,8 @@ class MaxPool2dWithIndexFunctor { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -624,8 +614,8 @@ class MaxPool2dWithIndexFunctor { } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -753,8 +743,8 @@ class MaxPool3dWithIndexFunctor { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { if (adaptive) { - dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = AdaptStartIndex(pd, input_depth, output_depth); + dend = AdaptEndIndex(pd, input_depth, output_depth); } else { dstart = pd * stride_depth - padding_depth; dend = std::min(dstart + ksize_depth, input_depth); @@ -762,8 +752,8 @@ class MaxPool3dWithIndexFunctor { } for (int ph = 0; ph < output_height; ++ph) { if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -771,8 +761,8 @@ class MaxPool3dWithIndexFunctor { } for (int pw = 0; pw < output_width; ++pw) { if (adaptive) { - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index 5f3b82ed553d6b96a312fdc025ecb3da8b14fdcd..efce3f899a449c72ae3298f7ce0defb166ee8329 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -21,18 +21,6 @@ namespace paddle { namespace operators { namespace math { -__device__ __forceinline__ int ADAPT_START_INDEX(int ph, int input_size, - int output_size) { - return static_cast( - floor(static_cast(ph * input_size) / output_size)); -} - -__device__ __forceinline__ int ADAPT_END_INDEX(int ph, int input_size, - int output_size) { - return static_cast( - ceil(static_cast((ph + 1) * input_size) / output_size)); -} - template __global__ void KernelPool2D(const int nthreads, const T* input_data, const int channels, const int input_height, @@ -52,11 +40,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, int hstart, hend; int wstart, wend; if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { hstart = ph * stride_height - padding_height; hend = min(hstart + ksize_height, input_height); @@ -91,28 +79,29 @@ __global__ void KernelPool2DGrad( PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int offsetW = index % input_width + padding_width; - int offsetH = (index / input_width) % input_height + padding_height; + int w_offset = index % input_width + padding_width; + int h_offset = (index / input_width) % input_height + padding_height; int offsetC = (index / input_width / input_height) % channels; int batch_idx = index / input_width / input_height / channels; int phstart, phend; int pwstart, pwend; if (adaptive) { - phstart = offsetH * output_height / input_height; + phstart = h_offset * output_height / input_height; phend = - min((offsetH + 1) * output_height / input_height + 1, output_height); - pwstart = offsetW * output_width / input_width; - pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + min((h_offset + 1) * output_height / input_height + 1, output_height); + pwstart = w_offset * output_width / input_width; + pwend = + min((w_offset + 1) * output_width / input_width + 1, output_width); } else { - phstart = (offsetH < ksize_height) + phstart = (h_offset < ksize_height) ? 0 - : (offsetH - ksize_height) / stride_height + 1; - pwstart = (offsetW < ksize_width) + : (h_offset - ksize_height) / stride_height + 1; + pwstart = (w_offset < ksize_width) ? 0 - : (offsetW - ksize_width) / stride_width + 1; - phend = min(offsetH / stride_height + 1, output_height); - pwend = min(offsetW / stride_width + 1, output_width); + : (w_offset - ksize_width) / stride_width + 1; + phend = min(h_offset / stride_height + 1, output_height); + pwend = min(w_offset / stride_width + 1, output_width); } T gradient = 0; T input = input_data[index]; @@ -414,14 +403,14 @@ __global__ void KernelPool3D( int hstart, hend; int wstart, wend; if (adaptive) { - dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = AdaptStartIndex(pd, input_depth, output_depth); + dend = AdaptEndIndex(pd, input_depth, output_depth); - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { dstart = pd * stride_depth - padding_depth; hstart = ph * stride_height - padding_height; @@ -464,9 +453,9 @@ __global__ void KernelPool3DGrad( bool exclusive, bool adaptive, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int offsetW = index % input_width + padding_width; - int offsetH = (index / input_width) % input_height + padding_height; - int offsetD = + int w_offset = index % input_width + padding_width; + int h_offset = (index / input_width) % input_height + padding_height; + int d_offset = (index / input_width / input_height) % input_depth + padding_depth; int offsetC = (index / input_width / input_height / input_depth) % channels; int batch_idx = index / input_width / input_height / input_depth / channels; @@ -475,26 +464,28 @@ __global__ void KernelPool3DGrad( int phstart, phend; int pwstart, pwend; if (adaptive) { - pdstart = offsetD * output_depth / input_depth; - pdend = min((offsetD + 1) * output_depth / input_depth + 1, output_depth); - phstart = offsetH * output_height / input_height; + pdstart = d_offset * output_depth / input_depth; + pdend = + min((d_offset + 1) * output_depth / input_depth + 1, output_depth); + phstart = h_offset * output_height / input_height; phend = - min((offsetH + 1) * output_height / input_height + 1, output_height); - pwstart = offsetW * output_width / input_width; - pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + min((h_offset + 1) * output_height / input_height + 1, output_height); + pwstart = w_offset * output_width / input_width; + pwend = + min((w_offset + 1) * output_width / input_width + 1, output_width); } else { - pdstart = (offsetD < ksize_depth) + pdstart = (d_offset < ksize_depth) ? 0 - : (offsetD - ksize_depth) / stride_depth + 1; - phstart = (offsetH < ksize_height) + : (d_offset - ksize_depth) / stride_depth + 1; + phstart = (h_offset < ksize_height) ? 0 - : (offsetH - ksize_height) / stride_height + 1; - pwstart = (offsetW < ksize_width) + : (h_offset - ksize_height) / stride_height + 1; + pwstart = (w_offset < ksize_width) ? 0 - : (offsetW - ksize_width) / stride_width + 1; - pdend = min((offsetD) / stride_depth + 1, output_depth); - phend = min((offsetH) / stride_height + 1, output_height); - pwend = min((offsetW) / stride_width + 1, output_width); + : (w_offset - ksize_width) / stride_width + 1; + pdend = min((d_offset) / stride_depth + 1, output_depth); + phend = min((h_offset) / stride_height + 1, output_height); + pwend = min((w_offset) / stride_width + 1, output_width); } T gradient = 0; @@ -795,11 +786,11 @@ __global__ void KernelMaxPool2dWithIdx( int hstart, hend; int wstart, wend; if (adaptive) { - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { hstart = ph * stride_height - padding_height; hend = min(hstart + ksize_height, input_height); @@ -837,35 +828,36 @@ __global__ void KernelMaxPool2DWithIdxGrad( T1* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int offsetW = index % input_width; - int offsetH = (index / input_width) % input_height; + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; int offsetC = (index / input_width / input_height) % channels; int batch_idx = index / input_width / input_height / channels; int phstart, phend; int pwstart, pwend; if (adaptive) { - phstart = offsetH * output_height / input_height; + phstart = h_offset * output_height / input_height; phend = - min((offsetH + 1) * output_height / input_height + 1, output_height); - pwstart = offsetW * output_width / input_width; - pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + min((h_offset + 1) * output_height / input_height + 1, output_height); + pwstart = w_offset * output_width / input_width; + pwend = + min((w_offset + 1) * output_width / input_width + 1, output_width); } else { phstart = - (offsetH + padding_height < ksize_height) + (h_offset + padding_height < ksize_height) ? 0 - : (offsetH + padding_height - ksize_height) / stride_height + 1; + : (h_offset + padding_height - ksize_height) / stride_height + 1; pwstart = - (offsetW + padding_width < ksize_width) + (w_offset + padding_width < ksize_width) ? 0 - : (offsetW + padding_width - ksize_width) / stride_width + 1; + : (w_offset + padding_width - ksize_width) / stride_width + 1; phend = - min((offsetH + padding_height) / stride_height + 1, output_height); - pwend = min((offsetW + padding_width) / stride_width + 1, output_width); + min((h_offset + padding_height) / stride_height + 1, output_height); + pwend = min((w_offset + padding_width) / stride_width + 1, output_width); } T1 gradient = 0; - int input_current_featuremap_idx = offsetH * input_width + offsetW; + int input_current_featuremap_idx = h_offset * input_width + w_offset; int output_idx = (batch_idx * channels + offsetC) * output_height * output_width; @@ -1000,14 +992,14 @@ __global__ void KernelMaxPool3DWithIdx( int hstart, hend; int wstart, wend; if (adaptive) { - dstart = ADAPT_START_INDEX(pd, input_depth, output_depth); - dend = ADAPT_END_INDEX(pd, input_depth, output_depth); + dstart = AdaptStartIndex(pd, input_depth, output_depth); + dend = AdaptEndIndex(pd, input_depth, output_depth); - hstart = ADAPT_START_INDEX(ph, input_height, output_height); - hend = ADAPT_END_INDEX(ph, input_height, output_height); + hstart = AdaptStartIndex(ph, input_height, output_height); + hend = AdaptEndIndex(ph, input_height, output_height); - wstart = ADAPT_START_INDEX(pw, input_width, output_width); - wend = ADAPT_END_INDEX(pw, input_width, output_width); + wstart = AdaptStartIndex(pw, input_width, output_width); + wend = AdaptEndIndex(pw, input_width, output_width); } else { dstart = pd * stride_depth - padding_depth; hstart = ph * stride_height - padding_height; @@ -1051,9 +1043,9 @@ __global__ void KernelMaxPool3DWithIdxGrad( const int padding_width, bool adaptive, T1* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { - int offsetW = index % input_width; - int offsetH = (index / input_width) % input_height; - int offsetD = (index / input_width / input_height) % input_depth; + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int d_offset = (index / input_width / input_height) % input_depth; int offsetC = (index / input_width / input_height / input_depth) % channels; int batch_idx = index / input_width / input_height / input_depth / channels; @@ -1061,35 +1053,37 @@ __global__ void KernelMaxPool3DWithIdxGrad( int phstart, phend; int pwstart, pwend; if (adaptive) { - pdstart = offsetD * output_depth / input_depth; - pdend = min((offsetD + 1) * output_depth / input_depth + 1, output_depth); - phstart = offsetH * output_height / input_height; + pdstart = d_offset * output_depth / input_depth; + pdend = + min((d_offset + 1) * output_depth / input_depth + 1, output_depth); + phstart = h_offset * output_height / input_height; phend = - min((offsetH + 1) * output_height / input_height + 1, output_height); - pwstart = offsetW * output_width / input_width; - pwend = min((offsetW + 1) * output_width / input_width + 1, output_width); + min((h_offset + 1) * output_height / input_height + 1, output_height); + pwstart = w_offset * output_width / input_width; + pwend = + min((w_offset + 1) * output_width / input_width + 1, output_width); } else { pdstart = - (offsetD + padding_depth < ksize_depth) + (d_offset + padding_depth < ksize_depth) ? 0 - : (offsetD + padding_depth - ksize_depth) / stride_depth + 1; + : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; phstart = - (offsetH + padding_height < ksize_height) + (h_offset + padding_height < ksize_height) ? 0 - : (offsetH + padding_height - ksize_height) / stride_height + 1; + : (h_offset + padding_height - ksize_height) / stride_height + 1; pwstart = - (offsetW + padding_width < ksize_width) + (w_offset + padding_width < ksize_width) ? 0 - : (offsetW + padding_width - ksize_width) / stride_width + 1; - pdend = min((offsetD + padding_depth) / stride_depth + 1, output_depth); + : (w_offset + padding_width - ksize_width) / stride_width + 1; + pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth); phend = - min((offsetH + padding_height) / stride_height + 1, output_height); - pwend = min((offsetW + padding_width) / stride_width + 1, output_width); + min((h_offset + padding_height) / stride_height + 1, output_height); + pwend = min((w_offset + padding_width) / stride_width + 1, output_width); } T1 gradient = 0; int input_current_feature_map_idx = - (offsetD * input_height + offsetH) * input_width + offsetW; + (d_offset * input_height + h_offset) * input_width + w_offset; int output_idx = (batch_idx * channels + offsetC) * output_depth * output_height * output_width; mask += output_idx; diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h index d123af8924b77243e5a0d4d1212d35e16a44c3ae..e1f8e6df1d19b519e48bff326bc1aa9548c96905 100644 --- a/paddle/fluid/operators/math/pooling.h +++ b/paddle/fluid/operators/math/pooling.h @@ -68,6 +68,18 @@ class AvgPoolGrad { } }; +/* used for adaptive pool to calculate start and end index of each divided grid + */ +HOSTDEVICE inline int AdaptStartIndex(int ph, int input_size, int output_size) { + return static_cast( + floor(static_cast(ph * input_size) / output_size)); +} + +HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) { + return static_cast( + ceil(static_cast((ph + 1) * input_size) / output_size)); +} + /* * \brief Getting pooling results, and calculating gradient. * diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 61794f0d49a13a1796ef3f6eecdcd8c299327920..07fc4ccc6bc2668bd86787e3814fc6ede9c641ea 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2506,7 +2506,7 @@ def adaptive_pool2d(input, pool_size, pool_type="max", require_index=False, - use_cudnn=True, + use_cudnn=False, name=None): """ ${comment} @@ -2521,7 +2521,7 @@ def adaptive_pool2d(input, pool_type: ${pooling_type_comment} require_index (bool): If true, the index of max pooling point along with outputs. it cannot be set in average pooling type. - use_cudnn (bool): ${use_cudnn_comment} + use_cudnn (bool, default False): adaptive pool currently not supported in cudnn. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2531,6 +2531,7 @@ def adaptive_pool2d(input, Raises: ValueError: 'pool_type' is not 'max' nor 'avg'. ValueError: 'use_cudnn' is not a bool value. + ValueError: adaptive pool currently not supported in cudnn. ValueError: invalid setting 'require_index' true when 'pool_type' is 'avg'. ValueError: 'pool_size' should be a list or tuple with length as 2. @@ -2540,11 +2541,11 @@ def adaptive_pool2d(input, data = fluid.layers.data( name='data', shape=[3, 32, 32], dtype='float32') - conv2d = fluid.layers.pool2d( + pool_out = fluid.layers.adaptive_pool2d( input=data, pool_size=[3, 3], pool_type='max', - require_index=True) + require_index=False) """ if pool_type not in ["max", "avg"]: raise ValueError( @@ -2565,6 +2566,9 @@ def adaptive_pool2d(input, if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False.") + if use_cudnn: + raise ValueError("adaptive pool currently not supported in cudnn.") + if pool_type == "max": l_type = 'max_pool2d_with_index' else: @@ -2590,7 +2594,7 @@ def adaptive_pool2d(input, "adaptive": True, }) - return pool_out + return (pool_out, mask) if require_index else pool_out @templatedoc(op_type="pool3d") @@ -2598,7 +2602,7 @@ def adaptive_pool3d(input, pool_size, pool_type="max", require_index=False, - use_cudnn=True, + use_cudnn=False, name=None): """ ${comment} @@ -2613,7 +2617,7 @@ def adaptive_pool3d(input, pool_type: ${pooling_type_comment} require_index (bool): If true, the index of max pooling point along with outputs. it cannot be set in average pooling type. - use_cudnn (bool): ${use_cudnn_comment} + use_cudnn (bool, default False): adaptive pool currently not supported in cudnn. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2623,6 +2627,7 @@ def adaptive_pool3d(input, Raises: ValueError: 'pool_type' is not 'max' nor 'avg'. ValueError: 'use_cudnn' is not a bool value. + ValueError: adaptive pool currently not supported in cudnn. ValueError: invalid setting 'require_index' true when 'pool_type' is 'avg'. ValueError: 'pool_size' should be a list or tuple with length as 2. @@ -2632,7 +2637,7 @@ def adaptive_pool3d(input, data = fluid.layers.data( name='data', shape=[3, 32, 32], dtype='float32') - conv2d = fluid.layers.pool2d( + pool_out, mask = fluid.layers.adaptive_pool3d( input=data, pool_size=[3, 3], pool_type='max', @@ -2657,6 +2662,9 @@ def adaptive_pool3d(input, if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False.") + if use_cudnn: + raise ValueError("adaptive pool currently not supported in cudnn.") + if pool_type == "max": l_type = 'max_pool3d_with_index' else: @@ -2682,7 +2690,7 @@ def adaptive_pool3d(input, "adaptive": True, }) - return pool_out + return (pool_out, mask) if require_index else pool_out def batch_norm(input, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 9785b5063cdd1d90462fa63a4dc2d567926fc061..030bf012fa5b17030a62419f30a4b680a72c3fee 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -237,23 +237,24 @@ class TestBook(unittest.TestCase): program = Program() with program_guard(program): x = layers.data(name='x', shape=[3, 224, 224], dtype='float32') - self.assertIsNotNone( - layers.adaptive_pool2d( - x, [3, 3], require_index=True)) self.assertIsNotNone( layers.adaptive_pool2d( x, [3, 3], pool_type='avg')) + pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True) + self.assertIsNotNone(pool) + self.assertIsNotNone(mask) def test_adaptive_pool3d(self): program = Program() with program_guard(program): x = layers.data(name='x', shape=[3, 244, 224, 224], dtype='float32') - self.assertIsNotNone( - layers.adaptive_pool3d( - x, [3, 3, 3], require_index=True)) self.assertIsNotNone( layers.adaptive_pool3d( x, [3, 3, 3], pool_type='avg')) + pool, mask = layers.adaptive_pool3d( + x, [3, 3, 3], require_index=True) + self.assertIsNotNone(pool) + self.assertIsNotNone(mask) def test_lstm_unit(self): program = Program()