diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 942f59e000a24faef72844b45e833fc55061d2ac..e9c1a8d31110ef20dd66be28d78b1e866fcd85ae 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT from ppocr.data import build_dataloader +def export_single_model(quanter, model, infer_shape, save_path, logger): + quanter.save_quantized_model( + model, + save_path, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype='float32') + ]) + logger.info('inference QAT model is saved to {}'.format(save_path)) + + def main(): ############################################################################################################ # 1. quantization configs @@ -76,7 +87,14 @@ def main(): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) # get QAT model @@ -97,22 +115,25 @@ def main(): # start eval - metirc = program.eval(model, valid_dataloader, post_process_class, + metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, use_srn) + logger.info('metric eval ***************') - for k, v in metirc.items(): + for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) - save_path = '{}/inference'.format(config['Global']['save_inference_dir']) infer_shape = [3, 32, 100] if config['Architecture'][ 'model_type'] != "det" else [3, 640, 640] - quanter.save_quantized_model( - model, - save_path, - input_spec=[ - paddle.static.InputSpec( - shape=[None] + infer_shape, 
dtype='float32') - ]) - logger.info('inference QAT model is saved to {}'.format(save_path)) + save_path = config["Global"]["save_inference_dir"] + + arch_config = config["Architecture"] + if arch_config["algorithm"] in ["Distillation", ]: # distillation model + for idx, name in enumerate(model.model_name_list): + sub_model_save_path = os.path.join(save_path, name, "inference") + export_single_model(quanter, model.model_list[idx], infer_shape, + sub_model_save_path, logger) + else: + save_path = os.path.join(save_path, "inference") + export_single_model(quanter, model, infer_shape, save_path, logger) if __name__ == "__main__": diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 315e3b4321a544e77795c43d493873fcf46e1930..37aab68a0e88afce54e10fb6248c73684b58d808 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) + quanter = QAT(config=quant_config, act_preprocess=PACT) + quanter.quantize(model) + if config['Global']['distributed']: model = paddle.DataParallel(model) @@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer): logger.info('train dataloader has {} iters, valid dataloader has {} iters'. 
format(len(train_dataloader), len(valid_dataloader))) - quanter = QAT(config=quant_config, act_preprocess=PACT) - quanter.quantize(model) # start train program.train(config, train_dataloader, valid_dataloader, device, model, diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 76420abb5a0da3e0138478c34bdb53d593492bf4..1d760e983a635dcc6b48b839ee99434c67b4378d 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -91,14 +91,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): def load_dygraph_params(config, model, logger, optimizer): ckp = config['Global']['checkpoints'] - if ckp and os.path.exists(ckp): + if ckp and os.path.exists(ckp + ".pdparams"): pre_best_model_dict = init_model(config, model, optimizer) return pre_best_model_dict else: pm = config['Global']['pretrained_model'] if pm is None: return {} - if not os.path.exists(pm) or not os.path.exists(pm + ".pdparams"): + if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"): logger.info(f"The pretrained_model {pm} does not exists!") return {} pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'