From 4a6f7ceca6314dec02d0c0159cfddfb28c704ca8 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Wed, 24 Nov 2021 09:40:33 +0000 Subject: [PATCH] pair param with key when load trained model params --- ppocr/utils/save_load.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index bf973c90..7bdaafd5 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -57,22 +57,23 @@ def load_model(config, model, optimizer=None): if checkpoints.endswith('pdparams'): checkpoints = checkpoints.replace('.pdparams', '') assert os.path.exists(checkpoints + ".pdparams"), \ - f"The {checkpoints}.pdparams does not exists!" - + "The {}.pdparams does not exists!".format(checkpoints) + # load params from trained model params = paddle.load(checkpoints + '.pdparams') state_dict = model.state_dict() new_state_dict = {} for key, value in state_dict.items(): if key not in params: - logger.warning(f"{key} not in loaded params {params.keys()} !") + logger.warning("{} not in loaded params {} !".format( + key, params.keys())) pre_value = params[key] if list(value.shape) == list(pre_value.shape): new_state_dict[key] = pre_value else: logger.warning( - f"The shape of model params {key} {value.shape} not matched with loaded params shape {pre_value.shape} !" - ) + "The shape of model params {} {} not matched with loaded params shape {} !". + format(key, value.shape, pre_value.shape)) model.set_state_dict(new_state_dict) optim_dict = paddle.load(checkpoints + '.pdopt') @@ -99,7 +100,7 @@ def load_pretrained_params(model, path): if path.endswith('pdparams'): path = path.replace('.pdparams', '') assert os.path.exists(path + ".pdparams"), \ - f"The {path}.pdparams does not exists!" + "The {}.pdparams does not exists!".format(path) params = paddle.load(path + '.pdparams') state_dict = model.state_dict() @@ -109,10 +110,10 @@ def load_pretrained_params(model, path): new_state_dict[k1] = params[k2] else: logger.warning( - f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" - ) + "The shape of model params {} {} not matched with loaded params {} {} !". + format(k1, state_dict[k1].shape, k2, params[k2].shape)) model.set_state_dict(new_state_dict) - logger.info(f"load pretrain successful from {path}") + logger.info("load pretrain successful from {}".format(path)) return model -- GitLab