提交 67ae525e 编写于 作者: z37757's avatar z37757

删除backbone里的pretrained_model字段

上级 03802c7f
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 1260 iterations # evaluation is run every 1260 iterations
eval_batch_step: [37800, 1260] eval_batch_step: [37800, 1260]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
...@@ -23,8 +23,6 @@ Architecture: ...@@ -23,8 +23,6 @@ Architecture:
Backbone: Backbone:
name: ResNet_vd name: ResNet_vd
layers: 50 layers: 50
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
Neck: Neck:
name: FPN_UNet name: FPN_UNet
in_channels: [256, 512, 1024, 2048] in_channels: [256, 512, 1024, 2048]
......
...@@ -16,8 +16,6 @@ from __future__ import absolute_import ...@@ -16,8 +16,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
...@@ -27,8 +25,6 @@ from paddle.vision.ops import DeformConv2D ...@@ -27,8 +25,6 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform from paddle.nn.initializer import Normal, Constant, XavierUniform
from ppocr.utils.logging import get_logger
__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"] __all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
...@@ -250,7 +246,6 @@ class ResNet_vd(nn.Layer): ...@@ -250,7 +246,6 @@ class ResNet_vd(nn.Layer):
layers=50, layers=50,
dcn_stage=None, dcn_stage=None,
out_indices=None, out_indices=None,
pretrained_model=None,
**kwargs): **kwargs):
super(ResNet_vd, self).__init__() super(ResNet_vd, self).__init__()
...@@ -344,30 +339,6 @@ class ResNet_vd(nn.Layer): ...@@ -344,30 +339,6 @@ class ResNet_vd(nn.Layer):
self.out_channels.append(num_filters[block]) self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list)) self.stages.append(nn.Sequential(*block_list))
if pretrained_model is not None:
self.load_pretrained_params(pretrained_model)
def load_pretrained_params(self, path):
logger = get_logger()
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams')
state_dict = self.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
self.set_state_dict(new_state_dict)
logger.info(f"loaded backbone pretrained_model successful from {path}")
def forward(self, inputs): def forward(self, inputs):
y = self.conv1_1(inputs) y = self.conv1_1(inputs)
y = self.conv1_2(y) y = self.conv1_2(y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册