diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 23f5401bb71a2ef50ff2ff2c3c27275d7e10b3c0..76420abb5a0da3e0138478c34bdb53d593492bf4 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -25,7 +25,7 @@ import paddle
 
 from ppocr.utils.logging import get_logger
 
-__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
+__all__ = ['init_model', 'save_model', 'load_dygraph_params']
 
 
 def _mkdir_if_not_exist(path, logger):
@@ -89,6 +89,49 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
     return best_model_dict
 
 
+def load_dygraph_params(config, model, logger, optimizer):
+    """Resume from a checkpoint or load pretrained weights into ``model``.
+
+    When ``config['Global']['checkpoints']`` is an existing path, loading is
+    delegated to ``init_model``. Otherwise the weights file named by
+    ``config['Global']['pretrained_model']`` (given with or without the
+    ``.pdparams`` suffix) is loaded, skipping and logging any parameter whose
+    shape differs from the model's.
+
+    Returns:
+        The result of ``init_model`` when resuming from a checkpoint,
+        otherwise an empty dict.
+    """
+    ckp = config['Global']['checkpoints']
+    if ckp and os.path.exists(ckp):
+        return init_model(config, model, optimizer)
+
+    pm = config['Global']['pretrained_model']
+    if not pm:
+        return {}
+    # Accept both "path/to/model" and "path/to/model.pdparams".
+    weights_path = pm if pm.endswith('.pdparams') else pm + '.pdparams'
+    if not os.path.exists(weights_path):
+        logger.info(f"The pretrained_model {pm} does not exists!")
+        return {}
+
+    params = paddle.load(weights_path)
+    state_dict = model.state_dict()
+    new_state_dict = {}
+    # Parameters are paired by position, not by name, so only shapes are
+    # compared; zip() stops at the shorter of the two key sequences.
+    for k1, k2 in zip(state_dict.keys(), params.keys()):
+        if list(state_dict[k1].shape) == list(params[k2].shape):
+            new_state_dict[k1] = params[k2]
+        else:
+            logger.info(
+                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+            )
+    model.set_state_dict(new_state_dict)
+    logger.info(f"loaded pretrained_model successful from {weights_path}")
+    return {}
+
+
 def save_model(model,
                optimizer,
                model_path,
diff --git a/tools/train.py b/tools/train.py
index b024240b4d5d4973645336c62d3762087ec7bbeb..20f5a670d5c8e666678259e0042b3b790e528590 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_dygraph_params
 import tools.program as program
 
 dist.get_world_size()
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, optimizer)
+    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None: