未验证 提交 43b410b7 编写于 作者: W wangxinxin08 提交者: GitHub

add nms trt (#5603)

* add nms trt support

* add check version code

* fix bugs
上级 71424eb4
......@@ -653,7 +653,7 @@ def load_predictor(model_dir,
}
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 25,
workspace_size=(1 << 25) * batch_size,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
......
......@@ -23,6 +23,7 @@ from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
__all__ = ['PPYOLOEHead']
......@@ -86,6 +87,8 @@ class PPYOLOEHead(nn.Layer):
self.static_assigner = static_assigner
self.assigner = assigner
self.nms = nms
if isinstance(self.nms, MultiClassNMS) and trt:
self.nms.trt = trt
self.exclude_nms = exclude_nms
# stem
self.stem_cls = nn.LayerList()
......
......@@ -440,7 +440,8 @@ class MultiClassNMS(object):
normalized=True,
nms_eta=1.0,
return_index=False,
return_rois_num=True):
return_rois_num=True,
trt=False):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
......@@ -450,6 +451,7 @@ class MultiClassNMS(object):
self.nms_eta = nms_eta
self.return_index = return_index
self.return_rois_num = return_rois_num
self.trt = trt
def __call__(self, bboxes, score, background_label=-1):
"""
......@@ -471,7 +473,19 @@ class MultiClassNMS(object):
kwargs.update({'rois_num': bbox_num})
if background_label > -1:
kwargs.update({'background_label': background_label})
return ops.multiclass_nms(bboxes, score, **kwargs)
kwargs.pop('trt')
# TODO(wangxinxin08): paddle version should be develop or 2.3 and above to run nms on tensorrt
if self.trt and (int(paddle.version.major) == 0 or
(int(paddle.version.major) >= 2 and
int(paddle.version.minor) >= 3)):
# TODO(wangxinxin08): tricky switch to run nms on tensorrt
kwargs.update({'nms_eta': 1.1})
bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
mask = paddle.slice(bbox, [-1], [0], [1]) != -1
bbox = paddle.masked_select(bbox, mask).reshape((-1, 6))
return bbox, bbox_num, None
else:
return ops.multiclass_nms(bboxes, score, **kwargs)
@register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册