From be181cb3bd3fb99fd1f4be4b5e94c12e80095fd8 Mon Sep 17 00:00:00 2001 From: Double_V Date: Mon, 28 Jun 2021 20:44:06 +0800 Subject: [PATCH] add new load dygraph func (#3088) * add new load dygraph func * update load_pretrain_params * update load_dygrah_params * Update save_load.py * Update train.py * Update save_load.py * return {} when path is None * return {} when path is None --- ppocr/utils/save_load.py | 30 +++++++++++++++++++++++++++++- tools/train.py | 4 ++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 23f5401b..76420abb 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -25,7 +25,7 @@ import paddle from ppocr.utils.logging import get_logger -__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] +__all__ = ['init_model', 'save_model', 'load_dygraph_params'] def _mkdir_if_not_exist(path, logger): @@ -89,6 +89,34 @@ def init_model(config, model, optimizer=None, lr_scheduler=None): return best_model_dict +def load_dygraph_params(config, model, logger, optimizer): + ckp = config['Global']['checkpoints'] + if ckp and os.path.exists(ckp): + pre_best_model_dict = init_model(config, model, optimizer) + return pre_best_model_dict + else: + pm = config['Global']['pretrained_model'] + if pm is None: + return {} + if not os.path.exists(pm) or not os.path.exists(pm + ".pdparams"): + logger.info(f"The pretrained_model {pm} does not exists!") + return {} + pm = pm if pm.endswith('.pdparams') else pm + '.pdparams' + params = paddle.load(pm) + state_dict = model.state_dict() + new_state_dict = {} + for k1, k2 in zip(state_dict.keys(), params.keys()): + if list(state_dict[k1].shape) == list(params[k2].shape): + new_state_dict[k1] = params[k2] + else: + logger.info( + f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" + ) + model.set_state_dict(new_state_dict) + logger.info(f"loaded pretrained_model successful from {pm}") + return {} + + def save_model(model, optimizer, model_path, diff --git a/tools/train.py b/tools/train.py index b024240b..20f5a670 100755 --- a/tools/train.py +++ b/tools/train.py @@ -35,7 +35,7 @@ from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import init_model, load_dygraph_params import tools.program as program dist.get_world_size() @@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, optimizer) + pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: -- GitLab