未验证 提交 fb2c24a0 编写于 作者: G Guanghua Yu 提交者: GitHub

support export post_process in PicoDet (#5044)

* support export post_process in PicoDet
上级 1a46f29f
architecture: PicoDet architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
export_post_process: False # Whether post-processing is included in the network
PicoDet: PicoDet:
backbone: ESNet backbone: ESNet
......
...@@ -165,6 +165,8 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -165,6 +165,8 @@ def _dump_infer_config(config, path, image_shape, model):
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:]) reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
if infer_arch == 'PicoDet': if infer_arch == 'PicoDet':
if config.get('export_post_process', False):
infer_cfg['arch'] = 'GFL'
infer_cfg['NMS'] = config['PicoHead']['nms'] infer_cfg['NMS'] = config['PicoHead']['nms']
# In order to speed up the prediction, the threshold of nms # In order to speed up the prediction, the threshold of nms
# is adjusted here, which can be changed in infer_cfg.yml # is adjusted here, which can be changed in infer_cfg.yml
......
...@@ -631,7 +631,8 @@ class Trainer(object): ...@@ -631,7 +631,8 @@ class Trainer(object):
im_shape = [image_shape[0], 2] im_shape = [image_shape[0], 2]
scale_factor = [image_shape[0], 2] scale_factor = [image_shape[0], 2]
if hasattr(self.model, 'deploy'): export_post_process = self.cfg.get('export_post_process', False)
if hasattr(self.model, 'deploy') and not export_post_process:
self.model.deploy = True self.model.deploy = True
if hasattr(self.model, 'fuse_norm'): if hasattr(self.model, 'fuse_norm'):
self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize', self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
...@@ -668,7 +669,7 @@ class Trainer(object): ...@@ -668,7 +669,7 @@ class Trainer(object):
pruned_input_spec = input_spec pruned_input_spec = input_spec
# TODO: Hard code, delete it when support prune input_spec. # TODO: Hard code, delete it when support prune input_spec.
if self.cfg.architecture == 'PicoDet': if self.cfg.architecture == 'PicoDet' and not export_post_process:
pruned_input_spec = [{ pruned_input_spec = [{
"image": InputSpec( "image": InputSpec(
shape=image_shape, name='image') shape=image_shape, name='image')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册