未验证 提交 c62657b3 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case9: paddle.static.nn.deform_conv2d (#49990)

* add dimension check for deformable_conv

* add unittest
上级 5349b9b9
...@@ -726,5 +726,23 @@ class TestDeformConv2DFunctionalWithGroups(TestDeformConv2DFunctional): ...@@ -726,5 +726,23 @@ class TestDeformConv2DFunctionalWithGroups(TestDeformConv2DFunctional):
self.no_bias = False self.no_bias = False
class TestDeformConv2DError(unittest.TestCase):
def test_input_error(self):
def test_input_rank_error():
paddle.enable_static()
x = paddle.static.data(name='error_x_1', shape=[0], dtype='float32')
offset = paddle.static.data(
name='error_offset_1', shape=[0], dtype='float32'
)
mask = paddle.static.data(
name='error_mask_1', shape=[0, 0, 0], dtype='float32'
)
out = paddle.static.nn.deform_conv2d(
x, offset, mask, 0, 0, deformable_groups=0
)
self.assertRaises(ValueError, test_input_rank_error)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -2244,6 +2244,11 @@ def deformable_conv( ...@@ -2244,6 +2244,11 @@ def deformable_conv(
mask, 'mask', (paddle.static.Variable, type(None)), 'deformable_conv' mask, 'mask', (paddle.static.Variable, type(None)), 'deformable_conv'
) )
if input.ndim != 4:
raise ValueError(
f'The input should be of [N, C, H, W] format, but received {input.shape}'
)
num_channels = input.shape[1] num_channels = input.shape[1]
assert param_attr is not False, "param_attr should not be False here." assert param_attr is not False, "param_attr should not be False here."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册