From 6e61c9f9d1831d43a1c0b9dba4cd362cb13af846 Mon Sep 17 00:00:00 2001 From: JYChen Date: Mon, 31 Jul 2023 11:04:46 +0800 Subject: [PATCH] fix distribute_fpn_proposals (#55785) --- python/paddle/vision/ops.py | 8 +++++ test/legacy_test/test_detection.py | 57 ++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index b02451f53b3..0f73f950d15 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1218,7 +1218,15 @@ def distribute_fpn_proposals( rois_num=rois_num) """ + assert ( + max_level > 0 and min_level > 0 + ), "min_level and max_level should be greater than 0" + num_lvl = max_level - min_level + 1 + assert num_lvl > 1, "max_level should be greater than min_level" + assert ( + num_lvl < 100 + ), "Only support max to 100 levels, (max_level - min_level + 1 < 100)" if in_dygraph_mode(): assert ( diff --git a/test/legacy_test/test_detection.py b/test/legacy_test/test_detection.py index fbb0c35824d..e3aebb94669 100644 --- a/test/legacy_test/test_detection.py +++ b/test/legacy_test/test_detection.py @@ -259,6 +259,63 @@ class TestDistributeFpnProposals(LayerTest): refer_scale=224, ) + def test_distribute_fpn_proposals_error2(self): + program = Program() + with program_guard(program): + fpn_rois = paddle.static.data( + name='min_max_level_error1', + shape=[10, 4], + dtype='float32', + lod_level=1, + ) + self.assertRaises( + AssertionError, + paddle.vision.ops.distribute_fpn_proposals, + fpn_rois=fpn_rois, + min_level=0, + max_level=-1, + refer_level=4, + refer_scale=224, + ) + + def test_distribute_fpn_proposals_error3(self): + program = Program() + with program_guard(program): + fpn_rois = paddle.static.data( + name='min_max_level_error2', + shape=[10, 4], + dtype='float32', + lod_level=1, + ) + self.assertRaises( + AssertionError, + paddle.vision.ops.distribute_fpn_proposals, + fpn_rois=fpn_rois, + min_level=2, + max_level=2, + refer_level=4, + refer_scale=224, + ) + + def test_distribute_fpn_proposals_error4(self): + program = Program() + with program_guard(program): + fpn_rois = paddle.static.data( + name='min_max_level_error3', + shape=[10, 4], + dtype='float32', + lod_level=1, + ) + self.assertRaises( + AssertionError, + paddle.vision.ops.distribute_fpn_proposals, + fpn_rois=fpn_rois, + min_level=2, + max_level=500, + refer_level=4, + refer_scale=224, + ) + if __name__ == '__main__': paddle.enable_static() -- GitLab