diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu index c8fd58eca55318ae12348c7c7d173cbb51aabb25..99e6fd052a3bd48577d7a3fa60b40ecc45bc17a5 100644 --- a/paddle/operators/math/unpooling.cu +++ b/paddle/operators/math/unpooling.cu @@ -29,19 +29,19 @@ __global__ void KernelUnpool2dMax(const int nthreads, T* output_data, const int output_height, const int output_width) { - int bsize = input_height * input_width * channels; - int csize = input_height * input_width; - int out_bsize = output_height * output_width * channels; - int out_csize = output_height * output_width; + int in_n_stride = input_height * input_width * channels; + int in_c_stride = input_height * input_width; + int out_n_stride = output_height * output_width * channels; + int out_c_stride = output_height * output_width; int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (int i = index; i < nthreads; i += offset) { - int bidx = i / bsize; - int boffset = i % bsize; - int cidx = boffset / csize; - int out_offset = bidx * out_bsize + cidx * out_csize; + int bidx = i / in_n_stride; + int boffset = i % in_n_stride; + int cidx = boffset / in_c_stride; + int out_offset = bidx * out_n_stride + cidx * out_c_stride; int out_index = indices_data[i]; - PADDLE_ASSERT(out_index < (output_height * output_width)); + PADDLE_ASSERT(out_index < out_c_stride); output_data[out_offset + out_index] = input_data[i]; } } @@ -57,19 +57,19 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads, const int output_height, const int output_width, T* input_grad) { - int bsize = input_height * input_width * channels; - int csize = input_height * input_width; - int out_bsize = output_height * output_width * channels; - int out_csize = output_height * output_width; + int in_n_stride = input_height * input_width * channels; + int in_c_stride = input_height * input_width; + int out_n_stride = output_height * output_width * channels; + int out_c_stride = output_height * output_width; int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (int i = index; i < nthreads; i += offset) { - int bidx = i / bsize; - int boffset = i % bsize; - int cidx = boffset / csize; - int out_offset = bidx * out_bsize + cidx * out_csize; + int bidx = i / in_n_stride; + int boffset = i % in_n_stride; + int cidx = boffset / in_c_stride; + int out_offset = bidx * out_n_stride + cidx * out_c_stride; int out_index = indices_data[i]; - PADDLE_ASSERT(out_index < (output_height * output_width)); + PADDLE_ASSERT(out_index < out_c_stride); input_grad[i] = output_grad[out_offset + out_index]; } } @@ -93,10 +93,8 @@ class Unpool2dMaxFunctor { const T2 * indices_data = indices.data(); T* output_data = output->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * input_height * input_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - + int threads = 1024; + int grid = (input.numel() + threads - 1) / threads; KernelUnpool2dMax< T, T2><<(context) @@ -129,10 +127,8 @@ class Unpool2dMaxGradFunctor { const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * input_height * input_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - + int threads = 1024; + int grid = (input.numel() + threads - 1) / threads; KernelUnpool2dMaxGrad< T, T2><<(context) diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc index addceca1590b1ec59e8fbd3183d992bd03888109..49a5129188e64617c3841dc6eccf4c417a362ac0 100644 --- a/paddle/operators/unpool_op.cc +++ b/paddle/operators/unpool_op.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/operators/unpool_op.h" namespace paddle { @@ -25,7 +25,7 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) The input tensor of unpool operator. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of feature."); - AddInput("Y", + AddInput("Indices", "(Tensor) The input tensor of the indices given out by MaxPool2d. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of feature."); @@ -50,12 +50,10 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { "(string), unpooling type, can be \"max\" for max-unpooling ") .InEnum({"max"}); AddComment(R"DOC( - "input: the input Tensor to invert - indices: the indices given out by MaxPool2d - ksize – Size of the max pooling window. - stride – Stride of the max pooling window. - "It is set to kernel_size by default. - padding – Padding that was added to the input" + "Paper: http://www.matthewzeiler.com/wp-content/uploads/2017 + /07/iccv2011.pdf + PyTorch: http://pytorch.org/docs/master/nn.html?highlight=unpool# + torch.nn.MaxUnpool2d" )DOC"); } }; @@ -79,27 +77,20 @@ public: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp" "should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of UnpoolOp" + PADDLE_ENFORCE(ctx->HasInput("Indices"), "Input(Indices) of UnpoolOp" "should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UnpoolOp should not be null."); - auto in_x_dims = ctx->GetInputDim("X"); - auto in_y_dims = ctx->GetInputDim("Y"); + auto in_y_dims = ctx->GetInputDim("Indices"); std::string unpooling_type = ctx->Attrs().Get("unpooling_type"); std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); - PADDLE_ENFORCE(in_x_dims.size() == 4, "Unpooling intput must be of 4-dimensional."); - for (int i = 0; i < 4; ++i) { - PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i], - "X size must be eq Y size!"); - } - - + PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims); std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { output_shape.push_back( diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc index 0a1d8b5996de47faef50042911dcca72d5d8a337..9b5ac667d39e43fe85ee830602daf1ad839748c2 100644 --- a/paddle/operators/unpool_op.cu.cc +++ b/paddle/operators/unpool_op.cu.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/operators/unpool_op.h" diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h index f05d22b49fe9cd32e6adca1c1f4019fec5d9bfe3..dfd4ef12b5d685ef8be463a6a2e508c3cc82c8e1 100644 --- a/paddle/operators/unpool_op.h +++ b/paddle/operators/unpool_op.h @@ -2,7 +2,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. -You may obtain a copy of the License at +Indicesou may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -26,7 +26,7 @@ class UnpoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* in_x = context.Input("X"); - const framework::Tensor* in_y = context.Input("Y"); + const framework::Tensor* in_y = context.Input("Indices"); auto * out = context.Output("Out"); std::string unpooling_type = context.Attr("unpooling_type"); std::vector ksize = context.Attr>("ksize"); @@ -47,7 +47,7 @@ class UnpoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* in_x = context.Input("X"); - const framework::Tensor* in_y = context.Input("Y"); + const framework::Tensor* in_y = context.Input("Indices"); const framework::Tensor* out = context.Input("Out"); const framework::Tensor* out_grad = context.Input(framework::GradVarName("Out")); diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py index 22826dc1b3561f9abd5ced628c0a1c9ccd483602..b3c6c85025dc375c96e7c8f9aa35299e37b3dbaa 100644 --- a/python/paddle/v2/fluid/tests/test_unpool_op.py +++ b/python/paddle/v2/fluid/tests/test_unpool_op.py @@ -5,16 +5,16 @@ from op_test import OpTest def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): s0, s1, s2, s3 = input.shape - out_H=(s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0] - out_W=(s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1] - out = np.zeros((s0, s1, out_H, out_W)) + out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0] + out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1] + out = np.zeros((s0, s1, out_hsize, out_wsize)) for nidx in xrange(s0): for cidx in xrange(s1): for h in xrange(s2): for w in xrange(s3): index = indices[nidx, cidx, h, w] - hidx = (index - index % out_W) / out_W - widx = index % out_W + hidx = (index - index % out_wsize) / out_wsize + widx = index % out_wsize out[nidx, cidx, int(hidx), int(widx)] = \ input[nidx, cidx, h, w] @@ -26,34 +26,34 @@ class TestUnpoolOp(OpTest): self.op_type = "unpool" self.init_test_case() pre_input = np.random.random(self.shape).astype("float32") - N, C, H, W = pre_input.shape - H_out = (H - self.ksize[0] + 2 * self.paddings[0]) / \ + nsize, csize, hsize, wsize = pre_input.shape + hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \ self.strides[0] + 1 - W_out = (W - self.ksize[1] + 2 * self.paddings[1]) / \ + wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \ self.strides[1] + 1 - input = np.zeros((N, C, H_out, W_out)) - indices = np.zeros((N, C, H_out, W_out)) - for i in xrange(H_out): - for j in xrange(W_out): + input = np.zeros((nsize, csize, hsize_out, wsize_out)) + indices = np.zeros((nsize, csize, hsize_out, wsize_out)) + for i in xrange(hsize_out): + for j in xrange(wsize_out): r_start = np.max((i * self.strides[0] - self.paddings[0], 0)) r_end = np.min((i * self.strides[0] + self.ksize[0] - \ - self.paddings[0], H)) + self.paddings[0], hsize)) c_start = np.max((j * self.strides[1] - self.paddings[1], 0)) c_end = np.min((j * self.strides[1] + self.ksize[1] - \ - self.paddings[1], W)) - for nidx in xrange(N): - for cidx in xrange(C): + self.paddings[1], wsize)) + for nidx in xrange(nsize): + for cidx in xrange(csize): x_masked = pre_input[nidx, cidx, r_start:r_end, \ c_start:c_end] input[nidx, cidx, i, j] = x_masked.max() arg = x_masked.argmax() indices[nidx, cidx, i, j] = \ - (r_start + arg / self.ksize[1]) * W + \ + (r_start + arg / self.ksize[1]) * wsize + \ c_start + arg % self.ksize[1] output = self.Unpool2d_forward_naive(input, indices, self.ksize, \ self.strides, self.paddings).astype("float32") self.inputs = {'X': input.astype('float32'), - 'Y': indices.astype('int32')} + 'Indices': indices.astype('int32')} self.attrs = { 'strides': self.strides, 'paddings': self.paddings,