diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 35e1b97102bf3c1650a9907f33d0d938b0288c60..622b9c3b290cb53fc40c4d07341ac115e8b73922 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -103,6 +103,13 @@ class Trainer(object): if 'slim' in cfg and cfg['slim_type'] == 'OFA': self.model.model.load_meanstd(cfg['TestReader'][ 'sample_transforms']) + elif 'slim' in cfg and cfg['slim_type'] == 'Distill': + self.model.student_model.load_meanstd(cfg['TestReader'][ + 'sample_transforms']) + elif 'slim' in cfg and cfg[ + 'slim_type'] == 'DistillPrune' and self.mode == 'train': + self.model.student_model.load_meanstd(cfg['TestReader'][ + 'sample_transforms']) else: self.model.load_meanstd(cfg['TestReader']['sample_transforms']) diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index e71481d1c8fd61c646f4919ddb52c85020f11725..8b343eb6001ec48bc720f5ec31bce65e95cc9797 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -37,14 +37,15 @@ def build_slim_model(cfg, slim_cfg, mode='train'): if slim_load_cfg['slim'] == 'Distill': model = DistillModel(cfg, slim_cfg) cfg['model'] = model + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'OFA': load_config(slim_cfg) model = create(cfg.architecture) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) - cfg['slim_type'] = cfg.slim - cfg['model'] = slim(model, model.state_dict()) cfg['slim'] = slim + cfg['model'] = slim(model, model.state_dict()) + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'DistillPrune': if mode == 'train': model = DistillModel(cfg, slim_cfg) @@ -64,9 +65,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): load_config(slim_cfg) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) - cfg['slim_type'] = cfg.slim - cfg['model'] = slim(model) cfg['slim'] = slim + cfg['model'] = slim(model) + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'UnstructuredPruner': load_config(slim_cfg) slim = create(cfg.slim)