From d83bc9cd7592cae6863d6c957d9ae7cad3c198be Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Wed, 13 Apr 2022 10:05:16 +0800 Subject: [PATCH] [Cherry-Pick] Update Slim in PaddleDetection (#5670) --- ppdet/engine/trainer.py | 7 +++++++ ppdet/slim/__init__.py | 9 +++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 35e1b9710..622b9c3b2 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -103,6 +103,13 @@ class Trainer(object): if 'slim' in cfg and cfg['slim_type'] == 'OFA': self.model.model.load_meanstd(cfg['TestReader'][ 'sample_transforms']) + elif 'slim' in cfg and cfg['slim_type'] == 'Distill': + self.model.student_model.load_meanstd(cfg['TestReader'][ + 'sample_transforms']) + elif 'slim' in cfg and cfg[ + 'slim_type'] == 'DistillPrune' and self.mode == 'train': + self.model.student_model.load_meanstd(cfg['TestReader'][ + 'sample_transforms']) else: self.model.load_meanstd(cfg['TestReader']['sample_transforms']) diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index e71481d1c..8b343eb60 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -37,14 +37,15 @@ def build_slim_model(cfg, slim_cfg, mode='train'): if slim_load_cfg['slim'] == 'Distill': model = DistillModel(cfg, slim_cfg) cfg['model'] = model + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'OFA': load_config(slim_cfg) model = create(cfg.architecture) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) - cfg['slim_type'] = cfg.slim - cfg['model'] = slim(model, model.state_dict()) cfg['slim'] = slim + cfg['model'] = slim(model, model.state_dict()) + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'DistillPrune': if mode == 'train': model = DistillModel(cfg, slim_cfg) @@ -64,9 +65,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): load_config(slim_cfg) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) - cfg['slim_type'] = cfg.slim - cfg['model'] = slim(model) cfg['slim'] = slim + cfg['model'] = slim(model) + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'UnstructuredPruner': load_config(slim_cfg) slim = create(cfg.slim) -- GitLab