未验证 提交 f9a2b26a 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix quant logic (#5806)

* fix quant logic

* fix undef

* fix doc
上级 3d692957
...@@ -118,6 +118,11 @@ def main(config, device, logger, vdl_writer): ...@@ -118,6 +118,11 @@ def main(config, device, logger, vdl_writer):
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
pre_best_model_dict = dict()
# load fp32 model to begin quantization
if config["Global"]["pretrained_model"] is not None:
pre_best_model_dict = load_model(config, model)
quanter = QAT(config=quant_config, act_preprocess=PACT) quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model) quanter.quantize(model)
...@@ -134,10 +139,12 @@ def main(config, device, logger, vdl_writer): ...@@ -134,10 +139,12 @@ def main(config, device, logger, vdl_writer):
step_each_epoch=len(train_dataloader), step_each_epoch=len(train_dataloader),
parameters=model.parameters()) parameters=model.parameters())
# resume PACT training process
if config["Global"]["checkpoints"] is not None:
pre_best_model_dict = load_model(config, model, optimizer)
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册