From 18860735ed1544f866e42e661c050015edbea555 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 22 Sep 2022 10:51:43 +0800 Subject: [PATCH] [BugFix]Fix pooling output_size bug if encounter list[Tensor] (#46352) * [Check]Enhance pooling output_size type check * add unittest --- .../fluid/tests/unittests/test_unpool_op.py | 29 ++++++++++++ python/paddle/nn/functional/pooling.py | 45 ++++++++++--------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_unpool_op.py b/python/paddle/fluid/tests/unittests/test_unpool_op.py index 42522ae6cb..9d94e5acbf 100644 --- a/python/paddle/fluid/tests/unittests/test_unpool_op.py +++ b/python/paddle/fluid/tests/unittests/test_unpool_op.py @@ -436,6 +436,35 @@ class TestZOutputSizeTensor2(unittest.TestCase): np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7]) +class TestZOutputSizeTensor3(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def test_dygraph(self): + x = paddle.randn([1, 3, 6, 6]) + pool_out, indices = F.max_pool2d(x, + kernel_size=2, + stride=2, + padding=0, + return_mask=True) + output_size = [ + paddle.assign([1]), + paddle.assign([1]), + paddle.assign([7]), + paddle.assign([7]) + ] + unpool_out = F.max_unpool2d(pool_out, + indices, + kernel_size=2, + padding=0, + output_size=output_size) + np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7]) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 307b0783b1..d2ba7468c8 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -646,6 +646,9 @@ def max_pool1d(x, def _unpool_output_size(x, kernel_size, stride, padding, output_size): + assert output_size is None or isinstance( + output_size, (list, tuple) + ), "Required output_size is None|list|tuple, but received %s" % output_size input_size = x.shape default_size = [] for d in range(len(kernel_size)): @@ -654,7 +657,7 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): has_static_var = False if output_size is None: - ret = default_size + return default_size elif utils._contain_var(output_size): if not _non_static_mode(): has_static_var = True @@ -663,27 +666,25 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): for i, var in enumerate(output_size): if isinstance(var, Variable): output_size[i] = var.numpy()[0] - ret = output_size - else: - if len(output_size) == len(kernel_size) + 2: - output_size = output_size[2:] - if len(output_size) != len(kernel_size): - raise ValueError( - "output_size should be a sequence containing " - "{} or {} elements, but it has a length of '{}'".format( - len(kernel_size), - len(kernel_size) + 2, len(output_size))) - if not has_static_var: - for d in range(len(kernel_size)): - min_size = default_size[d] - stride[d] - max_size = default_size[d] + stride[d] - if not (min_size < output_size[d] < max_size): - raise ValueError( - 'invalid output_size "{}" (dim {} must be between {} and {})' - .format(output_size, d, min_size, max_size)) - - ret = output_size - return ret + + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), + len(kernel_size) + 2, len(output_size))) + if not has_static_var: + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + 'invalid output_size "{}" (dim {} must be between {} and {})' + .format(output_size, d, min_size, max_size)) + + return output_size def max_unpool1d(x, -- GitLab