From 24c7087f73a05c1e875a5b3b90f6f9cdb711120b Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 22 Jul 2021 10:40:09 +0800 Subject: [PATCH] fix index erro in conv2d_transpose (#34270) --- python/paddle/fluid/layers/nn.py | 13 ++++++++++++- .../unittests/test_conv2d_transpose_op.py | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 239f46d1d95..71ea085d381 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3924,6 +3924,10 @@ def conv2d_transpose(input, print(conv2d_transpose.shape) # [-1, 2, 34, 34] """ assert param_attr is not False, "param_attr should not be False in conv2d_transpose." + if len(input.shape) != 4: + raise ValueError("Input size should be 4, " + "but received {}".format(len(input.shape))) + if data_format not in ['NCHW', 'NHWC']: raise ValueError( "Attr(data_format) of Op(fluid.layers.conv2d_transpose) got wrong value: received " @@ -4015,7 +4019,14 @@ def conv2d_transpose(input, output_size = utils.convert_to_list(output_size, 2, 'output_size') else: raise ValueError("output_size should be int, list[int] or tuple[int]") - groups = 1 if groups is None else groups + + if groups is None: + groups = 1 + elif groups <= 0: + raise ValueError("the groups of input must be greater than 0, " + "but received the groups of input is {}".format( + groups)) + filter_shape = [input_channel, num_filters // groups] + filter_size img_filter = helper.create_parameter( diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index b106f7aa9c1..027c806fc02 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -898,6 +898,25 @@ class TestConv2DTransposeOpException(unittest.TestCase): self.assertRaises(ValueError, attr_padding_with_data_format) + error_input = fluid.layers.data( + name='error_data', shape=[1], dtype="float32") + + def error_input_size(): + out = fluid.layers.conv2d_transpose( + input=error_input, groups=1, num_filters=6, filter_size=3) + + self.assertRaises(ValueError, error_input_size) + + def error_groups(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=0, + num_filters=6, + filter_size=3, + data_format='NHWC') + + self.assertRaises(ValueError, error_groups) + class TestConv2DTransposeRepr(unittest.TestCase): def test_case(self): -- GitLab