From ae4167dc32fe4e35dffc90b48a80ee07cd1dc00f Mon Sep 17 00:00:00 2001 From: zhoujun Date: Fri, 12 Nov 2021 11:06:36 +0800 Subject: [PATCH] merge init_model and load_dygraph_params to load_model (#4623) * merge init_model and load_dygraph_params to load_model --- deploy/slim/prune/export_prune_model.py | 4 +- deploy/slim/prune/sensitivity_anal.py | 4 +- deploy/slim/quantization/export_model.py | 4 +- deploy/slim/quantization/quant.py | 4 +- deploy/slim/quantization/quant_kl.py | 2 +- .../architectures/distillation_model.py | 2 +- ppocr/utils/save_load.py | 76 +++++-------------- tools/eval.py | 4 +- tools/export_center.py | 4 +- tools/export_model.py | 4 +- tools/infer_cls.py | 4 +- tools/infer_det.py | 4 +- tools/infer_e2e.py | 4 +- tools/infer_rec.py | 8 +- tools/infer_table.py | 6 +- tools/train.py | 4 +- 16 files changed, 48 insertions(+), 90 deletions(-) diff --git a/deploy/slim/prune/export_prune_model.py b/deploy/slim/prune/export_prune_model.py index 29f7d211..2c9d0a18 100644 --- a/deploy/slim/prune/export_prune_model.py +++ b/deploy/slim/prune/export_prune_model.py @@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model import tools.program as program @@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer): logger.info(f"FLOPs after pruning: {flops}") # load pretrain model - pre_best_model_dict = init_model(config, model, logger, None) + load_model(config, model) metric = program.eval(model, valid_dataloader, post_process_class, eval_class) logger.info(f"metric['hmean']: {metric['hmean']}") diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py index 0f0492af..c5d00877 100644 --- a/deploy/slim/prune/sensitivity_anal.py +++ b/deploy/slim/prune/sensitivity_anal.py @@ -32,7 +32,7 @@ from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model import tools.program as program dist.get_world_size() @@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, logger, optimizer) + pre_best_model_dict = load_model(config, model, optimizer) logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader))) diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index d94e5303..dddae923 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -28,7 +28,7 @@ from paddle.jit import to_static from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser from ppocr.metrics import build_metric @@ -101,7 +101,7 @@ def main(): quanter = QAT(config=quant_config) quanter.quantize(model) - init_model(config, model) + load_model(config, model) model.eval() # build metric diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 37aab68a..941cfb36 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -37,7 +37,7 @@ from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model import tools.program as program from paddleslim.dygraph.quant import QAT @@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, logger, optimizer) + pre_best_model_dict = load_model(config, model, optimizer) logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader))) diff --git a/deploy/slim/quantization/quant_kl.py b/deploy/slim/quantization/quant_kl.py index d866784a..cc3a455b 100755 --- a/deploy/slim/quantization/quant_kl.py +++ b/deploy/slim/quantization/quant_kl.py @@ -37,7 +37,7 @@ from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model import tools.program as program import paddleslim from paddleslim.dygraph.quant import QAT diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index 1e95fe57..5e867940 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone from ppocr.modeling.necks import build_neck from ppocr.modeling.heads import build_head from .base_model import BaseModel -from ppocr.utils.save_load import init_model, load_pretrained_params +from ppocr.utils.save_load import load_pretrained_params __all__ = ['DistillationModel'] diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index a7d24dd7..702f3e97 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -25,7 +25,7 @@ import paddle from ppocr.utils.logging import get_logger -__all__ = ['init_model', 'save_model', 'load_dygraph_params'] +__all__ = ['load_model'] def _mkdir_if_not_exist(path, logger): @@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def init_model(config, model, optimizer=None, lr_scheduler=None): +def load_model(config, model, optimizer=None): """ load model from checkpoint or pretrained_model """ @@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): pretrained_model = global_config.get('pretrained_model') best_model_dict = {} if checkpoints: - assert os.path.exists(checkpoints + ".pdparams"), \ - "Given dir {}.pdparams not exist.".format(checkpoints) + if checkpoints.endswith('pdparams'): + checkpoints = checkpoints.replace('.pdparams', '') assert os.path.exists(checkpoints + ".pdopt"), \ - "Given dir {}.pdopt not exist.".format(checkpoints) - para_dict = paddle.load(checkpoints + '.pdparams') - opti_dict = paddle.load(checkpoints + '.pdopt') - model.set_state_dict(para_dict) + f"The {checkpoints}.pdopt does not exists!" + load_pretrained_params(model, checkpoints) + optim_dict = paddle.load(checkpoints + '.pdopt') if optimizer is not None: - optimizer.set_state_dict(opti_dict) + optimizer.set_state_dict(optim_dict) if os.path.exists(checkpoints + '.states'): with open(checkpoints + '.states', 'rb') as f: @@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): best_model_dict['start_epoch'] = states_dict['epoch'] + 1 logger.info("resume from {}".format(checkpoints)) elif pretrained_model: - if not isinstance(pretrained_model, list): - pretrained_model = [pretrained_model] - for pretrained in pretrained_model: - if not (os.path.isdir(pretrained) or - os.path.exists(pretrained + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(pretrained)) - param_state_dict = paddle.load(pretrained + '.pdparams') - model.set_state_dict(param_state_dict) - logger.info("load pretrained model from {}".format( - pretrained_model)) + load_pretrained_params(model, pretrained_model) else: logger.info('train from scratch') return best_model_dict -def load_dygraph_params(config, model, logger, optimizer): - ckp = config['Global']['checkpoints'] - if ckp and os.path.exists(ckp + ".pdparams"): - pre_best_model_dict = init_model(config, model, optimizer) - return pre_best_model_dict - else: - pm = config['Global']['pretrained_model'] - if pm is None: - return {} - if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"): - logger.info(f"The pretrained_model {pm} does not exists!") - return {} - pm = pm if pm.endswith('.pdparams') else pm + '.pdparams' - params = paddle.load(pm) - state_dict = model.state_dict() - new_state_dict = {} - for k1, k2 in zip(state_dict.keys(), params.keys()): - if list(state_dict[k1].shape) == list(params[k2].shape): - new_state_dict[k1] = params[k2] - else: - logger.info( - f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" - ) - model.set_state_dict(new_state_dict) - logger.info(f"loaded pretrained_model successful from {pm}") - return {} - - def load_pretrained_params(model, path): - if path is None: - return False - if not os.path.exists(path) and not os.path.exists(path + ".pdparams"): - print(f"The pretrained_model {path} does not exists!") - return False - - path = path if path.endswith('.pdparams') else path + '.pdparams' - params = paddle.load(path) + logger = get_logger() + if path.endswith('pdparams'): + path = path.replace('.pdparams', '') + assert os.path.exists(path + ".pdparams"), \ + f"The {path}.pdparams does not exists!" + + params = paddle.load(path + '.pdparams') state_dict = model.state_dict() new_state_dict = {} for k1, k2 in zip(state_dict.keys(), params.keys()): if list(state_dict[k1].shape) == list(params[k2].shape): new_state_dict[k1] = params[k2] else: - print( + logger.info( f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" ) model.set_state_dict(new_state_dict) - print(f"load pretrain successful from {path}") + logger.info(f"load pretrain successful from {path}") return model diff --git a/tools/eval.py b/tools/eval.py index 28247bc5..c85490a3 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -27,7 +27,7 @@ from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model, load_dygraph_params +from ppocr.utils.save_load import load_model from ppocr.utils.utility import print_dict import tools.program as program @@ -60,7 +60,7 @@ def main(): else: model_type = None - best_model_dict = load_dygraph_params(config, model, logger, None) + best_model_dict = load_model(config, model) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): diff --git a/tools/export_center.py b/tools/export_center.py index c46e8b9d..30b9c334 100644 --- a/tools/export_center.py +++ b/tools/export_center.py @@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model, load_dygraph_params +from ppocr.utils.save_load import load_model from ppocr.utils.utility import print_dict import tools.program as program @@ -57,7 +57,7 @@ def main(): model = build_model(config['Architecture']) - best_model_dict = load_dygraph_params(config, model, logger, None) + best_model_dict = load_model(config, model) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): diff --git a/tools/export_model.py b/tools/export_model.py index 64a0d403..9ed8e1b6 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -26,7 +26,7 @@ from paddle.jit import to_static from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser @@ -107,7 +107,7 @@ def main(): else: # base rec model config["Architecture"]["Head"]["out_channels"] = char_num model = build_model(config["Architecture"]) - init_model(config, model) + load_model(config, model) model.eval() save_path = config["Global"]["save_inference_dir"] diff --git a/tools/infer_cls.py b/tools/infer_cls.py index a588cab4..7522e439 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -32,7 +32,7 @@ import paddle from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -47,7 +47,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model) + load_model(config, model) # create data ops transforms = [] diff --git a/tools/infer_det.py b/tools/infer_det.py index ce16da8d..bb2cca73 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -34,7 +34,7 @@ import paddle from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model, load_dygraph_params +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -59,7 +59,7 @@ def main(): # build model model = build_model(config['Architecture']) - _ = load_dygraph_params(config, model, logger, None) + load_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess']) diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 1cd468b8..96dbac8e 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -34,7 +34,7 @@ import paddle from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -68,7 +68,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model) + load_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess'], diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 29d4b530..adc3c1c3 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -33,7 +33,7 @@ import paddle from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -58,7 +58,7 @@ def main(): model = build_model(config['Architecture']) - init_model(config, model) + load_model(config, model) # create data ops transforms = [] @@ -75,9 +75,7 @@ def main(): 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' ] elif config['Architecture']['algorithm'] == "SAR": - op[op_name]['keep_keys'] = [ - 'image', 'valid_ratio' - ] + op[op_name]['keep_keys'] = ['image', 'valid_ratio'] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) diff --git a/tools/infer_table.py b/tools/infer_table.py index f743d875..c73e3840 100644 --- a/tools/infer_table.py +++ b/tools/infer_table.py @@ -34,11 +34,12 @@ from paddle.jit import to_static from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program import cv2 + def main(config, device, logger, vdl_writer): global_config = config['Global'] @@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer): model = build_model(config['Architecture']) - init_model(config, model, logger) + load_model(config, model) # create data ops transforms = [] @@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() main(config, device, logger, vdl_writer) - diff --git a/tools/train.py b/tools/train.py index d182af29..f3852469 100755 --- a/tools/train.py +++ b/tools/train.py @@ -35,7 +35,7 @@ from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model, load_dygraph_params +from ppocr.utils.save_load import load_model import tools.program as program dist.get_world_size() @@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) + pre_best_model_dict = load_model(config, model, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( -- GitLab