diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index e3e934bcccd1a5f34d88a2f33f3708a46ddabe05..429e5526dbcf82591dff2e9e4101f12125fd8724 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -55,6 +55,13 @@ function(op_library TARGET)
     set(pybind_flag 1)
   endif()
 
+  # pool_op contains several operators
+  if ("${TARGET}" STREQUAL "pool_op")
+    set(pybind_flag 1)
+    # It's enough to just add one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
+  endif()
+
   # pybind USE_NO_KERNEL_OP
   file(READ ${TARGET}.cc TARGET_CONTENT)
   string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index f8333f34f7b4c7b0f9a0f14a7a33f9d98e1d331c..185708cdaab4af29824961260ca04f71048a0978 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,9 +1,8 @@
-
 if(WITH_GPU)
-  nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
-  im2col.cu DEPS cblas device_context)
+  nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
+  im2col.cu pooling.cc pooling.cu DEPS cblas device_context)
 else()
-  cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
+  cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context)
 endif()
 
 nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc
index 6146e710fcd3e75040bd8f17fbd48266aaad5357..a78c5f929ccc8fa22134709e652a6b2361a1183f 100644
--- a/paddle/operators/math/pooling.cc
+++ b/paddle/operators/math/pooling.cc
@@ -111,6 +111,7 @@ class Pool2dBackwardFunctor {
           int wend = std::min(wstart + ksize_width, input_width);
           wstart = std::max(wstart, 0);
           int pool_size = (hend - hstart) * (wend - wstart);
+          float scale = 1.0 / pool_size;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
               pool_process.gradProcess(
@@ -118,7 +119,7 @@ class Pool2dBackwardFunctor {
                   output_data[ph * output_width + pw],
                   output_grad_data[ph * output_width + pw],
                   input_grad_data[h * input_width + w],
-                  static_cast<T>(pool_size));
+                  static_cast<T>(scale));
             }
           }
         }
@@ -244,7 +245,6 @@ class Pool3dBackwardFunctor {
     const int padding_depth = paddings[0];
     const int padding_height = paddings[1];
     const int padding_width = paddings[2];
-
     const int input_stride = input_depth * input_height * input_width;
     const int output_stride = output_depth * output_height * output_width;
 
@@ -271,6 +271,7 @@ class Pool3dBackwardFunctor {
             wstart = std::max(wstart, 0);
             int pool_size =
                 (dend - dstart) * (hend - hstart) * (wend - wstart);
+            float scale = 1.0 / pool_size;
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
                 for (int w = wstart; w < wend; ++w) {
@@ -280,17 +281,17 @@ class Pool3dBackwardFunctor {
                   pool_process.gradProcess(
                       input_data[input_idx], output_data[output_idx],
                       output_grad_data[output_idx],
-                      input_grad_data[input_idx], static_cast<T>(pool_size));
+                      input_grad_data[input_idx], static_cast<T>(scale));
                 }
               }
             }
           }
         }
-        input_data += input_stride;
-        output_data += output_stride;
-        input_grad_data += input_stride;
-        output_grad_data += output_stride;
       }
+      input_data += input_stride;
+      output_data += output_stride;
+      input_grad_data += input_stride;
+      output_grad_data += output_stride;
     }
   }
 }
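Note on the pooling.cc change above: the backward functors now pass `static_cast<T>(scale)` with `scale = 1.0 / pool_size`, so `gradProcess` receives the reciprocal of the window size instead of the raw element count. For average pooling, each of the `pool_size` inputs covered by a window should receive `output_grad / pool_size`. A minimal standalone sketch of that arithmetic; the `ave_grad` helper below is hypothetical and only mimics what `avePool<T>::gradProcess`, defined in pooling.h and not shown in this diff, presumably does:

    #include <cassert>

    // Stand-in for what avePool<T>::gradProcess presumably does: accumulate
    // the output gradient into each input, scaled by 1/pool_size.
    template <typename T>
    void ave_grad(T /*input*/, T /*output*/, T output_grad, T& input_grad,
                  T scale) {
      input_grad += output_grad * scale;
    }

    int main() {
      // A 2x2 average-pool window: pool_size = 4, so scale = 1/4.
      float output_grad = 8.0f;
      float scale = 1.0f / 4;
      float input_grad[4] = {0, 0, 0, 0};
      for (int i = 0; i < 4; ++i)
        ave_grad(0.0f, 0.0f, output_grad, input_grad[i], scale);
      // Each of the 4 inputs receives 8/4 = 2; the sum reproduces output_grad.
      assert(input_grad[0] == 2.0f && input_grad[3] == 2.0f);
      return 0;
    }

Summed over the window, the scattered gradients reproduce `output_grad` exactly, which is the defining property of the average-pool backward pass.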
diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu
index 87bfb43c472d4246184276e3a954efb0c26787c0..0a399f7ca0cfa6eb591299de28eeb660f4df60cd 100644
--- a/paddle/operators/math/pooling.cu
+++ b/paddle/operators/math/pooling.cu
@@ -95,7 +95,7 @@ __global__ void KernelPool2dBackward(
         int output_sub_idx = ph * output_width + pw;
         pool_process.gradProcess(input, output_data[output_sub_idx],
                                  output_grad[output_sub_idx], gradient,
-                                 static_cast<T>(pool_size));
+                                 static_cast<T>(1.0 / pool_size));
       }
     }
     input_grad[index] = gradient;
@@ -264,7 +264,7 @@ __global__ void KernelPool3DBackward(
 
     int pdstart = (offsetD < ksize_depth)
                       ? 0
-                      : (offsetD + ksize_depth) / stride_depth + 1;
+                      : (offsetD - ksize_depth) / stride_depth + 1;
     int phstart = (offsetH < ksize_height)
                       ? 0
                       : (offsetH - ksize_height) / stride_height + 1;
@@ -296,10 +296,10 @@ __global__ void KernelPool3DBackward(
           hstart = max(hstart, 0);
           wstart = max(wstart, 0);
           int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-          int output_sub_idx = ph * output_width + pw;
+          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
           pool_process.gradProcess(input, output_data[output_sub_idx],
                                    output_grad[output_sub_idx], gradient,
-                                   static_cast<T>(pool_size));
+                                   static_cast<T>(1.0 / pool_size));
         }
       }
     }
@@ -385,7 +385,8 @@ class Pool3dBackwardFunctor {
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad.mutable_data<T>(context->GetPlace());
 
-    int nthreads = batch_size * input_channels * input_height * input_width;
+    int nthreads =
+        batch_size * input_channels * input_depth * input_height * input_width;
     int blocks = (nthreads + 1024 - 1) / 1024;
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
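Note on the two index fixes in KernelPool3DBackward: `pdstart` now mirrors the height/width cases (`offsetD - ksize_depth`), and `output_sub_idx` now includes the depth coordinate. A small sketch of the inverse mapping for one input element, with names mirroring the kernel's locals; the `pdend` bound is an assumption following the usual form `min((offsetD + padding_depth) / stride_depth + 1, output_depth)`, which this hunk does not show:

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Mirrors KernelPool3DBackward's locals for one input element.
      int offsetD = 5, ksize_depth = 3, stride_depth = 2, padding_depth = 0;
      int output_depth = 4, output_height = 4, output_width = 4;

      // First/last pooled depth index whose window covers this input element.
      int pdstart = (offsetD < ksize_depth)
                        ? 0
                        : (offsetD - ksize_depth) / stride_depth + 1;
      int pdend = std::min((offsetD + padding_depth) / stride_depth + 1,
                           output_depth);  // assumed bound, see note above

      for (int pd = pdstart; pd < pdend; ++pd) {
        int ph = 1, pw = 2;  // arbitrary height/width coordinates
        // Flattened offset within one channel's output slab (D x H x W).
        int output_sub_idx = (pd * output_height + ph) * output_width + pw;
        std::printf("pd=%d -> output_sub_idx=%d\n", pd, output_sub_idx);
      }
      return 0;
    }

Here pdstart = (5 - 3)/2 + 1 = 2 and pdend = 3, so exactly the window at pd = 2 (covering depths 4..6) contributes. With the old `offsetD + ksize_depth` formula this example yields pdstart = 5 > pdend, and the element's gradient would be silently dropped.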
diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..984cbefabf8de30d7736b0cc35c04e9f77be60a8
--- /dev/null
+++ b/paddle/operators/pool_op.cc
@@ -0,0 +1,165 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/pool_op.h"
+
+namespace paddle {
+namespace operators {
+
+int outputSize(int input_size, int filter_size, int padding, int stride) {
+  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  return output_size;
+}
+
+class PoolOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
+                            "Input(Input) of Pooling should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
+                            "Output(Output) of Pooling should not be null.");
+    //    PADDLE_ENFORCE_NOT_NULL(Attr<std::string>("pooling_type"),
+    //                            "pooling_type should not be null.");
+    //    PADDLE_ENFORCE_NOT_NULL(Attr<std::vector<int>>("ksize"),
+    //                            "ksize should not be null.");
+    auto input = ctx.Input<Tensor>("Input");
+    auto output = ctx.Output<Tensor>("Output");
+    int global_pooling = Attr<int>("global_pooling");
+    std::string pooling_type = Attr<std::string>("pooling_type");
+    std::vector<int> ksize = Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
+
+    PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "ave",
+                   "pooling_type should be 'max' or 'ave'");
+    PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
+                   "Pooling ksize should be 2-D or 3-D");
+
+    if (global_pooling == 1) {
+      for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = input->dims()[i + 2];
+    }
+    if (ksize.size() == 2) {
+      PADDLE_ENFORCE_EQ(input->dims().size(), 4,
+                        "Pool2DOp input should be 4-D.");
+      PADDLE_ENFORCE_EQ(strides.size(), 2, "Pool2DOp strides should be 2-D.");
+      PADDLE_ENFORCE_EQ(paddings.size(), 2, "Pool2DOp paddings should be 2-D.");
+    } else {
+      PADDLE_ENFORCE_EQ(input->dims().size(), 5,
+                        "Pool3DOp input should be 5-D.");
+      PADDLE_ENFORCE_EQ(strides.size(), 3, "Pool3DOp strides should be 3-D.");
+      PADDLE_ENFORCE_EQ(paddings.size(), 3, "Pool3DOp paddings should be 3-D.");
+    }
+    std::vector<int64_t> output_shape({input->dims()[0], input->dims()[1]});
+    for (size_t i = 0; i < ksize.size(); ++i) {
+      output_shape.push_back(
+          outputSize(input->dims()[i + 2], ksize[i], paddings[i], strides[i]));
+    }
+    output->Resize(framework::make_ddim(output_shape));
+  }
+};
+
+class PoolOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto in = ctx.Input<Tensor>("Input");
+    auto d_in = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    if (d_in) d_in->Resize(in->dims());
+  }
+};
+
+class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "Input",
+        "The input tensor of pooling operator. "
+        "The format of input tensor is NCDHW, where N is batch size, C is "
+        "the number of channels, and D, H and W are the depth, height and "
+        "width of the image.");
+    AddOutput("Output",
+              "The output tensor of pooling operator."
+ "The format of output tensor is also NCDHW."); + + AddAttr("pooling_type", + "pooling_type of pooling operator.['max' or 'ave']"); + AddAttr>("ksize", "strides of pooling operator."); + AddAttr("global_pooling", "whether to use the global_pooling.") + .SetDefault(0); + AddAttr>("strides", "strides of pooling operator.") + .SetDefault({1, 1, 1}); + AddAttr>("paddings", "paddings of pooling operator.") + .SetDefault({0, 0, 0}); + AddComment(R"DOC( +The pooling3d operation calculates the output based on +the input, pooling_type and ksize, strides, paddings parameters. +)DOC"); + } +}; + +class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddOutput("Output", + "The output tensor of pooling operator." + "The format of output tensor is also NCHW."); + + AddAttr("pooling_type", + "pooling_type of pooling operator.['max' or 'ave']"); + AddAttr>("ksize", "strides of pooling operator."); + AddAttr("global_pooling", "whether to use the global_pooling.") + .SetDefault(0); + AddAttr>("strides", "strides of pooling operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "paddings of pooling operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( +The pooling2d operation calculates the output based on +the input, pooling_type and ksize, strides, paddings parameters. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool2d, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool2d_grad, + ops::PoolGradKernel) + +REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool3d, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool3d_grad, + ops::PoolGradKernel); diff --git a/paddle/operators/pool_op.cu b/paddle/operators/pool_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..b011d20da51fe9bc24e47fc1d933a868bf59a605 --- /dev/null +++ b/paddle/operators/pool_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/paddle/operators/pool_op.cu b/paddle/operators/pool_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b011d20da51fe9bc24e47fc1d933a868bf59a605
--- /dev/null
+++ b/paddle/operators/pool_op.cu
@@ -0,0 +1,26 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/pool_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(pool2d,
+                       ops::PoolKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(pool2d_grad,
+                       ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
+
+REGISTER_OP_GPU_KERNEL(pool3d,
+                       ops::PoolKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(pool3d_grad,
+                       ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
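One detail worth calling out before the header file: in PoolGradKernel (pool_op.h, next file), `input_grad` is zero-filled through Eigen before the backward functor runs, because the backward pass presumably accumulates into it and overlapping windows touch the same input element more than once. A reduced sketch of that zero-then-accumulate pattern; plain C++ stands in for the Eigen/Tensor calls to keep it self-contained:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      // mutable_data may hand back reused, uninitialized memory, so the
      // kernel zero-fills input_grad before the backward functor runs.
      std::vector<float> input_grad(4, /*stale*/ 13.0f);
      std::fill(input_grad.begin(), input_grad.end(), 0.0f);  // the constant(0) step

      // Overlapping windows then accumulate into the same elements.
      float output_grads[2] = {4.0f, 8.0f};
      int windows[2][2] = {{0, 1}, {1, 2}};  // element 1 is shared
      for (int w = 0; w < 2; ++w)
        for (int idx : windows[w]) input_grad[idx] += output_grads[w] / 2;
      assert(input_grad[1] == 2.0f + 4.0f);  // receives from both windows
      return 0;
    }

Without the zero-fill, the stale 13.0f would leak into every accumulated gradient.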
diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..aca1c5a137db03510b16acac6ef1c256fd86204b
--- /dev/null
+++ b/paddle/operators/pool_op.h
@@ -0,0 +1,157 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/pooling.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class PoolKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    Tensor* output = context.Output<Tensor>("Output");
+
+    int global_pooling = context.Attr<int>("global_pooling");
+    std::string pooling_type = context.Attr<std::string>("pooling_type");
+    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    if (global_pooling == 1) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        ksize[i] = input->dims()[i + 2];
+      }
+    }
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(context.device_context_);
+
+    switch (ksize.size()) {
+      case 2: {
+        if (pooling_type == "max") {
+          paddle::operators::math::Pool2dForwardFunctor<
+              Place, paddle::operators::math::pool::maxPool<T>, T>
+              pool2d_forward;
+          paddle::operators::math::pool::maxPool<T> pool_process;
+          pool2d_forward(*input, *output, ksize, strides, paddings,
+                         pool_process, device_context);
+
+        } else if (pooling_type == "ave") {
+          paddle::operators::math::Pool2dForwardFunctor<
+              Place, paddle::operators::math::pool::avePool<T>, T>
+              pool2d_forward;
+          paddle::operators::math::pool::avePool<T> pool_process;
+          pool2d_forward(*input, *output, ksize, strides, paddings,
+                         pool_process, device_context);
+        }
+      } break;
+      case 3: {
+        if (pooling_type == "max") {
+          paddle::operators::math::Pool3dForwardFunctor<
+              Place, paddle::operators::math::pool::maxPool<T>, T>
+              pool3d_forward;
+          paddle::operators::math::pool::maxPool<T> pool_process;
+          pool3d_forward(*input, *output, ksize, strides, paddings,
+                         pool_process, device_context);
+        } else if (pooling_type == "ave") {
+          paddle::operators::math::Pool3dForwardFunctor<
+              Place, paddle::operators::math::pool::avePool<T>, T>
+              pool3d_forward;
+          paddle::operators::math::pool::avePool<T> pool_process;
+          pool3d_forward(*input, *output, ksize, strides, paddings,
+                         pool_process, device_context);
+        }
+      } break;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class PoolGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output = context.Input<Tensor>("Output");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+
+    int global_pooling = context.Attr<int>("global_pooling");
+    std::string pooling_type = context.Attr<std::string>("pooling_type");
+    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+
+    if (global_pooling == 1) {
+      for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = input->dims()[i + 2];
+    }
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(context.device_context_);
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto temp = framework::EigenVector<T>::Flatten(*input_grad);
+      temp.device(context.GetEigenDevice<Place>()) =
+          temp.constant(static_cast<T>(0));
+
+      switch (ksize.size()) {
+        case 2: {
+          if (pooling_type == "max") {
+            paddle::operators::math::Pool2dBackwardFunctor<
+                Place, paddle::operators::math::pool::maxPool<T>, T>
+                pool2d_backward;
+            paddle::operators::math::pool::maxPool<T> pool_process;
+            pool2d_backward(*input, *input_grad, *output, *output_grad, ksize,
+                            strides, paddings, pool_process, device_context);
+          } else if (pooling_type == "ave") {
+            paddle::operators::math::Pool2dBackwardFunctor<
+                Place, paddle::operators::math::pool::avePool<T>, T>
+                pool2d_backward;
+            paddle::operators::math::pool::avePool<T> pool_process;
+            pool2d_backward(*input, *input_grad, *output, *output_grad, ksize,
+                            strides, paddings, pool_process, device_context);
+          }
+        } break;
+        case 3: {
+          if (pooling_type == "max") {
+            paddle::operators::math::Pool3dBackwardFunctor<
+                Place, paddle::operators::math::pool::maxPool<T>, T>
+                pool3d_backward;
+            paddle::operators::math::pool::maxPool<T> pool_process;
+            pool3d_backward(*input, *input_grad, *output, *output_grad, ksize,
+                            strides, paddings, pool_process, device_context);
+          } else if (pooling_type == "ave") {
+            paddle::operators::math::Pool3dBackwardFunctor<
+                Place, paddle::operators::math::pool::avePool<T>, T>
+                pool3d_backward;
+            paddle::operators::math::pool::avePool<T> pool_process;
+            pool3d_backward(*input, *input_grad, *output, *output_grad, ksize,
+                            strides, paddings, pool_process, device_context);
+          }
+        } break;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
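Finally, the max-pooling counterpart to the average-pool note above: the scale argument is irrelevant for max pooling, since the gradient flows only to the input that produced the window's maximum. A sketch of what `maxPool<T>::gradProcess` presumably does; its real definition lives in paddle/operators/math/pooling.h, which this diff does not touch:

    #include <cassert>

    // Assumed shape of maxPool<T>::gradProcess: route the output gradient to
    // an input only if that input equals the pooled output (it was the max).
    template <typename T>
    void max_grad(T input, T output, T output_grad, T& input_grad, T /*scale*/) {
      input_grad += (input == output) ? output_grad : static_cast<T>(0);
    }

    int main() {
      // 2x2 window {1, 5, 3, 2}: the max is 5, so only it gets the gradient.
      float window[4] = {1, 5, 3, 2}, output = 5, output_grad = 7;
      float input_grad[4] = {0, 0, 0, 0};
      for (int i = 0; i < 4; ++i)
        max_grad(window[i], output, output_grad, input_grad[i], 0.25f);
      assert(input_grad[1] == 7 && input_grad[0] == 0);
      return 0;
    }

Under this rule a tie would send the gradient to every matching element; whether the real implementation breaks ties differently is not visible in this diff.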