From 2d59aa09b1b9b6c8e50bfe625241f3968ae2d139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 30 Jan 2023 16:30:19 +0800 Subject: [PATCH] fix the div 0 error of deform_conv2d (#49962) --- .../tests/unittests/test_deformable_conv_op.py | 17 +++++++++++++++++ python/paddle/static/nn/common.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py index e3bc04d414..7359da3c8e 100644 --- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py @@ -441,6 +441,23 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase): self.assertRaises(ValueError, test_invalid_filter) + def test_invalid_groups(): + paddle.enable_static() + input = paddle.static.data( + name='input_groups', shape=[1, 1, 1, 1], dtype='float32' + ) + offset = paddle.static.data( + name='offset_groups', shape=[1, 1], dtype='float32' + ) + mask = paddle.static.data( + name='mask_groups', shape=[1], dtype='float32' + ) + paddle.static.nn.deform_conv2d( + input, offset, mask, 1, 1, padding=1, groups=0 + ) + + self.assertRaises(ValueError, test_invalid_groups) + class TestDeformConv2DAPI(unittest.TestCase): def test_api(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 53954f49f3..0b278eefa1 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2255,6 +2255,8 @@ def deformable_conv( if groups is None: num_filter_channels = num_channels else: + if groups == 0: + raise ValueError("groups should not be 0.") if num_channels % groups != 0: raise ValueError("num_channels must be divisible by groups.") num_filter_channels = num_channels // groups -- GitLab