未验证 提交 a11fbc0f 编写于 作者: A andyjpaddle 提交者: GitHub

Merge pull request #5166 from tink2123/fix_save_load

fix save load, update det pretrain
...@@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 ...@@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
cd PaddleOCR/ cd PaddleOCR/
# 根据backbone的不同选择下载对应的预训练模型 # 根据backbone的不同选择下载对应的预训练模型
# 下载MobileNetV3的预训练模型 # 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
# 或,下载ResNet18_vd的预训练模型 # 或,下载ResNet18_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
# 或,下载ResNet50_vd的预训练模型 # 或,下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
``` ```
<a name="2-----"></a> <a name="2-----"></a>
......
...@@ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h ...@@ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h
```shell ```shell
cd PaddleOCR/ cd PaddleOCR/
# Download the pre-trained model of MobileNetV3 # Download the pre-trained model of MobileNetV3
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
# or, download the pre-trained model of ResNet18_vd # or, download the pre-trained model of ResNet18_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
# or, download the pre-trained model of ResNet50_vd # or, download the pre-trained model of ResNet50_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
``` ```
......
...@@ -111,13 +111,16 @@ def load_pretrained_params(model, path): ...@@ -111,13 +111,16 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()): for k1 in params.keys():
if list(state_dict[k1].shape) == list(params[k2].shape): if k1 not in state_dict.keys():
new_state_dict[k1] = params[k2] logger.warning("The pretrained params {} not in model".format(k1))
else: else:
logger.warning( if list(state_dict[k1].shape) == list(params[k1].shape):
"The shape of model params {} {} not matched with loaded params {} {} !". new_state_dict[k1] = params[k1]
format(k1, state_dict[k1].shape, k2, params[k2].shape)) else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info("load pretrain successful from {}".format(path)) logger.info("load pretrain successful from {}".format(path))
return model return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册