diff --git a/ppcls/arch/backbone/legendary_models/hrnet.py b/ppcls/arch/backbone/legendary_models/hrnet.py index 20744f375ffd31b14a7fd479750c2f1150e88773..d720fdfeeb294e19ab86f3dc4ee93482df3358b7 100644 --- a/ppcls/arch/backbone/legendary_models/hrnet.py +++ b/ppcls/arch/backbone/legendary_models/hrnet.py @@ -94,47 +94,6 @@ class Layer1(TheseusLayer): return y -class TransitionLayer(TheseusLayer): - def __init__(self, in_channels, out_channels, name=None): - super(TransitionLayer, self).__init__() - - num_in = len(in_channels) - num_out = len(out_channels) - out = [] - self.conv_bn_func_list = [] - for i in range(num_out): - residual = None - if i < num_in: - if in_channels[i] != out_channels[i]: - residual = self.add_sublayer( - "transition_{}_layer_{}".format(name, i + 1), - ConvBNLayer( - num_channels=in_channels[i], - num_filters=out_channels[i], - filter_size=3)) - else: - residual = self.add_sublayer( - "transition_{}_layer_{}".format(name, i + 1), - ConvBNLayer( - num_channels=in_channels[-1], - num_filters=out_channels[i], - filter_size=3, - stride=2)) - self.conv_bn_func_list.append(residual) - - def forward(self, x, res_dict=None): - outs = [] - for idx, conv_bn_func in enumerate(self.conv_bn_func_list): - if conv_bn_func is None: - outs.append(x[idx]) - else: - if idx < len(x): - outs.append(conv_bn_func(x[idx])) - else: - outs.append(conv_bn_func(x[-1])) - return outs - - class Branches(TheseusLayer): def __init__(self, block_num, @@ -537,8 +496,16 @@ class HRNet(TheseusLayer): self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2") - self.tr1 = TransitionLayer( - in_channels=[256], out_channels=channels_2, name="tr1") + self.tr1_1 = BasicBlock( + num_channels=256, + num_filters=width, + has_se=has_se, + name="tr1_1") + self.tr1_2 = BasicBlock( + num_channels=width, + num_filters=width * 2, + has_se=has_se, + name="tr1_2") self.st2 = Stage( num_channels=channels_2, @@ -547,8 +514,11 @@ class HRNet(TheseusLayer): has_se=self.has_se, name="st2") - self.tr2 = TransitionLayer( - in_channels=channels_2, out_channels=channels_3, name="tr2") + self.tr2 = BasicBlock( + num_channels=width * 2, + num_filters=width * 4, + has_se=has_se, + name="tr2") self.st3 = Stage( num_channels=channels_3, num_modules=num_modules_3, @@ -556,8 +526,12 @@ class HRNet(TheseusLayer): has_se=self.has_se, name="st3") - self.tr3 = TransitionLayer( - in_channels=channels_3, out_channels=channels_4, name="tr3") + self.tr3 = BasicBlock( + num_channels=width * 4, + num_filters=width * 8, + has_se=has_se, + name="tr3") + self.st4 = Stage( num_channels=channels_4, num_modules=num_modules_4,