未验证 提交 cc3e7cd8 编写于 作者: A Aurelius84 提交者: GitHub

[Cherry-pick][BugFix]Fix pooling output_size bug if encounter list[Tensor] (#46360)

* [Check]Enhance pooling output_size type check

* add unittest
上级 b74c0920
...@@ -436,6 +436,35 @@ class TestZOutputSizeTensor2(unittest.TestCase): ...@@ -436,6 +436,35 @@ class TestZOutputSizeTensor2(unittest.TestCase):
np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7]) np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7])
class TestZOutputSizeTensor3(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def tearDown(self):
paddle.enable_static()
def test_dygraph(self):
x = paddle.randn([1, 3, 6, 6])
pool_out, indices = F.max_pool2d(x,
kernel_size=2,
stride=2,
padding=0,
return_mask=True)
output_size = [
paddle.assign([1]),
paddle.assign([1]),
paddle.assign([7]),
paddle.assign([7])
]
unpool_out = F.max_unpool2d(pool_out,
indices,
kernel_size=2,
padding=0,
output_size=output_size)
np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7])
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -646,6 +646,9 @@ def max_pool1d(x, ...@@ -646,6 +646,9 @@ def max_pool1d(x,
def _unpool_output_size(x, kernel_size, stride, padding, output_size): def _unpool_output_size(x, kernel_size, stride, padding, output_size):
assert output_size is None or isinstance(
output_size, (list, tuple)
), "Required output_size is None|list|tuple, but received %s" % output_size
input_size = x.shape input_size = x.shape
default_size = [] default_size = []
for d in range(len(kernel_size)): for d in range(len(kernel_size)):
...@@ -654,7 +657,7 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): ...@@ -654,7 +657,7 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
has_static_var = False has_static_var = False
if output_size is None: if output_size is None:
ret = default_size return default_size
elif utils._contain_var(output_size): elif utils._contain_var(output_size):
if not _non_static_mode(): if not _non_static_mode():
has_static_var = True has_static_var = True
...@@ -663,27 +666,25 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): ...@@ -663,27 +666,25 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
for i, var in enumerate(output_size): for i, var in enumerate(output_size):
if isinstance(var, Variable): if isinstance(var, Variable):
output_size[i] = var.numpy()[0] output_size[i] = var.numpy()[0]
ret = output_size
else: if len(output_size) == len(kernel_size) + 2:
if len(output_size) == len(kernel_size) + 2: output_size = output_size[2:]
output_size = output_size[2:] if len(output_size) != len(kernel_size):
if len(output_size) != len(kernel_size): raise ValueError(
raise ValueError( "output_size should be a sequence containing "
"output_size should be a sequence containing " "{} or {} elements, but it has a length of '{}'".format(
"{} or {} elements, but it has a length of '{}'".format( len(kernel_size),
len(kernel_size), len(kernel_size) + 2, len(output_size)))
len(kernel_size) + 2, len(output_size))) if not has_static_var:
if not has_static_var: for d in range(len(kernel_size)):
for d in range(len(kernel_size)): min_size = default_size[d] - stride[d]
min_size = default_size[d] - stride[d] max_size = default_size[d] + stride[d]
max_size = default_size[d] + stride[d] if not (min_size < output_size[d] < max_size):
if not (min_size < output_size[d] < max_size): raise ValueError(
raise ValueError( 'invalid output_size "{}" (dim {} must be between {} and {})'
'invalid output_size "{}" (dim {} must be between {} and {})' .format(output_size, d, min_size, max_size))
.format(output_size, d, min_size, max_size))
return output_size
ret = output_size
return ret
def max_unpool1d(x, def max_unpool1d(x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册