提交 f6df698c 编写于 作者: G gaotingquan 提交者: Tingquan Gao

micro_block -> layer_type

上级 81de331e
......@@ -164,18 +164,18 @@ class BottleneckBlock(TheseusLayer):
stride,
shortcut=True,
if_first=False,
micro_block=ConvBNLayer,
layer=ConvBNLayer,
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.conv0 = micro_block(
self.conv0 = layer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act="relu",
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = micro_block(
self.conv1 = layer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
......@@ -183,7 +183,7 @@ class BottleneckBlock(TheseusLayer):
act="relu",
lr_mult=lr_mult,
data_format=data_format)
self.conv2 = micro_block(
self.conv2 = layer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
......@@ -226,13 +226,13 @@ class BasicBlock(TheseusLayer):
stride,
shortcut=True,
if_first=False,
micro_block=ConvBNLayer,
layer=ConvBNLayer,
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.stride = stride
self.conv0 = micro_block(
self.conv0 = layer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
......@@ -240,7 +240,7 @@ class BasicBlock(TheseusLayer):
act="relu",
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = micro_block(
self.conv1 = layer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
......@@ -296,7 +296,7 @@ class ResNet(TheseusLayer):
input_image_channel=3,
return_patterns=None,
return_stages=None,
micro_block="ConvBNLayer",
layer_type="ConvBNLayer",
use_first_short_conv=True,
**kargs):
super().__init__()
......@@ -312,10 +312,10 @@ class ResNet(TheseusLayer):
self.num_channels = self.cfg["num_channels"]
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
if micro_block == "ConvBNLayer":
micro_block = ConvBNLayer
elif micro_block == "DiverseBranchBlock":
micro_block = DiverseBranchBlock
if layer_type == "ConvBNLayer":
layer = ConvBNLayer
elif layer_type == "DiverseBranchBlock":
layer = DiverseBranchBlock
else:
raise Exception()
......@@ -377,7 +377,7 @@ class ResNet(TheseusLayer):
if i == 0 and block_idx != 0 else 1,
shortcut=shortcut,
if_first=block_idx == i == 0 if version == "vd" else True,
micro_block=micro_block,
layer=layer,
lr_mult=self.lr_mult_list[block_idx + 1],
data_format=data_format))
shortcut = True
......
......@@ -19,7 +19,7 @@ Global:
Arch:
name: ResNet18
class_num: 1000
micro_block: DiverseBranchBlock
layer_type: DiverseBranchBlock
use_first_short_conv: False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册