提交 1d393583 编写于 作者: W weishengyu

dbg

上级 0e568006
...@@ -218,7 +218,6 @@ class SELayer(TheseusLayer): ...@@ -218,7 +218,6 @@ class SELayer(TheseusLayer):
class Stage(TheseusLayer): class Stage(TheseusLayer):
def __init__(self, def __init__(self,
num_channels,
num_modules, num_modules,
num_filters, num_filters,
has_se=False, has_se=False,
...@@ -234,7 +233,6 @@ class Stage(TheseusLayer): ...@@ -234,7 +233,6 @@ class Stage(TheseusLayer):
stage_func = self.add_sublayer( stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1), "stage_{}_{}".format(name, i + 1),
HighResolutionModule( HighResolutionModule(
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
has_se=has_se, has_se=has_se,
multi_scale_output=False, multi_scale_output=False,
...@@ -243,7 +241,6 @@ class Stage(TheseusLayer): ...@@ -243,7 +241,6 @@ class Stage(TheseusLayer):
stage_func = self.add_sublayer( stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1), "stage_{}_{}".format(name, i + 1),
HighResolutionModule( HighResolutionModule(
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
has_se=has_se, has_se=has_se,
name=name + '_' + str(i + 1))) name=name + '_' + str(i + 1)))
...@@ -259,7 +256,6 @@ class Stage(TheseusLayer): ...@@ -259,7 +256,6 @@ class Stage(TheseusLayer):
class HighResolutionModule(TheseusLayer): class HighResolutionModule(TheseusLayer):
def __init__(self, def __init__(self,
num_channels,
num_filters, num_filters,
has_se=False, has_se=False,
multi_scale_output=True, multi_scale_output=True,
...@@ -271,7 +267,7 @@ class HighResolutionModule(TheseusLayer): ...@@ -271,7 +267,7 @@ class HighResolutionModule(TheseusLayer):
for i in range(len(num_filters)): for i in range(len(num_filters)):
self.basic_block_list.append([]) self.basic_block_list.append([])
for j in range(4): for j in range(4):
in_ch = num_channels[i] if j == 0 else num_channels[i] in_ch = num_filters[i]
basic_block_func = self.add_sublayer( basic_block_func = self.add_sublayer(
"bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1), "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
BasicBlock( BasicBlock(
...@@ -425,7 +421,6 @@ class HRNet(TheseusLayer): ...@@ -425,7 +421,6 @@ class HRNet(TheseusLayer):
self._class_dim = class_dim self._class_dim = class_dim
channels_2, channels_3, channels_4 = self.channels[width] channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
self.conv_layer1_1 = ConvBNLayer( self.conv_layer1_1 = ConvBNLayer(
num_channels=3, num_channels=3,
...@@ -462,8 +457,7 @@ class HRNet(TheseusLayer): ...@@ -462,8 +457,7 @@ class HRNet(TheseusLayer):
) )
self.st2 = Stage( self.st2 = Stage(
num_channels=channels_2, num_modules=1,
num_modules=num_modules_2,
num_filters=channels_2, num_filters=channels_2,
has_se=self.has_se, has_se=self.has_se,
name="st2") name="st2")
...@@ -475,8 +469,7 @@ class HRNet(TheseusLayer): ...@@ -475,8 +469,7 @@ class HRNet(TheseusLayer):
stride=2 stride=2
) )
self.st3 = Stage( self.st3 = Stage(
num_channels=channels_3, num_modules=4,
num_modules=num_modules_3,
num_filters=channels_3, num_filters=channels_3,
has_se=self.has_se, has_se=self.has_se,
name="st3") name="st3")
...@@ -489,8 +482,7 @@ class HRNet(TheseusLayer): ...@@ -489,8 +482,7 @@ class HRNet(TheseusLayer):
) )
self.st4 = Stage( self.st4 = Stage(
num_channels=channels_4, num_modules=3,
num_modules=num_modules_4,
num_filters=channels_4, num_filters=channels_4,
has_se=self.has_se, has_se=self.has_se,
name="st4") name="st4")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册