未验证 提交 43b410b7 编写于 作者: W wangxinxin08 提交者: GitHub

add nms trt (#5603)

* add nms trt support

* add check version code

* fix bugs
上级 71424eb4
...@@ -653,7 +653,7 @@ def load_predictor(model_dir, ...@@ -653,7 +653,7 @@ def load_predictor(model_dir,
} }
if run_mode in precision_map.keys(): if run_mode in precision_map.keys():
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
workspace_size=1 << 25, workspace_size=(1 << 25) * batch_size,
max_batch_size=batch_size, max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size, min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode], precision_mode=precision_map[run_mode],
......
...@@ -23,6 +23,7 @@ from ..initializer import bias_init_with_prob, constant_, normal_ ...@@ -23,6 +23,7 @@ from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
__all__ = ['PPYOLOEHead'] __all__ = ['PPYOLOEHead']
...@@ -86,6 +87,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -86,6 +87,8 @@ class PPYOLOEHead(nn.Layer):
self.static_assigner = static_assigner self.static_assigner = static_assigner
self.assigner = assigner self.assigner = assigner
self.nms = nms self.nms = nms
if isinstance(self.nms, MultiClassNMS) and trt:
self.nms.trt = trt
self.exclude_nms = exclude_nms self.exclude_nms = exclude_nms
# stem # stem
self.stem_cls = nn.LayerList() self.stem_cls = nn.LayerList()
......
...@@ -440,7 +440,8 @@ class MultiClassNMS(object): ...@@ -440,7 +440,8 @@ class MultiClassNMS(object):
normalized=True, normalized=True,
nms_eta=1.0, nms_eta=1.0,
return_index=False, return_index=False,
return_rois_num=True): return_rois_num=True,
trt=False):
super(MultiClassNMS, self).__init__() super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold self.score_threshold = score_threshold
self.nms_top_k = nms_top_k self.nms_top_k = nms_top_k
...@@ -450,6 +451,7 @@ class MultiClassNMS(object): ...@@ -450,6 +451,7 @@ class MultiClassNMS(object):
self.nms_eta = nms_eta self.nms_eta = nms_eta
self.return_index = return_index self.return_index = return_index
self.return_rois_num = return_rois_num self.return_rois_num = return_rois_num
self.trt = trt
def __call__(self, bboxes, score, background_label=-1): def __call__(self, bboxes, score, background_label=-1):
""" """
...@@ -471,7 +473,19 @@ class MultiClassNMS(object): ...@@ -471,7 +473,19 @@ class MultiClassNMS(object):
kwargs.update({'rois_num': bbox_num}) kwargs.update({'rois_num': bbox_num})
if background_label > -1: if background_label > -1:
kwargs.update({'background_label': background_label}) kwargs.update({'background_label': background_label})
return ops.multiclass_nms(bboxes, score, **kwargs) kwargs.pop('trt')
# TODO(wangxinxin08): running NMS on TensorRT requires the Paddle develop build (reported as major version 0) or Paddle 2.3 and above
if self.trt and (int(paddle.version.major) == 0 or
(int(paddle.version.major) >= 2 and
int(paddle.version.minor) >= 3)):
# TODO(wangxinxin08): workaround: nms_eta is overridden to 1.1 as a switch that makes NMS run on TensorRT
kwargs.update({'nms_eta': 1.1})
bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
mask = paddle.slice(bbox, [-1], [0], [1]) != -1
bbox = paddle.masked_select(bbox, mask).reshape((-1, 6))
return bbox, bbox_num, None
else:
return ops.multiclass_nms(bboxes, score, **kwargs)
@register @register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册