From 9db4a31787602439e8564dda3f0c2bbd1eee4585 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 18 May 2021 10:47:39 +0800 Subject: [PATCH] fix_test_nms (#3050) --- ppdet/modeling/tests/test_ops.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index bccebc258..c08bcca4b 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -570,16 +570,19 @@ class TestAnchorGenerator(LayerTest): class TestMulticlassNms(LayerTest): def test_multiclass_nms(self): - boxes_np = np.random.rand(81, 4).astype('float32') - scores_np = np.random.rand(81).astype('float32') - rois_num_np = np.array([40, 41]).astype('int32') + boxes_np = np.random.rand(10, 81, 4).astype('float32') + scores_np = np.random.rand(10, 81).astype('float32') + rois_num_np = np.array([2, 8]).astype('int32') with self.static_graph(): boxes = paddle.static.data( - name='bboxes', shape=[81, 4], dtype='float32', lod_level=1) + name='bboxes', + shape=[None, 81, 4], + dtype='float32', + lod_level=1) scores = paddle.static.data( - name='scores', shape=[81], dtype='float32', lod_level=1) + name='scores', shape=[None, 81], dtype='float32', lod_level=1) rois_num = paddle.static.data( - name='rois_num', shape=[40, 41], dtype='int32') + name='rois_num', shape=[None], dtype='int32') output = ops.multiclass_nms( bboxes=boxes, @@ -599,7 +602,10 @@ class TestMulticlassNms(LayerTest): 'rois_num': rois_num_np }, fetch_list=output, - with_lod=False) + with_lod=True) + out_np = np.array(out_np) + index_np = np.array(index_np) + nms_rois_num_np = np.array(nms_rois_num_np) with self.dynamic_graph(): boxes_dy = base.to_variable(boxes_np) -- GitLab