提交 6c193662 编写于 作者: L LDOUBLEV

fix comment

上级 a2cdabd0
......@@ -161,6 +161,12 @@ def main(config, device, logger, vdl_writer):
if config["Global"]["pretrained_model"] is not None:
pre_best_model_dict = load_model(config, model)
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build loss
loss_class = build_loss(config['Loss'])
......@@ -175,12 +181,6 @@ def main(config, device, logger, vdl_writer):
if config["Global"]["checkpoints"] is not None:
pre_best_model_dict = load_model(config, model, optimizer)
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build metric
eval_class = build_metric(config['Metric'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册