未验证 提交 750ea214 编写于 作者: G Guanghua Yu 提交者: GitHub

Support NMS options when exporting the model (#5373)

* support nms options when export model

* fix config
上级 c872ec6c
architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
export_post_process: False # Whether post-processing is included in the network when exporting the model.
PicoDet:
backbone: ESNet
......
architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
export_post_process: False # Whether post-processing is included in the network when exporting the model.
PicoDet:
backbone: ESNet
......
......@@ -4,3 +4,9 @@ log_iter: 20
save_dir: output
snapshot_epoch: 1
print_flops: false
# Exporting the model
export:
post_process: True # Whether post-processing is included in the network when exporting the model.
nms: True # Whether NMS is included in the network when exporting the model.
benchmark: False # Used to test model performance; if set `True`, post-processing and NMS will not be exported.
......@@ -41,7 +41,7 @@ TRT_MIN_SUBGRAPH = {
'HigherHRNet': 3,
'HRNet': 3,
'DeepSORT': 3,
'ByteTrack':10,
'ByteTrack': 10,
'JDE': 10,
'FairMOT': 5,
'GFL': 16,
......@@ -166,7 +166,8 @@ def _dump_infer_config(config, path, image_shape, model):
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
if infer_arch == 'PicoDet':
if config.get('export_post_process', False):
if hasattr(config, 'export') and config['export'].get('post_process',
False):
infer_cfg['arch'] = 'GFL'
head_name = 'PicoHeadV2' if config['PicoHeadV2'] else 'PicoHead'
infer_cfg['NMS'] = config[head_name]['nms']
......
......@@ -651,13 +651,21 @@ class Trainer(object):
if hasattr(layer, 'convert_to_deploy'):
layer.convert_to_deploy()
export_post_process = self.cfg.get('export_post_process', False)
if hasattr(self.model, 'export_post_process'):
self.model.export_post_process = export_post_process
image_shape = [None] + image_shape[1:]
export_post_process = self.cfg['export'].get(
'post_process', False) if hasattr(self.cfg, 'export') else True
export_nms = self.cfg['export'].get('nms', False) if hasattr(
self.cfg, 'export') else True
export_benchmark = self.cfg['export'].get(
'benchmark', False) if hasattr(self.cfg, 'export') else False
if hasattr(self.model, 'fuse_norm'):
self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
False)
if hasattr(self.model, 'export_post_process'):
self.model.export_post_process = export_post_process if not export_benchmark else False
if hasattr(self.model, 'export_nms'):
self.model.export_nms = export_nms if not export_benchmark else False
if export_post_process and not export_benchmark:
image_shape = [None] + image_shape[1:]
# Save infer cfg
_dump_infer_config(self.cfg,
......@@ -789,5 +797,6 @@ class Trainer(object):
images.sort()
assert len(images) > 0, "no image found in {}".format(infer_dir)
all_images.extend(images)
logger.info("Found {} inference images in total.".format(len(images)))
logger.info("Found {} inference images in total.".format(
len(images)))
return all_images
......@@ -42,6 +42,7 @@ class PicoDet(BaseArch):
self.neck = neck
self.head = head
self.export_post_process = True
self.export_nms = True
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -68,8 +69,8 @@ class PicoDet(BaseArch):
else:
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
bboxes, bbox_num = self.head.post_process(head_outs, im_shape,
scale_factor)
bboxes, bbox_num = self.head.post_process(
head_outs, im_shape, scale_factor, export_nms=self.export_nms)
return bboxes, bbox_num
def get_loss(self, ):
......@@ -85,7 +86,11 @@ class PicoDet(BaseArch):
def get_pred(self):
    """Return prediction outputs in the format expected by the export mode.

    Three mutually exclusive modes, selected by the export flags:
      * ``export_post_process`` is False: return the raw network output
        (first element of ``_forward()``) under key ``'picodet'``.
      * ``export_nms`` is True: ``_forward()`` yields NMS-ed boxes and the
        per-image box count -> keys ``'bbox'`` and ``'bbox_num'``.
      * otherwise: ``_forward()`` yields decoded boxes and multi-level
        class scores (NMS left to the caller) -> keys ``'bbox'`` and
        ``'scores'``.

    Returns:
        dict: mapping described above.
    """
    # NOTE: the original diff left a stray `else:` between the first two
    # branches; the reconstructed control flow below matches the new version.
    if not self.export_post_process:
        return {'picodet': self._forward()[0]}
    elif self.export_nms:
        bbox_pred, bbox_num = self._forward()
        output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
        return output
    else:
        bboxes, mlvl_scores = self._forward()
        output = {'bbox': bboxes, 'scores': mlvl_scores}
        return output
......@@ -335,6 +335,24 @@ class PicoHead(OTAVFLHead):
return (cls_logits_list, bboxes_reg_list)
def post_process(self,
                 gfl_head_outs,
                 im_shape,
                 scale_factor,
                 export_nms=True):
    """Decode head outputs into boxes/scores, optionally running NMS.

    Args:
        gfl_head_outs (tuple): per-level (cls_scores, bboxes_reg) lists.
        im_shape: input image shape (unused here; kept for interface parity).
        scale_factor: per-image [h_scale, w_scale] used to rescale boxes.
        export_nms (bool): when False, return raw concatenated boxes and
            transposed scores so NMS can run outside the exported graph.

    Returns:
        (bboxes, mlvl_scores) when ``export_nms`` is False, otherwise
        (bbox_pred, bbox_num) from ``self.nms``.
    """
    per_level_scores, per_level_boxes = gfl_head_outs
    all_boxes = paddle.concat(per_level_boxes, axis=1)
    all_scores = paddle.concat(per_level_scores, axis=1).transpose([0, 2, 1])
    # Skip NMS entirely when the deploy target will do it itself.
    if not export_nms:
        return all_boxes, all_scores
    # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
    rescale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
    all_boxes /= rescale
    kept_boxes, kept_num, _ = self.nms(all_boxes, all_scores)
    return kept_boxes, kept_num
@register
class PicoHeadV2(GFLHead):
......@@ -625,3 +643,21 @@ class PicoHeadV2(GFLHead):
loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
return loss_states
def post_process(self,
                 gfl_head_outs,
                 im_shape,
                 scale_factor,
                 export_nms=True):
    """Decode head outputs into boxes/scores, optionally running NMS.

    Args:
        gfl_head_outs (tuple): per-level (cls_scores, bboxes_reg) lists.
        im_shape: input image shape (unused here; kept for interface parity).
        scale_factor: per-image [h_scale, w_scale] used to rescale boxes.
        export_nms (bool): when False, return raw concatenated boxes and
            transposed scores so NMS can run outside the exported graph.

    Returns:
        (bboxes, mlvl_scores) when ``export_nms`` is False, otherwise
        (bbox_pred, bbox_num) from ``self.nms``.
    """
    level_scores, level_boxes = gfl_head_outs
    merged_boxes = paddle.concat(level_boxes, axis=1)
    merged_scores = paddle.concat(level_scores, axis=1)
    merged_scores = merged_scores.transpose([0, 2, 1])
    # Raw outputs requested: leave NMS to the downstream consumer.
    if not export_nms:
        return merged_boxes, merged_scores
    # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
    inv_scale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
    merged_boxes /= inv_scale
    final_boxes, final_num, _ = self.nms(merged_boxes, merged_scores)
    return final_boxes, final_num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册