diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 702f3e9770d0572e9128357bfd6b39199566a959..bf973c90eb5b8c3ef3b3fbda669c6d0f40cfa9ad 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -56,9 +56,25 @@ def load_model(config, model, optimizer=None): if checkpoints: if checkpoints.endswith('pdparams'): checkpoints = checkpoints.replace('.pdparams', '') - assert os.path.exists(checkpoints + ".pdopt"), \ - f"The {checkpoints}.pdopt does not exists!" - load_pretrained_params(model, checkpoints) + assert os.path.exists(checkpoints + ".pdparams"), \ + f"The {checkpoints}.pdparams does not exists!" + + # load params from trained model + params = paddle.load(checkpoints + '.pdparams') + state_dict = model.state_dict() + new_state_dict = {} + for key, value in state_dict.items(): + if key not in params: + logger.warning(f"{key} not in loaded params {params.keys()} !") + pre_value = params[key] + if list(value.shape) == list(pre_value.shape): + new_state_dict[key] = pre_value + else: + logger.warning( + f"The shape of model params {key} {value.shape} not matched with loaded params shape {pre_value.shape} !" + ) + model.set_state_dict(new_state_dict) + optim_dict = paddle.load(checkpoints + '.pdopt') if optimizer is not None: optimizer.set_state_dict(optim_dict) @@ -92,7 +108,7 @@ def load_pretrained_params(model, path): if list(state_dict[k1].shape) == list(params[k2].shape): new_state_dict[k1] = params[k2] else: - logger.info( + logger.warning( f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" ) model.set_state_dict(new_state_dict)