From 7bb67db3d53297df8cce4b30992bb1035ba3bf62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 31 Jan 2023 10:48:58 +0800 Subject: [PATCH] fix the div 0 errors in psroi_pool (#49965) * fix the div 0 errors in psroi_pool * fix case 7 * rool back sth. --- .../fluid/tests/unittests/test_psroi_pool_op.py | 16 ++++++++++++++++ python/paddle/vision/ops.py | 2 ++ 2 files changed, 18 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py b/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py index 40f3c52d4f..c33d218cd8 100644 --- a/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py @@ -339,6 +339,22 @@ class TestPSROIPoolChannelError(unittest.TestCase): self.assertRaises(ValueError, test_channel_error) +class TestPSROIPoolZeroDivError(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.x = paddle.uniform([2, 490, 28, 28], dtype='float32') + self.boxes = paddle.to_tensor( + [[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], dtype='float32' + ) + self.boxes_num = paddle.to_tensor([1, 2], dtype='int32') + + def test_errors(self): + def test_zero_div_error(): + paddle.vision.ops.psroi_pool(self.x, self.boxes, self.boxes_num, 0) + + self.assertRaises(ValueError, test_zero_div_error) + + class TestPSROIPoolStaticAPI(unittest.TestCase): def setUp(self): paddle.enable_static() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 0696b5f7cc..0d43bd0fc5 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1424,6 +1424,8 @@ def psroi_pool(x, boxes, boxes_num, output_size, spatial_scale=1.0, name=None): output_size = (output_size, output_size) pooled_height, pooled_width = output_size assert len(x.shape) == 4, "Input features with shape should be (N, C, H, W)" + if pooled_height * pooled_width == 0: + raise ValueError('output_size should not contain 0.') output_channels = int(x.shape[1] / (pooled_height * pooled_width)) if in_dygraph_mode(): return _C_ops.psroi_pool( -- GitLab