未验证 提交 480c12d6 编写于 作者: F Feng Ni 提交者: GitHub

fix name of pretrain weights in backbone for dcn (#2582)

* fix name of pretrain weights

* format, test=document_fix
上级 5e24f530
...@@ -21,7 +21,9 @@ import paddle.nn.functional as F ...@@ -21,7 +21,9 @@ import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from paddle.nn.initializer import Uniform from paddle.nn.initializer import Uniform
from ppdet.modeling.layers import DeformableConvV2 from paddle import ParamAttr
from paddle.nn.initializer import Constant
from paddle.vision.ops import DeformConv2D
from .name_adapter import NameAdapter from .name_adapter import NameAdapter
from ..shape_spec import ShapeSpec from ..shape_spec import ShapeSpec
...@@ -53,8 +55,9 @@ class ConvNormLayer(nn.Layer): ...@@ -53,8 +55,9 @@ class ConvNormLayer(nn.Layer):
assert norm_type in ['bn', 'sync_bn'] assert norm_type in ['bn', 'sync_bn']
self.norm_type = norm_type self.norm_type = norm_type
self.act = act self.act = act
self.dcn_v2 = dcn_v2
if not dcn_v2: if not self.dcn_v2:
self.conv = nn.Conv2D( self.conv = nn.Conv2D(
in_channels=ch_in, in_channels=ch_in,
out_channels=ch_out, out_channels=ch_out,
...@@ -62,25 +65,37 @@ class ConvNormLayer(nn.Layer): ...@@ -62,25 +65,37 @@ class ConvNormLayer(nn.Layer):
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
weight_attr=paddle.ParamAttr(learning_rate=lr), weight_attr=ParamAttr(learning_rate=lr),
bias_attr=False) bias_attr=False)
else: else:
self.conv = DeformableConvV2( self.offset_channel = 2 * filter_size**2
self.mask_channel = filter_size**2
self.conv_offset = nn.Conv2D(
in_channels=ch_in,
out_channels=3 * filter_size**2,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
weight_attr=ParamAttr(initializer=Constant(0.)),
bias_attr=ParamAttr(initializer=Constant(0.)))
self.conv = DeformConv2D(
in_channels=ch_in, in_channels=ch_in,
out_channels=ch_out, out_channels=ch_out,
kernel_size=filter_size, kernel_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
dilation=1,
groups=groups, groups=groups,
weight_attr=paddle.ParamAttr(learning_rate=lr), weight_attr=ParamAttr(learning_rate=lr),
bias_attr=False) bias_attr=False)
norm_lr = 0. if freeze_norm else lr norm_lr = 0. if freeze_norm else lr
param_attr = paddle.ParamAttr( param_attr = ParamAttr(
learning_rate=norm_lr, learning_rate=norm_lr,
regularizer=L2Decay(norm_decay), regularizer=L2Decay(norm_decay),
trainable=False if freeze_norm else True) trainable=False if freeze_norm else True)
bias_attr = paddle.ParamAttr( bias_attr = ParamAttr(
learning_rate=norm_lr, learning_rate=norm_lr,
regularizer=L2Decay(norm_decay), regularizer=L2Decay(norm_decay),
trainable=False if freeze_norm else True) trainable=False if freeze_norm else True)
...@@ -103,7 +118,17 @@ class ConvNormLayer(nn.Layer): ...@@ -103,7 +118,17 @@ class ConvNormLayer(nn.Layer):
param.stop_gradient = True param.stop_gradient = True
def forward(self, inputs): def forward(self, inputs):
out = self.conv(inputs) if not self.dcn_v2:
out = self.conv(inputs)
else:
offset_mask = self.conv_offset(inputs)
offset, mask = paddle.split(
offset_mask,
num_or_sections=[self.offset_channel, self.mask_channel],
axis=1)
mask = F.sigmoid(mask)
out = self.conv(inputs, offset, mask=mask)
if self.norm_type in ['bn', 'sync_bn']: if self.norm_type in ['bn', 'sync_bn']:
out = self.norm(out) out = self.norm(out)
if self.act: if self.act:
......
...@@ -157,7 +157,7 @@ def load_pretrain_weight(model, pretrain_weight): ...@@ -157,7 +157,7 @@ def load_pretrain_weight(model, pretrain_weight):
weights_path = path + '.pdparams' weights_path = path + '.pdparams'
param_state_dict = paddle.load(weights_path) param_state_dict = paddle.load(weights_path)
ignore_set = set() lack_backbone_weights_cnt = 0
lack_modules = set() lack_modules = set()
for name, weight in model_dict.items(): for name, weight in model_dict.items():
if name in param_state_dict.keys(): if name in param_state_dict.keys():
...@@ -168,7 +168,13 @@ def load_pretrain_weight(model, pretrain_weight): ...@@ -168,7 +168,13 @@ def load_pretrain_weight(model, pretrain_weight):
param_state_dict.pop(name, None) param_state_dict.pop(name, None)
else: else:
lack_modules.add(name.split('.')[0]) lack_modules.add(name.split('.')[0])
logger.debug('Lack weights: {}'.format(name)) if name.find('backbone') >= 0:
logger.info('Lack backbone weights: {}'.format(name))
lack_backbone_weights_cnt += 1
if lack_backbone_weights_cnt > 0:
logger.info('Lack {} weights in backbone.'.format(
lack_backbone_weights_cnt))
if len(lack_modules) > 0: if len(lack_modules) > 0:
logger.info('Lack weights of modules: {}'.format(', '.join( logger.info('Lack weights of modules: {}'.format(', '.join(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册