未验证 提交 a5944b8c 编写于 作者: C Chang Xu 提交者: GitHub

fix_ofa_demo (#5607)

上级 bbece395
...@@ -37,14 +37,15 @@ def build_slim_model(cfg, slim_cfg, mode='train'): ...@@ -37,14 +37,15 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
if slim_load_cfg['slim'] == 'Distill': if slim_load_cfg['slim'] == 'Distill':
model = DistillModel(cfg, slim_cfg) model = DistillModel(cfg, slim_cfg)
cfg['model'] = model cfg['model'] = model
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'OFA': elif slim_load_cfg['slim'] == 'OFA':
load_config(slim_cfg) load_config(slim_cfg)
model = create(cfg.architecture) model = create(cfg.architecture)
load_pretrain_weight(model, cfg.weights) load_pretrain_weight(model, cfg.weights)
slim = create(cfg.slim) slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
cfg['model'] = slim(model, model.state_dict())
cfg['slim'] = slim cfg['slim'] = slim
cfg['model'] = slim(model, model.state_dict())
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'DistillPrune': elif slim_load_cfg['slim'] == 'DistillPrune':
if mode == 'train': if mode == 'train':
model = DistillModel(cfg, slim_cfg) model = DistillModel(cfg, slim_cfg)
...@@ -64,9 +65,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): ...@@ -64,9 +65,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
load_config(slim_cfg) load_config(slim_cfg)
load_pretrain_weight(model, cfg.weights) load_pretrain_weight(model, cfg.weights)
slim = create(cfg.slim) slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
cfg['model'] = slim(model)
cfg['slim'] = slim cfg['slim'] = slim
cfg['model'] = slim(model)
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'UnstructuredPruner': elif slim_load_cfg['slim'] == 'UnstructuredPruner':
load_config(slim_cfg) load_config(slim_cfg)
slim = create(cfg.slim) slim = create(cfg.slim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册