未验证 提交 d83bc9cd 编写于 作者: C Chang Xu 提交者: GitHub

[Cherry-Pick] Update Slim in PaddleDetection (#5670)

上级 1c8b0368
......@@ -103,6 +103,13 @@ class Trainer(object):
if 'slim' in cfg and cfg['slim_type'] == 'OFA':
self.model.model.load_meanstd(cfg['TestReader'][
'sample_transforms'])
elif 'slim' in cfg and cfg['slim_type'] == 'Distill':
self.model.student_model.load_meanstd(cfg['TestReader'][
'sample_transforms'])
elif 'slim' in cfg and cfg[
'slim_type'] == 'DistillPrune' and self.mode == 'train':
self.model.student_model.load_meanstd(cfg['TestReader'][
'sample_transforms'])
else:
self.model.load_meanstd(cfg['TestReader']['sample_transforms'])
......
......@@ -37,14 +37,15 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
if slim_load_cfg['slim'] == 'Distill':
model = DistillModel(cfg, slim_cfg)
cfg['model'] = model
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'OFA':
load_config(slim_cfg)
model = create(cfg.architecture)
load_pretrain_weight(model, cfg.weights)
slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
cfg['model'] = slim(model, model.state_dict())
cfg['slim'] = slim
cfg['model'] = slim(model, model.state_dict())
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'DistillPrune':
if mode == 'train':
model = DistillModel(cfg, slim_cfg)
......@@ -64,9 +65,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
load_config(slim_cfg)
load_pretrain_weight(model, cfg.weights)
slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
cfg['model'] = slim(model)
cfg['slim'] = slim
cfg['model'] = slim(model)
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'UnstructuredPruner':
load_config(slim_cfg)
slim = create(cfg.slim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册