diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index 255a8859d043e10e5587ca5abb5dda97fc39ea07..d929807c3ded78cad9e704fded69273d6020a566 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -47,6 +47,7 @@ def build_slim_model(cfg, slim_cfg, mode='train'): model = pruner(model) load_pretrain_weight(model, weights) cfg['model'] = model + cfg['slim_type'] = cfg.slim else: load_config(slim_cfg) model = create(cfg.architecture)