diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc
index 1918c8b16916f1be603399dbbc34b991febface5..3b706529d8f1ed0d673904b81047a5614bd4cf23 100644
--- a/paddle/operators/math/pooling.cc
+++ b/paddle/operators/math/pooling.cc
@@ -24,7 +24,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -54,14 +54,15 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
         int wstart = pw * stride_width - padding_width;
         int wend = std::min(wstart + ksize_width, input_width);
         wstart = std::max(wstart, 0);
-        T ele = pool_compute.initial();
+
+        T ele = pool_process.initial();
         for (int h = hstart; h < hend; ++h) {
           for (int w = wstart; w < wend; ++w) {
-            pool_compute.compute(ele, input_data[h * input_width + w]);
+            pool_process.compute(ele, input_data[h * input_width + w]);
           }
         }
         int pool_size = (hend - hstart) * (wend - wstart);
-        pool_compute.finalize(ele, (static_cast<T>(pool_size)));
+        pool_process.finalize(ele, (static_cast<T>(pool_size)));
         output_data[ph * output_width + pw] = ele;
       }
     }
@@ -80,7 +81,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_grad_process) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -115,11 +116,12 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
           float scale = 1.0 / pool_size;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              pool_compute.compute(input_data[h * input_width + w],
-                                   output_data[ph * output_width + pw],
-                                   output_grad_data[ph * output_width + pw],
-                                   input_grad_data[h * input_width + w],
-                                   static_cast<T>(scale));
+              pool_grad_process.compute(
+                  input_data[h * input_width + w],
+                  output_data[ph * output_width + pw],
+                  output_grad_data[ph * output_width + pw],
+                  input_grad_data[h * input_width + w],
+                  static_cast<T>(scale));
             }
           }
         }
@@ -198,21 +200,21 @@ template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
 // template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
 
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
 
 template <typename PoolProcess, typename T>
 class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
@@ -220,7 +222,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -260,11 +262,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
             int wend = std::min(wstart + ksize_width, input_width);
             wstart = std::max(wstart, 0);
             int output_idx = (pd * output_height + ph) * output_width + pw;
-            T ele = pool_compute.initial();
+            T ele = pool_process.initial();
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
                 for (int w = wstart; w < wend; ++w) {
-                  pool_compute.compute(
+                  pool_process.compute(
                       ele,
                       input_data[(d * input_height + h) * input_width + w]);
                 }
@@ -272,7 +274,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
             }
             int pool_size =
                 (dend - dstart) * (hend - hstart) * (wend - wstart);
-            pool_compute.finalize(ele, static_cast<T>(pool_size));
+            pool_process.finalize(ele, static_cast<T>(pool_size));
             output_data[output_idx] = ele;
           }
         }
@@ -292,7 +294,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_grad_process) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -343,7 +345,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
                   int input_idx = (d * input_height + h) * input_width + w;
                   int output_idx =
                       (pd * output_height + ph) * output_width + pw;
-                  pool_compute.compute(
+                  pool_grad_process.compute(
                       input_data[input_idx], output_data[output_idx],
                       output_grad_data[output_idx], input_grad_data[input_idx],
                       static_cast<T>(scale));
@@ -441,21 +443,21 @@ template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
 // template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
 
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu
index 4164920035ff9da8c33eb9e502a8cdffe605ebb7..8aeccd1f8e8855c51ad85016f0cb239b4c9c8fb0 100644
--- a/paddle/operators/math/pooling.cu
+++ b/paddle/operators/math/pooling.cu
@@ -20,14 +20,16 @@ namespace operators {
 namespace math {
 
 template <typename PoolProcess, typename T>
-__global__ void KernelPool2dForward(
-    const int nthreads, const T* input_data, T* output_data, const int channels,
-    const int input_height, const int input_width, const int output_height,
-    const int output_width, const int ksize_height, const int ksize_width,
-    const int stride_height, const int stride_width, const int padding_height,
-    const int padding_width, PoolProcess pool_compute) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+__global__ void KernelPool2D(const int nthreads, const T* input_data,
+                             T* output_data, const int channels,
+                             const int input_height, const int input_width,
+                             const int output_height, const int output_width,
+                             const int ksize_height, const int ksize_width,
+                             const int stride_height, const int stride_width,
+                             const int padding_height, const int padding_width,
+                             PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
     int c = (index / output_width / output_height) % channels;
@@ -42,28 +44,28 @@ __global__ void KernelPool2dForward(
     wstart = max(wstart, 0);
 
     input_data += (batch_idx * channels + c) * input_height * input_width;
-    T ele = pool_compute.initial();
+    T ele = pool_process.initial();
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        pool_compute.compute(ele, input_data[h * input_width + w]);
+        pool_process.compute(ele, input_data[h * input_width + w]);
       }
     }
     int pool_size = (hend - hstart) * (wend - wstart);
-    pool_compute.finalize(ele, (static_cast<T>(pool_size)));
+    pool_process.finalize(ele, (static_cast<T>(pool_size)));
     output_data[index] = ele;
   }
 }
 
 template <typename PoolProcess, typename T>
-__global__ void KernelPool2dBackward(
+__global__ void KernelPool2DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_height, const int input_width, const int output_height,
     const int output_width, const int ksize_height, const int ksize_width,
     const int stride_height, const int stride_width, const int padding_height,
-    const int padding_width, PoolProcess pool_compute) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+    const int padding_width, PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
     int offsetW = index % input_width + padding_width;
     int offsetH = (index / input_width) % input_height + padding_height;
     int offsetC = (index / input_width / input_height) % channels;
@@ -93,7 +95,7 @@ __global__ void KernelPool2dBackward(
         wstart = max(wstart, 0);
         int pool_size = (hend - hstart) * (wend - wstart);
         int output_sub_idx = ph * output_width + pw;
-        pool_compute.compute(input, output_data[output_sub_idx],
+        pool_process.compute(input, output_data[output_sub_idx],
                              output_grad[output_sub_idx], gradient,
                              static_cast<T>(1.0 / pool_size));
       }
@@ -103,15 +105,15 @@ __global__ void KernelPool2dBackward(
 }
 
 template <typename T>
-__global__ void KernelMaxPool2dBackward(
+__global__ void KernelMaxPool2DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_height, const int input_width, const int output_height,
     const int output_width, const int ksize_height, const int ksize_width,
     const int stride_height, const int stride_width, const int padding_height,
     const int padding_width) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
     int c = (index / output_width / output_height) % channels;
@@ -153,7 +155,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -176,7 +178,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
-    KernelPool2dForward<
+    KernelPool2D<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -184,7 +186,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
         input_height, input_width, output_height, output_width, ksize_height,
         ksize_width, stride_height, stride_width, padding_height,
-        padding_width, pool_compute);
+        padding_width, pool_process);
   }
 };
 
@@ -196,7 +198,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -220,7 +222,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
-    KernelPool2dBackward<
+    KernelPool2DGrad<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -228,7 +230,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
         nthreads, input_data, output_data, output_grad_data, input_grad_data,
         input_channels, input_height, input_width, output_height, output_width,
         ksize_height, ksize_width, stride_height, stride_width, padding_height,
-        padding_width, pool_compute);
+        padding_width, pool_process);
   }
 };
 
@@ -264,7 +266,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
-    KernelMaxPool2dBackward<
+    KernelMaxPool2DGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
@@ -276,35 +278,37 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
 };
 
 template class MaxPool2dGradFunctor<platform::GPUPlace, float>;
-// template class MaxPool2dGradFunctor<platform::GPUPlace, double>;
+// template class MaxPool2dGradFunctor<platform::GPUPlace, double>;  // The
+// 64-bit floating-point version of atomicAdd() is only supported by devices of
+// compute capability 6.x and higher.
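A note on the recurring kernel change above: replacing the one-thread-per-element guard `if (index < nthreads)` with a grid-stride loop lets a fixed-size grid cover any `nthreads`, so correctness no longer depends on launching exactly ceil(nthreads / blockDim.x) blocks. A minimal standalone CUDA sketch of the pattern (the `ScaleKernel` name and its body are illustrative, not part of this patch):

```cuda
#include <cuda_runtime.h>

// Grid-stride loop: each thread starts at its global index and advances by
// the total number of threads in the grid, so every element is visited even
// when the grid is smaller than the problem size.
__global__ void ScaleKernel(const int n, const float alpha, float* data) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] *= alpha;  // per-element work goes here
  }
}

int main() {
  const int n = 1 << 20;
  float* d_data;
  cudaMalloc(&d_data, n * sizeof(float));
  // A modest fixed grid still processes all n elements.
  ScaleKernel<<<256, 256>>>(n, 2.0f, d_data);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  return 0;
}
```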
 
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
 
 template <typename PoolProcess, typename T>
-__global__ void KernelPool3DForward(
+__global__ void KernelPool3D(
     const int nthreads, const T* input_data, T* output_data, const int channels,
     const int input_depth, const int input_height, const int input_width,
     const int output_depth, const int output_height, const int output_width,
     const int ksize_depth, const int ksize_height, const int ksize_width,
     const int stride_depth, const int stride_height, const int stride_width,
     const int padding_depth, const int padding_height, const int padding_width,
-    PoolProcess pool_compute) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+    PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
@@ -321,25 +325,25 @@ __global__ void KernelPool3DForward(
     dstart = max(dstart, 0);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
-    T ele = pool_compute.initial();
+    T ele = pool_process.initial();
     input_data +=
         (batch_idx * channels + c) * input_depth * input_height * input_width;
     for (int d = dstart; d < dend; ++d) {
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
-          pool_compute.compute(
+          pool_process.compute(
               ele, input_data[(d * input_height + h) * input_width + w]);
         }
       }
     }
     int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-    pool_compute.finalize(ele, static_cast<T>(pool_size));
+    pool_process.finalize(ele, static_cast<T>(pool_size));
     output_data[index] = ele;
   }
 }
 
 template <typename PoolProcess, typename T>
-__global__ void KernelPool3DBackward(
+__global__ void KernelPool3DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_depth, const int input_height, const int input_width,
@@ -347,8 +351,8 @@
     const int ksize_depth, const int ksize_height, const int ksize_width,
     const int stride_depth, const int stride_height, const int stride_width,
     const int padding_depth, const int padding_height, const int padding_width,
-    PoolProcess pool_compute) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+    PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int offsetW = index % input_width + padding_width;
     int offsetH = (index / input_width) % input_height + padding_height;
@@ -392,7 +396,7 @@ __global__ void KernelPool3DBackward(
         wstart = max(wstart, 0);
         int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
         int output_sub_idx = (pd * output_height + ph) * output_width + pw;
-        pool_compute.compute(input, output_data[output_sub_idx],
+        pool_process.compute(input, output_data[output_sub_idx],
                              output_grad[output_sub_idx], gradient,
                              static_cast<T>(1.0 / pool_size));
       }
@@ -403,7 +407,7 @@ __global__ void KernelPool3DBackward(
 }
 
 template <typename T>
-__global__ void KernelMaxPool3DBackward(
+__global__ void KernelMaxPool3DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_depth, const int input_height, const int input_width,
@@ -412,7 +416,7 @@ __global__ void KernelMaxPool3DBackward(
     const int stride_depth, const int stride_height, const int stride_width,
     const int padding_depth, const int padding_height,
     const int padding_width) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
@@ -460,7 +464,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -489,7 +493,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
-    KernelPool3DForward<
+    KernelPool3D<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -498,7 +502,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
         input_height, input_width, output_depth, output_height, output_width,
         ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
         stride_width, padding_depth, padding_height, padding_width,
-        pool_compute);
+        pool_process);
   }
 };
 
@@ -510,7 +514,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -541,7 +545,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
-    KernelPool3DBackward<
+    KernelPool3DGrad<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -550,7 +554,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
         input_channels, input_depth, input_height, input_width, output_depth,
         output_height, output_width, ksize_depth, ksize_height, ksize_width,
         stride_depth, stride_height, stride_width, padding_depth,
-        padding_height, padding_width, pool_compute);
+        padding_height, padding_width, pool_process);
   }
 };
 
@@ -592,7 +596,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
-    KernelMaxPool3DBackward<
+    KernelMaxPool3DGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
@@ -605,24 +609,26 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
 };
 
 template class MaxPool3dGradFunctor<platform::GPUPlace, float>;
-// template class MaxPool3dGradFunctor<platform::GPUPlace, double>;
+// template class MaxPool3dGradFunctor<platform::GPUPlace, double>;  // The
+// 64-bit floating-point version of atomicAdd() is only supported by devices of
+// compute capability 6.x and higher.
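The comment above is why the double-precision `MaxPool2dGradFunctor` and `MaxPool3dGradFunctor` instantiations stay commented out: the max-pooling backward path presumably scatters gradients with `atomicAdd`, and hardware `atomicAdd(double*, double)` only exists on compute capability 6.x and higher. A possible workaround, not used by this patch, is the CAS-based emulation from the CUDA C Programming Guide (the `AtomicAddDouble` name here is ours):

```cuda
#include <cuda_runtime.h>

// Emulate atomicAdd on double for pre-6.x devices; on 6.x+ the native
// instruction is used directly.
__device__ double AtomicAddDouble(double* address, double val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
  return atomicAdd(address, val);  // native on compute capability 6.x+
#else
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    // Reinterpret the bits, add in double, and write back only if the
    // location has not changed since we read it.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // compares bit patterns, so NaN does not loop
  return __longlong_as_double(old);
#endif
}
```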
 
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h
index cf0ab7ecae8b7463185a2066ea75293cb168e1aa..d214c689235ad4233d3e4e1c2aa0fdc993bf20c6 100644
--- a/paddle/operators/math/pooling.h
+++ b/paddle/operators/math/pooling.h
@@ -22,11 +22,10 @@ namespace paddle {
 namespace operators {
 namespace math {
 //////////////////////
-#define FLT_MAX __FLT_MAX__
-/////////////////////
+#define FLT_MAX __FLT_MAX__  //
 
 template <class T>
-class maxPool {
+class MaxPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
   DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
@@ -34,14 +33,14 @@
 };
 
 template <class T>
-class avgPool {
+class AvgPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(0); }
   DEVICE inline void compute(T& y, const T& x) { y += x; }
   DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
 };
 template <class T>
-class maxPoolGrad {
+class MaxPoolGrad {
  public:
   DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
                              T scale) {
@@ -50,7 +49,7 @@
 };
 
 template <class T>
-class avgPoolGrad {
+class AvgPoolGrad {
  public:
   DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
                              T scale) {
diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc
index 6219c6f97cb749f1e27481c8fc3abf9c449031de..c29f51f05613832c838400eb114465c81290ea58 100644
--- a/paddle/operators/pool_op.cc
+++ b/paddle/operators/pool_op.cc
@@ -51,7 +51,7 @@ class PoolOp : public framework::OperatorWithKernel {
         ksize[i] = static_cast<int>(in_x_dims[i + 2]);
     }
 
-    PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2,
+    PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
                    "Input size and Pooling size should be consistent.");
     PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
                    "Pooling size should be 2 elements. or 3 elements.");
@@ -79,7 +79,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
                    "X(Input) of Pooling should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Input@Grad of Pooling should not be null.");
-
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 };
@@ -98,66 +97,36 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
              "The format of output tensor is also NCHW.");
 
     AddAttr<std::string>("poolingType",
-                         "poolingType of pooling operator."
-                         "str constant equal to 'max' or 'avg'");
+                         "PoolingType of pooling operator."
+                         "Str constant equal to 'max' or 'avg'.")
+        .InEnum({"max", "avg"});
     AddAttr<std::vector<int>>(
         "ksize",
         "Pooling size(depth, height, width) of pooling operator."
-        "If globalPooling = true, ksize is ignored and need not be specified.");
+        "If globalPooling = true, ksize is ignored and need not be "
+        "specified.");  // TODO(Add checker)
     AddAttr<bool>(
         "globalPooling",
-        "whether to use the globalPooling."
-        "int constant equal to false or true"
-        "default false"
+        "Whether to use the globalPooling."
+        "Bool constant equal to false or true."
+        "Default false."
         "If globalPooling = true, ksize is ignored and need not be specified.")
         .SetDefault(false);
     AddAttr<std::vector<int>>("strides",
-                              "strides(height, width) of pooling operator."
-                              "default {1,1}")
-        .SetDefault({1, 1})
-        .AddCustomChecker(GreaterThanChecker_pool({0, 0}));
+                              "Strides(height, width) of pooling operator."
+                              "Default {1,1}")
+        .SetDefault({1, 1});  // TODO(Add checker)
     AddAttr<std::vector<int>>("paddings",
-                              "paddings(height, width) of pooling operator."
-                              "default {0,0}")
-        .SetDefault({0, 0})
-        .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0}));
+                              "Paddings(height, width) of pooling operator."
+                              "Default {0,0}.")
+        .SetDefault({0, 0});  // TODO(Add checker)
     AddComment(R"DOC(
 The pooling2d operation calculates the output based on
 the input, poolingType and ksize, strides, paddings parameters.
 )DOC");
   }
-
- private:
-  struct GreaterThanChecker_pool {
-   public:
-    explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int> &value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
-
-  struct EqualGreaterThanChecker_pool {
-   public:
-    explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int> &value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
 };
+
 class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -173,67 +142,36 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
              "The format of output tensor is also NCDHW.");
 
     AddAttr<std::string>("poolingType",
-                         "poolingType of pooling operator."
-                         "str constant equal to 'max' or 'avg'");
+                         "PoolingType of pooling operator."
+                         "str constant equal to 'max' or 'avg'.")
+        .InEnum({"max", "avg"});
     AddAttr<std::vector<int>>(
         "ksize",
-        "pooling size(depth, height, width) of pooling operator."
-        "If globalPooling = true, ksize is ignored and need not be specified.");
+        "Pooling size(depth, height, width) of pooling operator."
+        "If globalPooling = true, ksize is ignored and need not be "
+        "specified.");  // TODO(Add checker)
     AddAttr<bool>(
         "globalPooling",
-        "whether to use the globalPooling."
-        "int constant equal to false or true"
-        "default false"
+        "Whether to use the globalPooling."
+        "Bool constant equal to false or true."
+        "Default false."
        "If globalPooling = true, ksize is ignored and need not be specified.")
         .SetDefault(false);
     AddAttr<std::vector<int>>(
         "strides",
-        "strides(depth, height, width) of pooling operator."
-        "default {1,1,1}")
-        .SetDefault({1, 1, 1})
-        .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
+        "Strides(depth, height, width) of pooling operator."
+        "Default {1,1,1}.")
+        .SetDefault({1, 1, 1});  // TODO(Add checker)
     AddAttr<std::vector<int>>(
         "paddings",
-        "paddings(depth, height, width) of pooling operator."
-        "default {0,0,0}")
-        .SetDefault({0, 0, 0})
-        .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0}));
+        "Paddings(depth, height, width) of pooling operator."
+        "Default {0,0,0}.")
+        .SetDefault({0, 0, 0});  // TODO(Add checker)
     AddComment(R"DOC(
 The pooling3d operation calculates the output based on
 the input, poolingType and ksize, strides, paddings parameters.
 )DOC");
   }
-
- private:
-  struct GreaterThanChecker_pool {
-   public:
-    explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int> &value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
-
-  struct EqualGreaterThanChecker_pool {
-   public:
-    explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int> &value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
 };
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h
index 73c9721624992a73c5aafd8a3755fba74ef3318e..0c246b38efa8352a42249e5113c985c5193463d0 100644
--- a/paddle/operators/pool_op.h
+++ b/paddle/operators/pool_op.h
@@ -31,12 +31,11 @@ class PoolKernel : public framework::OpKernel {
     const Tensor* in_x = context.Input<Tensor>("X");
     Tensor* out = context.Output<Tensor>("Out");
 
-    bool global_pooling = context.Attr<bool>("globalPooling");
     std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (global_pooling) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
@@ -46,17 +45,17 @@ class PoolKernel : public framework::OpKernel {
       case 2: {
         if (pooling_type == "max") {
           paddle::operators::math::Pool2dFunctor<
-              Place, paddle::operators::math::maxPool<T>, T>
+              Place, paddle::operators::math::MaxPool<T>, T>
               pool2d_forward;
-          paddle::operators::math::maxPool<T> pool_process;
+          paddle::operators::math::MaxPool<T> pool_process;
           pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);
 
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool2dFunctor<
-              Place, paddle::operators::math::avgPool<T>, T>
+              Place, paddle::operators::math::AvgPool<T>, T>
               pool2d_forward;
-          paddle::operators::math::avgPool<T> pool_process;
+          paddle::operators::math::AvgPool<T> pool_process;
           pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);
         }
@@ -64,16 +63,16 @@ class PoolKernel : public framework::OpKernel {
       case 3: {
         if (pooling_type == "max") {
           paddle::operators::math::Pool3dFunctor<
-              Place, paddle::operators::math::maxPool<T>, T>
+              Place, paddle::operators::math::MaxPool<T>, T>
               pool3d_forward;
-          paddle::operators::math::maxPool<T> pool_process;
+          paddle::operators::math::MaxPool<T> pool_process;
           pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool3dFunctor<
-              Place, paddle::operators::math::avgPool<T>, T>
+              Place, paddle::operators::math::AvgPool<T>, T>
               pool3d_forward;
-          paddle::operators::math::avgPool<T> pool_process;
+          paddle::operators::math::AvgPool<T> pool_process;
           pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);
         }
@@ -92,13 +91,12 @@ class PoolGradKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
 
-    bool global_pooling = context.Attr<bool>("globalPooling");
     std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
 
-    if (global_pooling) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i)
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
     }
@@ -118,9 +116,9 @@ class PoolGradKernel : public framework::OpKernel {
                           *out_grad, ksize, strides, paddings);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool2dGradFunctor<
-              Place, paddle::operators::math::avgPoolGrad<T>, T>
+              Place, paddle::operators::math::AvgPoolGrad<T>, T>
               pool2d_backward;
-          paddle::operators::math::avgPoolGrad<T> pool_process;
+          paddle::operators::math::AvgPoolGrad<T> pool_process;
           pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                           *out_grad, ksize, strides, paddings, pool_process);
         }
@@ -133,9 +131,9 @@ class PoolGradKernel : public framework::OpKernel {
                           *out_grad, ksize, strides, paddings);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool3dGradFunctor<
-              Place, paddle::operators::math::avgPoolGrad<T>, T>
+              Place, paddle::operators::math::AvgPoolGrad<T>, T>
               pool3d_backward;
-          paddle::operators::math::avgPoolGrad<T> pool_process;
+          paddle::operators::math::AvgPoolGrad<T> pool_process;
           pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                           *out_grad, ksize, strides, paddings, pool_process);
         }
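Taken together, the renamed `MaxPool`, `AvgPool`, `MaxPoolGrad`, and `AvgPoolGrad` classes in pooling.h are the `PoolProcess` policies the functors above are parameterized on: `initial()` seeds the accumulator, `compute()` folds in one element, and `finalize()` rescales by the window size (a no-op for max). A minimal host-side sketch of that composition, assuming nothing beyond the three-method interface shown in the diff (`PoolWindow` is an illustrative helper, not Paddle API):

```cpp
#include <cassert>
#include <vector>

// Illustrative policy mirroring paddle::operators::math::AvgPool<T>.
template <class T>
struct AvgPool {
  T initial() { return static_cast<T>(0); }
  void compute(T& y, const T& x) { y += x; }
  void finalize(T& y, const T& pool_size) { y /= pool_size; }
};

// Generic reduction over one pooling window; the same skeleton as the inner
// loops of Pool2dFunctor, but over a flat window for brevity.
template <class PoolProcess, class T>
T PoolWindow(const std::vector<T>& window, PoolProcess pool_process) {
  T ele = pool_process.initial();
  for (const T& x : window) pool_process.compute(ele, x);
  pool_process.finalize(ele, static_cast<T>(window.size()));
  return ele;
}

int main() {
  std::vector<float> window = {1.f, 2.f, 3.f, 6.f};
  assert(PoolWindow(window, AvgPool<float>{}) == 3.f);  // (1+2+3+6)/4
  return 0;
}
```

Keeping the element-wise policy separate from the traversal is what lets one functor body serve both max and average pooling on CPU and GPU.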