diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 702f3e9770d0572e9128357bfd6b39199566a959..4b890f6fa352772e6ebe1614b798e1ce69cdd17c 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -54,11 +54,29 @@ def load_model(config, model, optimizer=None):
     pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
     if checkpoints:
-        if checkpoints.endswith('pdparams'):
+        if checkpoints.endswith('.pdparams'):
             checkpoints = checkpoints.replace('.pdparams', '')
-        assert os.path.exists(checkpoints + ".pdopt"), \
-            f"The {checkpoints}.pdopt does not exists!"
-        load_pretrained_params(model, checkpoints)
+        assert os.path.exists(checkpoints + ".pdparams"), \
+            "The {}.pdparams does not exist!".format(checkpoints)
+
+        # load params from trained model
+        params = paddle.load(checkpoints + '.pdparams')
+        state_dict = model.state_dict()
+        new_state_dict = {}
+        for key, value in state_dict.items():
+            if key not in params:
+                logger.warning("{} not in loaded params {} !".format(
+                    key, params.keys()))
+                continue
+            pre_value = params[key]
+            if list(value.shape) == list(pre_value.shape):
+                new_state_dict[key] = pre_value
+            else:
+                logger.warning(
+                    "The shape of model params {} {} does not match loaded params shape {}!".
+                    format(key, value.shape, pre_value.shape))
+        model.set_state_dict(new_state_dict)
+
         optim_dict = paddle.load(checkpoints + '.pdopt')
         if optimizer is not None:
             optimizer.set_state_dict(optim_dict)
@@ -80,10 +98,10 @@ def load_model(config, model, optimizer=None):
 
 def load_pretrained_params(model, path):
     logger = get_logger()
-    if path.endswith('pdparams'):
+    if path.endswith('.pdparams'):
         path = path.replace('.pdparams', '')
     assert os.path.exists(path + ".pdparams"), \
-        f"The {path}.pdparams does not exists!"
+        "The {}.pdparams does not exist!".format(path)
 
     params = paddle.load(path + '.pdparams')
     state_dict = model.state_dict()
@@ -92,11 +110,11 @@ def load_pretrained_params(model, path):
         if list(state_dict[k1].shape) == list(params[k2].shape):
             new_state_dict[k1] = params[k2]
         else:
-            logger.info(
-                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
-            )
+            logger.warning(
+                "The shape of model params {} {} does not match loaded params {} {}!".
+                format(k1, state_dict[k1].shape, k2, params[k2].shape))
     model.set_state_dict(new_state_dict)
-    logger.info(f"load pretrain successful from {path}")
+    logger.info("loaded pretrained params successfully from {}".format(path))
     return model
 
 
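
For reviewers, here is a minimal standalone sketch of the shape-checked loading pattern this patch inlines into load_model. The toy Linear layer, the tensor shapes, and the mismatched 'bias' entry are illustrative assumptions, not part of the patch; print stands in for the logger.

import paddle

# Toy model standing in for the real network (assumption for illustration).
model = paddle.nn.Linear(4, 2)

# Stand-in for paddle.load('<checkpoints>.pdparams'): 'weight' matches,
# 'bias' has a deliberately mismatched shape.
params = {
    'weight': paddle.randn([4, 2]),  # same shape as model.weight -> kept
    'bias': paddle.randn([3]),       # model.bias is [2] -> skipped, warned
}

# Same pattern the patch adds: copy only shape-compatible params,
# warn on missing keys or shape mismatches, then set the filtered dict.
new_state_dict = {}
for key, value in model.state_dict().items():
    if key not in params:
        print('{} not in loaded params!'.format(key))
        continue
    pre_value = params[key]
    if list(value.shape) == list(pre_value.shape):
        new_state_dict[key] = pre_value
    else:
        print('shape of {} {} does not match loaded {}!'.format(
            key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)  # only 'weight' is overwritten here

Because set_state_dict is called with the filtered dict, mismatched or missing entries leave the model's existing initialization in place rather than raising, which is what allows partially compatible checkpoints to load.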