提交 64345cc0 编写于 作者: W weishengyu

dbg

上级 f645fac1
......@@ -75,42 +75,6 @@ class ConvBNLayer(TheseusLayer):
return y
class Branches(TheseusLayer):
def __init__(self,
block_num,
in_channels,
out_channels,
has_se=False,
name=None):
super(Branches, self).__init__()
self.basic_block_list = []
for i in range(len(out_channels)):
self.basic_block_list.append([])
for j in range(block_num):
in_ch = in_channels[i] if j == 0 else out_channels[i]
basic_block_func = self.add_sublayer(
"bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
BasicBlock(
num_channels=in_ch,
num_filters=out_channels[i],
has_se=has_se,
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.basic_block_list[i].append(basic_block_func)
def forward(self, x, res_dict=None):
outs = []
for idx, xi in enumerate(x):
conv = xi
basic_block_list = self.basic_block_list[idx]
for basic_block_func in basic_block_list:
conv = basic_block_func(conv)
outs.append(conv)
return outs
class BottleneckBlock(TheseusLayer):
def __init__(self,
num_channels,
......@@ -172,38 +136,30 @@ class BottleneckBlock(TheseusLayer):
return y
class BasicBlock(TheseusLayer):
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride=1,
has_se=False,
downsample=False,
name=None):
super(BasicBlock, self).__init__()
self.has_se = has_se
self.downsample = downsample
self.conv1 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu")
stride=1,
act="relu",
name=name + "_conv1")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=1,
act=None)
if self.downsample:
self.conv_down = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
act="relu")
act=None,
name=name + "_conv2")
if self.has_se:
self.se = SELayer(
......@@ -212,14 +168,11 @@ class BasicBlock(TheseusLayer):
reduction_ratio=16,
name='fc' + name)
def forward(self, input, res_dict=None):
def forward(self, input):
residual = input
conv1 = self.conv1(input)
conv2 = self.conv2(conv1)
if self.downsample:
residual = self.conv_down(input)
if self.has_se:
conv2 = self.se(conv2)
......@@ -315,12 +268,21 @@ class HighResolutionModule(TheseusLayer):
name=None):
super(HighResolutionModule, self).__init__()
self.branches_func = Branches(
block_num=4,
in_channels=num_channels,
out_channels=num_filters,
has_se=has_se,
name=name)
self.basic_block_list = []
for i in range(len(num_filters)):
self.basic_block_list.append([])
for j in range(4):
in_ch = num_channels[i] if j == 0 else num_channels[i]
basic_block_func = self.add_sublayer(
"bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
BasicBlock(
num_channels=in_ch,
num_filters=num_filters[i],
has_se=has_se,
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.basic_block_list[i].append(basic_block_func)
self.fuse_func = FuseLayers(
in_channels=num_filters,
......@@ -329,8 +291,14 @@ class HighResolutionModule(TheseusLayer):
name=name)
def forward(self, input, res_dict=None):
out = self.branches_func(input)
out = self.fuse_func(out)
outs = []
for idx, input in enumerate(input):
conv = input
basic_block_list = self.basic_block_list[idx]
for basic_block_func in basic_block_list:
conv = basic_block_func(conv)
outs.append(conv)
out = self.fuse_func(outs)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册