提交 9ab8639c 编写于 作者: W weishengyu

remove transition

上级 b38b0f38
......@@ -94,47 +94,6 @@ class Layer1(TheseusLayer):
return y
class TransitionLayer(TheseusLayer):
def __init__(self, in_channels, out_channels, name=None):
super(TransitionLayer, self).__init__()
num_in = len(in_channels)
num_out = len(out_channels)
out = []
self.conv_bn_func_list = []
for i in range(num_out):
residual = None
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self.add_sublayer(
"transition_{}_layer_{}".format(name, i + 1),
ConvBNLayer(
num_channels=in_channels[i],
num_filters=out_channels[i],
filter_size=3))
else:
residual = self.add_sublayer(
"transition_{}_layer_{}".format(name, i + 1),
ConvBNLayer(
num_channels=in_channels[-1],
num_filters=out_channels[i],
filter_size=3,
stride=2))
self.conv_bn_func_list.append(residual)
def forward(self, x, res_dict=None):
outs = []
for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
if conv_bn_func is None:
outs.append(x[idx])
else:
if idx < len(x):
outs.append(conv_bn_func(x[idx]))
else:
outs.append(conv_bn_func(x[-1]))
return outs
class Branches(TheseusLayer):
def __init__(self,
block_num,
......@@ -537,8 +496,16 @@ class HRNet(TheseusLayer):
self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
self.tr1 = TransitionLayer(
in_channels=[256], out_channels=channels_2, name="tr1")
self.tr1_1 = BasicBlock(
num_channels=256,
num_filters=width,
has_se=has_se,
name="tr1_1")
self.tr1_2 = BasicBlock(
num_channels=width,
num_filters=width * 2,
has_se=has_se,
name="tr1_2")
self.st2 = Stage(
num_channels=channels_2,
......@@ -547,8 +514,11 @@ class HRNet(TheseusLayer):
has_se=self.has_se,
name="st2")
self.tr2 = TransitionLayer(
in_channels=channels_2, out_channels=channels_3, name="tr2")
self.tr2 = BasicBlock(
num_channels=width * 2,
num_filters=width * 4,
has_se=has_se,
name="tr2")
self.st3 = Stage(
num_channels=channels_3,
num_modules=num_modules_3,
......@@ -556,8 +526,12 @@ class HRNet(TheseusLayer):
has_se=self.has_se,
name="st3")
self.tr3 = TransitionLayer(
in_channels=channels_3, out_channels=channels_4, name="tr3")
self.tr3 = BasicBlock(
num_channels=width * 4,
num_filters=width * 8,
has_se=has_se,
name="tr3")
self.st4 = Stage(
num_channels=channels_4,
num_modules=num_modules_4,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册