diff --git a/dygraph/paddleseg/models/common/layer_libs.py b/dygraph/paddleseg/models/common/layer_libs.py index 6f79f84ed2bee9059cdd0137783760d2fb80fb0d..0335cfacfd190f1d84a52e73feb05a316d7af4a2 100644 --- a/dygraph/paddleseg/models/common/layer_libs.py +++ b/dygraph/paddleseg/models/common/layer_libs.py @@ -21,11 +21,17 @@ from paddle.nn import SyncBatchNorm as BatchNorm class ConvBNReLU(nn.Layer): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding='same', + **kwargs): super(ConvBNReLU, self).__init__() - self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs) + self._conv = Conv2d( + in_channels, out_channels, kernel_size, padding=padding, **kwargs) self._batch_norm = BatchNorm(out_channels) @@ -37,10 +43,16 @@ class ConvBNReLU(nn.Layer): class ConvBN(nn.Layer): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding='same', + **kwargs): super(ConvBN, self).__init__() - self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs) + self._conv = Conv2d( + in_channels, out_channels, kernel_size, padding=padding, **kwargs) self._batch_norm = BatchNorm(out_channels) def forward(self, x): @@ -67,17 +79,23 @@ class ConvReluPool(nn.Layer): return x -class DepthwiseConvBNReLU(nn.Layer): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): - super(DepthwiseConvBNReLU, self).__init__() +class SeparableConvBNReLU(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding='same', + **kwargs): + super(SeparableConvBNReLU, self).__init__() self.depthwise_conv = ConvBN( in_channels, out_channels=in_channels, kernel_size=kernel_size, + padding=padding, groups=in_channels, **kwargs) self.piontwise_conv = ConvBNReLU( - in_channels, out_channels, kernel_size=1, groups=1) + in_channels, out_channels, kernel_size=1, padding=padding, groups=1) def forward(self, x): x = self.depthwise_conv(x) @@ -86,20 +104,23 @@ class DepthwiseConvBNReLU(nn.Layer): class DepthwiseConvBN(nn.Layer): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding='same', + **kwargs): super(DepthwiseConvBN, self).__init__() self.depthwise_conv = ConvBN( in_channels, - out_channels=in_channels, + out_channels=out_channels, kernel_size=kernel_size, + padding=padding, groups=in_channels, **kwargs) - self.piontwise_conv = ConvBN( - in_channels, out_channels, kernel_size=1, groups=1) def forward(self, x): x = self.depthwise_conv(x) - x = self.piontwise_conv(x) return x diff --git a/dygraph/paddleseg/models/common/pyramid_pool.py b/dygraph/paddleseg/models/common/pyramid_pool.py index 86a98b38625221db34eef765a000f91536ec4b72..f196be509b693b198e7b595c96ca0f63fcd1a01f 100644 --- a/dygraph/paddleseg/models/common/pyramid_pool.py +++ b/dygraph/paddleseg/models/common/pyramid_pool.py @@ -46,7 +46,7 @@ class ASPPModule(nn.Layer): for ratio in aspp_ratios: if sep_conv and ratio > 1: - conv_func = layer_libs.DepthwiseConvBNReLU + conv_func = layer_libs.SeparableConvBNReLU else: conv_func = layer_libs.ConvBNReLU @@ -134,7 +134,7 @@ class PPModule(nn.Layer): Create one pooling layer. In our implementation, we adopt the same dimension reduction as the original paper that might be - slightly different with other implementations. + slightly different with other implementations. After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations keep the channels to be same. diff --git a/dygraph/paddleseg/models/deeplab.py b/dygraph/paddleseg/models/deeplab.py index 1b041e5dacb16d7c68fb5f251a8a29f8e39e8370..465503eeeb5502be2c63d9b1a113f780b7e6ef8f 100644 --- a/dygraph/paddleseg/models/deeplab.py +++ b/dygraph/paddleseg/models/deeplab.py @@ -19,7 +19,7 @@ import paddle.nn.functional as F from paddle import nn from paddleseg.cvlibs import manager from paddleseg.models.common import pyramid_pool -from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer +from paddleseg.models.common.layer_libs import ConvBNReLU, SeparableConvBNReLU, AuxLayer from paddleseg.utils import utils __all__ = ['DeepLabV3P', 'DeepLabV3'] @@ -99,7 +99,7 @@ class DeepLabV3PHead(nn.Layer): if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18). if output_stride=8, aspp_ratios is (1, 12, 24, 36). aspp_out_channels (int): the output channels of ASPP module. - + """ def __init__(self, @@ -146,7 +146,7 @@ class DeepLabV3(nn.Layer): (https://arxiv.org/pdf/1706.05587.pdf) Args: - Refer to DeepLabV3P above + Refer to DeepLabV3P above """ def __init__(self, @@ -234,9 +234,9 @@ class Decoder(nn.Layer): self.conv_bn_relu1 = ConvBNReLU( in_channels=in_channels, out_channels=48, kernel_size=1) - self.conv_bn_relu2 = DepthwiseConvBNReLU( + self.conv_bn_relu2 = SeparableConvBNReLU( in_channels=304, out_channels=256, kernel_size=3, padding=1) - self.conv_bn_relu3 = DepthwiseConvBNReLU( + self.conv_bn_relu3 = SeparableConvBNReLU( in_channels=256, out_channels=256, kernel_size=3, padding=1) self.conv = nn.Conv2d( in_channels=256, out_channels=num_classes, kernel_size=1) diff --git a/dygraph/paddleseg/models/fast_scnn.py b/dygraph/paddleseg/models/fast_scnn.py index baddfe333117eb57ad1916bab5630e14c9cd51f3..2a916835241581f9f3cab4616bcbf39330ad70fb 100644 --- a/dygraph/paddleseg/models/fast_scnn.py +++ b/dygraph/paddleseg/models/fast_scnn.py @@ -17,18 +17,19 @@ from paddle import nn from paddleseg.cvlibs import manager from paddleseg.models.common import pyramid_pool -from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer +from paddleseg.models.common.layer_libs import ConvBNReLU, SeparableConvBNReLU, AuxLayer from paddleseg.utils import utils + @manager.MODELS.add_component class FastSCNN(nn.Layer): """ The FastSCNN implementation based on PaddlePaddle. - As mentioned in the original paper, FastSCNN is a real-time segmentation algorithm (123.5fps) + As mentioned in the original paper, FastSCNN is a real-time segmentation algorithm (123.5fps) even for high resolution images (1024x2048). - The original article refers to + The original article refers to Poudel, Rudra PK, et al. "Fast-scnn: Fast semantic segmentation network." (https://arxiv.org/pdf/1902.04502.pdf) @@ -40,9 +41,7 @@ class FastSCNN(nn.Layer): pretrained (str): the path of pretrained model. Default to None. """ - def __init__(self, - num_classes, - enable_auxiliary_loss=True, + def __init__(self, num_classes, enable_auxiliary_loss=True, pretrained=None): super(FastSCNN, self).__init__() @@ -103,13 +102,13 @@ class LearningToDownsample(nn.Layer): self.conv_bn_relu = ConvBNReLU( in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2) - self.dsconv_bn_relu1 = DepthwiseConvBNReLU( + self.dsconv_bn_relu1 = SeparableConvBNReLU( in_channels=dw_channels1, out_channels=dw_channels2, kernel_size=3, stride=2, padding=1) - self.dsconv_bn_relu2 = DepthwiseConvBNReLU( + self.dsconv_bn_relu2 = SeparableConvBNReLU( in_channels=dw_channels2, out_channels=out_channels, kernel_size=3, @@ -127,7 +126,7 @@ class GlobalFeatureExtractor(nn.Layer): """ Global feature extractor module - This module consists of three LinearBottleneck blocks (like inverted residual introduced by MobileNetV2) and + This module consists of three LinearBottleneck blocks (like inverted residual introduced by MobileNetV2) and a PPModule (introduced by PSPNet). Args: @@ -297,13 +296,13 @@ class Classifier(nn.Layer): def __init__(self, input_channels, num_classes): super(Classifier, self).__init__() - self.dsconv1 = DepthwiseConvBNReLU( + self.dsconv1 = SeparableConvBNReLU( in_channels=input_channels, out_channels=input_channels, kernel_size=3, padding=1) - self.dsconv2 = DepthwiseConvBNReLU( + self.dsconv2 = SeparableConvBNReLU( in_channels=input_channels, out_channels=input_channels, kernel_size=3,