From 90f664d0b0eb4cb0f13a5ac5c434ed9cb6544687 Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Wed, 22 Nov 2017 12:52:43 +0800 Subject: [PATCH] test unpool ok cpu --- paddle/operators/CMakeLists.txt | 7 -- paddle/operators/math/unpooling.cc | 9 +-- paddle/operators/math/unpooling.cu | 4 +- paddle/operators/unpool_op.cc | 25 +++---- paddle/operators/unpool_op.cu.cc | 4 +- paddle/operators/unpool_op.h | 8 +- .../paddle/v2/fluid/tests/test_unpool2d_op.py | 47 ------------ .../paddle/v2/fluid/tests/test_unpool_op.py | 74 +++++++++++++++++++ 8 files changed, 98 insertions(+), 80 deletions(-) delete mode 100644 python/paddle/v2/fluid/tests/test_unpool2d_op.py create mode 100644 python/paddle/v2/fluid/tests/test_unpool_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d53bca277da..ee25abd6cb5 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -80,13 +80,6 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() - # unpool_op contains several operators - if ("${TARGET}" STREQUAL "unpool_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(unpool2d);\n") - endif() - # pool_cudnn_op contains several operators if ("${TARGET}" STREQUAL "pool_cudnn_op") set(pybind_flag 1) diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc index a1747e76e73..0becab721ec 100644 --- a/paddle/operators/math/unpooling.cc +++ b/paddle/operators/math/unpooling.cc @@ -32,13 +32,13 @@ class Unpool2dMaxFunctor { const int output_channels = output->dims()[1]; const int output_height = output->dims()[2]; const int output_width = output->dims()[3]; - int input_feasize = input_height * input_width; int output_feasize = output_height * output_width; const T* input_data = input.data(); - const int * indices_data = indices.data(); + const T * indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); - + memset(output_data, 0, \ + sizeof(T) * output_feasize * output_channels * batch_size); for (int b = 0; b < batch_size; ++b) { for (int c = 0; c < output_channels; ++c) { for (int i = 0; i < input_feasize; ++i) { @@ -74,9 +74,8 @@ public: int input_feasize = input_height * input_width; int output_feasize = output_height * output_width; - const int* indices_data = indices.data(); + const T* indices_data = indices.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int b = 0; b < batch_size; ++b) { diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu index f14dd0626fc..cd313770ab2 100644 --- a/paddle/operators/math/unpooling.cu +++ b/paddle/operators/math/unpooling.cu @@ -76,7 +76,7 @@ class Unpool2dMaxFunctor { const int output_height = output->dims()[2]; const int output_width = output->dims()[3]; const T* input_data = input.data(); - const int* indices_data = indices.data(); + const T* indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); int nthreads = output->numel(); @@ -111,7 +111,7 @@ class Unpool2dMaxGradFunctor { const int output_height = output.dims()[2]; const int output_width = output.dims()[3]; const T* input_data = input.data(); - const int* indices_data = indices.data(); + const T* indices_data = indices.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc index d450d9f62ae..9036005a4d6 100644 --- a/paddle/operators/unpool_op.cc +++ b/paddle/operators/unpool_op.cc @@ -48,7 +48,7 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { "(vector defalut:{0,0}), " "paddings(height, width) of unpooling operator.") .SetDefault({0, 0}); - AddAttr("unpoolingType", + AddAttr("unpoolingtype", "(string), unpooling type, can be \"max\" for max-unpooling ") .InEnum({"max"}); AddComment(R"DOC( @@ -80,8 +80,8 @@ class UnpoolOp : public framework::OperatorWithKernel { auto in_x_dims = ctx->GetInputDim("X"); auto in_y_dims = ctx->GetInputDim("Y"); - std::string unpooling_type = \ - ctx->Attrs().Get("unpooling_type"); + std::string unpoolingtype = \ + ctx->Attrs().Get("unpoolingtype"); std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); @@ -108,9 +108,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); + // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null."); + // PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + // "Input(Out@GRAD) should not be null"); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Input(X@GRAD) should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); @@ -120,13 +120,12 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool2d_grad, +REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad, ops::UnpoolOpGrad); -REGISTER_OP_CPU_KERNEL(unpool2d, +REGISTER_OP_CPU_KERNEL(unpool, ops::UnpoolKernel, ops::UnpoolKernel); -REGISTER_OP_CPU_KERNEL(unpool2d_grad, - ops::UnpoolGradKernel, - ops::UnpoolGradKernel); +REGISTER_OP_CPU_KERNEL(unpool_grad, + ops::UnpoolGradKernel, + ops::UnpoolGradKernel); + diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc index 96fb9e40c3f..4949fc467e0 100644 --- a/paddle/operators/unpool_op.cu.cc +++ b/paddle/operators/unpool_op.cu.cc @@ -15,10 +15,10 @@ #include "paddle/operators/unpool_op.h" namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(unpool2d, +REGISTER_OP_GPU_KERNEL(unpool, ops::UnpoolKernel, ops::UnpoolKernel); -REGISTER_OP_GPU_KERNEL(unpool2d_grad, +REGISTER_OP_GPU_KERNEL(unpool_grad, ops::UnpoolGradKernel, ops::UnpoolGradKernel { const Tensor* in_x = context.Input("X"); const Tensor* in_y = context.Input("Y"); Tensor* out = context.Output("Out"); - std::string pooling_type = context.Attr("unpooling_type"); + std::string unpoolingtype = context.Attr("unpoolingtype"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); switch (ksize.size()) { case 2: { - if (pooling_type == "max") { + if (unpoolingtype == "max") { math::Unpool2dMaxFunctor unpool2d_max_forward; unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); } @@ -56,7 +56,7 @@ class UnpoolGradKernel : public framework::OpKernel { const Tensor* out_grad = context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); - std::string pooling_type = context.Attr("unpooling_type"); + std::string unpoolingtype = context.Attr("unpoolingtype"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); @@ -69,7 +69,7 @@ class UnpoolGradKernel : public framework::OpKernel { } switch (ksize.size()) { case 2: { - if (pooling_type == "max") { + if (unpoolingtype == "max") { math::Unpool2dMaxGradFunctor unpool2d_max_backward; unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, *out, *out_grad); diff --git a/python/paddle/v2/fluid/tests/test_unpool2d_op.py b/python/paddle/v2/fluid/tests/test_unpool2d_op.py deleted file mode 100644 index 08f734a264f..00000000000 --- a/python/paddle/v2/fluid/tests/test_unpool2d_op.py +++ /dev/null @@ -1,47 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -def maxout_forward_naive(input, groups): - s0, s1, s2, s3 = input.shape - return np.ndarray([s0, s1 / groups, groups, s2, s3], \ - buffer = input, dtype=input.dtype).max(axis=(2)) - - -class TestUnpool2dOp(OpTest): - def setUp(self): - self.op_type = "unpool2d" - self.init_test_case() - input = np.random.random(self.shape).astype("float32") - output = self.MaxOut_forward_naive(input, self.groups).astype("float32") - - self.inputs = {'X': input} - self.attrs = { - 'strides': self.strides, - 'paddings': self.paddings, - 'ksize': self.ksize, - 'unpooling_type': self.pool_type, - } - - self.outputs = {'Out': output.astype('float32')} - - def init_pool_type(self): - self.pool_type = "max" - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - def init_test_case(self): - self.MaxOut_forward_naive = maxout_forward_naive - self.shape = [100, 6, 2, 2] - self.groups=2 - - - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py new file mode 100644 index 00000000000..566da6e26ee --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_unpool_op.py @@ -0,0 +1,74 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): + s0, s1, s2, s3 = input.shape + out_H=(s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0] + out_W=(s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1] + out = np.zeros((s0, s1, out_H, out_W)) + for nidx in xrange(s0): + for cidx in xrange(s1): + for h in xrange(s2): + for w in xrange(s3): + index = indices[nidx, cidx, h, w] + hidx = (index - index % out_W) / out_W + widx = index % out_W + out[nidx, cidx, int(hidx), int(widx)] = input[nidx, cidx, h, w] + + return out + + +class TestUnpoolOp(OpTest): + def setUp(self): + self.op_type = "unpool" + self.init_test_case() + pre_input = np.random.random(self.shape).astype("float32") + N, C, H, W = pre_input.shape + H_out = (H - self.ksize[0] + 2 * self.paddings[0]) / self.strides[0] + 1 + W_out = (W - self.ksize[1] + 2 * self.paddings[1]) / self.strides[1] + 1 + input = np.zeros((N, C, H_out, W_out)) + indices = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * self.strides[0] - self.paddings[0], 0)) + r_end = np.min((i * self.strides[0] + self.ksize[0] - self.paddings[0], H)) + c_start = np.max((j * self.strides[1] - self.paddings[1], 0)) + c_end = np.min((j * self.strides[1] + self.ksize[1] - self.paddings[1], W)) + for nidx in xrange(N): + for cidx in xrange(C): + x_masked = pre_input[nidx, cidx, r_start:r_end, c_start:c_end] + input[nidx, cidx, i, j] = x_masked.max() + arg = x_masked.argmax() + indices[nidx, cidx, i, j] = (r_start + arg / self.ksize[1]) * W + c_start + arg % self.ksize[1] + output = self.Unpool2d_forward_naive(input, indices, self.ksize, self.strides, self.paddings).astype("float32") + self.inputs = {'X': input.astype('float32'), + 'Y': indices.astype('int16')} + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'unpoolingtype': self.unpoolingtype, + } + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + print self.outputs['Out'] + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', max_relative_error=0.5) + + def init_test_case(self): + self.Unpool2d_forward_naive = unpool2dmax_forward_naive + self.unpoolingtype = "max" + self.shape = [10, 2, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + + +if __name__ == '__main__': + unittest.main() -- GitLab