From 47a244f6e2f2d18b52760622563efe074e43414d Mon Sep 17 00:00:00 2001 From: Channingss Date: Tue, 1 Sep 2020 08:21:23 +0000 Subject: [PATCH] fix bug of multiclass_nms when attr:keep_top_k==-1 --- .../opset11/paddle_custom_layer/multiclass_nms.py | 6 ++++-- .../opset9/paddle_custom_layer/multiclass_nms.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py b/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py index b6bb8cc..7060cfb 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py +++ b/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py @@ -72,6 +72,8 @@ def multiclass_nms(op, block): dims=(), vals=[float(attrs['nms_threshold'])])) + boxes_num = block.var( outputs['Out'][0]).shape[0] + top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k']) node_keep_top_k = onnx.helper.make_node( 'Constant', inputs=[], @@ -80,7 +82,7 @@ def multiclass_nms(op, block): name=name_keep_top_k[0] + "@const", data_type=onnx.TensorProto.INT64, dims=(), - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) node_keep_top_k_2D = onnx.helper.make_node( 'Constant', @@ -90,7 +92,7 @@ def multiclass_nms(op, block): name=name_keep_top_k_2D[0] + "@const", data_type=onnx.TensorProto.INT64, dims=[1, 1], - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) # the paddle data format is x1,y1,x2,y2 kwargs = {'center_point_box': 0} diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py b/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py index 65430bb..57d8a74 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py @@ -72,6 +72,8 @@ def multiclass_nms(op, block): dims=(), vals=[float(attrs['nms_threshold'])])) + boxes_num = block.var( outputs['Out'][0]).shape[0] + top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k']) node_keep_top_k = onnx.helper.make_node( 'Constant', inputs=[], @@ -80,7 +82,7 @@ def multiclass_nms(op, block): name=name_keep_top_k[0] + "@const", data_type=onnx.TensorProto.INT64, dims=(), - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) node_keep_top_k_2D = onnx.helper.make_node( 'Constant', @@ -90,7 +92,7 @@ def multiclass_nms(op, block): name=name_keep_top_k_2D[0] + "@const", data_type=onnx.TensorProto.INT64, dims=[1, 1], - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) # the paddle data format is x1,y1,x2,y2 kwargs = {'center_point_box': 0} -- GitLab