diff --git a/python/paddle/fluid/tests/unittests/test_unpool_op.py b/python/paddle/fluid/tests/unittests/test_unpool_op.py index 42522ae6cb88b611f6b67501c28817db01f74ad4..9d94e5acbfea7ea7335fd0560d7831356cdf8d4c 100644 --- a/python/paddle/fluid/tests/unittests/test_unpool_op.py +++ b/python/paddle/fluid/tests/unittests/test_unpool_op.py @@ -436,6 +436,35 @@ class TestZOutputSizeTensor2(unittest.TestCase): np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7]) +class TestZOutputSizeTensor3(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def test_dygraph(self): + x = paddle.randn([1, 3, 6, 6]) + pool_out, indices = F.max_pool2d(x, + kernel_size=2, + stride=2, + padding=0, + return_mask=True) + output_size = [ + paddle.assign([1]), + paddle.assign([1]), + paddle.assign([7]), + paddle.assign([7]) + ] + unpool_out = F.max_unpool2d(pool_out, + indices, + kernel_size=2, + padding=0, + output_size=output_size) + np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7]) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 637b192207eed150370bc36f05d21235172beb26..4fd6c75e4c6404fa252ab5b19abef0e03820a529 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -646,6 +646,9 @@ def max_pool1d(x, def _unpool_output_size(x, kernel_size, stride, padding, output_size): + assert output_size is None or isinstance( + output_size, (list, tuple) + ), "Required output_size is None|list|tuple, but received %s" % output_size input_size = x.shape default_size = [] for d in range(len(kernel_size)): @@ -654,7 +657,7 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): has_static_var = False if output_size is None: - ret = default_size + return default_size elif utils._contain_var(output_size): if not _non_static_mode(): has_static_var = True @@ -663,27 +666,25 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): for i, var in enumerate(output_size): if isinstance(var, Variable): output_size[i] = var.numpy()[0] - ret = output_size - else: - if len(output_size) == len(kernel_size) + 2: - output_size = output_size[2:] - if len(output_size) != len(kernel_size): - raise ValueError( - "output_size should be a sequence containing " - "{} or {} elements, but it has a length of '{}'".format( - len(kernel_size), - len(kernel_size) + 2, len(output_size))) - if not has_static_var: - for d in range(len(kernel_size)): - min_size = default_size[d] - stride[d] - max_size = default_size[d] + stride[d] - if not (min_size < output_size[d] < max_size): - raise ValueError( - 'invalid output_size "{}" (dim {} must be between {} and {})' - .format(output_size, d, min_size, max_size)) - - ret = output_size - return ret + + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), + len(kernel_size) + 2, len(output_size))) + if not has_static_var: + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + 'invalid output_size "{}" (dim {} must be between {} and {})' + .format(output_size, d, min_size, max_size)) + + return output_size def max_unpool1d(x,