未验证 提交 2c82482c 编写于 作者: M michaelowenliu 提交者: GitHub

Merge pull request #397 from michaelowenliu/develop

add feat_channels in backbone
...@@ -224,9 +224,9 @@ class ResNet_vd(nn.Layer): ...@@ -224,9 +224,9 @@ class ResNet_vd(nn.Layer):
] if layers >= 50 else [64, 64, 128, 256] ] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512] num_filters = [64, 128, 256, 512]
# for channels of returned stage # for channels of four returned stages
self.backbone_channels = [c * 4 for c in num_filters self.feat_channels = [c * 4 for c in num_filters
] if layers >= 50 else num_filters ] if layers >= 50 else num_filters
dilation_dict = None dilation_dict = None
if output_stride == 8: if output_stride == 8:
...@@ -319,7 +319,7 @@ class ResNet_vd(nn.Layer): ...@@ -319,7 +319,7 @@ class ResNet_vd(nn.Layer):
block_list.append(basic_block) block_list.append(basic_block)
shortcut = True shortcut = True
self.stage_list.append(block_list) self.stage_list.append(block_list)
utils.load_pretrained_model(self, pretrained) utils.load_pretrained_model(self, pretrained)
def forward(self, inputs): def forward(self, inputs):
...@@ -336,8 +336,6 @@ class ResNet_vd(nn.Layer): ...@@ -336,8 +336,6 @@ class ResNet_vd(nn.Layer):
feat_list.append(y) feat_list.append(y)
return feat_list return feat_list
@manager.BACKBONES.add_component @manager.BACKBONES.add_component
......
...@@ -190,7 +190,7 @@ class DANet(nn.Layer): ...@@ -190,7 +190,7 @@ class DANet(nn.Layer):
self.backbone = backbone self.backbone = backbone
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices] in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
self.head = DAHead(num_classes=num_classes, in_channels=in_channels) self.head = DAHead(num_classes=num_classes, in_channels=in_channels)
......
...@@ -62,14 +62,13 @@ class DeepLabV3P(nn.Layer): ...@@ -62,14 +62,13 @@ class DeepLabV3P(nn.Layer):
super(DeepLabV3P, self).__init__() super(DeepLabV3P, self).__init__()
self.backbone = backbone self.backbone = backbone
backbone_channels = backbone.backbone_channels backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = DeepLabV3PHead( self.head = DeepLabV3PHead(num_classes, backbone_indices,
num_classes, backbone_channels, aspp_ratios,
backbone_indices, aspp_out_channels)
backbone_channels,
aspp_ratios,
aspp_out_channels)
utils.load_entire_model(self, pretrained) utils.load_entire_model(self, pretrained)
...@@ -81,6 +80,7 @@ class DeepLabV3P(nn.Layer): ...@@ -81,6 +80,7 @@ class DeepLabV3P(nn.Layer):
F.resize_bilinear(logit, input.shape[2:]) for logit in logit_list F.resize_bilinear(logit, input.shape[2:]) for logit in logit_list
] ]
class DeepLabV3PHead(nn.Layer): class DeepLabV3PHead(nn.Layer):
""" """
The DeepLabV3PHead implementation based on PaddlePaddle. The DeepLabV3PHead implementation based on PaddlePaddle.
...@@ -110,14 +110,14 @@ class DeepLabV3PHead(nn.Layer): ...@@ -110,14 +110,14 @@ class DeepLabV3PHead(nn.Layer):
aspp_out_channels=256): aspp_out_channels=256):
super(DeepLabV3PHead, self).__init__() super(DeepLabV3PHead, self).__init__()
self.aspp = pyramid_pool.ASPPModule( self.aspp = pyramid_pool.ASPPModule(
aspp_ratios, aspp_ratios,
backbone_channels[backbone_indices[1]], backbone_channels[1],
aspp_out_channels, aspp_out_channels,
sep_conv=True, sep_conv=True,
image_pooling=True) image_pooling=True)
self.decoder = Decoder(num_classes, backbone_channels[backbone_indices[0]]) self.decoder = Decoder(num_classes, backbone_channels[0])
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
self.init_weight() self.init_weight()
...@@ -135,6 +135,7 @@ class DeepLabV3PHead(nn.Layer): ...@@ -135,6 +135,7 @@ class DeepLabV3PHead(nn.Layer):
def init_weight(self): def init_weight(self):
pass pass
@manager.MODELS.add_component @manager.MODELS.add_component
class DeepLabV3(nn.Layer): class DeepLabV3(nn.Layer):
""" """
...@@ -147,7 +148,7 @@ class DeepLabV3(nn.Layer): ...@@ -147,7 +148,7 @@ class DeepLabV3(nn.Layer):
Args: Args:
Refer to DeepLabV3P above Refer to DeepLabV3P above
""" """
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone, backbone,
...@@ -159,15 +160,14 @@ class DeepLabV3(nn.Layer): ...@@ -159,15 +160,14 @@ class DeepLabV3(nn.Layer):
super(DeepLabV3, self).__init__() super(DeepLabV3, self).__init__()
self.backbone = backbone self.backbone = backbone
backbone_channels = backbone.backbone_channels backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = DeepLabV3Head(num_classes, backbone_indices,
backbone_channels, aspp_ratios,
aspp_out_channels)
self.head = DeepLabV3Head(
num_classes,
backbone_indices,
backbone_channels,
aspp_ratios,
aspp_out_channels)
utils.load_entire_model(self, pretrained) utils.load_entire_model(self, pretrained)
def forward(self, input): def forward(self, input):
...@@ -191,13 +191,13 @@ class DeepLabV3Head(nn.Layer): ...@@ -191,13 +191,13 @@ class DeepLabV3Head(nn.Layer):
self.aspp = pyramid_pool.ASPPModule( self.aspp = pyramid_pool.ASPPModule(
aspp_ratios, aspp_ratios,
backbone_channels[backbone_indices[0]], backbone_channels[0],
aspp_out_channels, aspp_out_channels,
sep_conv=False, sep_conv=False,
image_pooling=True) image_pooling=True)
self.cls = nn.Conv2d( self.cls = nn.Conv2d(
in_channels=backbone_channels[backbone_indices[0]], in_channels=backbone_channels[0],
out_channels=num_classes, out_channels=num_classes,
kernel_size=1) kernel_size=1)
......
...@@ -203,7 +203,7 @@ class OCRNet(nn.Layer): ...@@ -203,7 +203,7 @@ class OCRNet(nn.Layer):
self.backbone = backbone self.backbone = backbone
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices] in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
self.head = OCRHead( self.head = OCRHead(
num_classes=num_classes, num_classes=num_classes,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册