From fc2c779983d8bccbd1b83f2e62d2f82922b04911 Mon Sep 17 00:00:00 2001 From: michaelowenliu Date: Tue, 15 Sep 2020 17:20:51 +0800 Subject: [PATCH] use new api --- dygraph/paddleseg/models/ann.py | 17 ++- .../paddleseg/models/backbones/mobilenetv3.py | 2 +- .../paddleseg/models/backbones/resnet_vd.py | 6 +- .../models/backbones/xception_deeplab.py | 2 +- dygraph/paddleseg/models/deeplab.py | 141 ++++-------------- dygraph/paddleseg/models/fast_scnn.py | 18 +-- dygraph/paddleseg/models/gcnet.py | 16 +- dygraph/paddleseg/models/ocrnet.py | 2 +- dygraph/paddleseg/models/pspnet.py | 13 +- 9 files changed, 70 insertions(+), 147 deletions(-) diff --git a/dygraph/paddleseg/models/ann.py b/dygraph/paddleseg/models/ann.py index 48c381d2..3cde2992 100644 --- a/dygraph/paddleseg/models/ann.py +++ b/dygraph/paddleseg/models/ann.py @@ -17,8 +17,9 @@ import os import paddle import paddle.nn.functional as F from paddle import nn + from paddleseg.cvlibs import manager -from paddleseg.models.common import layer_utils, model_utils +from paddleseg.models.common import layer_libs from paddleseg.utils import utils @@ -88,7 +89,7 @@ class ANN(nn.Layer): psp_size=psp_size) self.context = nn.Sequential( - layer_utils.ConvBnRelu( + layer_libs.ConvBnRelu( in_channels=high_in_channels, out_channels=inter_channels, kernel_size=3, @@ -106,7 +107,7 @@ class ANN(nn.Layer): in_channels=inter_channels, out_channels=num_classes, kernel_size=1) - self.auxlayer = model_utils.AuxLayer( + self.auxlayer = layer_libs.AuxLayer( in_channels=low_in_channels, inter_channels=low_in_channels // 2, out_channels=num_classes, @@ -189,7 +190,7 @@ class AFNB(nn.Layer): key_channels, value_channels, out_channels, size) for size in sizes ]) - self.conv_bn = layer_utils.ConvBn( + self.conv_bn = layer_libs.ConvBn( in_channels=out_channels + high_in_channels, out_channels=out_channels, kernel_size=1) @@ -243,7 +244,7 @@ class APNB(nn.Layer): SelfAttentionBlock_APNB(in_channels, out_channels, key_channels, value_channels, 
size) for size in sizes ]) - self.conv_bn = layer_utils.ConvBnRelu( + self.conv_bn = layer_libs.ConvBnRelu( in_channels=in_channels * 2, out_channels=out_channels, kernel_size=1) @@ -310,11 +311,11 @@ class SelfAttentionBlock_AFNB(nn.Layer): if out_channels == None: self.out_channels = high_in_channels self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max") - self.f_key = layer_utils.ConvBnRelu( + self.f_key = layer_libs.ConvBnRelu( in_channels=low_in_channels, out_channels=key_channels, kernel_size=1) - self.f_query = layer_utils.ConvBnRelu( + self.f_query = layer_libs.ConvBnRelu( in_channels=high_in_channels, out_channels=key_channels, kernel_size=1) @@ -393,7 +394,7 @@ class SelfAttentionBlock_APNB(nn.Layer): self.value_channels = value_channels self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max") - self.f_key = layer_utils.ConvBnRelu( + self.f_key = layer_libs.ConvBnRelu( in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1) diff --git a/dygraph/paddleseg/models/backbones/mobilenetv3.py b/dygraph/paddleseg/models/backbones/mobilenetv3.py index 6204d773..ac1778ad 100644 --- a/dygraph/paddleseg/models/backbones/mobilenetv3.py +++ b/dygraph/paddleseg/models/backbones/mobilenetv3.py @@ -27,7 +27,7 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout from paddle.nn import SyncBatchNorm as BatchNorm -from paddleseg.models.common import layer_utils +from paddleseg.models.common import layer_libs from paddleseg.cvlibs import manager from paddleseg.utils import utils diff --git a/dygraph/paddleseg/models/backbones/resnet_vd.py b/dygraph/paddleseg/models/backbones/resnet_vd.py index d7dfc66f..787f6a3b 100644 --- a/dygraph/paddleseg/models/backbones/resnet_vd.py +++ b/dygraph/paddleseg/models/backbones/resnet_vd.py @@ -28,7 +28,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout from paddle.nn import SyncBatchNorm as BatchNorm from 
paddleseg.utils import utils -from paddleseg.models.common import layer_utils +from paddleseg.models.common import layer_libs, activation from paddleseg.cvlibs import manager __all__ = [ @@ -77,7 +77,7 @@ class ConvBNLayer(fluid.dygraph.Layer): num_filters, weight_attr=ParamAttr(name=bn_name + '_scale'), bias_attr=ParamAttr(bn_name + '_offset')) - self._act_op = layer_utils.Activation(act=act) + self._act_op = activation.Activation(act=act) def forward(self, inputs): if self.is_vd_mode: @@ -213,7 +213,7 @@ class ResNet_vd(fluid.dygraph.Layer): layers=50, class_dim=1000, output_stride=None, - multi_grid=(1, 2, 4)): + multi_grid=(1, 1, 1)): super(ResNet_vd, self).__init__() self.layers = layers diff --git a/dygraph/paddleseg/models/backbones/xception_deeplab.py b/dygraph/paddleseg/models/backbones/xception_deeplab.py index f512e31a..b07d3ac1 100644 --- a/dygraph/paddleseg/models/backbones/xception_deeplab.py +++ b/dygraph/paddleseg/models/backbones/xception_deeplab.py @@ -21,7 +21,7 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout from paddle.nn import SyncBatchNorm as BatchNorm -from paddleseg.models.common import layer_utils +from paddleseg.models.common import layer_libs from paddleseg.cvlibs import manager from paddleseg.utils import utils diff --git a/dygraph/paddleseg/models/deeplab.py b/dygraph/paddleseg/models/deeplab.py index 7c7e0cb1..ff530b2f 100644 --- a/dygraph/paddleseg/models/deeplab.py +++ b/dygraph/paddleseg/models/deeplab.py @@ -18,7 +18,7 @@ import paddle import paddle.nn.functional as F from paddle import nn from paddleseg.cvlibs import manager -from paddleseg.models.common import layer_utils +from paddleseg.models.common import pyramid_pool, layer_libs from paddleseg.utils import utils __all__ = ['DeepLabV3P', 'DeepLabV3'] @@ -43,8 +43,9 @@ class DeepLabV3P(nn.Layer): model_pretrained (str): the path of pretrained model. 
- output_stride (int): the ratio of input size and final feature size. - Support 16 or 8. Default to 16. + aspp_ratios (tuple): the dilation rates used in the ASPP module. + if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18). + if output_stride=8, aspp_ratios should be set as (1, 12, 24, 36). backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone. the first index will be taken as a low-level feature in Deconder component; @@ -61,18 +62,24 @@ class DeepLabV3P(nn.Layer): def __init__(self, num_classes, backbone, + backbone_pretrained=None, model_pretrained=None, backbone_indices=(0, 3), backbone_channels=(256, 2048), - output_stride=16): + aspp_ratios=(1, 6, 12, 18), + aspp_out_channels=256): super(DeepLabV3P, self).__init__() self.backbone = backbone - self.aspp = ASPP(output_stride, backbone_channels[1]) + self.backbone_pretrained = backbone_pretrained + self.model_pretrained = model_pretrained + + self.aspp = pyramid_pool.ASPPModule( + aspp_ratios, backbone_channels[1], aspp_out_channels, sep_conv=True, image_pooling=True) self.decoder = Decoder(num_classes, backbone_channels[0]) self.backbone_indices = backbone_indices - self.init_weight(model_pretrained) + self.init_weight() def forward(self, input, label=None): @@ -87,19 +94,17 @@ class DeepLabV3P(nn.Layer): return logit_list - def init_weight(self, pretrained_model=None): + def init_weight(self): """ Initialize the parameters of model parts. Args: pretrained_model ([str], optional): the path of pretrained model. Defaults to None. 
""" - if pretrained_model is not None: - if os.path.exists(pretrained_model): - utils.load_pretrained_model(self, pretrained_model) - else: - raise Exception('Pretrained model is not found: {}'.format( - pretrained_model)) - + if self.model_pretrained is not None: + utils.load_pretrained_model(self, self.model_pretrained) + elif self.backbone_pretrained is not None: + utils.load_pretrained_model(self.backbone, self.backbone_pretrained) + @manager.MODELS.add_component class DeepLabV3(nn.Layer): @@ -119,15 +124,21 @@ class DeepLabV3(nn.Layer): def __init__(self, num_classes, backbone, + backbone_pretrained=None, model_pretrained=None, backbone_indices=(3,), backbone_channels=(2048,), - output_stride=16): + aspp_ratios=(1, 6, 12, 18), + aspp_out_channels=256): super(DeepLabV3, self).__init__() self.backbone = backbone - self.aspp = ASPP(output_stride, backbone_channels[0]) + + self.aspp = pyramid_pool.ASPPModule( + aspp_ratios, backbone_channels[0], aspp_out_channels, + sep_conv=False, image_pooling=True) + self.cls = nn.Conv2d( in_channels=backbone_channels[0], out_channels=num_classes, @@ -161,98 +172,6 @@ class DeepLabV3(nn.Layer): pretrained_model)) -class ImageAverage(nn.Layer): - """ - Global average pooling - - Args: - in_channels (int): the number of input channels. - - """ - - def __init__(self, in_channels): - super(ImageAverage, self).__init__() - self.conv_bn_relu = layer_utils.ConvBnRelu( - in_channels, out_channels=256, kernel_size=1) - - def forward(self, input): - x = paddle.reduce_mean(input, dim=[2, 3], keep_dim=True) - x = self.conv_bn_relu(x) - x = F.resize_bilinear(x, out_shape=input.shape[2:]) - return x - - -class ASPP(nn.Layer): - """ - Decoder module of DeepLabV3P model - - Args: - output_stride (int): the ratio of input size and final feature size. Support 16 or 8. - - in_channels (int): the number of input channels in decoder module. 
- - """ - - def __init__(self, output_stride, in_channels): - super(ASPP, self).__init__() - - if output_stride == 16: - aspp_ratios = (6, 12, 18) - elif output_stride == 8: - aspp_ratios = (12, 24, 36) - else: - raise NotImplementedError( - "Only support output_stride is 8 or 16, but received{}".format( - output_stride)) - - self.image_average = ImageAverage(in_channels=in_channels) - - # The first aspp using 1*1 conv - self.aspp1 = layer_utils.DepthwiseConvBnRelu( - in_channels=in_channels, out_channels=256, kernel_size=1) - - # The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0] - self.aspp2 = layer_utils.DepthwiseConvBnRelu( - in_channels=in_channels, - out_channels=256, - kernel_size=3, - dilation=aspp_ratios[0], - padding=aspp_ratios[0]) - - # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1] - self.aspp3 = layer_utils.DepthwiseConvBnRelu( - in_channels=in_channels, - out_channels=256, - kernel_size=3, - dilation=aspp_ratios[1], - padding=aspp_ratios[1]) - - # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2] - self.aspp4 = layer_utils.DepthwiseConvBnRelu( - in_channels=in_channels, - out_channels=256, - kernel_size=3, - dilation=aspp_ratios[2], - padding=aspp_ratios[2]) - - # After concat op, using 1*1 conv - self.conv_bn_relu = layer_utils.ConvBnRelu( - in_channels=1280, out_channels=256, kernel_size=1) - - def forward(self, x): - - x1 = self.image_average(x) - x2 = self.aspp1(x) - x3 = self.aspp2(x) - x4 = self.aspp3(x) - x5 = self.aspp4(x) - x = paddle.concat([x1, x2, x3, x4, x5], axis=1) - - x = self.conv_bn_relu(x) - x = F.dropout(x, p=0.1) # dropout_prob - return x - - class Decoder(nn.Layer): """ Decoder module of DeepLabV3P model @@ -267,12 +186,12 @@ class Decoder(nn.Layer): def __init__(self, num_classes, in_channels): super(Decoder, self).__init__() - self.conv_bn_relu1 = layer_utils.ConvBnRelu( + self.conv_bn_relu1 = layer_libs.ConvBnRelu( in_channels=in_channels, 
out_channels=48, kernel_size=1) - self.conv_bn_relu2 = layer_utils.DepthwiseConvBnRelu( + self.conv_bn_relu2 = layer_libs.DepthwiseConvBnRelu( in_channels=304, out_channels=256, kernel_size=3, padding=1) - self.conv_bn_relu3 = layer_utils.DepthwiseConvBnRelu( + self.conv_bn_relu3 = layer_libs.DepthwiseConvBnRelu( in_channels=256, out_channels=256, kernel_size=3, padding=1) self.conv = nn.Conv2d( in_channels=256, out_channels=num_classes, kernel_size=1) diff --git a/dygraph/paddleseg/models/fast_scnn.py b/dygraph/paddleseg/models/fast_scnn.py index 434f083e..3abbcffc 100644 --- a/dygraph/paddleseg/models/fast_scnn.py +++ b/dygraph/paddleseg/models/fast_scnn.py @@ -15,7 +15,7 @@ import paddle.nn.functional as F from paddle import nn from paddleseg.cvlibs import manager -from paddleseg.models.common import layer_utils, model_utils +from paddleseg.models.common import layer_libs @manager.MODELS.add_component @@ -110,15 +110,15 @@ class LearningToDownsample(nn.Layer): def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64): super(LearningToDownsample, self).__init__() - self.conv_bn_relu = layer_utils.ConvBnRelu( + self.conv_bn_relu = layer_libs.ConvBnRelu( in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2) - self.dsconv_bn_relu1 = layer_utils.DepthwiseConvBnRelu( + self.dsconv_bn_relu1 = layer_libs.DepthwiseConvBnRelu( in_channels=dw_channels1, out_channels=dw_channels2, kernel_size=3, stride=2, padding=1) - self.dsconv_bn_relu2 = layer_utils.DepthwiseConvBnRelu( + self.dsconv_bn_relu2 = layer_libs.DepthwiseConvBnRelu( in_channels=dw_channels2, out_channels=out_channels, kernel_size=3, @@ -220,13 +220,13 @@ class LinearBottleneck(nn.Layer): expand_channels = in_channels * expansion self.block = nn.Sequential( # pw - layer_utils.ConvBnRelu( + layer_libs.ConvBnRelu( in_channels=in_channels, out_channels=expand_channels, kernel_size=1, bias_attr=False), # dw - layer_utils.ConvBnRelu( + layer_libs.ConvBnRelu( in_channels=expand_channels, 
out_channels=expand_channels, kernel_size=3, @@ -267,7 +267,7 @@ class FeatureFusionModule(nn.Layer): super(FeatureFusionModule, self).__init__() # There only depth-wise conv is used WITHOUT point-wise conv - self.dwconv = layer_utils.ConvBnRelu( + self.dwconv = layer_libs.ConvBnRelu( in_channels=low_in_channels, out_channels=out_channels, kernel_size=3, @@ -317,13 +317,13 @@ class Classifier(nn.Layer): def __init__(self, input_channels, num_classes): super(Classifier, self).__init__() - self.dsconv1 = layer_utils.DepthwiseConvBnRelu( + self.dsconv1 = layer_libs.DepthwiseConvBnRelu( in_channels=input_channels, out_channels=input_channels, kernel_size=3, padding=1) - self.dsconv2 = layer_utils.DepthwiseConvBnRelu( + self.dsconv2 = layer_libs.DepthwiseConvBnRelu( in_channels=input_channels, out_channels=input_channels, kernel_size=3, diff --git a/dygraph/paddleseg/models/gcnet.py b/dygraph/paddleseg/models/gcnet.py index 97a70d13..09a90065 100644 --- a/dygraph/paddleseg/models/gcnet.py +++ b/dygraph/paddleseg/models/gcnet.py @@ -18,7 +18,7 @@ import paddle import paddle.nn.functional as F from paddle import nn from paddleseg.cvlibs import manager -from paddleseg.models.common import layer_utils, model_utils +from paddleseg.models.common import layer_libs from paddleseg.utils import utils @@ -72,7 +72,7 @@ class GCNet(nn.Layer): self.backbone = backbone in_channels = backbone_channels[1] - self.conv_bn_relu1 = layer_utils.ConvBnRelu( + self.conv_bn_relu1 = layer_libs.ConvBnRelu( in_channels=in_channels, out_channels=gc_channels, kernel_size=3, @@ -80,13 +80,13 @@ class GCNet(nn.Layer): self.gc_block = GlobalContextBlock(in_channels=gc_channels, ratio=ratio) - self.conv_bn_relu2 = layer_utils.ConvBnRelu( + self.conv_bn_relu2 = layer_libs.ConvBnRelu( in_channels=gc_channels, out_channels=gc_channels, kernel_size=3, padding=1) - self.conv_bn_relu3 = layer_utils.ConvBnRelu( + self.conv_bn_relu3 = layer_libs.ConvBnRelu( in_channels=in_channels + gc_channels, 
out_channels=gc_channels, kernel_size=3, @@ -96,7 +96,7 @@ class GCNet(nn.Layer): in_channels=gc_channels, out_channels=num_classes, kernel_size=1) if enable_auxiliary_loss: - self.auxlayer = model_utils.AuxLayer( + self.auxlayer = layer_libs.AuxLayer( in_channels=backbone_channels[0], inter_channels=backbone_channels[0] // 4, out_channels=num_classes) @@ -161,9 +161,9 @@ class GlobalContextBlock(nn.Layer): self.conv_mask = nn.Conv2d( in_channels=in_channels, out_channels=1, kernel_size=1) - # current paddle version does not support Softmax class - # self.softmax = layer_utils.Activation("softmax", dim=2) + self.softmax = nn.Softmax(axis=2) + inter_channels = int(in_channels * ratio) self.channel_add_conv = nn.Sequential( nn.Conv2d( @@ -188,7 +188,7 @@ class GlobalContextBlock(nn.Layer): # [N, 1, H * W] context_mask = paddle.reshape( context_mask, shape=[batch, 1, height * width]) - context_mask = F.softmax(context_mask) + context_mask = self.softmax(context_mask) # [N, 1, H * W, 1] context_mask = paddle.unsqueeze(context_mask, axis=-1) # [N, 1, C, 1] diff --git a/dygraph/paddleseg/models/ocrnet.py b/dygraph/paddleseg/models/ocrnet.py index 78dfd136..00cf079c 100644 --- a/dygraph/paddleseg/models/ocrnet.py +++ b/dygraph/paddleseg/models/ocrnet.py @@ -18,7 +18,7 @@ import paddle.fluid as fluid from paddle.fluid.dygraph import Sequential, Conv2D from paddleseg.cvlibs import manager -from paddleseg.models.common.layer_utils import ConvBnRelu +from paddleseg.models.common.layer_libs import ConvBnRelu from paddleseg import utils diff --git a/dygraph/paddleseg/models/pspnet.py b/dygraph/paddleseg/models/pspnet.py index 764749ce..69b831eb 100644 --- a/dygraph/paddleseg/models/pspnet.py +++ b/dygraph/paddleseg/models/pspnet.py @@ -17,7 +17,7 @@ import os import paddle.nn.functional as F from paddle import nn from paddleseg.cvlibs import manager -from paddleseg.models.common import model_utils +from paddleseg.models.common import layer_libs, pyramid_pool from 
paddleseg.utils import utils @@ -70,7 +70,7 @@ class PSPNet(nn.Layer): self.backbone = backbone self.backbone_indices = backbone_indices - self.psp_module = model_utils.PPModule( + self.psp_module = pyramid_pool.PPModule( in_channels=backbone_channels[1], out_channels=pp_out_channels, bin_sizes=bin_sizes) @@ -81,8 +81,11 @@ class PSPNet(nn.Layer): kernel_size=1) if enable_auxiliary_loss: - self.fcn_head = model_utils.FCNHead( - in_channels=backbone_channels[0], out_channels=num_classes) + + self.auxlayer = layer_libs.AuxLayer( + in_channels=backbone_channels[0], + inter_channels=backbone_channels[0] // 4, + out_channels=num_classes) self.enable_auxiliary_loss = enable_auxiliary_loss @@ -102,7 +105,7 @@ class PSPNet(nn.Layer): if self.enable_auxiliary_loss: auxiliary_feat = feat_list[self.backbone_indices[0]] - auxiliary_logit = self.fcn_head(auxiliary_feat) + auxiliary_logit = self.auxlayer(auxiliary_feat) auxiliary_logit = F.resize_bilinear(auxiliary_logit, input.shape[2:]) logit_list.append(auxiliary_logit) -- GitLab