diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc index 8cfdb4bb605039e36e3b781a4b74493568a80d52..a1747e76e73689cbc36ce3e6b035c33f9f7f9afb 100644 --- a/paddle/operators/math/unpooling.cc +++ b/paddle/operators/math/unpooling.cc @@ -20,7 +20,7 @@ namespace math { // All tensors are in NCHW format template -class Unpool2d_MaxFunctor { +class Unpool2dMaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -43,7 +43,7 @@ class Unpool2d_MaxFunctor { for (int c = 0; c < output_channels; ++c) { for (int i = 0; i < input_feasize; ++i) { int index = indices_data[i]; - // PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); + PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); output_data[index] = input_data[i]; } input_data += input_feasize; @@ -57,7 +57,7 @@ class Unpool2d_MaxFunctor { template -class Unpool2d_MaxGradFunctor { +class Unpool2dMaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -83,7 +83,7 @@ public: for (int c = 0; c < output_channels; ++c) { for (int i = 0; i < input_feasize; ++i) { int index = indices_data[i]; - // PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); + PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); input_grad_data[i] = output_grad_data[index]; } input_grad_data += input_feasize; @@ -94,10 +94,10 @@ public: } }; -template class Unpool2d_MaxGradFunctor; -template class Unpool2d_MaxGradFunctor; -template class Unpool2d_MaxFunctor; -template class Unpool2d_MaxFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu index c8e7b252349bd2c63877ee8e6eb09c1135fe6a39..f14dd0626fc1caed8389302f54fc9a17b30a14b8 100644 --- a/paddle/operators/math/unpooling.cu +++ b/paddle/operators/math/unpooling.cu @@ -30,12 +30,11 @@ __global__ void KernelUnpool2dMax(const int nthreads, const int output_width) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; - // int output_feasize = output_height * output_width; for (int i = index; i < nthreads; i += offset) { int out_offset = i / (input_height * input_width) \ * output_height * output_width; int out_index = indices_data[i]; - // PADDLE_ENFORCE(out_index < output_feasize, "err index in unpooling!"); + PADDLE_ASSERT(out_index < (output_height * output_width)); output_data[out_offset + out_index] = input_data[i]; } } @@ -52,13 +51,11 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, T* input_grad) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; - // int output_feasize = output_height * output_width; for (int i = index; i < nthreads; i += offset) { int out_offset = i / (input_height * input_width) \ * output_height * output_width; int out_index = indices_data[i]; - // PADDLE_ENFORCE(out_index < output_feasize, - // "err index in unpooling!"); + PADDLE_ASSERT(out_index < (output_height * output_width)); input_grad[i] = output_grad[out_offset + out_index]; } } @@ -66,7 +63,7 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, * All tensors are in NCHW format. */ template -class Unpool2d_MaxFunctor { +class Unpool2dMaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -99,7 +96,7 @@ class Unpool2d_MaxFunctor { * All tensors are in NCHW format. */ template -class Unpool2d_MaxGradFunctor { +class Unpool2dMaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -135,11 +132,11 @@ class Unpool2d_MaxGradFunctor { } }; -template class Unpool2d_MaxGradFunctor; -template class Unpool2d_MaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; -template class Unpool2d_MaxFunctor; -template class Unpool2d_MaxFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h index ba4be89746fb946c3ec42394e1e732d43881ea5c..93a77bf53e1389f6a96fefab14d71e8af6453958 100644 --- a/paddle/operators/math/unpooling.h +++ b/paddle/operators/math/unpooling.h @@ -26,7 +26,7 @@ namespace math { template -class Unpool2d_MaxFunctor { +class Unpool2dMaxFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, @@ -35,7 +35,7 @@ class Unpool2d_MaxFunctor { }; template -class Unpool2d_MaxGradFunctor { +class Unpool2dMaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc index 9d6e69dffb8ee54796db7715e65daa79643f1b24..d450d9f62ae4cd885b6ee4901b9b3b630895038f 100644 --- a/paddle/operators/unpool_op.cc +++ b/paddle/operators/unpool_op.cc @@ -49,11 +49,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { "paddings(height, width) of unpooling operator.") .SetDefault({0, 0}); AddAttr("unpoolingType", - "(string), unpooling type, can be \"max\" for max-unpooling " - "and \"avg\" for average-unpooling.") - .InEnum({"max", "avg"}); + "(string), unpooling type, can be \"max\" for max-unpooling ") + .InEnum({"max"}); AddComment(R"DOC( - + "input: the input Tensor to invert" + "indices: the indices given out by MaxPool2d" + "ksize – Size of the max pooling window." + "stride – Stride of the max pooling window." + "It is set to kernel_size by default." + "padding – Padding that was added to the input" )DOC"); } }; @@ -82,8 +86,13 @@ class UnpoolOp : public framework::OperatorWithKernel { std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); - PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, - "Unpooling intput should be 4-D or 5-D tensor."); + PADDLE_ENFORCE(in_x_dims.size() == 4, + "Unpooling intput should be 4-D."); + for (int i = 0; i < 4; ++i) { + PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i], + "X size must be eq Y size!"); + } + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h index 47dd8da6f7c28050ce37a8fb74fd5329bf230d27..44115b0726d803a5de7b717db65be8c95fbf8f8b 100644 --- a/paddle/operators/unpool_op.h +++ b/paddle/operators/unpool_op.h @@ -37,7 +37,7 @@ class UnpoolKernel : public framework::OpKernel { switch (ksize.size()) { case 2: { if (pooling_type == "max") { - math::Unpool2d_MaxFunctor unpool2d_max_forward; + math::Unpool2dMaxFunctor unpool2d_max_forward; unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); } } break; @@ -70,7 +70,7 @@ class UnpoolGradKernel : public framework::OpKernel { switch (ksize.size()) { case 2: { if (pooling_type == "max") { - math::Unpool2d_MaxGradFunctor unpool2d_max_backward; + math::Unpool2dMaxGradFunctor unpool2d_max_backward; unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out, *out_grad); } diff --git a/python/paddle/v2/fluid/tests/test_unpool2d_op.py b/python/paddle/v2/fluid/tests/test_unpool2d_op.py new file mode 100644 index 0000000000000000000000000000000000000000..08f734a264f2f97a679b49af29b0115a27d8dad4 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_unpool2d_op.py @@ -0,0 +1,47 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def maxout_forward_naive(input, groups): + s0, s1, s2, s3 = input.shape + return np.ndarray([s0, s1 / groups, groups, s2, s3], \ + buffer = input, dtype=input.dtype).max(axis=(2)) + + +class TestUnpool2dOp(OpTest): + def setUp(self): + self.op_type = "unpool2d" + self.init_test_case() + input = np.random.random(self.shape).astype("float32") + output = self.MaxOut_forward_naive(input, self.groups).astype("float32") + + self.inputs = {'X': input} + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'unpooling_type': self.pool_type, + } + + self.outputs = {'Out': output.astype('float32')} + + def init_pool_type(self): + self.pool_type = "max" + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.MaxOut_forward_naive = maxout_forward_naive + self.shape = [100, 6, 2, 2] + self.groups=2 + + + + +if __name__ == '__main__': + unittest.main()