diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7e5d4fd640f4399d1a217d1a0be76b3da457c0cc..937441b318095eadb9022c1d7578ad8aca2dadc8 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -191,6 +191,7 @@ set(DEPS_OPS sum_op pool_op maxout_op + unpool_op pool_with_index_op conv_op conv_transpose_op @@ -235,6 +236,7 @@ op_library(adagrad_op DEPS selected_rows_functor) op_library(conv_op DEPS vol2col) op_library(pool_op DEPS pooling) op_library(maxout_op DEPS maxouting) +op_library(unpool_op DEPS unpooling) op_library(pool_with_index_op DEPS pooling) op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 3017f133afc5d4dcd484c78b44591a876ab4d667..bf47879f772a3013bd7ce78c6f8a6aefe65298f9 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -13,8 +13,9 @@ if(WITH_GPU) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) - nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) + nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) + nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -26,8 +27,9 @@ else() cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) - cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) cc_library(maxouting SRCS maxouting.cc DEPS device_context) + cc_library(unpooling SRCS unpooling.cc DEPS device_context) + cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..b57d3dc1414cff492db8d7d503a7fce370a3f151 --- /dev/null +++ b/paddle/operators/math/unpooling.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/unpooling.h" +namespace paddle { +namespace operators { +namespace math { +template +class Unpool2dMaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); + output_data[index] = input_data[i]; + } + input_data += input_feasize; + indices_data += input_feasize; + output_data += output_feasize; + } + } + } +}; +template +class Unpool2dMaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const int* indices_data = indices.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); + input_grad_data[i] = output_grad_data[index]; + } + input_grad_data += input_feasize; + indices_data += input_feasize; + output_grad_data += output_feasize; + } + } + } +}; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu new file mode 100644 index 0000000000000000000000000000000000000000..37c3c8b689f9a69b68ddffd23813fa9ad8ced0e7 --- /dev/null +++ b/paddle/operators/math/unpooling.cu @@ -0,0 +1,134 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/unpooling.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { +template +__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, + const int* indices_data, + const int input_height, const int input_width, + const int channels, T* output_data, + const int output_height, + const int output_width) { + int in_n_stride = input_height * input_width * channels; + int in_c_stride = input_height * input_width; + int out_n_stride = output_height * output_width * channels; + int out_c_stride = output_height * output_width; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int bidx = i / in_n_stride; + int boffset = i % in_n_stride; + int cidx = boffset / in_c_stride; + int out_offset = bidx * out_n_stride + cidx * out_c_stride; + int out_index = indices_data[i]; + PADDLE_ASSERT(out_index < out_c_stride); + output_data[out_offset + out_index] = input_data[i]; + } +} +template +__global__ void KernelUnpool2dMaxGrad( + const int nthreads, const T* input_data, const int* indices_data, + const int input_height, const int input_width, const int channels, + const T* output_data, const T* output_grad, const int output_height, + const int output_width, T* input_grad) { + int in_n_stride = input_height * input_width * channels; + int in_c_stride = input_height * input_width; + int out_n_stride = output_height * output_width * channels; + int out_c_stride = output_height * output_width; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int bidx = i / in_n_stride; + int boffset = i % in_n_stride; + int cidx = boffset / in_c_stride; + int out_offset = bidx * out_n_stride + cidx * out_c_stride; + int out_index = indices_data[i]; + PADDLE_ASSERT(out_index < out_c_stride); + input_grad[i] = output_grad[out_offset + out_index]; + } +} +/* + * All tensors are in NCHW format. + */ +template +class Unpool2dMaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + int threads = 1024; + int grid = (input.numel() + threads - 1) / threads; + KernelUnpool2dMax< + T><<(context) + .stream()>>>(input.numel(), input_data, indices_data, + input_height, input_width, output_channels, + output_data, output_height, output_width); + } +}; +/* + * All tensors are in NCHW format. + */ +template +class Unpool2dMaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + int threads = 1024; + int grid = (input.numel() + threads - 1) / threads; + KernelUnpool2dMaxGrad< + T><<(context) + .stream()>>>(input.numel(), input_data, indices_data, + input_height, input_width, output_channels, + output_data, output_grad_data, output_height, + output_width, input_grad_data); + } +}; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h new file mode 100644 index 0000000000000000000000000000000000000000..7077d7c2274fd9e02b69ef343f310f4ffbbcff1a --- /dev/null +++ b/paddle/operators/math/unpooling.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { +template +class Unpool2dMaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output); +}; +template +class Unpool2dMaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..89c48e071cf351f7d7b9cf26a5d4989af291da57 --- /dev/null +++ b/paddle/operators/unpool_op.cc @@ -0,0 +1,143 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/unpool_op.h" +namespace paddle { +namespace operators { + +class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Unpool2dOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of unpool operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddInput( + "Indices", + "(Tensor) The input tensor of the indices given out by MaxPool2d. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of unpool operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); + AddAttr>( + "ksize", + "(vector), the unpooling window size(height, width) " + "of unpooling operator."); + AddAttr>("strides", + "(vector, default:{1, 1}), " + "strides (height, width) of unpooling operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "(vector defalut:{0,0}), " + "paddings (height, width) of unpooling operator.") + .SetDefault({0, 0}); + AddAttr( + "unpooling_type", + "(string), unpooling type, can be \"max\" for max-unpooling ") + .InEnum({"max"}); + AddComment(R"DOC( + "Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + Where + $$ + H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\ + W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1] + $$ + Paper: http://www.matthewzeiler.com/wp-content/uploads/2017 + /07/iccv2011.pdf + )DOC"); + } +}; + +int OutputSize(int input_size, int ksize, int padding, int stride) { + int output_size = (input_size - 1) * stride - 2 * padding + ksize; + return output_size; +} + +class UnpoolOp : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnpoolOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input(Indices) of UnpoolOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnpoolOp should not be null."); + auto in_x_dims = ctx->GetInputDim("X"); + auto in_y_dims = ctx->GetInputDim("Indices"); + std::string unpooling_type = + ctx->Attrs().Get("unpooling_type"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + PADDLE_ENFORCE(in_x_dims.size() == 4, + "Unpooling intput must be of 4-dimensional."); + PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims); + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back( + OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class UnpoolOpGrad : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad, + ops::UnpoolOpGrad); +REGISTER_OP_CPU_KERNEL(unpool, + ops::UnpoolKernel, + ops::UnpoolKernel); +REGISTER_OP_CPU_KERNEL( + unpool_grad, ops::UnpoolGradKernel, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..18aafb7dc74ed474ed3ec5e8a388ecdb71b9a8f5 --- /dev/null +++ b/paddle/operators/unpool_op.cu.cc @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/unpool_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(unpool, + ops::UnpoolKernel, + ops::UnpoolKernel); +REGISTER_OP_GPU_KERNEL( + unpool_grad, ops::UnpoolGradKernel, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..243eb7e532c5149db4fb1b381fd8664ae4bdd81a --- /dev/null +++ b/paddle/operators/unpool_op.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/unpooling.h" + +namespace paddle { +namespace operators { +template +class UnpoolKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Indices"); + auto* out = context.Output("Out"); + std::string unpooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + T* output_data = out->mutable_data(context.GetPlace()); + if (output_data) { + math::SetConstant set_zero; + set_zero(context.device_context(), out, static_cast(0)); + } + math::Unpool2dMaxFunctor unpool2d_max_forward; + unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); + } +}; +template +class UnpoolGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Indices"); + const framework::Tensor* out = context.Input("Out"); + const framework::Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + framework::Tensor* in_x_grad = + context.Output(framework::GradVarName("X")); + std::string unpooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + auto& device_ctx = context.device_context(); + math::SetConstant zero; + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + zero(device_ctx, in_x_grad, static_cast(0)); + } + math::Unpool2dMaxGradFunctor unpool2d_max_backward; + unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out, + *out_grad, in_x_grad); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e87f283042c081ed9f232d140ff8c303cd3d1858 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_unpool_op.py @@ -0,0 +1,83 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): + s0, s1, s2, s3 = input.shape + out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0] + out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1] + out = np.zeros((s0, s1, out_hsize, out_wsize)) + for nidx in xrange(s0): + for cidx in xrange(s1): + for h in xrange(s2): + for w in xrange(s3): + index = indices[nidx, cidx, h, w] + hidx = (index - index % out_wsize) / out_wsize + widx = index % out_wsize + out[nidx, cidx, int(hidx), int(widx)] = \ + input[nidx, cidx, h, w] + + return out + + +class TestUnpoolOp(OpTest): + def setUp(self): + self.op_type = "unpool" + self.init_test_case() + pre_input = np.random.random(self.shape).astype("float32") + nsize, csize, hsize, wsize = pre_input.shape + hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \ + self.strides[0] + 1 + wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \ + self.strides[1] + 1 + input = np.zeros((nsize, csize, hsize_out, wsize_out)) + indices = np.zeros((nsize, csize, hsize_out, wsize_out)) + for i in xrange(hsize_out): + for j in xrange(wsize_out): + r_start = np.max((i * self.strides[0] - self.paddings[0], 0)) + r_end = np.min((i * self.strides[0] + self.ksize[0] - \ + self.paddings[0], hsize)) + c_start = np.max((j * self.strides[1] - self.paddings[1], 0)) + c_end = np.min((j * self.strides[1] + self.ksize[1] - \ + self.paddings[1], wsize)) + for nidx in xrange(nsize): + for cidx in xrange(csize): + x_masked = pre_input[nidx, cidx, r_start:r_end, \ + c_start:c_end] + input[nidx, cidx, i, j] = x_masked.max() + arg = x_masked.argmax() + indices[nidx, cidx, i, j] = \ + (r_start + arg / self.ksize[1]) * wsize + \ + c_start + arg % self.ksize[1] + output = self.unpool2d_forward_naive(input, indices, self.ksize, \ + self.strides, self.paddings).astype("float32") + self.inputs = { + 'X': input.astype('float32'), + 'Indices': indices.astype('int32') + } + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'unpooling_type': self.unpooling_type, + } + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.unpool2d_forward_naive = unpool2dmax_forward_naive + self.unpooling_type = "max" + self.shape = [6, 4, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +if __name__ == '__main__': + unittest.main()