diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 2808d2c58a2514c60b68a3e1ab4a21f0093a25c4..362a8b1a113bbda66e1be16482e83b983f219a5b 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -653,7 +653,7 @@ def load_predictor(model_dir, } if run_mode in precision_map.keys(): config.enable_tensorrt_engine( - workspace_size=1 << 25, + workspace_size=(1 << 25) * batch_size, max_batch_size=batch_size, min_subgraph_size=min_subgraph_size, precision_mode=precision_map[run_mode], diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 393709eb571b630dc69283d366816bffbf7a8286..31e7590a080ed37ac28c78e5e7b22e6ebe283181 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -23,6 +23,7 @@ from ..initializer import bias_init_with_prob, constant_, normal_ from ..assigners.utils import generate_anchors_for_grid_cell from ppdet.modeling.backbones.cspresnet import ConvBNLayer from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn +from ppdet.modeling.layers import MultiClassNMS __all__ = ['PPYOLOEHead'] @@ -86,6 +87,8 @@ class PPYOLOEHead(nn.Layer): self.static_assigner = static_assigner self.assigner = assigner self.nms = nms + if isinstance(self.nms, MultiClassNMS) and trt: + self.nms.trt = trt self.exclude_nms = exclude_nms # stem self.stem_cls = nn.LayerList() diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 055cbf4f6388ce2acf6ef5d4cd58ab93dbbc9fcb..8238e89734909b23dd0956b838d18d7f8cf019de 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -440,7 +440,8 @@ class MultiClassNMS(object): normalized=True, nms_eta=1.0, return_index=False, - return_rois_num=True): + return_rois_num=True, + trt=False): super(MultiClassNMS, self).__init__() self.score_threshold = score_threshold self.nms_top_k = nms_top_k @@ -450,6 +451,7 @@ class MultiClassNMS(object): self.nms_eta = nms_eta self.return_index = return_index self.return_rois_num = return_rois_num + self.trt = trt def __call__(self, bboxes, score, background_label=-1): """ @@ -471,7 +473,19 @@ class MultiClassNMS(object): kwargs.update({'rois_num': bbox_num}) if background_label > -1: kwargs.update({'background_label': background_label}) - return ops.multiclass_nms(bboxes, score, **kwargs) + kwargs.pop('trt') + # TODO(wangxinxin08): paddle version should be develop or 2.3 and above to run nms on tensorrt + if self.trt and (int(paddle.version.major) == 0 or + (int(paddle.version.major) >= 2 and + int(paddle.version.minor) >= 3)): + # TODO(wangxinxin08): tricky switch to run nms on tensorrt + kwargs.update({'nms_eta': 1.1}) + bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs) + mask = paddle.slice(bbox, [-1], [0], [1]) != -1 + bbox = paddle.masked_select(bbox, mask).reshape((-1, 6)) + return bbox, bbox_num, None + else: + return ops.multiclass_nms(bboxes, score, **kwargs) @register