未验证 提交 6e61c9f9 编写于 作者: J JYChen 提交者: GitHub

fix distribute_fpn_proposals (#55785)

上级 608a3f28
......@@ -1218,7 +1218,15 @@ def distribute_fpn_proposals(
rois_num=rois_num)
"""
assert (
max_level > 0 and min_level > 0
), "min_level and max_level should be greater than 0"
num_lvl = max_level - min_level + 1
assert num_lvl > 1, "max_level should be greater than min_level"
assert (
num_lvl < 100
), "Only support max to 100 levels, (max_level - min_level + 1 < 100)"
if in_dygraph_mode():
assert (
......
......@@ -259,6 +259,63 @@ class TestDistributeFpnProposals(LayerTest):
refer_scale=224,
)
def test_distribute_fpn_proposals_error2(self):
program = Program()
with program_guard(program):
fpn_rois = paddle.static.data(
name='min_max_level_error1',
shape=[10, 4],
dtype='float32',
lod_level=1,
)
self.assertRaises(
AssertionError,
paddle.vision.ops.distribute_fpn_proposals,
fpn_rois=fpn_rois,
min_level=0,
max_level=-1,
refer_level=4,
refer_scale=224,
)
def test_distribute_fpn_proposals_error3(self):
program = Program()
with program_guard(program):
fpn_rois = paddle.static.data(
name='min_max_level_error2',
shape=[10, 4],
dtype='float32',
lod_level=1,
)
self.assertRaises(
AssertionError,
paddle.vision.ops.distribute_fpn_proposals,
fpn_rois=fpn_rois,
min_level=2,
max_level=2,
refer_level=4,
refer_scale=224,
)
def test_distribute_fpn_proposals_error4(self):
program = Program()
with program_guard(program):
fpn_rois = paddle.static.data(
name='min_max_level_error3',
shape=[10, 4],
dtype='float32',
lod_level=1,
)
self.assertRaises(
AssertionError,
paddle.vision.ops.distribute_fpn_proposals,
fpn_rois=fpn_rois,
min_level=2,
max_level=500,
refer_level=4,
refer_scale=224,
)
if __name__ == '__main__':
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册