From 5238c9fb0beaff08399b91e996e090852f4b87bf Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 17 Nov 2017 15:54:08 +0800 Subject: [PATCH] input type should be different --- paddle/operators/math/pooling.cc | 60 ++++---- paddle/operators/math/pooling.cu | 130 +++++++++--------- paddle/operators/math/pooling.h | 8 +- paddle/operators/pool_with_index_op.cc | 34 +++-- paddle/operators/pool_with_index_op.cu.cc | 8 +- paddle/operators/pool_with_index_op.h | 18 +-- .../paddle/v2/fluid/tests/test_pool_max_op.py | 104 +++++--------- 7 files changed, 172 insertions(+), 190 deletions(-) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index ead89e146f3..135984586a6 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -498,8 +498,8 @@ template class Pool3dGradFunctor< * Ksize, strides, paddings are two elements. These two elements represent * height and width, respectively. */ -template -class MaxPool2dWithIndexFunctor { +template +class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, std::vector& ksize, @@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor { const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; - const T* input_data = input.data(); - T* output_data = output->mutable_data(context.GetPlace()); - T* mask_data = mask->mutable_data(context.GetPlace()); + const T1* input_data = input.data(); + T1* output_data = output->mutable_data(context.GetPlace()); + T2* mask_data = mask->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor { int wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); - T ele = static_cast(-FLT_MAX); + T1 ele = static_cast(-FLT_MAX); int index = -1; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor { * Ksize, strides, paddings are two elements. These two elements represent * height and width, respectively. */ -template -class MaxPool2dWithIndexGradFunctor { +template +class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& output_grad, @@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor { const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; - const T* mask_data = mask.data(); - const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + const T2* mask_data = mask.data(); + const T1* output_grad_data = output_grad.data(); + T1* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int n = 0; n < batch_size; ++n) { for (int c = 0; c < output_channels; ++c) { @@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor { } }; -template class MaxPool2dWithIndexFunctor; -template class MaxPool2dWithIndexGradFunctor; -template class MaxPool2dWithIndexFunctor; -template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; /* * All tensors are in NCDHW format. * Ksize, strides, paddings are three elements. These three elements represent * depth, height and width, respectively. */ -template -class MaxPool3dWithIndexFunctor { +template +class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, std::vector& ksize, @@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor { const int input_stride = input_depth * input_height * input_width; const int output_stride = output_depth * output_height * output_width; - const T* input_data = input.data(); - T* output_data = output->mutable_data(context.GetPlace()); - T* mask_data = mask->mutable_data(context.GetPlace()); + const T1* input_data = input.data(); + T1* output_data = output->mutable_data(context.GetPlace()); + T2* mask_data = mask->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor { wstart = std::max(wstart, 0); int output_idx = (pd * output_height + ph) * output_width + pw; - T ele = static_cast(-FLT_MAX); + T1 ele = static_cast(-FLT_MAX); int index = -1; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor { * Ksize, strides, paddings are three elements. These three elements represent * depth, height and width, respectively. */ -template -class MaxPool3dWithIndexGradFunctor { +template +class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& output_grad, @@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor { const int input_stride = input_depth * input_height * input_width; const int output_stride = output_depth * output_height * output_width; - const T* mask_data = mask.data(); - const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + const T2* mask_data = mask.data(); + const T1* output_grad_data = output_grad.data(); + T1* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int n = 0; n < batch_size; ++n) { for (int c = 0; c < output_channels; ++c) { @@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor { } }; -template class MaxPool3dWithIndexFunctor; -template class MaxPool3dWithIndexGradFunctor; -template class MaxPool3dWithIndexFunctor; -template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 6d1138ad50c..ca3560f264b 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -658,13 +658,13 @@ template class Pool3dGradFunctor< template class Pool3dGradFunctor< platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; -template +template __global__ void KernelMaxPool2dWithIdx( - const int nthreads, const T* input_data, const int channels, + const int nthreads, const T1* input_data, const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, - const int padding_width, T* output_data, T* mask_data) { + const int padding_width, T1* output_data, T2* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -681,7 +681,7 @@ __global__ void KernelMaxPool2dWithIdx( wstart = max(wstart, 0); input_data += (batch_idx * channels + c) * input_height * input_width; - T ele = -FLT_MAX; + T1 ele = -FLT_MAX; int max_index = -1; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -697,13 +697,13 @@ __global__ void KernelMaxPool2dWithIdx( } } -template +template __global__ void KernelMaxPool2DWithIdxGrad( - const int nthreads, const T* output_grad, const T* mask_data, + const int nthreads, const T1* output_grad, const T2* mask_data, const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, T* input_grad) { + const int padding_height, const int padding_width, T1* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; @@ -724,7 +724,7 @@ __global__ void KernelMaxPool2DWithIdxGrad( int pw_end = min((w_offset + padding_width) / stride_width + 1, output_width); - T gradient = 0; + T1 gradient = 0; int input_current_featuremap_idx = h_offset * input_width + w_offset; int output_idx = (batch_idx * channels + c_offset) * output_height * output_width; @@ -746,8 +746,8 @@ __global__ void KernelMaxPool2DWithIdxGrad( * Ksize, strides, paddings are two elements. These two elements represent * height and width, respectively. */ -template -class MaxPool2dWithIndexFunctor { +template +class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, std::vector& ksize, @@ -767,9 +767,9 @@ class MaxPool2dWithIndexFunctor { const int padding_height = paddings[0]; const int padding_width = paddings[1]; - const T* input_data = input.data(); - T* output_data = output->mutable_data(context.GetPlace()); - T* mask_data = mask->mutable_data(context.GetPlace()); + const T1* input_data = input.data(); + T1* output_data = output->mutable_data(context.GetPlace()); + T2* mask_data = mask->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -777,9 +777,9 @@ class MaxPool2dWithIndexFunctor { dim3 grid(blocks, 1); KernelMaxPool2dWithIdx< - T><<(context) - .stream()>>>( + T1, T2><<(context) + .stream()>>>( nthreads, input_data, input_channels, input_height, input_width, output_height, output_width, ksize_height, ksize_width, stride_height, stride_width, padding_height, padding_width, output_data, mask_data); @@ -791,8 +791,8 @@ class MaxPool2dWithIndexFunctor { * Ksize, strides, paddings are two elements. These two elements represent * height and width, respectively. */ -template -class MaxPool2dWithIndexGradFunctor { +template +class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& output_grad, @@ -812,9 +812,9 @@ class MaxPool2dWithIndexGradFunctor { const int padding_height = paddings[0]; const int padding_width = paddings[1]; - const T* mask_data = mask.data(); - const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + const T2* mask_data = mask.data(); + const T1* output_grad_data = output_grad.data(); + T1* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -822,30 +822,30 @@ class MaxPool2dWithIndexGradFunctor { dim3 grid(blocks, 1); KernelMaxPool2DWithIdxGrad< - T><<(context) - .stream()>>>(nthreads, output_grad_data, mask_data, - input_channels, input_height, input_width, - output_height, output_width, ksize_height, - ksize_width, stride_height, stride_width, - padding_height, padding_width, input_grad_data); + T1, T2><<(context) + .stream()>>>( + nthreads, output_grad_data, mask_data, input_channels, input_height, + input_width, output_height, output_width, ksize_height, ksize_width, + stride_height, stride_width, padding_height, padding_width, + input_grad_data); } }; -template class MaxPool2dWithIndexFunctor; -template class MaxPool2dWithIndexGradFunctor; -template class MaxPool2dWithIndexFunctor; -template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; -template +template __global__ void KernelMaxPool3DWithIdx( - const int nthreads, const T* input_data, const int channels, + const int nthreads, const T1* input_data, const int channels, const int input_depth, const int input_height, const int input_width, const int output_depth, const int output_height, const int output_width, const int ksize_depth, const int ksize_height, const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width, - T* output_data, T* mask_data) { + T1* output_data, T2* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -865,7 +865,7 @@ __global__ void KernelMaxPool3DWithIdx( hstart = max(hstart, 0); wstart = max(wstart, 0); - T ele = -FLT_MAX; + T1 ele = -FLT_MAX; int max_index = -1; input_data += (batch_idx * channels + c) * input_depth * input_height * input_width; @@ -885,15 +885,15 @@ __global__ void KernelMaxPool3DWithIdx( } } -template +template __global__ void KernelMaxPool3DWithIdxGrad( - const int nthreads, const T* output_grad, const T* mask, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, const int padding_width, - T* input_grad) { + const int nthreads, const T1* output_grad, const T2* mask, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width, T1* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; @@ -922,7 +922,7 @@ __global__ void KernelMaxPool3DWithIdxGrad( int pw_end = min((w_offset + padding_width) / stride_width + 1, output_width); - T gradient = 0; + T1 gradient = 0; int input_current_feature_map_idx = (d_offset * input_height + h_offset) * input_width + w_offset; int output_idx = (batch_idx * channels + c_offset) * output_depth * @@ -949,8 +949,8 @@ __global__ void KernelMaxPool3DWithIdxGrad( * Ksize, strides, paddings are three elements. These three elements represent * depth, height and width, respectively. */ -template -class MaxPool3dWithIndexFunctor { +template +class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, std::vector& ksize, @@ -975,9 +975,9 @@ class MaxPool3dWithIndexFunctor { const int padding_height = paddings[1]; const int padding_width = paddings[2]; - const T* input_data = input.data(); - T* output_data = output->mutable_data(context.GetPlace()); - T* mask_data = mask->mutable_data(context.GetPlace()); + const T1* input_data = input.data(); + T1* output_data = output->mutable_data(context.GetPlace()); + T2* mask_data = mask->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -986,9 +986,9 @@ class MaxPool3dWithIndexFunctor { dim3 grid(blocks, 1); KernelMaxPool3DWithIdx< - T><<(context) - .stream()>>>( + T1, T2><<(context) + .stream()>>>( nthreads, input_data, input_channels, input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, stride_width, @@ -1001,8 +1001,8 @@ class MaxPool3dWithIndexFunctor { * Ksize, strides, paddings are three elements. These three elements represent * depth, height and width, respectively. */ -template -class MaxPool3dWithIndexGradFunctor { +template +class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& output_grad, @@ -1027,9 +1027,9 @@ class MaxPool3dWithIndexGradFunctor { const int padding_height = paddings[1]; const int padding_width = paddings[2]; - const T* output_grad_data = output_grad.data(); - const T* mask_data = mask.data(); - T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + const T1* output_grad_data = output_grad.data(); + const T2* mask_data = mask.data(); + T1* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_depth * input_height * input_width; @@ -1038,9 +1038,9 @@ class MaxPool3dWithIndexGradFunctor { dim3 grid(blocks, 1); KernelMaxPool3DWithIdxGrad< - T><<(context) - .stream()>>>( + T1, T2><<(context) + .stream()>>>( nthreads, output_grad_data, mask_data, input_channels, input_depth, input_height, input_width, output_depth, output_height, output_width, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, @@ -1049,10 +1049,10 @@ class MaxPool3dWithIndexGradFunctor { } }; -template class MaxPool3dWithIndexFunctor; -template class MaxPool3dWithIndexGradFunctor; -template class MaxPool3dWithIndexFunctor; -template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index f6719e1e628..19fbd8b4bb2 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -153,7 +153,7 @@ class MaxPool3dGradFunctor { * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in * NCDHW format. */ -template +template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, @@ -162,7 +162,7 @@ class MaxPool2dWithIndexFunctor { framework::Tensor* output, framework::Tensor* mask); }; -template +template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, @@ -172,7 +172,7 @@ class MaxPool2dWithIndexGradFunctor { framework::Tensor* input_grad); }; -template +template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, @@ -181,7 +181,7 @@ class MaxPool3dWithIndexFunctor { framework::Tensor* output, framework::Tensor* mask); }; -template +template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index 1df36e965ab..4470e2b2798 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -29,11 +29,11 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "X(Input) of Pooling should not be null."); + "Input(X) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Out(Output) of Pooling should not be null."); + "Output(Out) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Mask"), - "Mask(Output) of Pooling should not be null."); + "Output(Mask) of Pooling should not be null."); auto in_x_dims = ctx->GetInputDim("X"); @@ -67,6 +67,14 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { @@ -80,6 +88,14 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { "Input(X@GRAD) should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { @@ -116,7 +132,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { // TypedAttrChecker don't support vector type.) AddAttr( "global_pooling", - "(bool, default false) Whether to use the global pooling. " + "(bool, default:false) Whether to use the global pooling. " "If global_pooling = true, ksize and paddings will be ignored.") .SetDefault(false); AddAttr>("strides", @@ -126,7 +142,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "(vector, defalut {0, 0}), paddings(height, width) of pooling " + "(vector, defalut:{0, 0}), paddings(height, width) of pooling " "operator. " "If global_pooling = true, paddings and will be ignored.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, @@ -250,10 +266,10 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, REGISTER_OP_CPU_KERNEL( max_pool2d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel); REGISTER_OP_CPU_KERNEL( max_pool2d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel) REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, @@ -261,7 +277,7 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, REGISTER_OP_CPU_KERNEL( max_pool3d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel); REGISTER_OP_CPU_KERNEL( max_pool3d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu.cc b/paddle/operators/pool_with_index_op.cu.cc index 287657d4b1c..7d4c294c5fb 100644 --- a/paddle/operators/pool_with_index_op.cu.cc +++ b/paddle/operators/pool_with_index_op.cu.cc @@ -18,14 +18,14 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( max_pool2d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel); REGISTER_OP_GPU_KERNEL( max_pool2d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel) REGISTER_OP_GPU_KERNEL( max_pool3d_with_index, - ops::MaxPoolWithIndexKernel); + ops::MaxPoolWithIndexKernel); REGISTER_OP_GPU_KERNEL( max_pool3d_with_index_grad, - ops::MaxPoolWithIndexGradKernel) + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index a081607edce..40766c7e821 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -24,8 +24,8 @@ namespace operators { using Tensor = framework::Tensor; -template -class MaxPoolWithIndexKernel : public framework::OpKernel { +template +class MaxPoolWithIndexKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* in_x = context.Input("X"); @@ -44,13 +44,13 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { switch (ksize.size()) { case 2: { - paddle::operators::math::MaxPool2dWithIndexFunctor + paddle::operators::math::MaxPool2dWithIndexFunctor pool2d_forward; pool2d_forward(context.device_context(), *in_x, ksize, strides, paddings, out, mask); } break; case 3: { - paddle::operators::math::MaxPool3dWithIndexFunctor + paddle::operators::math::MaxPool3dWithIndexFunctor pool3d_forward; pool3d_forward(context.device_context(), *in_x, ksize, strides, paddings, out, mask); @@ -60,8 +60,8 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { } }; -template -class MaxPoolWithIndexGradKernel : public framework::OpKernel { +template +class MaxPoolWithIndexGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* mask = context.Input("Mask"); @@ -80,19 +80,19 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { } if (in_x_grad) { - in_x_grad->mutable_data(context.GetPlace()); + in_x_grad->mutable_data(context.GetPlace()); auto& device_ctx = context.device_context(); math::set_constant(device_ctx, in_x_grad, 0); switch (ksize.size()) { case 2: { - paddle::operators::math::MaxPool2dWithIndexGradFunctor + paddle::operators::math::MaxPool2dWithIndexGradFunctor pool2d_backward; pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides, paddings, in_x_grad); } break; case 3: { - paddle::operators::math::MaxPool3dWithIndexGradFunctor + paddle::operators::math::MaxPool3dWithIndexGradFunctor pool3d_backward; pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides, paddings, in_x_grad); diff --git a/python/paddle/v2/fluid/tests/test_pool_max_op.py b/python/paddle/v2/fluid/tests/test_pool_max_op.py index 04843a28ac1..2c862ec4d55 100644 --- a/python/paddle/v2/fluid/tests/test_pool_max_op.py +++ b/python/paddle/v2/fluid/tests/test_pool_max_op.py @@ -3,11 +3,13 @@ import numpy as np from op_test import OpTest -def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): +def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False): N, C, D, H, W = x.shape - if global_pool == 1: + if global_pool: ksize = [D, H, W] + paddings = [0, 0, 0] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 @@ -40,11 +42,13 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): return out, mask -def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): +def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=False): N, C, H, W = x.shape - if global_pool == 1: + if global_pool: ksize = [H, W] + paddings = [0, 0] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 out = np.zeros((N, C, H_out, W_out)) @@ -74,13 +78,13 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): class TestMaxPoolWithIndex_Op(OpTest): def setUp(self): self.init_test_case() - if self.global_pool: - self.paddings = [0 for _ in range(len(self.paddings))] + self.init_global() + input = np.random.random(self.shape).astype("float32") output, mask = self.pool_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool) output = output.astype("float32") - mask = mask.astype("float32") + mask = mask.astype("int32") self.attrs = { 'strides': self.strides, @@ -99,41 +103,24 @@ class TestMaxPoolWithIndex_Op(OpTest): # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) def init_test_case(self): - self.global_pool = True - self.index = "max_pool3d_with_index" - self.op_type = "%s" % self.index + self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] self.strides = [1, 1, 1] self.paddings = [1, 1, 1] + def init_global(self): + self.global_pool = False + class TestCase1(TestMaxPoolWithIndex_Op): - def init_test_case(self): + def init_global(self): self.global_pool = True - self.op_type = "max_pool3d_with_index" - self.pool_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 5, 5, 5] - self.ksize = [3, 3, 3] - self.strides = [1, 1, 1] - self.paddings = [1, 1, 1] class TestCase2(TestMaxPoolWithIndex_Op): def init_test_case(self): - self.global_pool = False - self.op_type = "max_pool3d_with_index" - self.pool_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 7, 7, 7] - self.ksize = [3, 3, 3] - self.strides = [1, 1, 1] - self.paddings = [1, 1, 1] - - -class TestCase3(TestMaxPoolWithIndex_Op): - def init_test_case(self): - self.global_pool = False self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 7, 7, 7] @@ -141,32 +128,18 @@ class TestCase3(TestMaxPoolWithIndex_Op): self.strides = [2, 2, 2] self.paddings = [0, 0, 0] - -class TestCase4(TestMaxPoolWithIndex_Op): - def init_test_case(self): + def init_global(self): self.global_pool = True - self.op_type = "max_pool3d_with_index" - self.pool_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 5, 5, 5] - self.ksize = [3, 3, 3] - self.strides = [1, 1, 1] - self.paddings = [1, 1, 1] -class TestCase5(TestMaxPoolWithIndex_Op): - def init_test_case(self): - self.global_pool = True - self.op_type = "max_pool3d_with_index" - self.pool_forward_naive = max_pool3D_forward_naive - self.shape = [2, 3, 5, 5, 5] - self.ksize = [3, 3, 3] - self.strides = [2, 2, 2] - self.paddings = [0, 0, 0] +class TestCase3(TestCase2): + def init_global(self): + self.global_pool = False -class TestCase6(TestMaxPoolWithIndex_Op): +#----------------max_pool2d_with_index---------------- +class TestCase4(TestMaxPoolWithIndex_Op): def init_test_case(self): - self.global_pool = False self.op_type = "max_pool2d_with_index" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] @@ -174,10 +147,17 @@ class TestCase6(TestMaxPoolWithIndex_Op): self.strides = [1, 1] self.paddings = [1, 1] + def init_global(self): + self.global_pool = True -class TestCase7(TestMaxPoolWithIndex_Op): - def init_test_case(self): + +class TestCase5(TestMaxPoolWithIndex_Op): + def init_global(self): self.global_pool = False + + +class TestCase6(TestMaxPoolWithIndex_Op): + def init_test_case(self): self.op_type = "max_pool2d_with_index" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] @@ -185,27 +165,13 @@ class TestCase7(TestMaxPoolWithIndex_Op): self.strides = [2, 2] self.paddings = [0, 0] - -class TestCase8(TestMaxPoolWithIndex_Op): - def init_test_case(self): + def init_global(self): self.global_pool = True - self.op_type = "max_pool2d_with_index" - self.pool_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 5, 5] - self.ksize = [3, 3] - self.strides = [1, 1] - self.paddings = [1, 1] -class TestCase9(TestMaxPoolWithIndex_Op): - def init_test_case(self): - self.global_pool = True - self.op_type = "max_pool2d_with_index" - self.pool_forward_naive = max_pool2D_forward_naive - self.shape = [2, 3, 5, 5] - self.ksize = [3, 3] - self.strides = [2, 2] - self.paddings = [0, 0] +class TestCase7(TestCase6): + def init_global(self): + self.global_pool = False if __name__ == '__main__': -- GitLab