diff --git a/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py b/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py index b6bb8cce57fc06bde20ce1c0faa68cb9bd615cb0..7060cfbeebcaf69bbb9543774661e8b914aae287 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py +++ b/x2paddle/op_mapper/paddle2onnx/opset11/paddle_custom_layer/multiclass_nms.py @@ -72,6 +72,8 @@ def multiclass_nms(op, block): dims=(), vals=[float(attrs['nms_threshold'])])) + boxes_num = block.var( outputs['Out'][0]).shape[0] + top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k']) node_keep_top_k = onnx.helper.make_node( 'Constant', inputs=[], @@ -80,7 +82,7 @@ def multiclass_nms(op, block): name=name_keep_top_k[0] + "@const", data_type=onnx.TensorProto.INT64, dims=(), - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) node_keep_top_k_2D = onnx.helper.make_node( 'Constant', @@ -90,7 +92,7 @@ def multiclass_nms(op, block): name=name_keep_top_k_2D[0] + "@const", data_type=onnx.TensorProto.INT64, dims=[1, 1], - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) # the paddle data format is x1,y1,x2,y2 kwargs = {'center_point_box': 0} diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py b/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py index 65430bb159bb4698b62fa9bda6b572062b49b6fc..57d8a7467789e3bead12c2d3ffaf86a95c397cf8 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/paddle_custom_layer/multiclass_nms.py @@ -72,6 +72,8 @@ def multiclass_nms(op, block): dims=(), vals=[float(attrs['nms_threshold'])])) + boxes_num = block.var( outputs['Out'][0]).shape[0] + top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k']) node_keep_top_k = onnx.helper.make_node( 'Constant', inputs=[], @@ -80,7 +82,7 @@ def multiclass_nms(op, block): name=name_keep_top_k[0] + "@const", data_type=onnx.TensorProto.INT64, dims=(), - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) node_keep_top_k_2D = onnx.helper.make_node( 'Constant', @@ -90,7 +92,7 @@ def multiclass_nms(op, block): name=name_keep_top_k_2D[0] + "@const", data_type=onnx.TensorProto.INT64, dims=[1, 1], - vals=[np.int64(attrs['keep_top_k'])])) + vals=[top_k_value])) # the paddle data format is x1,y1,x2,y2 kwargs = {'center_point_box': 0}