From ee3df4c28e2754469fdc4141dfe7f146118d35dd Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Sun, 4 Jul 2021 01:02:41 +0800
Subject: [PATCH] add support for quant distillation

---
 deploy/slim/quantization/export_model.py | 47 +++++++++++++++++-------
 deploy/slim/quantization/quant.py        | 13 +++++--
 2 files changed, 44 insertions(+), 16 deletions(-)

diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py
index 100b107a..fe843d41 100755
--- a/deploy/slim/quantization/export_model.py
+++ b/deploy/slim/quantization/export_model.py
@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
 from ppocr.data import build_dataloader
 
 
+def export_single_model(quanter, model, infer_shape, save_path, logger):
+    quanter.save_quantized_model(
+        model,
+        save_path,
+        input_spec=[
+            paddle.static.InputSpec(
+                shape=[None] + infer_shape, dtype='float32')
+        ])
+    logger.info('inference QAT model is saved to {}'.format(save_path))
+
+
 def main():
     ############################################################################################################
     # 1. quantization configs
@@ -76,7 +87,14 @@ def main():
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
+
     model = build_model(config['Architecture'])
 
     # get QAT model
@@ -93,24 +111,27 @@ def main():
     valid_dataloader = build_dataloader(config, 'Eval', device, logger)
 
     # start eval
-    metirc = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class)
+    model_type = config['Architecture']['model_type']
+    metric = program.eval(model, valid_dataloader, post_process_class,
+                          eval_class, model_type)
 
     logger.info('metric eval ***************')
-    for k, v in metirc.items():
+    for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
 
-    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
     infer_shape = [3, 32, 100] if config['Architecture'][
         'model_type'] != "det" else [3, 640, 640]
-    quanter.save_quantized_model(
-        model,
-        save_path,
-        input_spec=[
-            paddle.static.InputSpec(
-                shape=[None] + infer_shape, dtype='float32')
-        ])
-    logger.info('inference QAT model is saved to {}'.format(save_path))
+    save_path = config["Global"]["save_inference_dir"]
+
+    arch_config = config["Architecture"]
+    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        for idx, name in enumerate(model.model_name_list):
+            sub_model_save_path = os.path.join(save_path, name, "inference")
+            export_single_model(quanter, model.model_list[idx], infer_shape,
+                                sub_model_save_path, logger)
+    else:
+        save_path = os.path.join(save_path, "inference")
+        export_single_model(quanter, model, infer_shape, save_path, logger)
 
 
 if __name__ == "__main__":
diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py
index 315e3b43..37aab68a 100755
--- a/deploy/slim/quantization/quant.py
+++ b/deploy/slim/quantization/quant.py
@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
 
+    quanter = QAT(config=quant_config, act_preprocess=PACT)
+    quanter.quantize(model)
+
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
 
@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))
 
-    quanter = QAT(config=quant_config, act_preprocess=PACT)
-    quanter.quantize(model)
 
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
-- 
GitLab
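
Note on the distillation branch added in both files: the per-sub-model out_channels override assumes that, once the YAML is loaded, config['Architecture'] for a distillation run carries an "algorithm: Distillation" field and a "Models" dict with one entry per sub-model, each with its own "Head". The sketch below illustrates that assumed shape; the sub-model names ("Teacher", "Student") and component names are illustrative and not part of this patch.

# Illustrative shape of config['Architecture'] after a distillation YAML is
# loaded. Only the keys touched by the patch matter here: "algorithm",
# "Models", and each sub-model's "Head". Names and values are assumptions.
arch = {
    "model_type": "rec",
    "algorithm": "Distillation",
    "Models": {
        "Teacher": {"Backbone": {"name": "MobileNetV3"},
                    "Head": {"name": "CTCHead"}},
        "Student": {"Backbone": {"name": "MobileNetV3"},
                    "Head": {"name": "CTCHead"}},
    },
}

# The loop the patch adds, applied to the sketch above:
char_num = 6625  # illustrative; the real value is len(post_process_class.character)
for key in arch["Models"]:
    arch["Models"][key]["Head"]["out_channels"] = char_num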
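
With a distillation config, export_model.py now writes one quantized inference model per sub-model under save_inference_dir/<name>/inference.*. Below is a minimal smoke-test sketch for loading one of them; the path prefix and the "Student" directory name are assumptions tied to the example above, and the output structure depends on the head configured in the YAML.

# Minimal sketch (not part of the patch): load one exported quantized
# sub-model and run a dummy forward pass. Path and sub-model name are
# assumptions.
import paddle

loaded = paddle.jit.load("./inference/quant_distill/Student/inference")
loaded.eval()

# Dummy recognition input matching the export InputSpec [None, 3, 32, 100].
x = paddle.randn([1, 3, 32, 100], dtype='float32')
preds = loaded(x)
print(type(preds))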