未验证 提交 9db4a317 编写于 作者: W wangguanzhong 提交者: GitHub

fix_test_nms (#3050)

上级 b7c50f5b
...@@ -570,16 +570,19 @@ class TestAnchorGenerator(LayerTest): ...@@ -570,16 +570,19 @@ class TestAnchorGenerator(LayerTest):
class TestMulticlassNms(LayerTest): class TestMulticlassNms(LayerTest):
def test_multiclass_nms(self): def test_multiclass_nms(self):
boxes_np = np.random.rand(81, 4).astype('float32') boxes_np = np.random.rand(10, 81, 4).astype('float32')
scores_np = np.random.rand(81).astype('float32') scores_np = np.random.rand(10, 81).astype('float32')
rois_num_np = np.array([40, 41]).astype('int32') rois_num_np = np.array([2, 8]).astype('int32')
with self.static_graph(): with self.static_graph():
boxes = paddle.static.data( boxes = paddle.static.data(
name='bboxes', shape=[81, 4], dtype='float32', lod_level=1) name='bboxes',
shape=[None, 81, 4],
dtype='float32',
lod_level=1)
scores = paddle.static.data( scores = paddle.static.data(
name='scores', shape=[81], dtype='float32', lod_level=1) name='scores', shape=[None, 81], dtype='float32', lod_level=1)
rois_num = paddle.static.data( rois_num = paddle.static.data(
name='rois_num', shape=[40, 41], dtype='int32') name='rois_num', shape=[None], dtype='int32')
output = ops.multiclass_nms( output = ops.multiclass_nms(
bboxes=boxes, bboxes=boxes,
...@@ -599,7 +602,10 @@ class TestMulticlassNms(LayerTest): ...@@ -599,7 +602,10 @@ class TestMulticlassNms(LayerTest):
'rois_num': rois_num_np 'rois_num': rois_num_np
}, },
fetch_list=output, fetch_list=output,
with_lod=False) with_lod=True)
out_np = np.array(out_np)
index_np = np.array(index_np)
nms_rois_num_np = np.array(nms_rois_num_np)
with self.dynamic_graph(): with self.dynamic_graph():
boxes_dy = base.to_variable(boxes_np) boxes_dy = base.to_variable(boxes_np)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册