未验证 提交 53f1d1e4 编写于 作者: M michaelowenliu 提交者: GitHub

Merge pull request #401 from michaelowenliu/develop

delete single learning_rate
batch_size: 4 batch_size: 4
iters: 100000 iters: 100000
learning_rate: 0.01
train_dataset: train_dataset:
type: Cityscapes type: Cityscapes
......
batch_size: 4 batch_size: 4
iters: 10000 iters: 10000
learning_rate: 0.01
train_dataset: train_dataset:
type: OpticDiscSeg type: OpticDiscSeg
......
...@@ -19,7 +19,7 @@ import paddle.nn.functional as F ...@@ -19,7 +19,7 @@ import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_libs from paddleseg.models.common.layer_libs import ConvBNReLU, ConvBN, AuxLayer
from paddleseg.utils import utils from paddleseg.utils import utils
...@@ -32,11 +32,62 @@ class ANN(nn.Layer): ...@@ -32,11 +32,62 @@ class ANN(nn.Layer):
Zhen, Zhu, et al. "Asymmetric Non-local Neural Networks for Semantic Segmentation." Zhen, Zhu, et al. "Asymmetric Non-local Neural Networks for Semantic Segmentation."
(https://arxiv.org/pdf/1908.07678.pdf) (https://arxiv.org/pdf/1908.07678.pdf)
Args:
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Default to None.
backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
key_value_channels (int): the key and value channels of self-attention map in both AFNB and APNB modules.
Default to 256.
inter_channels (int): both input and output channels of APNB modules.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
enable_auxiliary_loss (bool): a bool values indicates whether adding auxiliary loss. Default to True.
pretrained (str): the path of pretrained model. Default to None.
"""
def __init__(self,
num_classes,
backbone,
backbone_indices=(2, 3),
key_value_channels=256,
inter_channels=512,
psp_size=(1, 3, 6, 8),
enable_auxiliary_loss=True,
pretrained=None,):
super(ANN, self).__init__()
self.backbone = backbone
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = ANNHead(
num_classes,
backbone_indices,
backbone_channels,
key_value_channels,
inter_channels,
psp_size,
enable_auxiliary_loss)
utils.load_entire_model(self, pretrained)
def forward(self, input):
feat_list = self.backbone(input)
logit_list = self.head(feat_list)
return [
F.resize_bilinear(logit, input.shape[2:]) for logit in logit_list
]
class ANNHead(nn.Layer):
"""
The ANNHead implementation.
It mainly consists of AFNB and APNB modules. It mainly consists of AFNB and APNB modules.
Args: Args:
num_classes (int): the unique number of target classes. num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Default to None. model_pretrained (str): the path of pretrained model. Default to None.
backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone. backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as low-level features; the second one will be the first index will be taken as low-level features; the second one will be
...@@ -53,17 +104,13 @@ class ANN(nn.Layer): ...@@ -53,17 +104,13 @@ class ANN(nn.Layer):
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone,
model_pretrained=None,
backbone_indices=(2, 3), backbone_indices=(2, 3),
backbone_channels=(1024, 2048), backbone_channels=(1024, 2048),
key_value_channels=256, key_value_channels=256,
inter_channels=512, inter_channels=512,
psp_size=(1, 3, 6, 8), psp_size=(1, 3, 6, 8),
enable_auxiliary_loss=True): enable_auxiliary_loss=True):
super(ANN, self).__init__() super(ANNHead, self).__init__()
self.backbone = backbone
low_in_channels = backbone_channels[0] low_in_channels = backbone_channels[0]
high_in_channels = backbone_channels[1] high_in_channels = backbone_channels[1]
...@@ -79,7 +126,7 @@ class ANN(nn.Layer): ...@@ -79,7 +126,7 @@ class ANN(nn.Layer):
psp_size=psp_size) psp_size=psp_size)
self.context = nn.Sequential( self.context = nn.Sequential(
layer_libs.ConvBNReLU( ConvBNReLU(
in_channels=high_in_channels, in_channels=high_in_channels,
out_channels=inter_channels, out_channels=inter_channels,
kernel_size=3, kernel_size=3,
...@@ -95,7 +142,7 @@ class ANN(nn.Layer): ...@@ -95,7 +142,7 @@ class ANN(nn.Layer):
self.cls = nn.Conv2d( self.cls = nn.Conv2d(
in_channels=inter_channels, out_channels=num_classes, kernel_size=1) in_channels=inter_channels, out_channels=num_classes, kernel_size=1)
self.auxlayer = layer_libs.AuxLayer( self.auxlayer = AuxLayer(
in_channels=low_in_channels, in_channels=low_in_channels,
inter_channels=low_in_channels // 2, inter_channels=low_in_channels // 2,
out_channels=num_classes, out_channels=num_classes,
...@@ -104,39 +151,29 @@ class ANN(nn.Layer): ...@@ -104,39 +151,29 @@ class ANN(nn.Layer):
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
self.enable_auxiliary_loss = enable_auxiliary_loss self.enable_auxiliary_loss = enable_auxiliary_loss
self.init_weight(model_pretrained) self.init_weight()
def forward(self, input, label=None): def forward(self, feat_list):
logit_list = [] logit_list = []
_, feat_list = self.backbone(input)
low_level_x = feat_list[self.backbone_indices[0]] low_level_x = feat_list[self.backbone_indices[0]]
high_level_x = feat_list[self.backbone_indices[1]] high_level_x = feat_list[self.backbone_indices[1]]
x = self.fusion(low_level_x, high_level_x) x = self.fusion(low_level_x, high_level_x)
x = self.context(x) x = self.context(x)
logit = self.cls(x) logit = self.cls(x)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit) logit_list.append(logit)
if self.enable_auxiliary_loss: if self.enable_auxiliary_loss:
auxiliary_logit = self.auxlayer(low_level_x) auxiliary_logit = self.auxlayer(low_level_x)
auxiliary_logit = F.resize_bilinear(auxiliary_logit,
input.shape[2:])
logit_list.append(auxiliary_logit) logit_list.append(auxiliary_logit)
return logit_list return logit_list
def init_weight(self, pretrained_model=None): def init_weight(self):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
""" """
pass
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
class AFNB(nn.Layer): class AFNB(nn.Layer):
...@@ -171,7 +208,7 @@ class AFNB(nn.Layer): ...@@ -171,7 +208,7 @@ class AFNB(nn.Layer):
key_channels, value_channels, out_channels, key_channels, value_channels, out_channels,
size) for size in sizes size) for size in sizes
]) ])
self.conv_bn = layer_libs.ConvBn( self.conv_bn = ConvBN(
in_channels=out_channels + high_in_channels, in_channels=out_channels + high_in_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=1) kernel_size=1)
...@@ -218,7 +255,7 @@ class APNB(nn.Layer): ...@@ -218,7 +255,7 @@ class APNB(nn.Layer):
SelfAttentionBlock_APNB(in_channels, out_channels, key_channels, SelfAttentionBlock_APNB(in_channels, out_channels, key_channels,
value_channels, size) for size in sizes value_channels, size) for size in sizes
]) ])
self.conv_bn = layer_libs.ConvBNReLU( self.conv_bn = ConvBNReLU(
in_channels=in_channels * 2, in_channels=in_channels * 2,
out_channels=out_channels, out_channels=out_channels,
kernel_size=1) kernel_size=1)
...@@ -279,11 +316,11 @@ class SelfAttentionBlock_AFNB(nn.Layer): ...@@ -279,11 +316,11 @@ class SelfAttentionBlock_AFNB(nn.Layer):
if out_channels == None: if out_channels == None:
self.out_channels = high_in_channels self.out_channels = high_in_channels
self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max") self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max")
self.f_key = layer_libs.ConvBNReLU( self.f_key = ConvBNReLU(
in_channels=low_in_channels, in_channels=low_in_channels,
out_channels=key_channels, out_channels=key_channels,
kernel_size=1) kernel_size=1)
self.f_query = layer_libs.ConvBNReLU( self.f_query = ConvBNReLU(
in_channels=high_in_channels, in_channels=high_in_channels,
out_channels=key_channels, out_channels=key_channels,
kernel_size=1) kernel_size=1)
...@@ -357,7 +394,7 @@ class SelfAttentionBlock_APNB(nn.Layer): ...@@ -357,7 +394,7 @@ class SelfAttentionBlock_APNB(nn.Layer):
self.value_channels = value_channels self.value_channels = value_channels
self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max") self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max")
self.f_key = layer_libs.ConvBNReLU( self.f_key = ConvBNReLU(
in_channels=self.in_channels, in_channels=self.in_channels,
out_channels=self.key_channels, out_channels=self.key_channels,
kernel_size=1) kernel_size=1)
......
...@@ -18,7 +18,8 @@ import paddle ...@@ -18,7 +18,8 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.models.common import pyramid_pool, layer_libs from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer
from paddleseg.utils import utils from paddleseg.utils import utils
__all__ = ['DeepLabV3P', 'DeepLabV3'] __all__ = ['DeepLabV3P', 'DeepLabV3']
...@@ -47,8 +48,7 @@ class DeepLabV3P(nn.Layer): ...@@ -47,8 +48,7 @@ class DeepLabV3P(nn.Layer):
if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18). if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
if output_stride=8, aspp_ratios is (1, 12, 24, 36). if output_stride=8, aspp_ratios is (1, 12, 24, 36).
aspp_out_channels (int): the output channels of ASPP module. aspp_out_channels (int): the output channels of ASPP module.
pretrained (str): the path of pretrained model for fine tuning. pretrained (str): the path of pretrained model. Default to None.
""" """
def __init__(self, def __init__(self,
...@@ -94,7 +94,7 @@ class DeepLabV3PHead(nn.Layer): ...@@ -94,7 +94,7 @@ class DeepLabV3PHead(nn.Layer):
each stage, so we set default (0, 3), which means taking feature map of the first each stage, so we set default (0, 3), which means taking feature map of the first
stage in backbone as low-level feature used in Decoder, and feature map of the fourth stage in backbone as low-level feature used in Decoder, and feature map of the fourth
stage as input of ASPP. stage as input of ASPP.
backbone_channels (tuple): returned channels of backbone backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
aspp_ratios (tuple): the dilation rate using in ASSP module. aspp_ratios (tuple): the dilation rate using in ASSP module.
if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18). if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
if output_stride=8, aspp_ratios is (1, 12, 24, 36). if output_stride=8, aspp_ratios is (1, 12, 24, 36).
...@@ -231,12 +231,12 @@ class Decoder(nn.Layer): ...@@ -231,12 +231,12 @@ class Decoder(nn.Layer):
def __init__(self, num_classes, in_channels): def __init__(self, num_classes, in_channels):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.conv_bn_relu1 = layer_libs.ConvBNReLU( self.conv_bn_relu1 = ConvBNReLU(
in_channels=in_channels, out_channels=48, kernel_size=1) in_channels=in_channels, out_channels=48, kernel_size=1)
self.conv_bn_relu2 = layer_libs.DepthwiseConvBNReLU( self.conv_bn_relu2 = DepthwiseConvBNReLU(
in_channels=304, out_channels=256, kernel_size=3, padding=1) in_channels=304, out_channels=256, kernel_size=3, padding=1)
self.conv_bn_relu3 = layer_libs.DepthwiseConvBNReLU( self.conv_bn_relu3 = DepthwiseConvBNReLU(
in_channels=256, out_channels=256, kernel_size=3, padding=1) in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels=256, out_channels=num_classes, kernel_size=1) in_channels=256, out_channels=num_classes, kernel_size=1)
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_libs, pyramid_pool
from paddleseg.cvlibs import manager
from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer
from paddleseg.utils import utils
@manager.MODELS.add_component @manager.MODELS.add_component
class FastSCNN(nn.Layer): class FastSCNN(nn.Layer):
...@@ -33,15 +35,15 @@ class FastSCNN(nn.Layer): ...@@ -33,15 +35,15 @@ class FastSCNN(nn.Layer):
Args: Args:
num_classes (int): the unique number of target classes. Default to 2. num_classes (int): the unique number of target classes. Default to 2.
model_pretrained (str): the path of pretrained model. Default to None.
enable_auxiliary_loss (bool): a bool values indicates whether adding auxiliary loss. enable_auxiliary_loss (bool): a bool values indicates whether adding auxiliary loss.
if true, auxiliary loss will be added after LearningToDownsample module, where the weight is 0.4. Default to False. if true, auxiliary loss will be added after LearningToDownsample module, where the weight is 0.4. Default to False.
pretrained (str): the path of pretrained model. Default to None.
""" """
def __init__(self, def __init__(self,
num_classes, num_classes,
model_pretrained=None, enable_auxiliary_loss=True,
enable_auxiliary_loss=True): pretrained=None):
super(FastSCNN, self).__init__() super(FastSCNN, self).__init__()
...@@ -52,11 +54,12 @@ class FastSCNN(nn.Layer): ...@@ -52,11 +54,12 @@ class FastSCNN(nn.Layer):
self.classifier = Classifier(128, num_classes) self.classifier = Classifier(128, num_classes)
if enable_auxiliary_loss: if enable_auxiliary_loss:
self.auxlayer = layer_libs.AuxLayer(64, 32, num_classes) self.auxlayer = AuxLayer(64, 32, num_classes)
self.enable_auxiliary_loss = enable_auxiliary_loss self.enable_auxiliary_loss = enable_auxiliary_loss
self.init_weight(model_pretrained) self.init_weight()
utils.load_entire_model(self, pretrained)
def forward(self, input, label=None): def forward(self, input, label=None):
...@@ -76,18 +79,11 @@ class FastSCNN(nn.Layer): ...@@ -76,18 +79,11 @@ class FastSCNN(nn.Layer):
return logit_list return logit_list
def init_weight(self, pretrained_model=None): def init_weight(self):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
""" """
if pretrained_model is not None: pass
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
class LearningToDownsample(nn.Layer): class LearningToDownsample(nn.Layer):
...@@ -105,15 +101,15 @@ class LearningToDownsample(nn.Layer): ...@@ -105,15 +101,15 @@ class LearningToDownsample(nn.Layer):
def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64): def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64):
super(LearningToDownsample, self).__init__() super(LearningToDownsample, self).__init__()
self.conv_bn_relu = layer_libs.ConvBNReLU( self.conv_bn_relu = ConvBNReLU(
in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2) in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2)
self.dsconv_bn_relu1 = layer_libs.DepthwiseConvBNReLU( self.dsconv_bn_relu1 = DepthwiseConvBNReLU(
in_channels=dw_channels1, in_channels=dw_channels1,
out_channels=dw_channels2, out_channels=dw_channels2,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
padding=1) padding=1)
self.dsconv_bn_relu2 = layer_libs.DepthwiseConvBNReLU( self.dsconv_bn_relu2 = DepthwiseConvBNReLU(
in_channels=dw_channels2, in_channels=dw_channels2,
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
...@@ -208,13 +204,13 @@ class LinearBottleneck(nn.Layer): ...@@ -208,13 +204,13 @@ class LinearBottleneck(nn.Layer):
expand_channels = in_channels * expansion expand_channels = in_channels * expansion
self.block = nn.Sequential( self.block = nn.Sequential(
# pw # pw
layer_libs.ConvBNReLU( ConvBNReLU(
in_channels=in_channels, in_channels=in_channels,
out_channels=expand_channels, out_channels=expand_channels,
kernel_size=1, kernel_size=1,
bias_attr=False), bias_attr=False),
# dw # dw
layer_libs.ConvBNReLU( ConvBNReLU(
in_channels=expand_channels, in_channels=expand_channels,
out_channels=expand_channels, out_channels=expand_channels,
kernel_size=3, kernel_size=3,
...@@ -253,7 +249,7 @@ class FeatureFusionModule(nn.Layer): ...@@ -253,7 +249,7 @@ class FeatureFusionModule(nn.Layer):
super(FeatureFusionModule, self).__init__() super(FeatureFusionModule, self).__init__()
# There only depth-wise conv is used WITHOUT point-wise conv # There only depth-wise conv is used WITHOUT point-wise conv
self.dwconv = layer_libs.ConvBNReLU( self.dwconv = ConvBNReLU(
in_channels=low_in_channels, in_channels=low_in_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
...@@ -289,9 +285,9 @@ class FeatureFusionModule(nn.Layer): ...@@ -289,9 +285,9 @@ class FeatureFusionModule(nn.Layer):
class Classifier(nn.Layer): class Classifier(nn.Layer):
""" """
The Classifier module implemetation. The Classifier module implementation.
This module consists of two depth-wsie conv and one conv. This module consists of two depth-wise conv and one conv.
Args: Args:
input_channels (int): the input channels to this module. input_channels (int): the input channels to this module.
...@@ -301,13 +297,13 @@ class Classifier(nn.Layer): ...@@ -301,13 +297,13 @@ class Classifier(nn.Layer):
def __init__(self, input_channels, num_classes): def __init__(self, input_channels, num_classes):
super(Classifier, self).__init__() super(Classifier, self).__init__()
self.dsconv1 = layer_libs.DepthwiseConvBNReLU( self.dsconv1 = DepthwiseConvBNReLU(
in_channels=input_channels, in_channels=input_channels,
out_channels=input_channels, out_channels=input_channels,
kernel_size=3, kernel_size=3,
padding=1) padding=1)
self.dsconv2 = layer_libs.DepthwiseConvBNReLU( self.dsconv2 = DepthwiseConvBNReLU(
in_channels=input_channels, in_channels=input_channels,
out_channels=input_channels, out_channels=input_channels,
kernel_size=3, kernel_size=3,
......
...@@ -18,10 +18,12 @@ import paddle ...@@ -18,10 +18,12 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_libs from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
from paddleseg.utils import utils from paddleseg.utils import utils
@manager.MODELS.add_component @manager.MODELS.add_component
class GCNet(nn.Layer): class GCNet(nn.Layer):
""" """
...@@ -34,7 +36,54 @@ class GCNet(nn.Layer): ...@@ -34,7 +36,54 @@ class GCNet(nn.Layer):
Args: Args:
num_classes (int): the unique number of target classes. num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101. backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Default to None. backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
gc_channels (int): input channels to Global Context Block. Default to 512.
ratio (float): it indicates the ratio of attention channels and gc_channels. Default to 1/4.
enable_auxiliary_loss (bool): a bool values indicates whether adding auxiliary loss. Default to True.
pretrained (str): the path of pretrained model. Default to None.
"""
def __init__(self,
num_classes,
backbone,
backbone_indices=(2, 3),
gc_channels=512,
ratio=1 / 4,
enable_auxiliary_loss=True,
pretrained=None):
super(GCNet, self).__init__()
self.backbone = backbone
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = GCNetHead(
num_classes,
backbone_indices,
backbone_channels,
gc_channels,
ratio,
enable_auxiliary_loss)
utils.load_entire_model(self, pretrained)
def forward(self, input):
feat_list = self.backbone(input)
logit_list = self.head(feat_list)
return [
F.resize_bilinear(logit, input.shape[2:]) for logit in logit_list
]
class GCNetHead(nn.Layer):
"""
The GCNetHead implementation.
Args:
num_classes (int): the unique number of target classes.
backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone. backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer; the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of GlobalContextBlock. Usually backbone the second one will be taken as input of GlobalContextBlock. Usually backbone
...@@ -49,21 +98,16 @@ class GCNet(nn.Layer): ...@@ -49,21 +98,16 @@ class GCNet(nn.Layer):
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone,
model_pretrained=None,
backbone_indices=(2, 3), backbone_indices=(2, 3),
backbone_channels=(1024, 2048), backbone_channels=(1024, 2048),
gc_channels=512, gc_channels=512,
ratio=1 / 4, ratio=1 / 4,
enable_auxiliary_loss=True, enable_auxiliary_loss=True):
pretrained_model=None):
super(GCNet, self).__init__() super(GCNetHead, self).__init__()
self.backbone = backbone
in_channels = backbone_channels[1] in_channels = backbone_channels[1]
self.conv_bn_relu1 = layer_libs.ConvBNReLU( self.conv_bn_relu1 = ConvBNReLU(
in_channels=in_channels, in_channels=in_channels,
out_channels=gc_channels, out_channels=gc_channels,
kernel_size=3, kernel_size=3,
...@@ -71,13 +115,13 @@ class GCNet(nn.Layer): ...@@ -71,13 +115,13 @@ class GCNet(nn.Layer):
self.gc_block = GlobalContextBlock(in_channels=gc_channels, ratio=ratio) self.gc_block = GlobalContextBlock(in_channels=gc_channels, ratio=ratio)
self.conv_bn_relu2 = layer_libs.ConvBNReLU( self.conv_bn_relu2 = ConvBNReLU(
in_channels=gc_channels, in_channels=gc_channels,
out_channels=gc_channels, out_channels=gc_channels,
kernel_size=3, kernel_size=3,
padding=1) padding=1)
self.conv_bn_relu3 = layer_libs.ConvBNReLU( self.conv_bn_relu3 = ConvBNReLU(
in_channels=in_channels + gc_channels, in_channels=in_channels + gc_channels,
out_channels=gc_channels, out_channels=gc_channels,
kernel_size=3, kernel_size=3,
...@@ -87,7 +131,7 @@ class GCNet(nn.Layer): ...@@ -87,7 +131,7 @@ class GCNet(nn.Layer):
in_channels=gc_channels, out_channels=num_classes, kernel_size=1) in_channels=gc_channels, out_channels=num_classes, kernel_size=1)
if enable_auxiliary_loss: if enable_auxiliary_loss:
self.auxlayer = layer_libs.AuxLayer( self.auxlayer = AuxLayer(
in_channels=backbone_channels[0], in_channels=backbone_channels[0],
inter_channels=backbone_channels[0] // 4, inter_channels=backbone_channels[0] // 4,
out_channels=num_classes) out_channels=num_classes)
...@@ -95,12 +139,11 @@ class GCNet(nn.Layer): ...@@ -95,12 +139,11 @@ class GCNet(nn.Layer):
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
self.enable_auxiliary_loss = enable_auxiliary_loss self.enable_auxiliary_loss = enable_auxiliary_loss
self.init_weight(model_pretrained) self.init_weight()
def forward(self, input, label=None): def forward(self, feat_list):
logit_list = [] logit_list = []
_, feat_list = self.backbone(input)
x = feat_list[self.backbone_indices[1]] x = feat_list[self.backbone_indices[1]]
output = self.conv_bn_relu1(x) output = self.conv_bn_relu1(x)
...@@ -112,14 +155,11 @@ class GCNet(nn.Layer): ...@@ -112,14 +155,11 @@ class GCNet(nn.Layer):
output = F.dropout(output, p=0.1) # dropout_prob output = F.dropout(output, p=0.1) # dropout_prob
logit = self.conv(output) logit = self.conv(output)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit) logit_list.append(logit)
if self.enable_auxiliary_loss: if self.enable_auxiliary_loss:
low_level_feat = feat_list[self.backbone_indices[0]] low_level_feat = feat_list[self.backbone_indices[0]]
auxiliary_logit = self.auxlayer(low_level_feat) auxiliary_logit = self.auxlayer(low_level_feat)
auxiliary_logit = F.resize_bilinear(auxiliary_logit,
input.shape[2:])
logit_list.append(auxiliary_logit) logit_list.append(auxiliary_logit)
return logit_list return logit_list
...@@ -127,15 +167,8 @@ class GCNet(nn.Layer): ...@@ -127,15 +167,8 @@ class GCNet(nn.Layer):
def init_weight(self, pretrained_model=None): def init_weight(self, pretrained_model=None):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
""" """
if pretrained_model is not None: pass
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
class GlobalContextBlock(nn.Layer): class GlobalContextBlock(nn.Layer):
......
...@@ -17,7 +17,8 @@ import os ...@@ -17,7 +17,8 @@ import os
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_libs, pyramid_pool from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
from paddleseg.utils import utils from paddleseg.utils import utils
...@@ -35,6 +36,54 @@ class PSPNet(nn.Layer): ...@@ -35,6 +36,54 @@ class PSPNet(nn.Layer):
num_classes (int): the unique number of target classes. num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101. backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Default to None. model_pretrained (str): the path of pretrained model. Default to None.
backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
pp_out_channels (int): output channels after Pyramid Pooling Module. Default to 1024.
bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6).
enable_auxiliary_loss (bool): a bool values indicates whether adding auxiliary loss. Default to True.
pretrained (str): the path of pretrained model. Default to None.
"""
def __init__(self,
num_classes,
backbone,
backbone_indices=(2, 3),
pp_out_channels=1024,
bin_sizes=(1, 2, 3, 6),
enable_auxiliary_loss=True,
pretrained=None):
super(PSPNet, self).__init__()
self.backbone = backbone
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = PSPNetHead(
num_classes,
backbone_indices,
backbone_channels,
pp_out_channels,
bin_sizes,
enable_auxiliary_loss)
utils.load_entire_model(self, pretrained)
def forward(self, input):
feat_list = self.backbone(input)
logit_list = self.head(feat_list)
return [
F.resize_bilinear(logit, input.shape[2:]) for logit in logit_list
]
class PSPNetHead(nn.Layer):
"""
The PSPNetHead implementation.
Args:
num_classes (int): the unique number of target classes.
backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone. backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer; the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of Pyramid Pooling Module (PPModule). the second one will be taken as input of Pyramid Pooling Module (PPModule).
...@@ -49,17 +98,14 @@ class PSPNet(nn.Layer): ...@@ -49,17 +98,14 @@ class PSPNet(nn.Layer):
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone,
model_pretrained=None,
backbone_indices=(2, 3), backbone_indices=(2, 3),
backbone_channels=(1024, 2048), backbone_channels=(1024, 2048),
pp_out_channels=1024, pp_out_channels=1024,
bin_sizes=(1, 2, 3, 6), bin_sizes=(1, 2, 3, 6),
enable_auxiliary_loss=True): enable_auxiliary_loss=True):
super(PSPNet, self).__init__() super(PSPNetHead, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
self.psp_module = pyramid_pool.PPModule( self.psp_module = pyramid_pool.PPModule(
...@@ -74,32 +120,28 @@ class PSPNet(nn.Layer): ...@@ -74,32 +120,28 @@ class PSPNet(nn.Layer):
if enable_auxiliary_loss: if enable_auxiliary_loss:
self.auxlayer = layer_libs.AuxLayer( self.auxlayer = AuxLayer(
in_channels=backbone_channels[0], in_channels=backbone_channels[0],
inter_channels=backbone_channels[0] // 4, inter_channels=backbone_channels[0] // 4,
out_channels=num_classes) out_channels=num_classes)
self.enable_auxiliary_loss = enable_auxiliary_loss self.enable_auxiliary_loss = enable_auxiliary_loss
self.init_weight(model_pretrained) self.init_weight()
def forward(self, input, label=None): def forward(self, feat_list):
logit_list = [] logit_list = []
_, feat_list = self.backbone(input)
x = feat_list[self.backbone_indices[1]] x = feat_list[self.backbone_indices[1]]
x = self.psp_module(x) x = self.psp_module(x)
x = F.dropout(x, p=0.1) # dropout_prob x = F.dropout(x, p=0.1) # dropout_prob
logit = self.conv(x) logit = self.conv(x)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit) logit_list.append(logit)
if self.enable_auxiliary_loss: if self.enable_auxiliary_loss:
auxiliary_feat = feat_list[self.backbone_indices[0]] auxiliary_feat = feat_list[self.backbone_indices[0]]
auxiliary_logit = self.auxlayer(auxiliary_feat) auxiliary_logit = self.auxlayer(auxiliary_feat)
auxiliary_logit = F.resize_bilinear(auxiliary_logit,
input.shape[2:])
logit_list.append(auxiliary_logit) logit_list.append(auxiliary_logit)
return logit_list return logit_list
...@@ -107,13 +149,6 @@ class PSPNet(nn.Layer): ...@@ -107,13 +149,6 @@ class PSPNet(nn.Layer):
def init_weight(self, pretrained_model=None): def init_weight(self, pretrained_model=None):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
""" """
if pretrained_model is not None: pass
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册