Unverified commit ceee71a0. Author: xiaoting. Committer: GitHub.

Add unpool2d op & Expose max_unpool2d API (#35056)

* add maxunpool2d op, test=develop

* fix typo, test=develop

* fix unpool unittest, test=develop

* fix unpool code-example, test=develop

* fix for unpool_op_unittest,test=develop

* fix example code, test=develop

* add noqa:F401, test=develop

* fix coverage, test=develop

* fix unittest for unpool, test=develop

* rename unpool2d to unpool, test=develop

* rename unpool2d to unpool, test=develop
Parent 234ce932
......@@ -38,6 +38,7 @@ class Unpool2dMaxFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
......
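For orientation, here is a minimal NumPy sketch of what this CPU functor computes: every input value is scattered to its recorded argmax position inside the corresponding output plane, with the same bounds check PADDLE_ENFORCE_LT performs above. The helper name and shapes are illustrative, not part of the diff.

import numpy as np

def unpool2d_max_cpu_sketch(x, indices, out_h, out_w):
    # x, indices: (N, C, H, W); out: (N, C, out_h, out_w), zero-filled
    n, c, _, _ = x.shape
    out = np.zeros((n, c, out_h, out_w), dtype=x.dtype)
    flat_out = out.reshape(n, c, -1)   # view onto each output plane
    flat_x = x.reshape(n, c, -1)
    flat_idx = indices.reshape(n, c, -1)
    # mirrors PADDLE_ENFORCE_LT(index, output_feasize, ...)
    assert flat_idx.max() < out_h * out_w, "index exceeds output plane"
    for ni in range(n):
        for ci in range(c):
            flat_out[ni, ci, flat_idx[ni, ci]] = flat_x[ni, ci]
    return out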
......@@ -25,48 +25,27 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
const int channels, T* output_data,
const int output_height,
const int output_width) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ENFORCE(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
output_data[out_offset + out_index] = input_data[i];
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_width / input_height) % channels;
int n = linearIndex / input_width / input_height / channels;
output_data += (n * channels + c) * output_height * output_width;
int maxind = indices_data[linearIndex];
output_data[maxind] = input_data[linearIndex];
}
}
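The rewritten forward kernel is a flat scatter over input elements. A hedged Python rendering of the same index arithmetic, with one loop iteration standing in for one CUDA_KERNEL_LOOP iteration (names are illustrative):

import numpy as np

def kernel_unpool2d_max_sketch(input_data, indices_data, channels,
                               input_h, input_w, output_h, output_w):
    # Recover (n, c) from the flat NCHW index, then scatter into that
    # sample/channel's output plane; `base` plays the role of the pointer
    # bump `output_data += (n * channels + c) * output_height * output_width`
    # in the kernel, computed fresh for each element here.
    n_batch = input_data.size // (channels * input_h * input_w)
    output = np.zeros(n_batch * channels * output_h * output_w,
                      dtype=input_data.dtype)
    for linear_index in range(input_data.size):
        c = (linear_index // input_w // input_h) % channels
        n = linear_index // input_w // input_h // channels
        base = (n * channels + c) * output_h * output_w
        maxind = indices_data.flat[linear_index]
        output[base + maxind] = input_data.flat[linear_index]
    return output.reshape(n_batch, channels, output_h, output_w)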
template <typename T>
__global__ void KernelUnpool2dMaxGrad(
const int nthreads, const T* input_data, const int* indices_data,
const int input_height, const int input_width, const int channels,
const T* output_data, const T* output_grad, const int output_height,
const int output_width, T* input_grad) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ENFORCE(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
input_grad[i] = output_grad[out_offset + out_index];
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_width / input_height) % channels;
int n = linearIndex / input_width / input_height / channels;
output_grad += (n * channels + c) * output_height * output_width;
int maxind = indices_data[linearIndex];
input_grad[linearIndex] = output_grad[maxind];
}
}
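The backward kernel is the matching gather: each input-gradient element reads the output gradient at its recorded index. A sketch under the same illustrative naming:

import numpy as np

def kernel_unpool2d_max_grad_sketch(output_grad, indices_data, channels,
                                    input_h, input_w, output_h, output_w):
    # input_grad[i] = output_grad[plane_base(i) + indices[i]]
    input_grad = np.zeros(indices_data.size, dtype=output_grad.dtype)
    for linear_index in range(indices_data.size):
        c = (linear_index // input_w // input_h) % channels
        n = linear_index // input_w // input_h // channels
        base = (n * channels + c) * output_h * output_w
        maxind = indices_data.flat[linear_index]
        input_grad[linear_index] = output_grad.flat[base + maxind]
    n_batch = indices_data.size // (channels * input_h * input_w)
    return input_grad.reshape(n_batch, channels, input_h, input_w)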
/*
......
......@@ -54,6 +54,16 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddAttr<std::vector<int>>("output_size",
"(vector, optional). The shape of the output.")
.SetDefault({0, 0});
AddAttr<std::string>(
"data_format",
"(string, default NCHW). An optional string from: \"NHWC\", \"NCHW\". "
"Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCHW");
AddComment(R"DOC(
Input shape is: $(N, C_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, H_{out}, W_{out})$, where
......@@ -93,6 +103,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> output_size =
ctx->Attrs().Get<std::vector<int>>("output_size");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"Unpool Input(X) must be 4-dimensional, but "
......@@ -111,8 +123,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1);
} else {
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
output_shape.push_back(output_size[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
......
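With this change, InferShape passes the batch and channel dimensions through and takes the spatial dimensions from the new output_size attribute (which the Python layer derives from ksize, strides and paddings). A hypothetical Python rendering of the updated logic:

def infer_unpool_out_shape(in_x_dims, output_size, is_runtime=True):
    # in_x_dims: [N, C, H, W]; unknown spatial dims stay -1 at compile time
    out_shape = [in_x_dims[0], in_x_dims[1]]
    for i in range(2):
        if not is_runtime and in_x_dims[i + 2] <= 0:
            out_shape.append(-1)
        else:
            out_shape.append(output_size[i])
    return out_shape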
......@@ -19,10 +19,26 @@ import numpy as np
from op_test import OpTest
def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
def _unpool_output_size(x, kernel_size, stride, padding, output_size):
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+ kernel_size[d] - 2 * padding[d])
if output_size is None:
ret = default_size
else:
ret = output_size
return ret
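As a quick check of the default-size formula in this helper: a 3x3 map that was pooled with kernel 2, stride 2, padding 0 unpools back to 6x6. Illustrative only:

import numpy as np

x = np.zeros((1, 1, 3, 3))
# (3 - 1) * 2 + 2 - 2 * 0 == 6 per spatial dimension
assert _unpool_output_size(x, [2, 2], [2, 2], [0, 0], None) == [6, 6]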
def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings,
output_size):
s0, s1, s2, s3 = input.shape
out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
output_size = _unpool_output_size(input, ksize, strides, paddings,
output_size)
out_hsize = output_size[0]
out_wsize = output_size[1]
out = np.zeros((s0, s1, out_hsize, out_wsize))
for nidx in range(s0):
for cidx in range(s1):
......@@ -31,7 +47,7 @@ def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
index = indices[nidx, cidx, h, w]
hidx = (index - index % out_wsize) // out_wsize
widx = index % out_wsize
out[nidx, cidx, int(hidx), int(widx)] = \
out[nidx, cidx, hidx, widx] = \
input[nidx, cidx, h, w]
return out
......@@ -41,33 +57,25 @@ class TestUnpoolOp(OpTest):
def setUp(self):
self.op_type = "unpool"
self.init_test_case()
pre_input = np.random.random(self.shape).astype("float64")
nsize, csize, hsize, wsize = pre_input.shape
hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) // \
self.strides[0] + 1
wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) // \
self.strides[1] + 1
input = np.zeros((nsize, csize, hsize_out, wsize_out))
indices = np.zeros((nsize, csize, hsize_out, wsize_out))
for i in range(hsize_out):
for j in range(wsize_out):
r_start = np.max((i * self.strides[0] - self.paddings[0], 0))
r_end = np.min((i * self.strides[0] + self.ksize[0] - \
self.paddings[0], hsize))
c_start = np.max((j * self.strides[1] - self.paddings[1], 0))
c_end = np.min((j * self.strides[1] + self.ksize[1] - \
self.paddings[1], wsize))
for nidx in range(nsize):
for cidx in range(csize):
x_masked = pre_input[nidx, cidx, r_start:r_end, \
c_start:c_end]
input[nidx, cidx, i, j] = x_masked.max()
arg = x_masked.argmax()
indices[nidx, cidx, i, j] = \
(r_start + arg // self.ksize[1]) * wsize + \
c_start + arg % self.ksize[1]
input = np.random.randint(0, 100, self.shape)
nsize, csize, hsize, wsize = input.shape
self.output_size = _unpool_output_size(input, self.ksize, self.strides,
self.paddings, self.output_size)
indices = np.random.permutation(
np.arange(0, self.output_size[0] * self.output_size[1]))[:hsize *
wsize]
indices = np.reshape(indices, [hsize, wsize])
idx_list = []
for n in range(nsize):
c_list = []
for c in range(csize):
c_list.append(indices.tolist())
idx_list.append(c_list)
indices = np.array(idx_list)
output = self.unpool2d_forward_naive(input, indices, self.ksize, \
self.strides, self.paddings).astype("float64")
self.strides, self.paddings, self.output_size).astype("float64")
self.inputs = {
'X': input.astype('float64'),
'Indices': indices.astype('int32')
......@@ -77,6 +85,7 @@ class TestUnpoolOp(OpTest):
'paddings': self.paddings,
'ksize': self.ksize,
'unpooling_type': self.unpooling_type,
'output_size': self.output_size,
}
self.outputs = {'Out': output.astype('float64')}
......@@ -89,10 +98,209 @@ class TestUnpoolOp(OpTest):
def init_test_case(self):
self.unpool2d_forward_naive = unpool2dmax_forward_naive
self.unpooling_type = "max"
self.shape = [6, 4, 7, 7]
self.ksize = [3, 3]
self.shape = [2, 4, 7, 8]
self.ksize = [2, 2]
self.strides = [2, 2]
self.paddings = [0, 0]
self.output_size = None
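For this base case (output_size=None), the helper falls back to the default size; a worked instance of the formula, for reference only:

# shape [2, 4, 7, 8], ksize [2, 2], strides [2, 2], paddings [0, 0]:
# H_out = (7 - 1) * 2 + 2 - 0 == 14, W_out = (8 - 1) * 2 + 2 - 0 == 16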
class TestUnpoolOpcase1(TestUnpoolOp):
def init_test_case(self):
self.unpool2d_forward_naive = unpool2dmax_forward_naive
self.unpooling_type = "max"
self.shape = [3, 2, 5, 5]
self.ksize = [4, 4]
self.strides = [2, 2]
self.paddings = [0, 0]
self.output_size = None
class TestUnpoolOpOuputsize(TestUnpoolOp):
def init_test_case(self):
self.unpool2d_forward_naive = unpool2dmax_forward_naive
self.unpooling_type = "max"
self.shape = [3, 2, 5, 5]
self.ksize = [4, 4]
self.strides = [2, 2]
self.paddings = [0, 0]
self.output_size = [9, 9]
class TestUnpoolOpOuput(TestUnpoolOp):
def init_test_case(self):
self.unpool2d_forward_naive = unpool2dmax_forward_naive
self.unpooling_type = "max"
self.shape = [3, 2, 5, 5]
self.ksize = [4, 4]
self.strides = [2, 2]
self.paddings = [0, 0]
self.output_size = [9, 9]
class TestUnpoolOpException(unittest.TestCase):
def test_exception(self):
import paddle.nn.functional as F
import paddle
def indices_size_error():
data = paddle.randint(shape=[1, 1, 3, 3])
indices = paddle.reshape(paddle.arange(0, 12), shape=[1, 1, 3, 4])
MaxPool2D = F.max_unpool2d(data, indices, kernel_size=2, stride=2)
def indices_value_error():
data = paddle.randint(shape=[1, 1, 3, 3])
indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
MaxPool2D = F.max_unpool2d(data, indices, kernel_size=2, stride=2)
def data_format_error():
data = paddle.randint(shape=[1, 1, 3, 3])
indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
MaxPool2D = F.max_unpool2d(
data, indices, kernel_size=2, stride=2, data_format="NHWC")
def data_outputsize_error():
data = paddle.randint(shape=[1, 1, 3, 3])
indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
MaxPool2D = F.max_unpool2d(
data,
indices,
kernel_size=2,
stride=2,
output_size=[5, 6, 7, 8])
def data_outputsize_error2():
data = paddle.randint(shape=[1, 1, 3, 3])
indices = paddle.reshape(paddle.arange(4, 40), shape=[1, 1, 3, 4])
MaxPool2D = F.max_unpool2d(
data, indices, kernel_size=2, stride=2, output_size=[100, 100])
self.assertRaises(ValueError, indices_size_error)
self.assertRaises(ValueError, indices_value_error)
self.assertRaises(ValueError, data_format_error)
self.assertRaises(ValueError, data_outputsize_error)
self.assertRaises(ValueError, data_outputsize_error2)
class TestUnpoolOpAPI_dy(unittest.TestCase):
def test_case(self):
import paddle
import paddle.nn.functional as F
import paddle.fluid.core as core
import paddle.fluid as fluid
import numpy as np
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]]]).astype("float32")
input_x = paddle.to_tensor(input_data)
output, indices = F.max_pool2d(
input_x, kernel_size=2, stride=2, return_mask=True)
out_pp = F.max_unpool2d(
output, indices, kernel_size=2, stride=2, output_size=(5, 5))
output_np = output.numpy()
indices_np = indices.numpy()
expect_res = unpool2dmax_forward_naive(output_np, indices_np, [2, 2], \
[2, 2], [0, 0], [5, 5]).astype("float64")
self.assertTrue(np.allclose(out_pp.numpy(), expect_res))
class TestUnpoolOpAPI_dy2(unittest.TestCase):
def test_case(self):
import paddle
import paddle.nn.functional as F
import paddle.fluid.core as core
import paddle.fluid as fluid
import numpy as np
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]]]).astype("float32")
input_x = paddle.to_tensor(input_data)
output, indices = F.max_pool2d(
input_x, kernel_size=2, stride=2, return_mask=True)
out_pp = F.max_unpool2d(
output, indices, kernel_size=2, stride=None, output_size=(5, 5))
output_np = output.numpy()
indices_np = indices.numpy()
expect_res = unpool2dmax_forward_naive(output_np, indices_np, [2, 2], \
[2, 2], [0, 0], [5, 5]).astype("float64")
self.assertTrue(np.allclose(out_pp.numpy(), expect_res))
class TestUnpoolOpAPI_dy3(unittest.TestCase):
def test_case(self):
import paddle
import paddle.nn.functional as F
import paddle.fluid.core as core
import paddle.fluid as fluid
import numpy as np
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]]]).astype("float32")
input_x = paddle.to_tensor(input_data)
Pool2d = paddle.nn.MaxPool2D(
kernel_size=2, stride=2, return_mask=True)
UnPool = paddle.nn.MaxUnPool2D(kernel_size=2, stride=2)
output, indices = Pool2d(input_x)
out_pp = UnPool(output, indices)
output_np = output.numpy()
indices_np = indices.numpy()
expect_res = unpool2dmax_forward_naive(output_np, indices_np, [2, 2], \
[2, 2], [0, 0], [4, 4]).astype("float64")
self.assertTrue(np.allclose(out_pp.numpy(), expect_res))
class TestUnpoolOpAPI_st(unittest.TestCase):
def test_case(self):
import paddle
import paddle.nn.functional as F
import paddle.fluid.core as core
import paddle.fluid as fluid
paddle.enable_static()
input_data = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
[13, 14, 15, 16]]]]).astype("float32")
x = fluid.data(name="x", shape=[1, 1, 4, 4], dtype="float32")
output, indices = F.max_pool2d(
x, kernel_size=2, stride=2, return_mask=True)
unpool_out = F.max_unpool2d(
output, indices, kernel_size=2, stride=None, output_size=(5, 5))
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(paddle.fluid.default_main_program(),\
feed={"x":input_data},
fetch_list=[unpool_out],
return_numpy=True)
pool_out_np = np.array([[[[6., 8.], [14., 16.]]]]).astype("float32")
indices_np = np.array([[[[5, 7], [13, 15]]]]).astype("int32")
expect_res = unpool2dmax_forward_naive(pool_out_np, indices_np, [2, 2], \
[2, 2], [0, 0], [5, 5]).astype("float64")
self.assertTrue(np.allclose(results[0], expect_res))
if __name__ == '__main__':
......
......@@ -73,6 +73,7 @@ from .layer.pooling import AvgPool3D # noqa: F401
from .layer.pooling import MaxPool1D # noqa: F401
from .layer.pooling import MaxPool2D # noqa: F401
from .layer.pooling import MaxPool3D # noqa: F401
from .layer.pooling import MaxUnPool2D # noqa: F401
from .layer.pooling import AdaptiveAvgPool1D # noqa: F401
from .layer.pooling import AdaptiveAvgPool2D # noqa: F401
from .layer.pooling import AdaptiveAvgPool3D # noqa: F401
......
......@@ -101,6 +101,7 @@ from .pooling import adaptive_max_pool3d # noqa: F401
from .pooling import adaptive_avg_pool1d # noqa: F401
from .pooling import adaptive_avg_pool2d # noqa: F401
from .pooling import adaptive_avg_pool3d # noqa: F401
from .pooling import max_unpool2d # noqa: F401
from .vision import affine_grid # noqa: F401
from .vision import grid_sample # noqa: F401
......@@ -166,6 +167,7 @@ __all__ = [ #noqa
'max_pool1d',
'max_pool2d',
'max_pool3d',
'max_unpool2d',
'adaptive_avg_pool1d',
'adaptive_avg_pool2d',
'adaptive_avg_pool3d',
......
......@@ -616,75 +616,168 @@ def max_pool1d(x,
squeeze(mask, [2])) if return_mask else squeeze(pool_out, [2])
def max_pool2d(x,
kernel_size,
stride=None,
padding=0,
return_mask=False,
ceil_mode=False,
data_format="NCHW",
name=None):
def _unpool_output_size(x, kernel_size, stride, padding, output_size):
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+ kernel_size[d] - 2 * padding[d])
if output_size is None:
ret = default_size
else:
if len(output_size) == len(kernel_size) + 2:
output_size = output_size[2:]
if len(output_size) != len(kernel_size):
raise ValueError(
"output_size should be a sequence containing "
"{} or {} elements, but it has a length of '{}'".format(
len(kernel_size), len(kernel_size) + 2, len(output_size)))
for d in range(len(kernel_size)):
min_size = default_size[d] - stride[d]
max_size = default_size[d] + stride[d]
if not (min_size < output_size[d] < max_size):
raise ValueError(
'invalid output_size "{}" (dim {} must be between {} and {})'.
format(output_size, d, min_size, max_size))
ret = output_size
return ret
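Note the validation window: an explicit output_size may deviate from the default size by strictly less than one stride per dimension. A hedged illustration, assuming a 3x3 input with kernel 2, stride 2, padding 0 (default size 6):

# min_size = 6 - 2 = 4, max_size = 6 + 2 = 8, so 4 < output_size[d] < 8
# _unpool_output_size(x, [2, 2], [2, 2], [0, 0], [7, 7])  # [7, 7], accepted
# _unpool_output_size(x, [2, 2], [2, 2], [0, 0], [9, 9])  # ValueError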
def max_unpool2d(x,
indices,
kernel_size,
stride=None,
padding=0,
data_format="NCHW",
output_size=None,
name=None):
"""
This API implements max pooling 2d operation.
See more details in :ref:`api_nn_pooling_MaxPool2d` .
This API implements the max unpooling 2d operation.
`max_unpool2d` is not fully invertible, since the non-maximal values are lost.
`max_unpool2d` takes as input the output of `max_pool2d`
together with the indices of the maximal values, and computes a partial inverse
in which all non-maximal values are set to zero.
`max_unpool2d` can map several input sizes to the same output
size. Hence, the inversion process can get ambiguous.
To accommodate this, you can provide the needed output size
as an additional argument `output_size` in the forward call.
Args:
x (Tensor): The input tensor of pooling operator which is a 4-D tensor with
shape [N, C, H, W]. The format of input tensor is `"NCHW"` or
`"NHWC"`, where `N` is batch size, `C` is the number of channels,
x (Tensor): The input tensor of unpooling operator which is a 4-D tensor with
shape [N, C, H, W]. The format of input tensor is `"NCHW"`,
where `N` is batch size, `C` is the number of channels,
`H` is the height of the feature, and `W` is the width of the
feature. The data type is float32 or float64.
kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (kernel_size_Height, kernel_size_Width).
Otherwise, the pool kernel size will be a square of an int.
stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
it must contain two integers, (stride_Height, stride_Width).
Otherwise, the pool stride size will be a square of an int.
padding (string|int|list|tuple): The padding size. Padding could be in one of the following forms.
1. A string in ['valid', 'same'].
2. An int, which means the feature map is zero padded by size of `padding` on every sides.
3. A list[int] or tuple(int) whose length is 2, [pad_height, pad_weight] whose value means the padding size of each dimension.
4. A list[int] or tuple(int) whose length is 4. [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right] whose value means the padding size of each side.
5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0).
The default value is 0.
ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape
return_mask (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
indices (Tensor): The indices given out by max_pool2d which is a 4-D tensor with
shape [N, C, H, W]. The format of the input tensor is `"NCHW"`,
where `N` is batch size, `C` is the number of channels,
`H` is the height of the feature, and `W` is the width of the
feature. The data type is int32.
kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
it must contain two integers, (kernel_size_Height, kernel_size_Width).
stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
it must contain two integers, (stride_Height, stride_Width).
padding (int|tuple): Padding that was added to the input.
output_size(list|tuple, optional): The target output size. If output_size is not specified,
the actual output shape will be automatically calculated from the input shape,
kernel_size, stride and padding.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Tensor: The output tensor of pooling result. The data type is same as input tensor.
Raises:
ValueError: If `padding` is a string, but not "SAME" or "VALID".
ValueError: If `padding` is "VALID", but `ceil_mode` is True.
ShapeError: If the output's shape calculated is not greater than 0.
Examples:
.. code-block:: python
- Input: :math:`(N, C, H_{in}, W_{in})`
- Output: :math:`(N, C, H_{out}, W_{out})`, where
.. math::
H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
.. math::
W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
or as given by :attr:`output_size` in the call operator
Returns:
Tensor: The output tensor of unpooling result.
Raises:
ValueError: If the input is not a 4-D tensor.
ValueError: If the shape of indices does not match the shape of the input.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
# max pool2d
x = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32]).astype(np.float32))
out = F.max_pool2d(x,
kernel_size=2,
stride=2, padding=0)
# output.shape [1, 3, 16, 16]
# for return_mask=True
out, max_indices = F.max_pool2d(x,
kernel_size=2,
stride=2,
padding=0,
return_mask=True)
# out.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16],
data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 1, 6, 6]).astype(np.float32))
pool_out, indices = F.max_pool2d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
# pool_out shape: [1, 1, 3, 3], indices shape: [1, 1, 3, 3]
unpool_out = F.max_unpool2d(pool_out, indices, kernel_size=2, padding=0)
# unpool_out shape: [1, 1, 6, 6]
# specify a different output size than input size
unpool_out = F.max_unpool2d(pool_out, indices, kernel_size=2, padding=0, output_size=[7,7])
# unpool_out shape: [1, 1, 7, 7]
"""
kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size')
if stride is None:
stride = kernel_size
else:
stride = utils.convert_to_list(stride, 2, 'pool_stride')
padding = utils.convert_to_list(padding, 2, 'padding')
if data_format not in ["NCHW"]:
raise ValueError("Attr(data_format) should be 'NCHW'. Received "
"Attr(data_format): %s." % str(data_format))
output_size = _unpool_output_size(x, kernel_size, stride, padding,
output_size)
if in_dygraph_mode():
output = _C_ops.unpool(x, indices, 'unpooling_type', 'max', 'ksize',
kernel_size, 'strides', stride, 'paddings',
padding, "output_size", output_size,
"data_format", data_format)
return output
op_type = "unpool"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name="x")
unpool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=op_type,
inputs={"X": x,
"Indices": indices},
outputs={"Out": unpool_out},
attrs={
"unpooling_type": "max",
"ksize": kernel_size,
"strides": stride,
"paddings": padding,
"output_size": output_size
})
return unpool_out
def max_pool2d(x,
kernel_size,
stride=None,
padding=0,
return_mask=False,
ceil_mode=False,
data_format="NCHW",
name=None):
kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size')
if stride is None:
stride = kernel_size
else:
......
......@@ -54,6 +54,7 @@ from .pooling import AdaptiveAvgPool3D # noqa: F401
from .pooling import AdaptiveMaxPool1D # noqa: F401
from .pooling import AdaptiveMaxPool2D # noqa: F401
from .pooling import AdaptiveMaxPool3D # noqa: F401
from .pooling import MaxUnPool2D # noqa: F401
from .conv import Conv1D # noqa: F401
from .conv import Conv2D # noqa: F401
from .conv import Conv3D # noqa: F401
......
......@@ -1128,3 +1128,95 @@ class AdaptiveMaxPool3D(Layer):
def extra_repr(self):
return 'output_size={}, return_mask={}'.format(self._output_size,
self._return_mask)
class MaxUnPool2D(Layer):
"""
This API implements the max unpooling 2d operation.
`max_unpool2d` is not fully invertible, since the non-maximal values are lost.
`max_unpool2d` takes as input the output of `max_pool2d`
together with the indices of the maximal values, and computes a partial inverse
in which all non-maximal values are set to zero.
`max_unpool2d` can map several input sizes to the same output
size. Hence, the inversion process can get ambiguous.
To accommodate this, you can provide the needed output size
as an additional argument `output_size` in the forward call.
Parameters:
kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
it must contain two integers, (kernel_size_Height, kernel_size_Width).
stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list,
it must contain two integers, (stride_Height, stride_Width).
padding (int|tuple): Padding that was added to the input.
output_size(list|tuple, optional): The target output size. If output_size is not specified,
the actual output shape will be automatically calculated from the input shape,
kernel_size, stride and padding.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
- Input: :math:`(N, C, H_{in}, W_{in})`
- Output: :math:`(N, C, H_{out}, W_{out})`, where
.. math::
H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
.. math::
W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
or as given by :attr:`output_size` in the call operator
Returns:
A callable object of MaxUnPool2D.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 1, 7, 7]).astype(np.float32))
pool_out, indices = F.max_pool2d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
# pool_out shape: [1, 1, 3, 3], indices shape: [1, 1, 3, 3]
Unpool2D = paddle.nn.MaxUnPool2D(kernel_size=2, padding=0)
unpool_out = Unpool2D(pool_out, indices)
# unpool_out shape: [1, 1, 6, 6]
"""
def __init__(self,
kernel_size,
stride=None,
padding=0,
data_format="NCHW",
output_size=None,
name=None):
super(MaxUnPool2D, self).__init__()
self.ksize = kernel_size
self.stride = stride
self.padding = padding
self.data_format = data_format
self.output_size = output_size
self.name = name
def forward(self, x, indices):
return F.max_unpool2d(
x,
indices,
kernel_size=self.ksize,
stride=self.stride,
padding=self.padding,
data_format=self.data_format,
output_size=self.output_size,
name=self.name)
def extra_repr(self):
return 'output_size={}'.format(self.output_size)