提交 6c193662 编写于 作者: L LDOUBLEV

fix comment

上级 a2cdabd0
...@@ -161,6 +161,12 @@ def main(config, device, logger, vdl_writer): ...@@ -161,6 +161,12 @@ def main(config, device, logger, vdl_writer):
if config["Global"]["pretrained_model"] is not None: if config["Global"]["pretrained_model"] is not None:
pre_best_model_dict = load_model(config, model) pre_best_model_dict = load_model(config, model)
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build loss # build loss
loss_class = build_loss(config['Loss']) loss_class = build_loss(config['Loss'])
...@@ -175,12 +181,6 @@ def main(config, device, logger, vdl_writer): ...@@ -175,12 +181,6 @@ def main(config, device, logger, vdl_writer):
if config["Global"]["checkpoints"] is not None: if config["Global"]["checkpoints"] is not None:
pre_best_model_dict = load_model(config, model, optimizer) pre_best_model_dict = load_model(config, model, optimizer)
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册