From 0d32f554c17a97aa534b4ff9901dcfa9a9c77f97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:01:20 +0800 Subject: [PATCH] fix the indexerror of conv2d_transpose (#50005) --- .../fluid/tests/unittests/test_conv2d_transpose_op.py | 11 +++++++++++ python/paddle/static/nn/common.py | 3 +++ 2 files changed, 14 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index afbce517f6..8933930356 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -989,6 +989,17 @@ class TestConv2DTransposeOpException(unittest.TestCase): self.assertRaises(ValueError, error_groups) + def error_0_filter_number(): + out = paddle.static.nn.conv2d_transpose( + input=data, + groups=1, + num_filters=0, + filter_size=3, + data_format='NCHW', + ) + + self.assertRaises(ValueError, error_0_filter_number) + class TestConv2DTransposeRepr(unittest.TestCase): def test_case(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 5da81feb33..3b40153cbb 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -1542,6 +1542,9 @@ def conv2d_transpose( "but received {}".format(len(input.shape)) ) + if num_filters == 0: + raise ValueError("num of filters should not be 0.") + if data_format not in ['NCHW', 'NHWC']: raise ValueError( "Attr(data_format) of Op(paddle.static.nn.layers.conv2d_transpose) got wrong value: received " -- GitLab