提交 67ae525e 编写于 作者: z37757's avatar z37757

删除backbone里的pretrained_model字段

上级 03802c7f
......@@ -8,7 +8,7 @@ Global:
# evaluation is run every 1260 iterations
eval_batch_step: [37800, 1260]
cal_metric_during_train: False
pretrained_model:
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
......@@ -23,8 +23,6 @@ Architecture:
Backbone:
name: ResNet_vd
layers: 50
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
Neck:
name: FPN_UNet
in_channels: [256, 512, 1024, 2048]
......
......@@ -16,8 +16,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle
from paddle import ParamAttr
import paddle.nn as nn
......@@ -27,8 +25,6 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
from ppocr.utils.logging import get_logger
__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
......@@ -250,7 +246,6 @@ class ResNet_vd(nn.Layer):
layers=50,
dcn_stage=None,
out_indices=None,
pretrained_model=None,
**kwargs):
super(ResNet_vd, self).__init__()
......@@ -344,30 +339,6 @@ class ResNet_vd(nn.Layer):
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
if pretrained_model is not None:
self.load_pretrained_params(pretrained_model)
def load_pretrained_params(self, path):
logger = get_logger()
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams')
state_dict = self.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
self.set_state_dict(new_state_dict)
logger.info(f"loaded backbone pretrained_model successful from {path}")
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册