diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 941cfb36b291dcd1dbedbf51de5edd2cf0017167..1dffaab0eef35ec41c27c9c6e00f25dda048d490 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -118,6 +118,11 @@ def main(config, device, logger, vdl_writer): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) + pre_best_model_dict = dict() + # load fp32 model to begin quantization + if config["Global"]["pretrained_model"] is not None: + pre_best_model_dict = load_model(config, model) + quanter = QAT(config=quant_config, act_preprocess=PACT) quanter.quantize(model) @@ -134,10 +139,12 @@ def main(config, device, logger, vdl_writer): step_each_epoch=len(train_dataloader), parameters=model.parameters()) + # resume PACT training process + if config["Global"]["checkpoints"] is not None: + pre_best_model_dict = load_model(config, model, optimizer) + # build metric eval_class = build_metric(config['Metric']) - # load pretrain model - pre_best_model_dict = load_model(config, model, optimizer) logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader)))