diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index afbce517f624373813e1896dcd2edc3594ca0f48..89339303567f2a50d67475a6890a0a3d2578f75d 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -989,6 +989,17 @@ class TestConv2DTransposeOpException(unittest.TestCase): self.assertRaises(ValueError, error_groups) + def error_0_filter_number(): + out = paddle.static.nn.conv2d_transpose( + input=data, + groups=1, + num_filters=0, + filter_size=3, + data_format='NCHW', + ) + + self.assertRaises(ValueError, error_0_filter_number) + class TestConv2DTransposeRepr(unittest.TestCase): def test_case(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 5da81feb3369d9e4894dbef35fa0f36d0f6a5feb..3b40153cbb797d76ace0e3a950e533f2b36d6648 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -1542,6 +1542,9 @@ def conv2d_transpose( "but received {}".format(len(input.shape)) ) + if num_filters == 0: + raise ValueError("num of filters should not be 0.") + if data_format not in ['NCHW', 'NHWC']: raise ValueError( "Attr(data_format) of Op(paddle.static.nn.layers.conv2d_transpose) got wrong value: received "