From a39cf8fc840e5565d1682009d33bd886b1fced91 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 14 Feb 2023 10:33:41 +0800 Subject: [PATCH] fix post quant (#7751) --- ppdet/engine/trainer.py | 2 ++ ppdet/slim/__init__.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 98da6c477..ae0e21d8e 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -72,6 +72,8 @@ class Trainer(object): self.amp_level = self.cfg.get('amp_level', 'O1') self.custom_white_list = self.cfg.get('custom_white_list', None) self.custom_black_list = self.cfg.get('custom_black_list', None) + if 'slim' in cfg and cfg['slim_type'] == 'PTQ': + self.cfg['TestDataset'] = create('TestDataset')() # build data loader capital_mode = self.mode.capitalize() diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index 7d75082b2..712919002 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -83,9 +83,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): load_config(slim_cfg) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) + cfg['slim_type'] = cfg.slim cfg['slim'] = slim cfg['model'] = slim(model) - cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'UnstructuredPruner': load_config(slim_cfg) slim = create(cfg.slim) -- GitLab