diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index a3357867530c110df16a5f3ec8c799735206cc71..239ae5e1233c7f5c506930df374b5d0cc8de7c8d 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -195,6 +195,14 @@ std::vector vectorize(const DDim& ddim) { return result; } +// NOTE: framework::vectorize converts to type int64_t +// which does not fit cudnn inputs. +std::vector vectorize2int(const DDim& ddim) { + std::vector temp = vectorize(ddim); + std::vector result(temp.begin(), temp.end()); + return result; +} + struct ProductVisitor : public boost::static_visitor { template int64_t operator()(const Dim& dim) { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 4a871bb0a91ed4050847509cc3f24218bcd57142..2a5e2d2b6948b045642dbac5e83992a048ecb63d 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx); void set(DDim& dim, int idx, int val); std::vector vectorize(const DDim& ddim); +std::vector vectorize2int(const DDim& ddim); int64_t product(const DDim& ddim); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 132db540241b226deb94cfb65dd8ec8fe47e7e9b..c72261710173a0f3af199646d6800bf9d7c27b67 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -69,6 +69,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() + # pool_cudnn_op contains several operators + if ("${TARGET}" STREQUAL "pool_cudnn_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") + endif() + # save_restore_op contains several operators if ("${TARGET}" STREQUAL "save_restore_op") set(pybind_flag 1) diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu index 366d0323b840c338dd6ba5b28bdb29fd135fe91a..e2eb157f40c0039f87c41d28f8732cd4901a046d 100644 --- a/paddle/operators/conv_cudnn_op.cu +++ b/paddle/operators/conv_cudnn_op.cu @@ -31,16 +31,6 @@ using CUDADeviceContext = platform::CUDADeviceContext; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; -// NOTE: framework::vectorize converts to type int64_t -// which does not fit cudnn inputs. -std::vector Dims2Vector(const framework::DDim& dims) { - std::vector ret; - for (int i = 0; i < dims.size(); i++) { - ret.push_back(dims[i]); - } - return ret; -} - template class CudnnConvOpKernel : public framework::OpKernel { public: @@ -68,12 +58,12 @@ class CudnnConvOpKernel : public framework::OpKernel { ScopedConvolutionDescriptor conv_desc; DataLayout layout = DataLayout::kNCHW; - cudnnTensorDescriptor_t cudnn_input_desc = - input_desc.descriptor(layout, Dims2Vector(input->dims()), groups); - cudnnTensorDescriptor_t cudnn_output_desc = - output_desc.descriptor(layout, Dims2Vector(output->dims()), groups); - cudnnFilterDescriptor_t cudnn_filter_desc = - filter_desc.descriptor(layout, Dims2Vector(filter->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims()), groups); cudnnConvolutionDescriptor_t cudnn_conv_desc = conv_desc.descriptor(paddings, strides, dilations); @@ -156,13 +146,13 @@ class CudnnConvGradOpKernel : public framework::OpKernel { ScopedConvolutionDescriptor conv_desc; DataLayout layout = DataLayout::kNCHW; - cudnnTensorDescriptor_t cudnn_input_desc = - input_desc.descriptor(layout, Dims2Vector(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims()), groups); cudnnTensorDescriptor_t cudnn_output_grad_desc = - output_grad_desc.descriptor(layout, Dims2Vector(output_grad->dims()), - groups); - cudnnFilterDescriptor_t cudnn_filter_desc = - filter_desc.descriptor(layout, Dims2Vector(filter->dims()), groups); + output_grad_desc.descriptor( + layout, framework::vectorize2int(output_grad->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims()), groups); cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; @@ -192,7 +182,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel { auto handle = ctx.cuda_device_context().cudnn_handle(); if (input_grad) { cudnn_input_grad_desc = input_grad_desc.descriptor( - layout, Dims2Vector(input_grad->dims()), groups); + layout, framework::vectorize2int(input_grad->dims()), groups); PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( handle, cudnn_filter_desc, @@ -213,7 +203,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel { if (filter_grad) { cudnn_filter_grad_desc = filter_grad_desc.descriptor( - layout, Dims2Vector(filter_grad->dims()), groups); + layout, framework::vectorize2int(filter_grad->dims()), groups); PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, diff --git a/paddle/operators/pool_cudnn_op.cc b/paddle/operators/pool_cudnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f962d9e3e6abde14ce21eb0102f10d139fdb160e --- /dev/null +++ b/paddle/operators/pool_cudnn_op.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_cudnn_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool2d_cudnn, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad, + ops::PoolGradKernel) diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc29be18e76fde19c10c32e0299c395a150d8c40 --- /dev/null +++ b/paddle/operators/pool_cudnn_op.cu @@ -0,0 +1,152 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_cudnn_op.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor; +using DataLayout = platform::DataLayout; +using PoolingMode = platform::PoolingMode; + +template +class PoolCudnnOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + const Tensor *input = ctx.Input("X"); + Tensor *output = ctx.Output("Out"); + + const T *input_data = input->data(); + T *output_data = output->mutable_data(ctx.GetPlace()); + + std::string pooling_type = ctx.Attr("poolingType"); + std::vector ksize = ctx.Attr>("ksize"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + if (ctx.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(input->dims()[i + 2]); + } + } + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedPoolingDescriptor pool_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + + PoolingMode pooling_mode; + if (pooling_type == "max") { + pooling_mode = PoolingMode::kMaximum; + } else { + pooling_mode = PoolingMode::kAverage; + } + + cudnnPoolingDescriptor_t cudnn_pool_desc = + pool_desc.descriptor(pooling_mode, ksize, paddings, strides); + + // ------------------- cudnn pool algorithm --------------------- + auto handle = ctx.cuda_device_context().cudnn_handle(); + T alpha = 1.0f, beta = 0.0f; + + PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward( + handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, + cudnn_output_desc, output_data)); + } +}; + +template +class PoolCudnnGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + const Tensor *input = ctx.Input("X"); + const Tensor *output = ctx.Input("Out"); + const Tensor *output_grad = + ctx.Input(framework::GradVarName("Out")); + Tensor *input_grad = ctx.Output(framework::GradVarName("X")); + + std::string pooling_type = ctx.Attr("poolingType"); + std::vector ksize = ctx.Attr>("ksize"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + + if (ctx.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(input->dims()[i + 2]); + } + + const T *input_data = input->data(); + const T *output_data = output->data(); + const T *output_grad_data = output_grad->data(); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedPoolingDescriptor pool_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + + PoolingMode pooling_mode; + if (pooling_type == "max") { + pooling_mode = PoolingMode::kMaximum; + } else { + pooling_mode = PoolingMode::kAverage; + } + + cudnnPoolingDescriptor_t cudnn_pool_desc = + pool_desc.descriptor(pooling_mode, ksize, paddings, strides); + + // ------------------- cudnn pool algorithm --------------------- + auto handle = ctx.cuda_device_context().cudnn_handle(); + T alpha = 1.0f, beta = 0.0f; + + if (input_grad) { + T *input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.device_context(), input_grad, static_cast(0)); + + PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward( + handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data, + cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data, + &beta, cudnn_input_desc, input_grad_data)); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel); +REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel); diff --git a/paddle/operators/pool_cudnn_op.h b/paddle/operators/pool_cudnn_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5adf27f5bccae8542719612320bc6dbe21007634 --- /dev/null +++ b/paddle/operators/pool_cudnn_op.h @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/pool_op.h" + +namespace paddle { +namespace operators {} // namespace operators +} // namespace paddle diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index a326839c0f9ad14b8fd2aac596f21c7dd2539cd7..c4ab29e4d5f7c02d97a2185a58fdcd48de822d2d 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -29,7 +29,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { auto in_x_dims = ctx->GetInputDim("X"); - std::string pooling_type = ctx->Attrs().Get("pooling_type"); + std::string pooling_type = ctx->Attrs().Get("poolingType"); std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); @@ -37,7 +37,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, "Pooling intput should be 4-D or 5-D tensor."); - if (ctx->Attrs().Get("global_pooling")) { + if (ctx->Attrs().Get("globalPooling")) { ksize.resize(static_cast(in_x_dims.size()) - 2); for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x_dims[i + 2]); @@ -80,34 +80,30 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, "the number of channels, H and W is the height and " "width of feature."); - AddAttr("pooling_type", - "Pooling_type of pooling operator." - "Str constant equal to 'max' or 'avg'.") + AddAttr("poolingType", + "(string), pooling type, can be \"max\" for max-pooling " + "and \"avg\" for average-pooling.") .InEnum({"max", "avg"}); - AddAttr>( "ksize", - "The pooling window size(height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + // TypedAttrChecker don't support vector type.) + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); - AddAttr>("strides", - "The strides(height, width) of pooling window." - "Default {1,1}.") + AddAttr>( + "strides", + "(vector, default:{1, 1}), strides(height, width) of pooling operator.") .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr>("paddings", - "The zero padding(height, width) size on both sides" - "Default {0,0}.") + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "(vector defalut:{0,0}), paddings(height, width) of pooling operator.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( The pooling2d operation calculates the output based on @@ -123,7 +119,6 @@ Example: X shape: (N, C, H_in, W_in) Output: Out shape: (N, C, H_out, W_out) - Mask shape: (N, C, H_out, W_out) where H_out = (H_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; W_out = (W_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; @@ -146,33 +141,29 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, "the number of channels, D, H and W is the depth, height and " "width of feature."); - AddAttr("pooling_type", - "PoolingType of pooling operator." - "Str constant equal to 'max' or 'avg'.") + AddAttr("poolingType", + "(string), pooling type, can be \"max\" for max-pooling " + "and \"avg\" for average-pooling.") .InEnum({"max", "avg"}); - AddAttr>( "ksize", - "The pooling window size(depth, height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(depth, height, width) of pooling " + "operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); AddAttr>("strides", - "Strides(depth, height, width) of pooling operator." - "Default {1,1,1}.") + "(vector, default:{1,1,1}), strides(depth, height, " + "width) of pooling operator.") .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) - AddAttr>( - "paddings", - "Paddings(depth, height, width) of pooling operator." - "Default {0,0,0}.") + AddAttr>("paddings", + "(vector defalut:{0,0,0}), paddings(depth, height, " + "width) of pooling operator.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) @@ -190,7 +181,6 @@ Example: X shape: (N, C, D_in, H_in, W_in) Output: Out shape: (N, C, D_out, H_out, W_out) - Mask shape: (N, C, D_out, H_out, W_out) where D_out = (D_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; H_out = (H_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index ada956501918cc92a2d30ebb8d0c42453acd2839..ba8edc9cf60bcf90204ed11fa4fe1d408ad82d40 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -57,11 +57,11 @@ class PoolKernel : public framework::OpKernel { const Tensor* in_x = context.Input("X"); Tensor* out = context.Output("Out"); - std::string pooling_type = context.Attr("pooling_type"); + std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x->dims()[i + 2]); } @@ -117,12 +117,12 @@ class PoolGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); - std::string pooling_type = context.Attr("pooling_type"); + std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x->dims()[i + 2]); } diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index 29d0322a27b71fe8d335703e228969c084f5139f..ea21845751bee523fbbb85f7fdbeea7bcc586997 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -44,7 +44,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, "Pooling intput should be 4-D or 5-D tensor."); - if (ctx->Attrs().Get("global_pooling")) { + if (ctx->Attrs().Get("globalPooling")) { ksize.resize(static_cast(in_x_dims.size()) - 2); for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x_dims[i + 2]); @@ -105,28 +105,24 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "ksize", - "The pooling window size(height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + // TypedAttrChecker don't support vector type.) + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); - AddAttr>("strides", - "The strides(height, width) of pooling window." - "Default {1,1}.") + AddAttr>( + "strides", + "(vector, default:{1, 1}), strides(height, width) of pooling operator.") .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "The zero padding(height, width) size on both sides" - "Default {0,0}.") + "(vector defalut:{0,0}), paddings(height, width) of pooling operator.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( The maxPooling2d with index operation calculates the output and the mask @@ -176,29 +172,25 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "ksize", - "The pooling window size(depth, height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(depth, height, width) of pooling " + "operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + // TypedAttrChecker don't support vector type.) + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); - AddAttr>( - "strides", - "Strides(depth, height, width) of pooling operator." - "Default {1,1,1}.") + AddAttr>("strides", + "(vector, default:{1,1,1}), strides(depth, " + "height, width) of pooling operator.") .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr>( - "paddings", - "Paddings(depth, height, width) of pooling operator." - "Default {0,0,0}.") + // TypedAttrChecker don't support vector type.) + AddAttr>("paddings", + "(vector defalut:{0,0,0}), paddings(depth, " + "height, width) of pooling operator.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( The maxpooling3d with index operation calculates the output and the mask diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index 455c453efcd15bf0150bbd3de83d50729f338b4b..01b961ca8295f723bea7335e43ec5ab100dfc65c 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -35,7 +35,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x->dims()[i + 2]); } @@ -70,7 +70,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x_grad->dims()[i + 2]); } diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 471bd80096f76aa4172929b4d653cad1c6380025..4bb763e6d9be39f8f1cc3521767c4f46537db7d4 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -284,9 +284,9 @@ def pool2d(input, inputs={"X": input}, outputs={"Out": pool_out}, attrs={ - "pooling_type": pool_type, + "poolingType": pool_type, "ksize": pool_size, - "global_pooling": global_pooling, + "globalPooling": global_pooling, "strides": pool_stride, "paddings": pool_padding }) diff --git a/python/paddle/v2/framework/tests/test_pool2d_op.py b/python/paddle/v2/framework/tests/test_pool2d_op.py index 059b65e201efd30ba220a5951fac708a06b23663..f04de8133ad3b747d03500a1498b1516c21479b8 100644 --- a/python/paddle/v2/framework/tests/test_pool2d_op.py +++ b/python/paddle/v2/framework/tests/test_pool2d_op.py @@ -46,7 +46,9 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): class TestPool2d_Op(OpTest): def setUp(self): - self.initTestCase() + self.init_test_case() + self.init_op_type() + self.init_pool_type() input = np.random.random(self.shape).astype("float32") output = self.pool2D_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool) @@ -56,8 +58,8 @@ class TestPool2d_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'pooling_type': self.pool_type, - 'global_pooling': self.global_pool, + 'poolingType': self.pool_type, + 'globalPooling': self.global_pool, } self.outputs = {'Out': output.astype('float32')} @@ -69,76 +71,197 @@ class TestPool2d_Op(OpTest): if self.pool_type != "max": self.check_grad(set(['X']), 'Out', max_relative_error=0.07) - def initTestCase(self): + def init_test_case(self): self.global_pool = True - self.op_type = "pool2d" - self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "avg" + class TestCase1(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False - self.op_type = "pool2d" - self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "avg" + class TestCase2(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False - self.op_type = "pool2d" - self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [1, 1] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "avg" + class TestCase3(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = True - self.op_type = "pool2d" - self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "max" + class TestCase4(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False - self.op_type = "pool2d" - self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "max" + class TestCase5(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + def init_op_type(self): self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "max" + + +#--------------------test pool2d_cudnn-------------------- +class TestCaseCudnn1(TestPool2d_Op): + def init_test_case(self): + self.global_pool = True + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "avg" + + +class TestCaseCudnn2(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "avg" + + +class TestCaseCudnn3(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "avg" + + +class TestCaseCudnn4(TestPool2d_Op): + def init_test_case(self): + self.global_pool = True + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "max" + + +class TestCaseCudnn5(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): self.pool_type = "max" + + +class TestCaseCudnn6(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False self.pool2D_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [1, 1] + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "max" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool3d_op.py b/python/paddle/v2/framework/tests/test_pool3d_op.py index abb4d4e68f532c3bf4224ca30bdd35660361f833..d62fbee9746c5524cb8c428df584d2b76cf67bc9 100644 --- a/python/paddle/v2/framework/tests/test_pool3d_op.py +++ b/python/paddle/v2/framework/tests/test_pool3d_op.py @@ -64,8 +64,8 @@ class TestPool3d_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'pooling_type': self.pool_type, - 'global_pooling': self.global_pool, + 'poolingType': self.pool_type, + 'globalPooling': self.global_pool, } self.outputs = {'Out': output.astype('float32')} diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index b78f9bba05c5af38806f6cabb0e53379f8aa0526..f0f8aa6089c74d31702a6a5d37362099205d96b2 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'global_pooling': self.global_pool, + 'globalPooling': self.global_pool, } self.inputs = {'X': input}