From 45a8c9ddaf5d16fdeeb6a424988d23c121d207b4 Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Tue, 21 Nov 2017 16:28:51 +0800 Subject: [PATCH] add unpool2d make ok --- paddle/operators/CMakeLists.txt | 7 +++++++ paddle/operators/math/unpooling.cc | 26 ++++++++++---------------- paddle/operators/math/unpooling.cu | 21 ++++++++++++--------- paddle/operators/math/unpooling.h | 5 +++-- paddle/operators/unpool_op.cc | 25 ++++++++++++++++--------- paddle/operators/unpool_op.cu.cc | 7 +++++-- paddle/operators/unpool_op.h | 13 ++++++------- 7 files changed, 59 insertions(+), 45 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index ee25abd6cb5..d53bca277da 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -80,6 +80,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() + # unpool_op contains several operators + if ("${TARGET}" STREQUAL "unpool_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(unpool2d);\n") + endif() + # pool_cudnn_op contains several operators if ("${TARGET}" STREQUAL "pool_cudnn_op") set(pybind_flag 1) diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc index 36506b903ed..8cfdb4bb605 100644 --- a/paddle/operators/math/unpooling.cc +++ b/paddle/operators/math/unpooling.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/maxouting.h" +#include "paddle/operators/math/unpooling.h" namespace paddle { namespace operators { @@ -20,7 +20,7 @@ namespace math { // All tensors are in NCHW format template -class Unpool2d_Max_Functor { +class Unpool2d_MaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -36,16 +36,14 @@ class Unpool2d_Max_Functor { int input_feasize = input_height * input_width; int output_feasize = output_height * output_width; const T* input_data = input.data(); - const T* indices_data = indices.data(); + const int * indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); for (int b = 0; b < batch_size; ++b) { for (int c = 0; c < output_channels; ++c) { for (int i = 0; i < input_feasize; ++i) { int index = indices_data[i]; - if(index > output_feasize) { - //抛一个异常! - } + // PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); output_data[index] = input_data[i]; } input_data += input_feasize; @@ -70,26 +68,22 @@ public: const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output->dims()[1]; - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; int input_feasize = input_height * input_width; int output_feasize = output_height * output_width; - const T* input_data = input.data(); - const T* indices_data = indices.data(); - const T* output_data = output.data(); + const int* indices_data = indices.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int b = 0; b < batch_size; ++b) { for (int c = 0; c < output_channels; ++c) { - for (int f = 0; f < input_feasize; ++f) { + for (int i = 0; i < input_feasize; ++i) { int index = indices_data[i]; - if(index > output_feasize) { - //抛一个异常! - } + // PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); input_grad_data[i] = output_grad_data[index]; } input_grad_data += input_feasize; diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu index 53e88a57c14..c8e7b252349 100644 --- a/paddle/operators/math/unpooling.cu +++ b/paddle/operators/math/unpooling.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/maxouting.h" +#include "paddle/operators/math/unpooling.h" #include "paddle/platform/cuda_helper.h" namespace paddle { @@ -22,7 +22,7 @@ namespace math { template __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, - const T* indices_data, + const int* indices_data, const int input_height, const int input_width, T* output_data, @@ -30,16 +30,19 @@ __global__ void KernelUnpool2dMax(const int nthreads, const int output_width) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; + // int output_feasize = output_height * output_width; for (int i = index; i < nthreads; i += offset) { int out_offset = i / (input_height * input_width) \ * output_height * output_width; int out_index = indices_data[i]; + // PADDLE_ENFORCE(out_index < output_feasize, "err index in unpooling!"); output_data[out_offset + out_index] = input_data[i]; } } template __global__ void KernelUnpool2dMaxGrad(const int nthreads, const T* input_data, + const int* indices_data, const int input_height, const int input_width, const T* output_data, @@ -49,10 +52,13 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, T* input_grad) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; + // int output_feasize = output_height * output_width; for (int i = index; i < nthreads; i += offset) { int out_offset = i / (input_height * input_width) \ * output_height * output_width; int out_index = indices_data[i]; + // PADDLE_ENFORCE(out_index < output_feasize, + // "err index in unpooling!"); input_grad[i] = output_grad[out_offset + out_index]; } } @@ -72,10 +78,8 @@ class Unpool2d_MaxFunctor { const int output_channels = output->dims()[1]; const int output_height = output->dims()[2]; const int output_width = output->dims()[3]; - int input_feasize = input_height * input_width; - int output_feasize = output_height * output_width; const T* input_data = input.data(); - const T* indices_data = indices.data(); + const int* indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); int nthreads = output->numel(); @@ -99,19 +103,18 @@ class Unpool2d_MaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor * input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, - int groups) { + const framework::Tensor& output_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; const int output_channels = output.dims()[1]; const int output_height = output.dims()[2]; const int output_width = output.dims()[3]; - const T* input_data = input.data(); - const T* indices_data = indices.data(); + const int* indices_data = indices.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h index bb0e0d08f02..ba4be89746f 100644 --- a/paddle/operators/math/unpooling.h +++ b/paddle/operators/math/unpooling.h @@ -26,7 +26,7 @@ namespace math { template -class Unpool2d_Max_Functor { +class Unpool2d_MaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -35,10 +35,11 @@ class Unpool2d_Max_Functor { }; template -class Unpool2d_Max_GradFunctor { +class Unpool2d_MaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor * input_grad, const framework::Tensor& output, const framework::Tensor& output_grad); diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc index d81428e8023..9d6e69dffb8 100644 --- a/paddle/operators/unpool_op.cc +++ b/paddle/operators/unpool_op.cc @@ -20,7 +20,8 @@ using framework::Tensor; class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { public: - UnpoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + Unpool2dOpMaker(framework::OpProto* proto, \ + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(Tensor) The input tensor of unpool operator. " @@ -39,10 +40,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("ksize", "(vector ), the unpooling window size(height, width) " "of unpooling operator."); - AddAttr>("strides", "(vector, default:{1, 1}), " + AddAttr>("strides", + "(vector, default:{1, 1}), " "strides(height, width) of unpooling operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "(vector defalut:{0,0}), " + AddAttr>("paddings", + "(vector defalut:{0,0}), " "paddings(height, width) of unpooling operator.") .SetDefault({0, 0}); AddAttr("unpoolingType", @@ -73,7 +76,8 @@ class UnpoolOp : public framework::OperatorWithKernel { auto in_x_dims = ctx->GetInputDim("X"); auto in_y_dims = ctx->GetInputDim("Y"); - std::string unpooling_type = ctx->Attrs().Get("unpooling_type"); + std::string unpooling_type = \ + ctx->Attrs().Get("unpooling_type"); std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); @@ -95,7 +99,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), @@ -109,8 +113,11 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad, ops::UnpoolOpGrad); -REGISTER_OP_CPU_KERNEL(unpool2d, ops::UnpoolKernel); +REGISTER_OP_CPU_KERNEL(unpool2d, + ops::UnpoolKernel, + ops::UnpoolKernel); REGISTER_OP_CPU_KERNEL(unpool2d_grad, - ops::UnpoolGradKernel); + ops::UnpoolGradKernel, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc index 8aeef8b3cff..96fb9e40c3f 100644 --- a/paddle/operators/unpool_op.cu.cc +++ b/paddle/operators/unpool_op.cu.cc @@ -16,7 +16,10 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(unpool2d, - ops::UnpoolKernel); + ops::UnpoolKernel, + ops::UnpoolKernel); REGISTER_OP_GPU_KERNEL(unpool2d_grad, ops::UnpoolGradKernel); + float>, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h index 38903dee17b..47dd8da6f7c 100644 --- a/paddle/operators/unpool_op.h +++ b/paddle/operators/unpool_op.h @@ -37,9 +37,8 @@ class UnpoolKernel : public framework::OpKernel { switch (ksize.size()) { case 2: { if (pooling_type == "max") { - math::Unpool2d_Max_Functor unpool2d_max_forward; - unpool2d_max_forward(context.device_context(), *in_x, *in_y, - ksize, strides, paddings, out); + math::Unpool2d_MaxFunctor unpool2d_max_forward; + unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); } } break; default: { PADDLE_THROW("Pool op only supports 2D input."); } @@ -71,12 +70,12 @@ class UnpoolGradKernel : public framework::OpKernel { switch (ksize.size()) { case 2: { if (pooling_type == "max") { - math::UnpoolGradFunctor maxout_backward; - maxout_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out, - *out_grad, ksize, strides, paddings); + math::Unpool2d_MaxGradFunctor unpool2d_max_backward; + unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, + *out, *out_grad); } } break; - default: { PADDLE_THROW("Pool op only supports 2D input."); } + default: { PADDLE_THROW("Unpool op only supports 2D input."); } } } }; -- GitLab