diff --git a/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py b/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py index b6bb8cce57fc06bde20ce1c0faa68cb9bd615cb0..f953ad99d38958f1f8e6fc9927086401a4f28014 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py +++ b/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py @@ -72,6 +72,8 @@ def multiclass_nms(op, block): dims=(), vals=[float(attrs['nms_threshold'])])) + boxes_num = block.var(outputs['Out'][0]).shape[0] + top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k']) node_keep_top_k = onnx.helper.make_node( 'Constant', inputs=[], @@ -80,7 +82,7 @@ def multiclass_nms(op, block): name=name_keep_top_k[0] + "@const", data_type=onnx.TensorProto.INT64, dims=(), - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) node_keep_top_k_2D = onnx.helper.make_node( 'Constant', @@ -90,7 +92,7 @@ def multiclass_nms(op, block): name=name_keep_top_k_2D[0] + "@const", data_type=onnx.TensorProto.INT64, dims=[1, 1], - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) # the paddle data format is x1,y1,x2,y2 kwargs = {'center_point_box': 0} diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py b/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py index 65430bb159bb4698b62fa9bda6b572062b49b6fc..6d8172fc63c7441ff96b49da44ea700d5700289e 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py @@ -72,6 +72,8 @@ def multiclass_nms(op, block): dims=(), vals=[float(attrs['nms_threshold'])])) + boxes_num = block.var(outputs['Out'][0]).shape[0] + top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k']) node_keep_top_k = onnx.helper.make_node( 'Constant', inputs=[], @@ -80,7 +82,7 @@ def multiclass_nms(op, block): name=name_keep_top_k[0] + "@const", data_type=onnx.TensorProto.INT64, dims=(), - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) node_keep_top_k_2D = onnx.helper.make_node( 'Constant', @@ -90,7 +92,7 @@ def multiclass_nms(op, block): name=name_keep_top_k_2D[0] + "@const", data_type=onnx.TensorProto.INT64, dims=[1, 1], - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) # the paddle data format is x1,y1,x2,y2 kwargs = {'center_point_box': 0}