未验证 提交 ce0d2bdb 编写于 作者: C Chang Xu 提交者: GitHub

update_slim_in_trainer (#5622)

上级 0c4166af
......@@ -103,6 +103,13 @@ class Trainer(object):
if 'slim' in cfg and cfg['slim_type'] == 'OFA':
self.model.model.load_meanstd(cfg['TestReader'][
'sample_transforms'])
elif 'slim' in cfg and cfg['slim_type'] == 'Distill':
self.model.student_model.load_meanstd(cfg['TestReader'][
'sample_transforms'])
elif 'slim' in cfg and cfg[
'slim_type'] == 'DistillPrune' and self.mode == 'train':
self.model.student_model.load_meanstd(cfg['TestReader'][
'sample_transforms'])
else:
self.model.load_meanstd(cfg['TestReader']['sample_transforms'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册