From fb2c24a048e0b126b8f034acb262f137bf2ee57d Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 4 Jan 2022 10:47:56 +0800 Subject: [PATCH] support export post_process in PicoDet (#5044) * support export post_process in PicoDet --- configs/picodet/_base_/picodet_esnet.yml | 1 + ppdet/engine/export_utils.py | 2 ++ ppdet/engine/trainer.py | 5 +++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/configs/picodet/_base_/picodet_esnet.yml b/configs/picodet/_base_/picodet_esnet.yml index aa099fca1..150d25e9f 100644 --- a/configs/picodet/_base_/picodet_esnet.yml +++ b/configs/picodet/_base_/picodet_esnet.yml @@ -1,5 +1,6 @@ architecture: PicoDet pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams +export_post_process: False # Whether post-processing is included in the network PicoDet: backbone: ESNet diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 4fe623d17..192b89658 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -165,6 +165,8 @@ def _dump_infer_config(config, path, image_shape, model): reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:]) if infer_arch == 'PicoDet': + if config.get('export_post_process', False): + infer_cfg['arch'] = 'GFL' infer_cfg['NMS'] = config['PicoHead']['nms'] # In order to speed up the prediction, the threshold of nms # is adjusted here, which can be changed in infer_cfg.yml diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 0d073d028..b8d3b2aca 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -631,7 +631,8 @@ class Trainer(object): im_shape = [image_shape[0], 2] scale_factor = [image_shape[0], 2] - if hasattr(self.model, 'deploy'): + export_post_process = self.cfg.get('export_post_process', False) + if hasattr(self.model, 'deploy') and not export_post_process: self.model.deploy = True if hasattr(self.model, 'fuse_norm'): self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize', @@ -668,7 +669,7 @@ class Trainer(object): pruned_input_spec = input_spec # TODO: Hard code, delete it when support prune input_spec. - if self.cfg.architecture == 'PicoDet': + if self.cfg.architecture == 'PicoDet' and not export_post_process: pruned_input_spec = [{ "image": InputSpec( shape=image_shape, name='image') -- GitLab