From c62657b3eccf2fa41e12c94b5a34d82a6f54890f Mon Sep 17 00:00:00 2001 From: RedContritio Date: Wed, 1 Feb 2023 15:15:12 +0800 Subject: [PATCH] Fix Python IndexError of case9: paddle.static.nn.deform_conv2d (#49990) * add dimension check for deformable_conv * add unittest --- .../tests/unittests/test_deform_conv2d.py | 18 ++++++++++++++++++ python/paddle/static/nn/common.py | 5 +++++ 2 files changed, 23 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py index 7a5093f872..d484e140b6 100644 --- a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py +++ b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py @@ -726,5 +726,23 @@ class TestDeformConv2DFunctionalWithGroups(TestDeformConv2DFunctional): self.no_bias = False +class TestDeformConv2DError(unittest.TestCase): + def test_input_error(self): + def test_input_rank_error(): + paddle.enable_static() + x = paddle.static.data(name='error_x_1', shape=[0], dtype='float32') + offset = paddle.static.data( + name='error_offset_1', shape=[0], dtype='float32' + ) + mask = paddle.static.data( + name='error_mask_1', shape=[0, 0, 0], dtype='float32' + ) + out = paddle.static.nn.deform_conv2d( + x, offset, mask, 0, 0, deformable_groups=0 + ) + + self.assertRaises(ValueError, test_input_rank_error) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index c43385a8e9..1581f29921 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2244,6 +2244,11 @@ def deformable_conv( mask, 'mask', (paddle.static.Variable, type(None)), 'deformable_conv' ) + if input.ndim != 4: + raise ValueError( + f'The input should be of [N, C, H, W] format, but received {input.shape}' + ) + num_channels = input.shape[1] assert param_attr is not False, "param_attr should not be False here." -- GitLab