未验证 提交 b39d5ce1 编写于 作者: L littletomatodonkey 提交者: GitHub

fix cbnet (#785)

Add `lr_multi_list` param for SSLD pretrained models' usage.
上级 1c23743b
......@@ -65,7 +65,8 @@ class CBResNet(object):
feature_maps=[2, 3, 4, 5],
dcn_v2_stages=[],
nonlocal_stages=[],
repeat_num=2):
repeat_num=2,
lr_mult_list=[1., 1., 1., 1.]):
super(CBResNet, self).__init__()
if isinstance(feature_maps, Integral):
......@@ -108,6 +109,9 @@ class CBResNet(object):
200: 12,
}
self.lr_mult_list = lr_mult_list
self.stage_num = -1
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
......@@ -143,6 +147,13 @@ class CBResNet(object):
act=None,
name=None,
dcn=False):
# need fine lr for distilled model, default as 1.0
lr_mult = 1.0
mult_idx = max(self.stage_num - 2, 0)
mult_idx = min(self.stage_num - 2, 3)
lr_mult = self.lr_mult_list[mult_idx]
if not dcn:
conv = fluid.layers.conv2d(
input=input,
......@@ -153,7 +164,8 @@ class CBResNet(object):
groups=groups,
act=None,
param_attr=ParamAttr(
name=name + "_weights_" + str(self.curr_level)),
name=name + "_weights_" + str(self.curr_level),
learning_rate=lr_mult),
bias_attr=False)
else:
offset_mask = self._conv_offset(
......@@ -182,12 +194,13 @@ class CBResNet(object):
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(
name=name + "_weights_" + str(self.curr_level)),
name=name + "_weights_" + str(self.curr_level),
learning_rate=lr_mult),
bias_attr=False)
bn_name = self.na.fix_conv_norm_name(name)
norm_lr = 0. if self.freeze_norm else 1.
norm_lr = 0. if self.freeze_norm else lr_mult
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale_' + str(self.curr_level),
......@@ -315,6 +328,8 @@ class CBResNet(object):
"""
assert stage_num in [2, 3, 4, 5]
self.stage_num = stage_num
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
......@@ -422,6 +437,7 @@ class CBResNet(object):
res_endpoints.append(res)
for num in range(1, self.repeat_num):
self.stage_num = -1
self.curr_level = num
res = self.c1_stage(input)
for i in range(len(res_endpoints)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册