From 43b410b7aef51db9ad230e205bfd07a8bbeec424 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Fri, 8 Apr 2022 11:01:22 +0800 Subject: [PATCH] add nms trt (#5603) * add nms trt support * add check version code * fix bugs --- deploy/python/infer.py | 2 +- ppdet/modeling/heads/ppyoloe_head.py | 3 +++ ppdet/modeling/layers.py | 18 ++++++++++++++++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 2808d2c58..362a8b1a1 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -653,7 +653,7 @@ def load_predictor(model_dir, } if run_mode in precision_map.keys(): config.enable_tensorrt_engine( - workspace_size=1 << 25, + workspace_size=(1 << 25) * batch_size, max_batch_size=batch_size, min_subgraph_size=min_subgraph_size, precision_mode=precision_map[run_mode], diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 393709eb5..31e7590a0 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -23,6 +23,7 @@ from ..initializer import bias_init_with_prob, constant_, normal_ from ..assigners.utils import generate_anchors_for_grid_cell from ppdet.modeling.backbones.cspresnet import ConvBNLayer from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn +from ppdet.modeling.layers import MultiClassNMS __all__ = ['PPYOLOEHead'] @@ -86,6 +87,8 @@ class PPYOLOEHead(nn.Layer): self.static_assigner = static_assigner self.assigner = assigner self.nms = nms + if isinstance(self.nms, MultiClassNMS) and trt: + self.nms.trt = trt self.exclude_nms = exclude_nms # stem self.stem_cls = nn.LayerList() diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 055cbf4f6..8238e8973 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -440,7 +440,8 @@ class MultiClassNMS(object): normalized=True, nms_eta=1.0, return_index=False, - return_rois_num=True): + return_rois_num=True, + trt=False): super(MultiClassNMS, self).__init__() self.score_threshold = score_threshold self.nms_top_k = nms_top_k @@ -450,6 +451,7 @@ class MultiClassNMS(object): self.nms_eta = nms_eta self.return_index = return_index self.return_rois_num = return_rois_num + self.trt = trt def __call__(self, bboxes, score, background_label=-1): """ @@ -471,7 +473,19 @@ class MultiClassNMS(object): kwargs.update({'rois_num': bbox_num}) if background_label > -1: kwargs.update({'background_label': background_label}) - return ops.multiclass_nms(bboxes, score, **kwargs) + kwargs.pop('trt') + # TODO(wangxinxin08): paddle version should be develop or 2.3 and above to run nms on tensorrt + if self.trt and (int(paddle.version.major) == 0 or + (int(paddle.version.major) >= 2 and + int(paddle.version.minor) >= 3)): + # TODO(wangxinxin08): tricky switch to run nms on tensorrt + kwargs.update({'nms_eta': 1.1}) + bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs) + mask = paddle.slice(bbox, [-1], [0], [1]) != -1 + bbox = paddle.masked_select(bbox, mask).reshape((-1, 6)) + return bbox, bbox_num, None + else: + return ops.multiclass_nms(bboxes, score, **kwargs) @register -- GitLab