From ceee71a0ef63cee73e91d920c5739145bf4bf735 Mon Sep 17 00:00:00 2001
From: xiaoting <31891223+tink2123@users.noreply.github.com>
Date: Fri, 27 Aug 2021 18:59:14 +0800
Subject: [PATCH] Add unpool2d op & Expose max_unpool2d API (#35056)

* add maxunppol2d op, test=develop
* fix typo, test=develop
* fix unpool unitest, test=develop
* fix unpool code-example, test=develop
* fix for unpool_op_unittest,test=develop
* fix example code, test=develop
* add noqa:F401, test=develop
* fix converage, test=develop
* fix unitest for unpool, test=develop
* rename unpool2d to unpool, test=develop
* rename unpool2d to unpool, test=develop
---
 paddle/fluid/operators/math/unpooling.cc      |   1 +
 paddle/fluid/operators/math/unpooling.cu      |  47 +--
 paddle/fluid/operators/unpool_op.cc           |  15 +-
 .../fluid/tests/unittests/test_unpool_op.py   | 272 +++++++++++++++---
 python/paddle/nn/__init__.py                  |   1 +
 python/paddle/nn/functional/__init__.py       |   2 +
 python/paddle/nn/functional/pooling.py        | 203 +++++++++----
 python/paddle/nn/layer/__init__.py            |   1 +
 python/paddle/nn/layer/pooling.py             |  92 ++++++
 9 files changed, 511 insertions(+), 123 deletions(-)

diff --git a/paddle/fluid/operators/math/unpooling.cc b/paddle/fluid/operators/math/unpooling.cc
index f5f5b380df2..bcb2b92780c 100644
--- a/paddle/fluid/operators/math/unpooling.cc
+++ b/paddle/fluid/operators/math/unpooling.cc
@@ -38,6 +38,7 @@ class Unpool2dMaxFunctor {
     for (int c = 0; c < output_channels; ++c) {
       for (int i = 0; i < input_feasize; ++i) {
         int index = indices_data[i];
+
         PADDLE_ENFORCE_LT(
             index, output_feasize,
             platform::errors::InvalidArgument(
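The context above is the existing CPU-side bound check: every pooling index must address a valid slot in the flattened H_out * W_out output plane. For intuition, the check behaves roughly like this NumPy sketch (the helper name is mine, not Paddle's):

    import numpy as np

    def check_unpool_indices(indices, output_height, output_width):
        # Mirrors PADDLE_ENFORCE_LT: every index must fall inside the
        # flattened output plane of size output_height * output_width.
        output_feasize = output_height * output_width
        if ((indices < 0) | (indices >= output_feasize)).any():
            raise ValueError("unpool index should be less than output "
                             "feature size %d" % output_feasize)

    check_unpool_indices(np.array([[0, 5], [13, 15]]), 4, 4)  # all in range
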
diff --git a/paddle/fluid/operators/math/unpooling.cu b/paddle/fluid/operators/math/unpooling.cu
index a73f76f53be..ad23892f379 100644
--- a/paddle/fluid/operators/math/unpooling.cu
+++ b/paddle/fluid/operators/math/unpooling.cu
@@ -25,48 +25,27 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
                                   const int channels, T* output_data,
                                   const int output_height,
                                   const int output_width) {
-  int in_n_stride = input_height * input_width * channels;
-  int in_c_stride = input_height * input_width;
-  int out_n_stride = output_height * output_width * channels;
-  int out_c_stride = output_height * output_width;
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int offset = blockDim.x * gridDim.x;
-  for (int i = index; i < nthreads; i += offset) {
-    int bidx = i / in_n_stride;
-    int boffset = i % in_n_stride;
-    int cidx = boffset / in_c_stride;
-    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
-    int out_index = indices_data[i];
-    PADDLE_ENFORCE(out_index < out_c_stride,
-                   "out_index < out_c_stride. Expected %ld < %ld, but got "
-                   "%ld >= %ld. Please check input value.",
-                   out_index, out_c_stride, out_index, out_c_stride);
-    output_data[out_offset + out_index] = input_data[i];
+  CUDA_KERNEL_LOOP(linearIndex, nthreads) {
+    int c = (linearIndex / input_width / input_height) % channels;
+    int n = linearIndex / input_width / input_height / channels;
+    output_data += (n * channels + c) * output_height * output_width;
+    int maxind = indices_data[linearIndex];
+    output_data[maxind] = input_data[linearIndex];
   }
 }
+
 template <typename T>
 __global__ void KernelUnpool2dMaxGrad(
     const int nthreads, const T* input_data, const int* indices_data,
     const int input_height, const int input_width, const int channels,
     const T* output_data, const T* output_grad, const int output_height,
     const int output_width, T* input_grad) {
-  int in_n_stride = input_height * input_width * channels;
-  int in_c_stride = input_height * input_width;
-  int out_n_stride = output_height * output_width * channels;
-  int out_c_stride = output_height * output_width;
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int offset = blockDim.x * gridDim.x;
-  for (int i = index; i < nthreads; i += offset) {
-    int bidx = i / in_n_stride;
-    int boffset = i % in_n_stride;
-    int cidx = boffset / in_c_stride;
-    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
-    int out_index = indices_data[i];
-    PADDLE_ENFORCE(out_index < out_c_stride,
-                   "out_index < out_c_stride. Expected %ld < %ld, but got "
-                   "%ld >= %ld. Please check input value.",
-                   out_index, out_c_stride, out_index, out_c_stride);
-    input_grad[i] = output_grad[out_offset + out_index];
+  CUDA_KERNEL_LOOP(linearIndex, nthreads) {
+    int c = (linearIndex / input_width / input_height) % channels;
+    int n = linearIndex / input_width / input_height / channels;
+    output_grad += (n * channels + c) * output_height * output_width;
+    int maxind = indices_data[linearIndex];
+    input_grad[linearIndex] = output_grad[maxind];
   }
 }
 /*
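The rewritten kernels reduce to a plain scatter (forward) and gather (backward) within the (n, c) plane picked out of the linear NCHW index. A rough NumPy equivalent of the two kernels, for illustration only (the helper names are assumptions, not part of the patch):

    import numpy as np

    def unpool2d_max_forward(x, indices, out_h, out_w):
        # Scatter each input value to its recorded argmax position.
        n, c, h, w = x.shape
        out = np.zeros((n, c, out_h * out_w), dtype=x.dtype)
        np.put_along_axis(out, indices.reshape(n, c, h * w),
                          x.reshape(n, c, h * w), axis=2)
        return out.reshape(n, c, out_h, out_w)

    def unpool2d_max_backward(grad_out, indices):
        # Gather gradients back from the scattered positions.
        n, c, h, w = indices.shape
        flat = grad_out.reshape(n, c, -1)
        return np.take_along_axis(flat, indices.reshape(n, c, h * w),
                                  axis=2).reshape(n, c, h, w)
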
") + .SetDefault("NCHW"); AddComment(R"DOC( Input shape is: $(N, C_{in}, H_{in}, W_{in})$, Output shape is: $(N, C_{out}, H_{out}, W_{out})$, where @@ -93,6 +103,8 @@ class UnpoolOp : public framework::OperatorWithKernel { std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector output_size = + ctx->Attrs().Get>("output_size"); PADDLE_ENFORCE_EQ(in_x_dims.size() == 4, true, platform::errors::InvalidArgument( "Unpool Intput(X) must be of 4-dimensional, but " @@ -111,8 +123,7 @@ class UnpoolOp : public framework::OperatorWithKernel { if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) { output_shape.push_back(-1); } else { - output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i], - paddings[i], strides[i])); + output_shape.push_back(output_size[i]); } } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); diff --git a/python/paddle/fluid/tests/unittests/test_unpool_op.py b/python/paddle/fluid/tests/unittests/test_unpool_op.py index 256c0a83d12..f6dc3fba6a2 100644 --- a/python/paddle/fluid/tests/unittests/test_unpool_op.py +++ b/python/paddle/fluid/tests/unittests/test_unpool_op.py @@ -19,10 +19,26 @@ import numpy as np from op_test import OpTest -def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): +def _unpool_output_size(x, kernel_size, stride, padding, output_size): + input_size = x.shape + default_size = [] + for d in range(len(kernel_size)): + default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + ret = output_size + return ret + + +def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings, + output_size): s0, s1, s2, s3 = input.shape - out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0] - out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1] + output_size = _unpool_output_size(input, ksize, strides, paddings, + output_size) + out_hsize = output_size[0] + out_wsize = output_size[1] out = np.zeros((s0, s1, out_hsize, out_wsize)) for nidx in range(s0): for cidx in range(s1): @@ -31,7 +47,7 @@ def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): index = indices[nidx, cidx, h, w] hidx = (index - index % out_wsize) // out_wsize widx = index % out_wsize - out[nidx, cidx, int(hidx), int(widx)] = \ + out[nidx, cidx, hidx, widx] = \ input[nidx, cidx, h, w] return out @@ -41,33 +57,25 @@ class TestUnpoolOp(OpTest): def setUp(self): self.op_type = "unpool" self.init_test_case() - pre_input = np.random.random(self.shape).astype("float64") - nsize, csize, hsize, wsize = pre_input.shape - hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) // \ - self.strides[0] + 1 - wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) // \ - self.strides[1] + 1 - input = np.zeros((nsize, csize, hsize_out, wsize_out)) - indices = np.zeros((nsize, csize, hsize_out, wsize_out)) - for i in range(hsize_out): - for j in range(wsize_out): - r_start = np.max((i * self.strides[0] - self.paddings[0], 0)) - r_end = np.min((i * self.strides[0] + self.ksize[0] - \ - self.paddings[0], hsize)) - c_start = np.max((j * self.strides[1] - self.paddings[1], 0)) - c_end = np.min((j * self.strides[1] + self.ksize[1] - \ - self.paddings[1], wsize)) - for nidx in range(nsize): - for cidx in range(csize): - x_masked = pre_input[nidx, cidx, r_start:r_end, \ - c_start:c_end] - input[nidx, cidx, i, j] = x_masked.max() - arg 
diff --git a/python/paddle/fluid/tests/unittests/test_unpool_op.py b/python/paddle/fluid/tests/unittests/test_unpool_op.py
index 256c0a83d12..f6dc3fba6a2 100644
--- a/python/paddle/fluid/tests/unittests/test_unpool_op.py
+++ b/python/paddle/fluid/tests/unittests/test_unpool_op.py
@@ -19,10 +19,26 @@ import numpy as np
 from op_test import OpTest
 
 
-def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
+def _unpool_output_size(x, kernel_size, stride, padding, output_size):
+    input_size = x.shape
+    default_size = []
+    for d in range(len(kernel_size)):
+        default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+                            + kernel_size[d] - 2 * padding[d])
+    if output_size is None:
+        ret = default_size
+    else:
+        ret = output_size
+    return ret
+
+
+def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings,
+                              output_size):
     s0, s1, s2, s3 = input.shape
-    out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
-    out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
+    output_size = _unpool_output_size(input, ksize, strides, paddings,
+                                      output_size)
+    out_hsize = output_size[0]
+    out_wsize = output_size[1]
     out = np.zeros((s0, s1, out_hsize, out_wsize))
     for nidx in range(s0):
         for cidx in range(s1):
@@ -31,7 +47,7 @@ def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings,
                 index = indices[nidx, cidx, h, w]
                 hidx = (index - index % out_wsize) // out_wsize
                 widx = index % out_wsize
-                out[nidx, cidx, int(hidx), int(widx)] = \
+                out[nidx, cidx, hidx, widx] = \
                     input[nidx, cidx, h, w]
     return out
 
@@ -41,33 +57,25 @@ class TestUnpoolOp(OpTest):
     def setUp(self):
         self.op_type = "unpool"
         self.init_test_case()
-        pre_input = np.random.random(self.shape).astype("float64")
-        nsize, csize, hsize, wsize = pre_input.shape
-        hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) // \
-            self.strides[0] + 1
-        wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) // \
-            self.strides[1] + 1
-        input = np.zeros((nsize, csize, hsize_out, wsize_out))
-        indices = np.zeros((nsize, csize, hsize_out, wsize_out))
-        for i in range(hsize_out):
-            for j in range(wsize_out):
-                r_start = np.max((i * self.strides[0] - self.paddings[0], 0))
-                r_end = np.min((i * self.strides[0] + self.ksize[0] - \
-                    self.paddings[0], hsize))
-                c_start = np.max((j * self.strides[1] - self.paddings[1], 0))
-                c_end = np.min((j * self.strides[1] + self.ksize[1] - \
-                    self.paddings[1], wsize))
-                for nidx in range(nsize):
-                    for cidx in range(csize):
-                        x_masked = pre_input[nidx, cidx, r_start:r_end, \
-                            c_start:c_end]
-                        input[nidx, cidx, i, j] = x_masked.max()
-                        arg = x_masked.argmax()
-                        indices[nidx, cidx, i, j] = \
-                            (r_start + arg // self.ksize[1]) * wsize + \
-                            c_start + arg % self.ksize[1]
+        input = np.random.randint(0, 100, self.shape)
+        nsize, csize, hsize, wsize = input.shape
+        self.output_size = _unpool_output_size(input, self.ksize, self.strides,
+                                               self.paddings, self.output_size)
+        indices = np.random.permutation(
+            np.arange(0, self.output_size[0] * self.output_size[1]))[:hsize *
+                                                                     wsize]
+        indices = np.reshape(indices, [hsize, wsize])
+        idx_list = []
+        for n in range(nsize):
+            c_list = []
+            for c in range(csize):
+                c_list.append(indices.tolist())
+            idx_list.append(c_list)
+        indices = np.array(idx_list)
+
         output = self.unpool2d_forward_naive(input, indices, self.ksize, \
-                self.strides, self.paddings).astype("float64")
+                self.strides, self.paddings, self.output_size).astype("float64")
+
         self.inputs = {
             'X': input.astype('float64'),
             'Indices': indices.astype('int32')
@@ -77,6 +85,7 @@ class TestUnpoolOp(OpTest):
             'paddings': self.paddings,
             'ksize': self.ksize,
             'unpooling_type': self.unpooling_type,
+            'output_size': self.output_size,
         }
         self.outputs = {'Out': output.astype('float64')}
 
@@ -89,10 +98,209 @@ class TestUnpoolOp(OpTest):
     def init_test_case(self):
         self.unpool2d_forward_naive = unpool2dmax_forward_naive
         self.unpooling_type = "max"
-        self.shape = [6, 4, 7, 7]
-        self.ksize = [3, 3]
+        self.shape = [2, 4, 7, 8]
+        self.ksize = [2, 2]
+        self.strides = [2, 2]
+        self.paddings = [0, 0]
+        self.output_size = None
+
+
+class TestUnpoolOpcase1(TestUnpoolOp):
+    def init_test_case(self):
+        self.unpool2d_forward_naive = unpool2dmax_forward_naive
+        self.unpooling_type = "max"
+        self.shape = [3, 2, 5, 5]
+        self.ksize = [4, 4]
+        self.strides = [2, 2]
+        self.paddings = [0, 0]
+        self.output_size = None
+
+
+class TestUnpoolOpOutputsize(TestUnpoolOp):
+    def init_test_case(self):
+        self.unpool2d_forward_naive = unpool2dmax_forward_naive
+        self.unpooling_type = "max"
+        self.shape = [3, 2, 5, 5]
+        self.ksize = [4, 4]
+        self.strides = [2, 2]
+        self.paddings = [0, 0]
+        self.output_size = [9, 9]
+
+
+class TestUnpoolOpOutput(TestUnpoolOp):
+    def init_test_case(self):
+        self.unpool2d_forward_naive = unpool2dmax_forward_naive
+        self.unpooling_type = "max"
+        self.shape = [3, 2, 5, 5]
+        self.ksize = [4, 4]
         self.strides = [2, 2]
         self.paddings = [0, 0]
+        self.output_size = [9, 9]
+
+
+class TestUnpoolOpException(unittest.TestCase):
+    def test_exception(self):
+        import paddle.nn.functional as F
+        import paddle
+
+        def indices_size_error():
+            data = paddle.rand(shape=[1, 1, 3, 3])
+            indices = paddle.reshape(paddle.arange(0, 12), shape=[1, 1, 3, 4])
+            out = F.max_unpool2d(data, indices, kernel_size=2, stride=2)
+
+        def indices_value_error():
+            data = paddle.rand(shape=[1, 1, 3, 3])
+            indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
+            out = F.max_unpool2d(data, indices, kernel_size=2, stride=2)
+
+        def data_format_error():
+            data = paddle.rand(shape=[1, 1, 3, 3])
+            indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
+            out = F.max_unpool2d(
+                data, indices, kernel_size=2, stride=2, data_format="NHWC")
+
+        def data_outputsize_error():
+            data = paddle.rand(shape=[1, 1, 3, 3])
+            indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
+            out = F.max_unpool2d(
+                data,
+                indices,
+                kernel_size=2,
+                stride=2,
+                output_size=[5, 6, 7, 8])
+
+        def data_outputsize_error2():
+            data = paddle.rand(shape=[1, 1, 3, 3])
+            indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
+            out = F.max_unpool2d(
+                data, indices, kernel_size=2, stride=2, output_size=[100, 100])
+
+        self.assertRaises(ValueError, indices_size_error)
+        self.assertRaises(ValueError, indices_value_error)
+        self.assertRaises(ValueError, data_format_error)
+        self.assertRaises(ValueError, data_outputsize_error)
+        self.assertRaises(ValueError, data_outputsize_error2)
+
+
+class TestUnpoolOpAPI_dy(unittest.TestCase):
+    def test_case(self):
+        import paddle
+        import paddle.nn.functional as F
+        import paddle.fluid.core as core
+        import paddle.fluid as fluid
+        import numpy as np
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        with fluid.dygraph.guard(place):
+            input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8],
+                                     [9, 10, 11, 12],
+                                     [13, 14, 15, 16]]]]).astype("float32")
+            input_x = paddle.to_tensor(input_data)
+            output, indices = F.max_pool2d(
+                input_x, kernel_size=2, stride=2, return_mask=True)
+            out_pp = F.max_unpool2d(
+                output, indices, kernel_size=2, stride=2, output_size=(5, 5))
+            output_np = output.numpy()
+            indices_np = indices.numpy()
+            expect_res = unpool2dmax_forward_naive(output_np, indices_np, [2, 2], \
+                [2, 2], [0, 0], [5, 5]).astype("float64")
+            self.assertTrue(np.allclose(out_pp.numpy(), expect_res))
+
+
+class TestUnpoolOpAPI_dy2(unittest.TestCase):
+    def test_case(self):
+        import paddle
+        import paddle.nn.functional as F
+        import paddle.fluid.core as core
+        import paddle.fluid as fluid
+        import numpy as np
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        with fluid.dygraph.guard(place):
+            input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8],
+                                     [9, 10, 11, 12],
+                                     [13, 14, 15, 16]]]]).astype("float32")
+            input_x = paddle.to_tensor(input_data)
+            output, indices = F.max_pool2d(
+                input_x, kernel_size=2, stride=2, return_mask=True)
+            out_pp = F.max_unpool2d(
+                output, indices, kernel_size=2, stride=None, output_size=(5, 5))
+            output_np = output.numpy()
+            indices_np = indices.numpy()
+            expect_res = unpool2dmax_forward_naive(output_np, indices_np, [2, 2], \
+                [2, 2], [0, 0], [5, 5]).astype("float64")
+            self.assertTrue(np.allclose(out_pp.numpy(), expect_res))
+
+
+class TestUnpoolOpAPI_dy3(unittest.TestCase):
+    def test_case(self):
+        import paddle
+        import paddle.nn.functional as F
+        import paddle.fluid.core as core
+        import paddle.fluid as fluid
+        import numpy as np
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        with fluid.dygraph.guard(place):
+            input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8],
+                                     [9, 10, 11, 12],
+                                     [13, 14, 15, 16]]]]).astype("float32")
+            input_x = paddle.to_tensor(input_data)
+            Pool2d = paddle.nn.MaxPool2D(
+                kernel_size=2, stride=2, return_mask=True)
+            UnPool = paddle.nn.MaxUnPool2D(kernel_size=2, stride=2)
+
+            output, indices = Pool2d(input_x)
+            out_pp = UnPool(output, indices)
+            output_np = output.numpy()
+            indices_np = indices.numpy()
+            expect_res = unpool2dmax_forward_naive(output_np, indices_np, [2, 2], \
+                [2, 2], [0, 0], [4, 4]).astype("float64")
+            self.assertTrue(np.allclose(out_pp.numpy(), expect_res))
+
+
+class TestUnpoolOpAPI_st(unittest.TestCase):
+    def test_case(self):
+        import paddle
+        import paddle.nn.functional as F
+        import paddle.fluid.core as core
+        import paddle.fluid as fluid
+        paddle.enable_static()
+
+        input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
+                                 [13, 14, 15, 16]]]]).astype("float32")
+
+        x = fluid.data(name="x", shape=[1, 1, 4, 4], dtype="float32")
+        output, indices = F.max_pool2d(
+            x, kernel_size=2, stride=2, return_mask=True)
+        unpool_out = F.max_unpool2d(
+            output, indices, kernel_size=2, stride=None, output_size=(5, 5))
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+
+        results = exe.run(paddle.fluid.default_main_program(), \
+                          feed={"x": input_data},
+                          fetch_list=[unpool_out],
+                          return_numpy=True)
+
+        pool_out_np = np.array([[[[6., 8.], [14., 16.]]]]).astype("float32")
+        indices_np = np.array([[[[5, 7], [13, 15]]]]).astype("int32")
+        expect_res = unpool2dmax_forward_naive(pool_out_np, indices_np, [2, 2], \
+            [2, 2], [0, 0], [5, 5]).astype("float64")
+        self.assertTrue(np.allclose(results[0], expect_res))
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 394bac51ade..98444e69d0b 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -73,6 +73,7 @@ from .layer.pooling import AvgPool3D  # noqa: F401
 from .layer.pooling import MaxPool1D  # noqa: F401
 from .layer.pooling import MaxPool2D  # noqa: F401
 from .layer.pooling import MaxPool3D  # noqa: F401
+from .layer.pooling import MaxUnPool2D  # noqa: F401
 from .layer.pooling import AdaptiveAvgPool1D  # noqa: F401
 from .layer.pooling import AdaptiveAvgPool2D  # noqa: F401
 from .layer.pooling import AdaptiveAvgPool3D  # noqa: F401
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index e10f0f1686d..feacbeeea70 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -101,6 +101,7 @@ from .pooling import adaptive_max_pool3d  # noqa: F401
 from .pooling import adaptive_avg_pool1d  # noqa: F401
 from .pooling import adaptive_avg_pool2d  # noqa: F401
 from .pooling import adaptive_avg_pool3d  # noqa: F401
+from .pooling import max_unpool2d  # noqa: F401
 
 from .vision import affine_grid  # noqa: F401
 from .vision import grid_sample  # noqa: F401
@@ -166,6 +167,7 @@ __all__ = [  #noqa
     'max_pool1d',
     'max_pool2d',
     'max_pool3d',
+    'max_unpool2d',
     'adaptive_avg_pool1d',
     'adaptive_avg_pool2d',
     'adaptive_avg_pool3d',
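With these re-exports in place the op is reachable through both the functional and the layer API. A minimal smoke test of the new surface (dygraph; assumes a Paddle build that contains this patch):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([1, 3, 8, 8])
    y, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_mask=True)
    out_f = F.max_unpool2d(y, idx, kernel_size=2)         # functional form
    out_l = paddle.nn.MaxUnPool2D(kernel_size=2)(y, idx)  # layer form
    assert out_f.shape == out_l.shape == [1, 3, 8, 8]
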
diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py
index d3ae44bf7ce..ba69c639037 100755
--- a/python/paddle/nn/functional/pooling.py
+++ b/python/paddle/nn/functional/pooling.py
@@ -616,75 +616,168 @@ def max_pool1d(x,
         squeeze(mask, [2])) if return_mask else squeeze(pool_out, [2])
 
 
-def max_pool2d(x,
-               kernel_size,
-               stride=None,
-               padding=0,
-               return_mask=False,
-               ceil_mode=False,
-               data_format="NCHW",
-               name=None):
+def _unpool_output_size(x, kernel_size, stride, padding, output_size):
+    input_size = x.shape
+    default_size = []
+    for d in range(len(kernel_size)):
+        default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+                            + kernel_size[d] - 2 * padding[d])
+    if output_size is None:
+        ret = default_size
+    else:
+        if len(output_size) == len(kernel_size) + 2:
+            output_size = output_size[2:]
+        if len(output_size) != len(kernel_size):
+            raise ValueError(
+                "output_size should be a sequence containing "
+                "{} or {} elements, but it has a length of '{}'".format(
+                    len(kernel_size), len(kernel_size) + 2, len(output_size)))
+        for d in range(len(kernel_size)):
+            min_size = default_size[d] - stride[d]
+            max_size = default_size[d] + stride[d]
+            if not (min_size < output_size[d] < max_size):
+                raise ValueError(
+                    'invalid output_size "{}" (dim {} must be between {} and {})'.
+                    format(output_size, d, min_size, max_size))
+
+        ret = output_size
+    return ret
+
+
+def max_unpool2d(x,
+                 indices,
+                 kernel_size,
+                 stride=None,
+                 padding=0,
+                 data_format="NCHW",
+                 output_size=None,
+                 name=None):
     """
-    This API implements max pooling 2d operation.
-    See more details in :ref:`api_nn_pooling_MaxPool2d` .
+    This API implements max unpooling 2d operation.
+
+    `max_unpool2d` is not fully invertible, since the non-maximal values are lost.
+
+    `max_unpool2d` takes as input the output of `max_pool2d`
+    including the indices of the maximal values and computes a partial inverse
+    in which all non-maximal values are set to zero.
+
+    `max_unpool2d` can map several input sizes to the same output
+    sizes. Hence, the inversion process can get ambiguous.
+    To accommodate this, you can provide the needed output size
+    as an additional argument `output_size` in the forward call.
 
     Args:
-        x (Tensor): The input tensor of pooling operator which is a 4-D tensor with
-                          shape [N, C, H, W]. The format of input tensor is `"NCHW"` or
-                          `"NHWC"`, where `N` is batch size, `C` is the number of channels,
-                          `H` is the height of the feature, and `W` is the width of the
-                          feature. The data type is float32 or float64.
-        kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
-            it must contain two integers, (kernel_size_Height, kernel_size_Width).
-            Otherwise, the pool kernel size will be a square of an int.
-        stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
-            it must contain two integers, (stride_Height, stride_Width).
-            Otherwise, the pool stride size will be a square of an int.
-        padding (string|int|list|tuple): The padding size. Padding could be in one of the following forms.
-            1. A string in ['valid', 'same'].
-            2. An int, which means the feature map is zero padded by size of `padding` on every sides.
-            3. A list[int] or tuple(int) whose length is 2, [pad_height, pad_weight] whose value means the padding size of each dimension.
-            4. A list[int] or tuple(int) whose length is 4. [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right] whose value means the padding size of each side.
-            5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0).
-            The default value is 0.
-        ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape
-        return_mask (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format
-        data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
-            The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
-            `[batch_size, input_channels, input_height, input_width]`.
+        x (Tensor): The input tensor of unpooling operator which is a 4-D tensor with
+                          shape [N, C, H, W]. The format of input tensor is `"NCHW"`,
+                          where `N` is batch size, `C` is the number of channels,
+                          `H` is the height of the feature, and `W` is the width of the
+                          feature. The data type is float32 or float64.
+        indices (Tensor): The indices given out by maxpooling2d which is a 4-D tensor with
+                          shape [N, C, H, W]. The format of input tensor is `"NCHW"`,
+                          where `N` is batch size, `C` is the number of channels,
+                          `H` is the height of the feature, and `W` is the width of the
+                          feature. The data type is int32.
+        kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
+            it must contain two integers, (kernel_size_Height, kernel_size_Width).
+        stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
+            it must contain two integers, (stride_Height, stride_Width).
+            If not specified, it defaults to kernel_size.
+        padding (int | tuple): Padding that was added to the input.
+        output_size(list|tuple, optional): The target output size. If output_size is not specified,
+                           the actual output shape will be automatically calculated by (input_shape,
+                           kernel_size, padding).
         name(str, optional): For detailed information, please refer
                              to :ref:`api_guide_Name`. Usually name is no need to set and
                              None by default.
-    Returns:
-        Tensor: The output tensor of pooling result. The data type is same as input tensor.
 
-    Raises:
-        ValueError: If `padding` is a string, but not "SAME" or "VALID".
-        ValueError: If `padding` is "VALID", but `ceil_mode` is True.
-        ShapeError: If the output's shape calculated is not greater than 0.
-
-    Examples:
-        .. code-block:: python
+        - Input: :math:`(N, C, H_{in}, W_{in})`
+        - Output: :math:`(N, C, H_{out}, W_{out})`, where
+
+          .. math::
+            H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
+
+          .. math::
+            W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
+
+          or as specified via :attr:`output_size` in the call
+
+        Returns:
+            Tensor: The output tensor of unpooling result.
+
+        Raises:
+            ValueError: If the input is not a 4-D tensor.
+            ValueError: If the indices shape does not match the input shape.
+
+
+        Examples:
+            .. code-block:: python
+
             import paddle
             import paddle.nn.functional as F
             import numpy as np
-
-            # max pool2d
-            x = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32]).astype(np.float32))
-            out = F.max_pool2d(x,
-                               kernel_size=2,
-                               stride=2, padding=0)
-            # output.shape [1, 3, 16, 16]
-            # for return_mask=True
-            out, max_indices = F.max_pool2d(x,
-                                            kernel_size=2,
-                                            stride=2,
-                                            padding=0,
-                                            return_mask=True)
-            # out.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16],
+
+            data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 1, 6, 6]).astype(np.float32))
+            pool_out, indices = F.max_pool2d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
+            # pool_out shape: [1, 1, 3, 3],  indices shape: [1, 1, 3, 3]
+            unpool_out = F.max_unpool2d(pool_out, indices, kernel_size=2, padding=0)
+            # unpool_out shape: [1, 1, 6, 6]
+
+            # specify a different output size than input size
+            unpool_out = F.max_unpool2d(pool_out, indices, kernel_size=2, padding=0, output_size=[7,7])
+            # unpool_out shape: [1, 1, 7, 7]
+
     """
     kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size')
+    if stride is None:
+        stride = kernel_size
+    else:
+        stride = utils.convert_to_list(stride, 2, 'pool_stride')
+    padding = utils.convert_to_list(padding, 2, 'padding')
+
+    if data_format not in ["NCHW"]:
+        raise ValueError("Attr(data_format) should be 'NCHW'. Received "
+                         "Attr(data_format): %s." % str(data_format))
+
+    output_size = _unpool_output_size(x, kernel_size, stride, padding,
+                                      output_size)
+
+    if in_dygraph_mode():
+        output = _C_ops.unpool(x, indices, 'unpooling_type', 'max', 'ksize',
+                               kernel_size, 'strides', stride, 'paddings',
+                               padding, "output_size", output_size,
+                               "data_format", data_format)
+        return output
+
+    op_type = "unpool"
+    helper = LayerHelper(op_type, **locals())
+    dtype = helper.input_dtype(input_param_name="x")
+    unpool_out = helper.create_variable_for_type_inference(dtype)
+
+    helper.append_op(
+        type=op_type,
+        inputs={"X": x,
+                "Indices": indices},
+        outputs={"Out": unpool_out},
+        attrs={
+            "unpooling_type": "max",
+            "ksize": kernel_size,
+            "strides": stride,
+            "paddings": padding,
+            "output_size": output_size
+        })
+    return unpool_out
+
+
+def max_pool2d(x,
+               kernel_size,
+               stride=None,
+               padding=0,
+               return_mask=False,
+               ceil_mode=False,
+               data_format="NCHW",
+               name=None):
+    kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size')
     if stride is None:
         stride = kernel_size
     else:
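One subtlety of `_unpool_output_size` worth spelling out: an explicit `output_size` must lie strictly within one stride of the default size. Worked numbers for the 3x3 input / kernel 2 / stride 2 / padding 0 case used in the unit tests above:

    default = (3 - 1) * 2 - 2 * 0 + 2   # = 6
    lo, hi = default - 2, default + 2   # exclusive window (4, 8)
    # output_size values 5, 6 and 7 pass; (100, 100) raises ValueError,
    # which is exactly what data_outputsize_error2 asserts.
    assert all(lo < s < hi for s in (5, 6, 7))
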
math:: + H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + or as given by :attr:`output_size` in the call operator + + Returns: + A callable object of MaxUnPool2D. + + + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 1, 7, 7]).astype(np.float32)) + pool_out, indices = F.max_pool2d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 1, 3, 3], indices shape: [1, 1, 3, 3] + Unpool2D = paddle.nn.MaxUnPool2D(kernel_size=2, padding=0) + unpool_out = UnPool2D(pool_out, indices) + # unpool_out shape: [1, 1, 6, 6] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + data_format="NCHW", + output_size=None, + name=None): + super(MaxUnPool2D, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.data_format = data_format + self.output_size = output_size + self.name = name + + def forward(self, x, indices): + return F.max_unpool2d( + x, + indices, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + data_format=self.data_format, + output_size=self.output_size, + name=self.name) + + def extra_repr(self): + return 'output_size={}'.format(self.output_size) -- GitLab