diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 35e1b97102bf3c1650a9907f33d0d938b0288c60..622b9c3b290cb53fc40c4d07341ac115e8b73922 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -103,6 +103,13 @@ class Trainer(object): if 'slim' in cfg and cfg['slim_type'] == 'OFA': self.model.model.load_meanstd(cfg['TestReader'][ 'sample_transforms']) + elif 'slim' in cfg and cfg['slim_type'] == 'Distill': + self.model.student_model.load_meanstd(cfg['TestReader'][ + 'sample_transforms']) + elif 'slim' in cfg and cfg[ + 'slim_type'] == 'DistillPrune' and self.mode == 'train': + self.model.student_model.load_meanstd(cfg['TestReader'][ + 'sample_transforms']) else: self.model.load_meanstd(cfg['TestReader']['sample_transforms'])