From c42f53d506d612bfc909eb4c40d43839f95c812a Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 5 Jan 2022 15:35:12 +0800 Subject: [PATCH] fix save load, update det pretrain --- doc/doc_ch/detection.md | 6 +++--- doc/doc_en/detection_en.md | 6 +++--- ppocr/utils/save_load.py | 15 +++++++++------ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md index f76ae7f8..4114d9f2 100644 --- a/doc/doc_ch/detection.md +++ b/doc/doc_ch/detection.md @@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 cd PaddleOCR/ # 根据backbone的不同选择下载对应的预训练模型 # 下载MobileNetV3的预训练模型 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams # 或,下载ResNet18_vd的预训练模型 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams # 或,下载ResNet50_vd的预训练模型 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams ``` diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md index a634dd49..1be34b33 100644 --- a/doc/doc_en/detection_en.md +++ b/doc/doc_en/detection_en.md @@ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h ```shell cd PaddleOCR/ # Download the pre-trained model of MobileNetV3 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams # or, download the pre-trained model of ResNet18_vd -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams # or, download the pre-trained model of ResNet50_vd -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams ``` diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index f6013a40..0dd94e86 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -111,13 +111,16 @@ def load_pretrained_params(model, path): params = paddle.load(path + '.pdparams') state_dict = model.state_dict() new_state_dict = {} - for k1, k2 in zip(state_dict.keys(), params.keys()): - if list(state_dict[k1].shape) == list(params[k2].shape): - new_state_dict[k1] = params[k2] + for k1 in params.keys(): + if k1 not in state_dict.keys(): + logger.warning("The pretrained params {} not in model".format(k1)) else: - logger.warning( - "The shape of model params {} {} not matched with loaded params {} {} !". - format(k1, state_dict[k1].shape, k2, params[k2].shape)) + if list(state_dict[k1].shape) == list(params[k1].shape): + new_state_dict[k1] = params[k1] + else: + logger.warning( + "The shape of model params {} {} not matched with loaded params {} {} !". + format(k1, state_dict[k1].shape, k1, params[k1].shape)) model.set_state_dict(new_state_dict) logger.info("load pretrain successful from {}".format(path)) return model -- GitLab