未验证 提交 cc3e7cd8 编写于 作者: A Aurelius84 提交者: GitHub

[Cherry-pick][BugFix]Fix pooling output_size bug if encounter list[Tensor] (#46360)

* [Check]Enhance pooling output_size type check

* add unittest
上级 b74c0920
......@@ -436,6 +436,35 @@ class TestZOutputSizeTensor2(unittest.TestCase):
np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7])
class TestZOutputSizeTensor3(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def tearDown(self):
paddle.enable_static()
def test_dygraph(self):
x = paddle.randn([1, 3, 6, 6])
pool_out, indices = F.max_pool2d(x,
kernel_size=2,
stride=2,
padding=0,
return_mask=True)
output_size = [
paddle.assign([1]),
paddle.assign([1]),
paddle.assign([7]),
paddle.assign([7])
]
unpool_out = F.max_unpool2d(pool_out,
indices,
kernel_size=2,
padding=0,
output_size=output_size)
np.testing.assert_array_equal(unpool_out.shape, [1, 3, 7, 7])
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -646,6 +646,9 @@ def max_pool1d(x,
def _unpool_output_size(x, kernel_size, stride, padding, output_size):
assert output_size is None or isinstance(
output_size, (list, tuple)
), "Required output_size is None|list|tuple, but received %s" % output_size
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
......@@ -654,7 +657,7 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
has_static_var = False
if output_size is None:
ret = default_size
return default_size
elif utils._contain_var(output_size):
if not _non_static_mode():
has_static_var = True
......@@ -663,27 +666,25 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
for i, var in enumerate(output_size):
if isinstance(var, Variable):
output_size[i] = var.numpy()[0]
ret = output_size
else:
if len(output_size) == len(kernel_size) + 2:
output_size = output_size[2:]
if len(output_size) != len(kernel_size):
raise ValueError(
"output_size should be a sequence containing "
"{} or {} elements, but it has a length of '{}'".format(
len(kernel_size),
len(kernel_size) + 2, len(output_size)))
if not has_static_var:
for d in range(len(kernel_size)):
min_size = default_size[d] - stride[d]
max_size = default_size[d] + stride[d]
if not (min_size < output_size[d] < max_size):
raise ValueError(
'invalid output_size "{}" (dim {} must be between {} and {})'
.format(output_size, d, min_size, max_size))
ret = output_size
return ret
if len(output_size) == len(kernel_size) + 2:
output_size = output_size[2:]
if len(output_size) != len(kernel_size):
raise ValueError(
"output_size should be a sequence containing "
"{} or {} elements, but it has a length of '{}'".format(
len(kernel_size),
len(kernel_size) + 2, len(output_size)))
if not has_static_var:
for d in range(len(kernel_size)):
min_size = default_size[d] - stride[d]
max_size = default_size[d] + stride[d]
if not (min_size < output_size[d] < max_size):
raise ValueError(
'invalid output_size "{}" (dim {} must be between {} and {})'
.format(output_size, d, min_size, max_size))
return output_size
def max_unpool1d(x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册