提交 bcad7800 编写于 作者: G Guanghua Yu 提交者: qingqing01

Fix resnet.py for ResNet18 and ResNet34 (#2532)

* Fix ResNet18 and ResNet34 config.
上级 fc6abdd2
...@@ -152,8 +152,7 @@ class ResNet(object): ...@@ -152,8 +152,7 @@ class ResNet(object):
ch_in = input.shape[1] ch_in = input.shape[1]
# the naming rule is same as pretrained weight # the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name) name = self.na.fix_shortcut_name(name)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if ch_in != ch_out or stride != 1:
if max_pooling_in_short_cut and not is_first: if max_pooling_in_short_cut and not is_first:
input = fluid.layers.pool2d( input = fluid.layers.pool2d(
input=input, input=input,
...@@ -252,6 +251,8 @@ class ResNet(object): ...@@ -252,6 +251,8 @@ class ResNet(object):
conv = input conv = input
for i in range(count): for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i) conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func( conv = block_func(
input=conv, input=conv,
num_filters=ch_out, num_filters=ch_out,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册