diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 98da6c47772c4edff38224afbc6a2faea0e09f7c..ae0e21d8ea4b20b1f4995ab4bedd9720bafe8c95 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -72,6 +72,8 @@ class Trainer(object): self.amp_level = self.cfg.get('amp_level', 'O1') self.custom_white_list = self.cfg.get('custom_white_list', None) self.custom_black_list = self.cfg.get('custom_black_list', None) + if 'slim' in cfg and cfg['slim_type'] == 'PTQ': + self.cfg['TestDataset'] = create('TestDataset')() # build data loader capital_mode = self.mode.capitalize() diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index 7d75082b2b333d35b18dc66d1b24931255084f97..712919002ff49d9ff503fa8caaed85c954a02104 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -83,9 +83,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): load_config(slim_cfg) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) + cfg['slim_type'] = cfg.slim cfg['slim'] = slim cfg['model'] = slim(model) - cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'UnstructuredPruner': load_config(slim_cfg) slim = create(cfg.slim)