提交 fc8c813b 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix resnet-vd typo

上级 1b0daff6
...@@ -45,7 +45,11 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -45,7 +45,11 @@ class ConvBNLayer(fluid.dygraph.Layer):
self.is_vd_mode = is_vd_mode self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D( self._pool2d_avg = Pool2D(
pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True) pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
self._conv = Conv2D( self._conv = Conv2D(
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
...@@ -132,7 +136,7 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -132,7 +136,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
return layer_helper.append_activation(y) return layer_helper.append_activation(y)
class BisicBlock(fluid.dygraph.Layer): class BasicBlock(fluid.dygraph.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
...@@ -140,7 +144,7 @@ class BisicBlock(fluid.dygraph.Layer): ...@@ -140,7 +144,7 @@ class BisicBlock(fluid.dygraph.Layer):
shortcut=True, shortcut=True,
if_first=False, if_first=False,
name=None): name=None):
super(BisicBlock, self).__init__() super(BasicBlock, self).__init__()
self.stride = stride self.stride = stride
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
...@@ -258,9 +262,9 @@ class ResNet_vd(fluid.dygraph.Layer): ...@@ -258,9 +262,9 @@ class ResNet_vd(fluid.dygraph.Layer):
shortcut = False shortcut = False
for i in range(depth[block]): for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i) conv_name = "res" + str(block + 2) + chr(97 + i)
bisic_block = self.add_sublayer( basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i), 'bb_%d_%d' % (block, i),
BisicBlock( BasicBlock(
num_channels=num_channels[block] num_channels=num_channels[block]
if i == 0 else num_filters[block], if i == 0 else num_filters[block],
num_filters=num_filters[block], num_filters=num_filters[block],
...@@ -268,7 +272,7 @@ class ResNet_vd(fluid.dygraph.Layer): ...@@ -268,7 +272,7 @@ class ResNet_vd(fluid.dygraph.Layer):
shortcut=shortcut, shortcut=shortcut,
if_first=block == i == 0, if_first=block == i == 0,
name=conv_name)) name=conv_name))
self.block_list.append(bisic_block) self.block_list.append(basic_block)
shortcut = True shortcut = True
self.pool2d_avg = Pool2D( self.pool2d_avg = Pool2D(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册