diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d39f7bf452e5fa8f5add8480a81fa7563f39414a..c720cce1828c8c94544d6a52611e67b338fa2ab1 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -139,6 +139,7 @@ set(DEPS_OPS sum_op pool_op maxout_op + unpool_op pool_with_index_op nccl_op sequence_conv_op @@ -151,6 +152,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op selected_rows_functor) op_library(pool_op DEPS pooling) op_library(maxout_op DEPS maxouting) +op_library(unpool_op DEPS unpooling) op_library(pool_with_index_op DEPS pooling) op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) if(WITH_GPU) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index b330f30d216b8003482970c5a6f5c1f62613b320..cd7e33cd7cbf4cb0923d9a7d6f720d52a10c71a7 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -14,6 +14,7 @@ if(WITH_GPU) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) + nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -26,6 +27,7 @@ else() cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(maxouting SRCS maxouting.cc DEPS device_context) + cc_library(unpooling SRCS unpooling.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..36506b903ed8a69df46aaa1c7c69502bd24f9019 --- /dev/null +++ b/paddle/operators/math/unpooling.cc @@ -0,0 +1,110 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/maxouting.h" + +namespace paddle { +namespace operators { +namespace math { + +// All tensors are in NCHW format +template +class Unpool2d_Max_Functor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + framework::Tensor * output) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const T* input_data = input.data(); + const T* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + if(index > output_feasize) { + //抛一个异常! + } + output_data[index] = input_data[i]; + } + input_data += input_feasize; + indices_data += input_feasize; + output_data += output_feasize; + } + } + } +}; + + + +template +class Unpool2d_MaxGradFunctor { +public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + framework::Tensor * input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const T* input_data = input.data(); + const T* indices_data = indices.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int f = 0; f < input_feasize; ++f) { + int index = indices_data[i]; + if(index > output_feasize) { + //抛一个异常! + } + input_grad_data[i] = output_grad_data[index]; + } + input_grad_data += input_feasize; + indices_data += input_feasize; + output_grad_data += output_feasize; + } + } + } +}; + +template class Unpool2d_MaxGradFunctor; +template class Unpool2d_MaxGradFunctor; +template class Unpool2d_MaxFunctor; +template class Unpool2d_MaxFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu new file mode 100644 index 0000000000000000000000000000000000000000..53e88a57c14097a0da34d036f034e14f9dfc95d5 --- /dev/null +++ b/paddle/operators/math/unpooling.cu @@ -0,0 +1,143 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/maxouting.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void KernelUnpool2dMax(const int nthreads, + const T* input_data, + const T* indices_data, + const int input_height, + const int input_width, + T* output_data, + const int output_height, + const int output_width) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int out_offset = i / (input_height * input_width) \ + * output_height * output_width; + int out_index = indices_data[i]; + output_data[out_offset + out_index] = input_data[i]; + } +} +template +__global__ void KernelUnpool2dMaxGrad(const int nthreads, + const T* input_data, + const int input_height, + const int input_width, + const T* output_data, + const T* output_grad, + const int output_height, + const int output_width, + T* input_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int out_offset = i / (input_height * input_width) \ + * output_height * output_width; + int out_index = indices_data[i]; + input_grad[i] = output_grad[out_offset + out_index]; + } +} +/* + * All tensors are in NCHW format. + */ +template +class Unpool2d_MaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + framework::Tensor * output) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const T* input_data = input.data(); + const T* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + + int nthreads = output->numel(); + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelUnpool2dMax< + T><<(context) + .stream()>>>(nthreads, input_data, indices_data, + input_height, input_width, + output_data, output_height, output_width); + } +}; +/* + * All tensors are in NCHW format. + */ +template +class Unpool2d_MaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor * input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, + int groups) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + + const T* input_data = input.data(); + const T* indices_data = indices.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + int nthreads = output.numel(); + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelUnpool2dMaxGrad< + T><<(context) + .stream()>>>( + nthreads, input_data, indices_data, + input_height, input_width, + output_data, output_grad_data, + output_height, output_width, + input_grad_data); + } +}; + +template class Unpool2d_MaxGradFunctor; +template class Unpool2d_MaxGradFunctor; + +template class Unpool2d_MaxFunctor; +template class Unpool2d_MaxFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h new file mode 100644 index 0000000000000000000000000000000000000000..bb0e0d08f0202bfa424fbf4270a4746c14deeab8 --- /dev/null +++ b/paddle/operators/math/unpooling.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +#define FLT_MAX \ + __FLT_MAX__ + +template + +class Unpool2d_Max_Functor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + framework::Tensor * output); +}; + +template +class Unpool2d_Max_GradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor * input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d81428e8023afb9a29554f7c0b9d2bd63cb16434 --- /dev/null +++ b/paddle/operators/unpool_op.cc @@ -0,0 +1,116 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/operators/unpool_op.h" +namespace paddle { +namespace operators { + +using framework::Tensor; + +class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + UnpoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor) The input tensor of unpool operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddInput("Y", + "(Tensor) The input tensor of the indices given out by MaxPool2d. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of unpool operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); + AddAttr>("ksize", + "(vector ), the unpooling window size(height, width) " + "of unpooling operator."); + AddAttr>("strides", "(vector, default:{1, 1}), " + "strides(height, width) of unpooling operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "(vector defalut:{0,0}), " + "paddings(height, width) of unpooling operator.") + .SetDefault({0, 0}); + AddAttr("unpoolingType", + "(string), unpooling type, can be \"max\" for max-unpooling " + "and \"avg\" for average-unpooling.") + .InEnum({"max", "avg"}); + AddComment(R"DOC( + + )DOC"); + } +}; + +int OutputSize(int input_size, int ksize, int padding, int stride) { + int output_size = (input_size -1) * stride - 2 * padding + ksize; + return output_size; +} + +class UnpoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of UnpoolOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnpoolOp should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + auto in_y_dims = ctx->GetInputDim("Y"); + std::string unpooling_type = ctx->Attrs().Get("unpooling_type"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Unpooling intput should be 4-D or 5-D tensor."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back( + OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class UnpoolOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad, + ops::UnpoolOpGrad); +REGISTER_OP_CPU_KERNEL(unpool2d, ops::UnpoolKernel); +REGISTER_OP_CPU_KERNEL(unpool2d_grad, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..8aeef8b3cff5b8d763de2866cc2ade3d6ae58182 --- /dev/null +++ b/paddle/operators/unpool_op.cu.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/unpool_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(unpool2d, + ops::UnpoolKernel); +REGISTER_OP_GPU_KERNEL(unpool2d_grad, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..38903dee17ba1b5af3473a95359c08226d607735 --- /dev/null +++ b/paddle/operators/unpool_op.h @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/unpooling.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class UnpoolKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + const Tensor* in_y = context.Input("Y"); + Tensor* out = context.Output("Out"); + std::string pooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + switch (ksize.size()) { + case 2: { + if (pooling_type == "max") { + math::Unpool2d_Max_Functor unpool2d_max_forward; + unpool2d_max_forward(context.device_context(), *in_x, *in_y, + ksize, strides, paddings, out); + } + } break; + default: { PADDLE_THROW("Pool op only supports 2D input."); } + } + } +}; + +template +class UnpoolGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + const Tensor* in_y = context.Input("Y"); + const Tensor* out = context.Input("Out"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + std::string pooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + auto& device_ctx = context.device_context(); + math::SetConstant zero; + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + zero(device_ctx, in_x_grad, static_cast(0.0)); + } + switch (ksize.size()) { + case 2: { + if (pooling_type == "max") { + math::UnpoolGradFunctor maxout_backward; + maxout_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out, + *out_grad, ksize, strides, paddings); + } + } break; + default: { PADDLE_THROW("Pool op only supports 2D input."); } + } + } +}; + +} // namespace operators +} // namespace paddle