diff --git a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py index 7a5093f872ceda2e0e88095d873ab1a2df763ef4..d484e140b6e1dbec20ac53f2cfda9712adb7dc78 100644 --- a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py +++ b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py @@ -726,5 +726,23 @@ class TestDeformConv2DFunctionalWithGroups(TestDeformConv2DFunctional): self.no_bias = False +class TestDeformConv2DError(unittest.TestCase): + def test_input_error(self): + def test_input_rank_error(): + paddle.enable_static() + x = paddle.static.data(name='error_x_1', shape=[0], dtype='float32') + offset = paddle.static.data( + name='error_offset_1', shape=[0], dtype='float32' + ) + mask = paddle.static.data( + name='error_mask_1', shape=[0, 0, 0], dtype='float32' + ) + out = paddle.static.nn.deform_conv2d( + x, offset, mask, 0, 0, deformable_groups=0 + ) + + self.assertRaises(ValueError, test_input_rank_error) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index c43385a8e91409f90535ba428d1b83b9715ec503..1581f299214df89e194e9d235e2f9c8e2835552e 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2244,6 +2244,11 @@ def deformable_conv( mask, 'mask', (paddle.static.Variable, type(None)), 'deformable_conv' ) + if input.ndim != 4: + raise ValueError( + f'The input should be of [N, C, H, W] format, but received {input.shape}' + ) + num_channels = input.shape[1] assert param_attr is not False, "param_attr should not be False here."