未验证 提交 a39cf8fc 编写于 作者: W wangguanzhong 提交者: GitHub

fix post quant (#7751)

上级 fb6d68da
......@@ -72,6 +72,8 @@ class Trainer(object):
self.amp_level = self.cfg.get('amp_level', 'O1')
self.custom_white_list = self.cfg.get('custom_white_list', None)
self.custom_black_list = self.cfg.get('custom_black_list', None)
if 'slim' in cfg and cfg['slim_type'] == 'PTQ':
self.cfg['TestDataset'] = create('TestDataset')()
# build data loader
capital_mode = self.mode.capitalize()
......
......@@ -83,9 +83,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
load_config(slim_cfg)
load_pretrain_weight(model, cfg.weights)
slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
cfg['slim'] = slim
cfg['model'] = slim(model)
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'UnstructuredPruner':
load_config(slim_cfg)
slim = create(cfg.slim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册