diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md index f76ae7f842fb6b7002e084be59dc7ccb31f39771..4114d9f2e6c584566dbfc6d9280074d767848ce1 100644 --- a/doc/doc_ch/detection.md +++ b/doc/doc_ch/detection.md @@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 cd PaddleOCR/ # 根据backbone的不同选择下载对应的预训练模型 # 下载MobileNetV3的预训练模型 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams # 或,下载ResNet18_vd的预训练模型 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams # 或,下载ResNet50_vd的预训练模型 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams ``` diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md index a634dd4903483a819caee88cf6dd1781253e6f85..1be34b330f91b59060fc84af6f5ac44022da1a35 100644 --- a/doc/doc_en/detection_en.md +++ b/doc/doc_en/detection_en.md @@ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h ```shell cd PaddleOCR/ # Download the pre-trained model of MobileNetV3 -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams # or, download the pre-trained model of ResNet18_vd -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams # or, download the pre-trained model of ResNet50_vd -wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams ``` diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index f6013a406634ed110ea5af613a5f31e56ce90ead..0dd94e86c808ce4f77e27a5c819fccd59578f0c5 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -111,13 +111,16 @@ def load_pretrained_params(model, path): params = paddle.load(path + '.pdparams') state_dict = model.state_dict() new_state_dict = {} - for k1, k2 in zip(state_dict.keys(), params.keys()): - if list(state_dict[k1].shape) == list(params[k2].shape): - new_state_dict[k1] = params[k2] + for k1 in params.keys(): + if k1 not in state_dict.keys(): + logger.warning("The pretrained params {} not in model".format(k1)) else: - logger.warning( - "The shape of model params {} {} not matched with loaded params {} {} !". - format(k1, state_dict[k1].shape, k2, params[k2].shape)) + if list(state_dict[k1].shape) == list(params[k1].shape): + new_state_dict[k1] = params[k1] + else: + logger.warning( + "The shape of model params {} {} not matched with loaded params {} {} !". + format(k1, state_dict[k1].shape, k1, params[k1].shape)) model.set_state_dict(new_state_dict) logger.info("load pretrain successful from {}".format(path)) return model