未验证 提交 750ea214 编写于 作者: G Guanghua Yu 提交者: GitHub

Support nms options when export model (#5373)

* support nms options when export model

* fix config
上级 c872ec6c
architecture: PicoDet architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
export_post_process: False # Whether post-processing is included in the network when export model.
PicoDet: PicoDet:
backbone: ESNet backbone: ESNet
......
architecture: PicoDet architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
export_post_process: False # Whether post-processing is included in the network when export model.
PicoDet: PicoDet:
backbone: ESNet backbone: ESNet
......
...@@ -4,3 +4,9 @@ log_iter: 20 ...@@ -4,3 +4,9 @@ log_iter: 20
save_dir: output save_dir: output
snapshot_epoch: 1 snapshot_epoch: 1
print_flops: false print_flops: false
# Exporting the model
export:
post_process: True # Whether post-processing is included in the network when exporting the model.
nms: True # Whether NMS is included in the network when exporting the model.
benchmark: False # Used to test model performance; if set `True`, post-processing and NMS will not be exported.
...@@ -41,7 +41,7 @@ TRT_MIN_SUBGRAPH = { ...@@ -41,7 +41,7 @@ TRT_MIN_SUBGRAPH = {
'HigherHRNet': 3, 'HigherHRNet': 3,
'HRNet': 3, 'HRNet': 3,
'DeepSORT': 3, 'DeepSORT': 3,
'ByteTrack':10, 'ByteTrack': 10,
'JDE': 10, 'JDE': 10,
'FairMOT': 5, 'FairMOT': 5,
'GFL': 16, 'GFL': 16,
...@@ -166,7 +166,8 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -166,7 +166,8 @@ def _dump_infer_config(config, path, image_shape, model):
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:]) reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
if infer_arch == 'PicoDet': if infer_arch == 'PicoDet':
if config.get('export_post_process', False): if hasattr(config, 'export') and config['export'].get('post_process',
False):
infer_cfg['arch'] = 'GFL' infer_cfg['arch'] = 'GFL'
head_name = 'PicoHeadV2' if config['PicoHeadV2'] else 'PicoHead' head_name = 'PicoHeadV2' if config['PicoHeadV2'] else 'PicoHead'
infer_cfg['NMS'] = config[head_name]['nms'] infer_cfg['NMS'] = config[head_name]['nms']
......
...@@ -651,13 +651,21 @@ class Trainer(object): ...@@ -651,13 +651,21 @@ class Trainer(object):
if hasattr(layer, 'convert_to_deploy'): if hasattr(layer, 'convert_to_deploy'):
layer.convert_to_deploy() layer.convert_to_deploy()
export_post_process = self.cfg.get('export_post_process', False) export_post_process = self.cfg['export'].get(
if hasattr(self.model, 'export_post_process'): 'post_process', False) if hasattr(self.cfg, 'export') else True
self.model.export_post_process = export_post_process export_nms = self.cfg['export'].get('nms', False) if hasattr(
image_shape = [None] + image_shape[1:] self.cfg, 'export') else True
export_benchmark = self.cfg['export'].get(
'benchmark', False) if hasattr(self.cfg, 'export') else False
if hasattr(self.model, 'fuse_norm'): if hasattr(self.model, 'fuse_norm'):
self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize', self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
False) False)
if hasattr(self.model, 'export_post_process'):
self.model.export_post_process = export_post_process if not export_benchmark else False
if hasattr(self.model, 'export_nms'):
self.model.export_nms = export_nms if not export_benchmark else False
if export_post_process and not export_benchmark:
image_shape = [None] + image_shape[1:]
# Save infer cfg # Save infer cfg
_dump_infer_config(self.cfg, _dump_infer_config(self.cfg,
...@@ -789,5 +797,6 @@ class Trainer(object): ...@@ -789,5 +797,6 @@ class Trainer(object):
images.sort() images.sort()
assert len(images) > 0, "no image found in {}".format(infer_dir) assert len(images) > 0, "no image found in {}".format(infer_dir)
all_images.extend(images) all_images.extend(images)
logger.info("Found {} inference images in total.".format(len(images))) logger.info("Found {} inference images in total.".format(
len(images)))
return all_images return all_images
...@@ -42,6 +42,7 @@ class PicoDet(BaseArch): ...@@ -42,6 +42,7 @@ class PicoDet(BaseArch):
self.neck = neck self.neck = neck
self.head = head self.head = head
self.export_post_process = True self.export_post_process = True
self.export_nms = True
@classmethod @classmethod
def from_config(cls, cfg, *args, **kwargs): def from_config(cls, cfg, *args, **kwargs):
...@@ -68,8 +69,8 @@ class PicoDet(BaseArch): ...@@ -68,8 +69,8 @@ class PicoDet(BaseArch):
else: else:
im_shape = self.inputs['im_shape'] im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor'] scale_factor = self.inputs['scale_factor']
bboxes, bbox_num = self.head.post_process(head_outs, im_shape, bboxes, bbox_num = self.head.post_process(
scale_factor) head_outs, im_shape, scale_factor, export_nms=self.export_nms)
return bboxes, bbox_num return bboxes, bbox_num
def get_loss(self, ): def get_loss(self, ):
...@@ -85,7 +86,11 @@ class PicoDet(BaseArch): ...@@ -85,7 +86,11 @@ class PicoDet(BaseArch):
def get_pred(self):
    """Build the prediction dict according to the export flags.

    Three mutually exclusive output layouts:
    - export_post_process is False: raw head output only, all
      post-processing is expected to run outside the exported graph.
    - export_nms is True: fully decoded detections after NMS.
    - otherwise: decoded boxes and per-class scores, NMS left to the
      deployment side.
    """
    if not self.export_post_process:
        return {'picodet': self._forward()[0]}
    if self.export_nms:
        bbox_pred, bbox_num = self._forward()
        return {'bbox': bbox_pred, 'bbox_num': bbox_num}
    bboxes, mlvl_scores = self._forward()
    return {'bbox': bboxes, 'scores': mlvl_scores}
...@@ -335,6 +335,24 @@ class PicoHead(OTAVFLHead): ...@@ -335,6 +335,24 @@ class PicoHead(OTAVFLHead):
return (cls_logits_list, bboxes_reg_list) return (cls_logits_list, bboxes_reg_list)
def post_process(self,
                 gfl_head_outs,
                 im_shape,
                 scale_factor,
                 export_nms=True):
    """Turn raw head outputs into detection results.

    Args:
        gfl_head_outs (tuple): (cls_scores, bboxes_reg) — per-level lists
            of classification score and box regression tensors.
        im_shape (Tensor): input image shape; kept for a uniform head
            post-process signature (not read here).
        scale_factor (Tensor): per-image [h_scale, w_scale] used to map
            boxes back to original image coordinates.
        export_nms (bool): when False, skip rescaling and NMS and return
            the concatenated raw boxes/scores so NMS can run outside the
            exported model.

    Returns:
        (bboxes, mlvl_scores) if export_nms is False, else
        (bbox_pred, bbox_num) after NMS.
    """
    scores_per_level, boxes_per_level = gfl_head_outs
    all_boxes = paddle.concat(boxes_per_level, axis=1)
    # [N, anchors, classes] -> [N, classes, anchors] for NMS.
    all_scores = paddle.concat(scores_per_level, axis=1).transpose([0, 2, 1])
    if not export_nms:
        return all_boxes, all_scores
    # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
    rescale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
    all_boxes /= rescale
    bbox_pred, bbox_num, _ = self.nms(all_boxes, all_scores)
    return bbox_pred, bbox_num
@register @register
class PicoHeadV2(GFLHead): class PicoHeadV2(GFLHead):
...@@ -625,3 +643,21 @@ class PicoHeadV2(GFLHead): ...@@ -625,3 +643,21 @@ class PicoHeadV2(GFLHead):
loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl) loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
return loss_states return loss_states
def post_process(self,
                 gfl_head_outs,
                 im_shape,
                 scale_factor,
                 export_nms=True):
    """Decode head outputs to detections, optionally applying NMS.

    Args:
        gfl_head_outs (tuple): (cls_scores, bboxes_reg), per-level lists
            of score and box tensors.
        im_shape (Tensor): input image shape; present for signature
            compatibility with other heads, unused in this body.
        scale_factor (Tensor): per-image [h_scale, w_scale] for mapping
            boxes to the original image size.
        export_nms (bool): False leaves NMS (and rescaling) to the
            deployment pipeline.

    Returns:
        (bboxes, mlvl_scores) without NMS, or (bbox_pred, bbox_num)
        with NMS applied.
    """
    level_scores, level_boxes = gfl_head_outs
    decoded_boxes = paddle.concat(level_boxes, axis=1)
    class_scores = paddle.concat(level_scores, axis=1)
    # Move the class axis ahead of the anchor axis for NMS.
    class_scores = class_scores.transpose([0, 2, 1])
    if export_nms:
        # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
        scale_wh = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
        decoded_boxes /= scale_wh
        bbox_pred, bbox_num, _ = self.nms(decoded_boxes, class_scores)
        return bbox_pred, bbox_num
    return decoded_boxes, class_scores
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册