未验证 提交 b39d5ce1 编写于 作者: L littletomatodonkey 提交者: GitHub

fix cbnet (#785)

Add the `lr_mult_list` parameter (per-stage learning-rate multipliers) for use with SSLD pretrained models.
上级 1c23743b
...@@ -65,7 +65,8 @@ class CBResNet(object): ...@@ -65,7 +65,8 @@ class CBResNet(object):
feature_maps=[2, 3, 4, 5], feature_maps=[2, 3, 4, 5],
dcn_v2_stages=[], dcn_v2_stages=[],
nonlocal_stages=[], nonlocal_stages=[],
repeat_num=2): repeat_num=2,
lr_mult_list=[1., 1., 1., 1.]):
super(CBResNet, self).__init__() super(CBResNet, self).__init__()
if isinstance(feature_maps, Integral): if isinstance(feature_maps, Integral):
...@@ -108,6 +109,9 @@ class CBResNet(object): ...@@ -108,6 +109,9 @@ class CBResNet(object):
200: 12, 200: 12,
} }
self.lr_mult_list = lr_mult_list
self.stage_num = -1
self.stage_filters = [64, 128, 256, 512] self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64 self._c1_out_chan_num = 64
self.na = NameAdapter(self) self.na = NameAdapter(self)
...@@ -143,6 +147,13 @@ class CBResNet(object): ...@@ -143,6 +147,13 @@ class CBResNet(object):
act=None, act=None,
name=None, name=None,
dcn=False): dcn=False):
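# Distilled pretrained models need a finer, per-stage learning rate; each multiplier defaults to 1.0.
# NOTE(review): the min() line below recomputes from self.stage_num and overwrites the max() clamp, so stage_num == -1 gives index -3; probably meant min(mult_idx, 3).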
lr_mult = 1.0
mult_idx = max(self.stage_num - 2, 0)
mult_idx = min(self.stage_num - 2, 3)
lr_mult = self.lr_mult_list[mult_idx]
if not dcn: if not dcn:
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
...@@ -153,7 +164,8 @@ class CBResNet(object): ...@@ -153,7 +164,8 @@ class CBResNet(object):
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr( param_attr=ParamAttr(
name=name + "_weights_" + str(self.curr_level)), name=name + "_weights_" + str(self.curr_level),
learning_rate=lr_mult),
bias_attr=False) bias_attr=False)
else: else:
offset_mask = self._conv_offset( offset_mask = self._conv_offset(
...@@ -182,12 +194,13 @@ class CBResNet(object): ...@@ -182,12 +194,13 @@ class CBResNet(object):
deformable_groups=1, deformable_groups=1,
im2col_step=1, im2col_step=1,
param_attr=ParamAttr( param_attr=ParamAttr(
name=name + "_weights_" + str(self.curr_level)), name=name + "_weights_" + str(self.curr_level),
learning_rate=lr_mult),
bias_attr=False) bias_attr=False)
bn_name = self.na.fix_conv_norm_name(name) bn_name = self.na.fix_conv_norm_name(name)
norm_lr = 0. if self.freeze_norm else 1. norm_lr = 0. if self.freeze_norm else lr_mult
norm_decay = self.norm_decay norm_decay = self.norm_decay
pattr = ParamAttr( pattr = ParamAttr(
name=bn_name + '_scale_' + str(self.curr_level), name=bn_name + '_scale_' + str(self.curr_level),
...@@ -315,6 +328,8 @@ class CBResNet(object): ...@@ -315,6 +328,8 @@ class CBResNet(object):
""" """
assert stage_num in [2, 3, 4, 5] assert stage_num in [2, 3, 4, 5]
self.stage_num = stage_num
stages, block_func = self.depth_cfg[self.depth] stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2] count = stages[stage_num - 2]
...@@ -422,6 +437,7 @@ class CBResNet(object): ...@@ -422,6 +437,7 @@ class CBResNet(object):
res_endpoints.append(res) res_endpoints.append(res)
for num in range(1, self.repeat_num): for num in range(1, self.repeat_num):
self.stage_num = -1
self.curr_level = num self.curr_level = num
res = self.c1_stage(input) res = self.c1_stage(input)
for i in range(len(res_endpoints)): for i in range(len(res_endpoints)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册