From a5944b8cfe6ef3a6e00b9028fd1d40391832062e Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Thu, 7 Apr 2022 09:54:44 +0800 Subject: [PATCH] fix_ofa_demo (#5607) --- ppdet/slim/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index e71481d1c..8b343eb60 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -37,14 +37,15 @@ def build_slim_model(cfg, slim_cfg, mode='train'): if slim_load_cfg['slim'] == 'Distill': model = DistillModel(cfg, slim_cfg) cfg['model'] = model + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'OFA': load_config(slim_cfg) model = create(cfg.architecture) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) - cfg['slim_type'] = cfg.slim - cfg['model'] = slim(model, model.state_dict()) cfg['slim'] = slim + cfg['model'] = slim(model, model.state_dict()) + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'DistillPrune': if mode == 'train': model = DistillModel(cfg, slim_cfg) @@ -64,9 +65,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): load_config(slim_cfg) load_pretrain_weight(model, cfg.weights) slim = create(cfg.slim) - cfg['slim_type'] = cfg.slim - cfg['model'] = slim(model) cfg['slim'] = slim + cfg['model'] = slim(model) + cfg['slim_type'] = cfg.slim elif slim_load_cfg['slim'] == 'UnstructuredPruner': load_config(slim_cfg) slim = create(cfg.slim) -- GitLab