未验证 提交 09d35587 编写于 作者: J Jason 提交者: GitHub

Merge pull request #396 from Channingss/fix_nms

fix bug of multiclass_nms when attr:keep_top_k==-1
...@@ -72,6 +72,8 @@ def multiclass_nms(op, block): ...@@ -72,6 +72,8 @@ def multiclass_nms(op, block):
dims=(), dims=(),
vals=[float(attrs['nms_threshold'])])) vals=[float(attrs['nms_threshold'])]))
boxes_num = block.var(outputs['Out'][0]).shape[0]
top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k'])
node_keep_top_k = onnx.helper.make_node( node_keep_top_k = onnx.helper.make_node(
'Constant', 'Constant',
inputs=[], inputs=[],
...@@ -80,7 +82,7 @@ def multiclass_nms(op, block): ...@@ -80,7 +82,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k[0] + "@const", name=name_keep_top_k[0] + "@const",
data_type=onnx.TensorProto.INT64, data_type=onnx.TensorProto.INT64,
dims=(), dims=(),
vals=[np.int64(attrs['keep_top_k'])])) vals=[top_k_value]))
node_keep_top_k_2D = onnx.helper.make_node( node_keep_top_k_2D = onnx.helper.make_node(
'Constant', 'Constant',
...@@ -90,7 +92,7 @@ def multiclass_nms(op, block): ...@@ -90,7 +92,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k_2D[0] + "@const", name=name_keep_top_k_2D[0] + "@const",
data_type=onnx.TensorProto.INT64, data_type=onnx.TensorProto.INT64,
dims=[1, 1], dims=[1, 1],
vals=[np.int64(attrs['keep_top_k'])])) vals=[top_k_value]))
# the paddle data format is x1,y1,x2,y2 # the paddle data format is x1,y1,x2,y2
kwargs = {'center_point_box': 0} kwargs = {'center_point_box': 0}
......
...@@ -72,6 +72,8 @@ def multiclass_nms(op, block): ...@@ -72,6 +72,8 @@ def multiclass_nms(op, block):
dims=(), dims=(),
vals=[float(attrs['nms_threshold'])])) vals=[float(attrs['nms_threshold'])]))
boxes_num = block.var(outputs['Out'][0]).shape[0]
top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k'])
node_keep_top_k = onnx.helper.make_node( node_keep_top_k = onnx.helper.make_node(
'Constant', 'Constant',
inputs=[], inputs=[],
...@@ -80,7 +82,7 @@ def multiclass_nms(op, block): ...@@ -80,7 +82,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k[0] + "@const", name=name_keep_top_k[0] + "@const",
data_type=onnx.TensorProto.INT64, data_type=onnx.TensorProto.INT64,
dims=(), dims=(),
vals=[np.int64(attrs['keep_top_k'])])) vals=[top_k_value]))
node_keep_top_k_2D = onnx.helper.make_node( node_keep_top_k_2D = onnx.helper.make_node(
'Constant', 'Constant',
...@@ -90,7 +92,7 @@ def multiclass_nms(op, block): ...@@ -90,7 +92,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k_2D[0] + "@const", name=name_keep_top_k_2D[0] + "@const",
data_type=onnx.TensorProto.INT64, data_type=onnx.TensorProto.INT64,
dims=[1, 1], dims=[1, 1],
vals=[np.int64(attrs['keep_top_k'])])) vals=[top_k_value]))
# the paddle data format is x1,y1,x2,y2 # the paddle data format is x1,y1,x2,y2
kwargs = {'center_point_box': 0} kwargs = {'center_point_box': 0}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册