From f9a2b26aa667be938994516d75931021b83eb5ba Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 28 Mar 2022 19:00:35 +0800 Subject: [PATCH] fix quant logic (#5806) * fix quant logic * fix undef * fix doc --- deploy/slim/quantization/quant.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 941cfb36..1dffaab0 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -118,6 +118,11 @@ def main(config, device, logger, vdl_writer): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) + pre_best_model_dict = dict() + # load fp32 model to begin quantization + if config["Global"]["pretrained_model"] is not None: + pre_best_model_dict = load_model(config, model) + quanter = QAT(config=quant_config, act_preprocess=PACT) quanter.quantize(model) @@ -134,10 +139,12 @@ def main(config, device, logger, vdl_writer): step_each_epoch=len(train_dataloader), parameters=model.parameters()) + # resume PACT training process + if config["Global"]["checkpoints"] is not None: + pre_best_model_dict = load_model(config, model, optimizer) + # build metric eval_class = build_metric(config['Metric']) - # load pretrain model - pre_best_model_dict = load_model(config, model, optimizer) logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader))) -- GitLab