deeplab.py 13.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os

M
michaelowenliu 已提交
18 19
from dygraph.cvlibs import manager
from dygraph.models.architectures import layer_utils
20 21 22 23 24
from paddle import fluid
from paddle.fluid import dygraph
from paddle.fluid.dygraph import Conv2D

from dygraph.utils import utils
25 26 27 28 29

# Public API of this module: the model class and its registered factories.
__all__ = [
    "DeepLabV3P",
    "deeplabv3p_resnet101_vd",
    "deeplabv3p_resnet101_vd_os8",
    "deeplabv3p_resnet50_vd",
    "deeplabv3p_resnet50_vd_os8",
    "deeplabv3p_xception65_deeplab",
    "deeplabv3p_mobilenetv3_large",
    "deeplabv3p_mobilenetv3_small",
]
30 31 32 33 34 35 36 37 38 39 40 41 42 43


class ImageAverage(dygraph.Layer):
    """
    Image-level pooling branch used by the ASPP module.

    The input is averaged over dims 2 and 3 (presumably H and W of an NCHW
    tensor), projected by a 1x1 ConvBnRelu with 256 output filters, and then
    bilinearly resized back to the input's `shape[2:]`.

    Args:
        num_channels (int): the number of input channels.
    """

    def __init__(self, num_channels):
        super(ImageAverage, self).__init__()
        self.conv_bn_relu = layer_utils.ConvBnRelu(
            num_channels, num_filters=256, filter_size=1)

    def forward(self, input):
        # Remember the spatial size so the pooled feature can be broadcast back.
        target_size = input.shape[2:]
        pooled = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
        projected = self.conv_bn_relu(pooled)
        return fluid.layers.resize_bilinear(projected, out_shape=target_size)


class ASPP(dygraph.Layer):
    """
    Atrous Spatial Pyramid Pooling (ASPP) module of DeepLabV3P model.

    Five parallel branches are applied to the same input:
    - an image-level pooling branch (ImageAverage);
    - a 1x1 ConvBnRelu;
    - three 3x3 (optionally separable) ConvBnRelu layers at increasing
      dilation rates.

    Every branch produces 256 channels. The branch outputs are concatenated
    along axis 1 (5 * 256 = 1280 channels) and projected back to 256 channels
    by a 1x1 ConvBnRelu. The result is passed through dropout with
    probability 0.1, active only in training mode.

    Args:
        output_stride (int): the ratio of input size and final feature size. Support 16 or 8.
        in_channels (int): the number of input channels of the ASPP module.
        using_sep_conv (bool): whether use separable conv or not. Default to True.

    Raises:
        NotImplementedError: if output_stride is neither 8 nor 16.
    """

    def __init__(self, output_stride, in_channels, using_sep_conv=True):
        super(ASPP, self).__init__()

        # Dilation rates of the three 3x3 branches. A smaller output stride
        # keeps a larger feature map, so the rates are doubled for stride 8.
        if output_stride == 16:
            aspp_ratios = (6, 12, 18)
        elif output_stride == 8:
            aspp_ratios = (12, 24, 36)
        else:
            raise NotImplementedError(
                "Only support output_stride is 8 or 16, but received {}".format(output_stride))

        self.image_average = ImageAverage(num_channels=in_channels)

        # The first aspp using 1*1 conv
        self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels,
                                            num_filters=256,
                                            filter_size=1,
                                            using_sep_conv=False)

        # The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0]
        self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels,
                                            num_filters=256,
                                            filter_size=3,
                                            using_sep_conv=using_sep_conv,
                                            dilation=aspp_ratios[0],
                                            padding=aspp_ratios[0])

        # The third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1]
        self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels,
                                            num_filters=256,
                                            filter_size=3,
                                            using_sep_conv=using_sep_conv,
                                            dilation=aspp_ratios[1],
                                            padding=aspp_ratios[1])

        # The fourth aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2]
        self.aspp4 = layer_utils.ConvBnRelu(num_channels=in_channels,
                                            num_filters=256,
                                            filter_size=3,
                                            using_sep_conv=using_sep_conv,
                                            dilation=aspp_ratios[2],
                                            padding=aspp_ratios[2])

        # After concat op (5 branches x 256 channels = 1280), using 1*1 conv
        self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280,
                                                   num_filters=256,
                                                   filter_size=1)

    def forward(self, x):
        x1 = self.image_average(x)
        x2 = self.aspp1(x)
        x3 = self.aspp2(x)
        x4 = self.aspp3(x)
        x5 = self.aspp4(x)
        x = fluid.layers.concat([x1, x2, x3, x4, x5], axis=1)

        x = self.conv_bn_relu(x)
        # Tie dropout to this layer's train/eval state explicitly. Depending on
        # the Paddle version, fluid.layers.dropout may otherwise fall back to
        # is_test=False and keep dropping activations at inference time.
        x = fluid.layers.dropout(x, dropout_prob=0.1, is_test=not self.training)
        return x


class Decoder(dygraph.Layer):
    """
    Decoder module of DeepLabV3P model.

    Pipeline:
    1. The low-level feature is reduced to 48 channels by a 1x1 ConvBnRelu.
    2. `x` is bilinearly resized to the reduced feature's `shape[2:]`.
    3. Both are concatenated along axis 1.
    4. Two 3x3 (optionally separable) ConvBnRelu layers refine the result.
    5. A final 1x1 Conv2D maps it to `num_classes` channels.

    The hard-coded 304 input channels of the first refining layer assume that
    `x` carries 256 channels (256 + 48 = 304).

    Args:
        num_classes (int): the number of classes.
        in_channels (int): the number of channels of the low-level feature.
        using_sep_conv (bool): whether use separable conv or not. Default to True.
    """

    def __init__(self, num_classes, in_channels, using_sep_conv=True):
        super(Decoder, self).__init__()

        # Channel reduction of the low-level feature.
        self.conv_bn_relu1 = layer_utils.ConvBnRelu(
            num_channels=in_channels, num_filters=48, filter_size=1)

        # Two 3x3 refining layers applied after the concat.
        self.conv_bn_relu2 = layer_utils.ConvBnRelu(
            num_channels=304, num_filters=256, filter_size=3,
            using_sep_conv=using_sep_conv, padding=1)
        self.conv_bn_relu3 = layer_utils.ConvBnRelu(
            num_channels=256, num_filters=256, filter_size=3,
            using_sep_conv=using_sep_conv, padding=1)

        # Per-class logits.
        self.conv = Conv2D(num_channels=256, num_filters=num_classes, filter_size=1)

    def forward(self, x, low_level_feat):
        reduced = self.conv_bn_relu1(low_level_feat)
        upsampled = fluid.layers.resize_bilinear(x, reduced.shape[2:])
        fused = fluid.layers.concat([upsampled, reduced], axis=1)
        for refine in (self.conv_bn_relu2, self.conv_bn_relu3):
            fused = refine(fused)
        return self.conv(fused)


class DeepLabV3P(dygraph.Layer):
    """
    The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder.

    The original article refers to
    "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
     Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
     (https://arxiv.org/abs/1802.02611)

    Args:
        backbone (str): name of a backbone registered in `manager.BACKBONES`
                        (e.g. 'ResNet101_vd', 'ResNet50_vd', 'Xception65_deeplab',
                        'MobileNetV3_large_x1_0'). It is built with `output_stride=output_stride`.

        num_classes (int): the unique number of target classes. Default 2.

        output_stride (int): the ratio of input size and final feature size. Default 16.

        backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
                        The first index will be taken as a low-level feature in Decoder component;
                        the second one will be taken as input of ASPP component.
                        Usually backbone consists of four downsampling stages, and returns an output of
                        each stage, so we set default (0, 3), which means taking feature map of the first
                        stage in backbone as low-level feature used in Decoder, and feature map of the fourth
                        stage as input of ASPP.

        backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of
                        corresponding index.

        ignore_index (int): label value excluded from the training loss. Default 255.

        using_sep_conv (bool): a bool value indicates whether using separable convolutions
                        in ASPP and Decoder components. Default True.

        pretrained_model (str): the pretrained_model path of backbone. If it is given but does not
                        exist, a warning is issued and the backbone keeps its initial weights.
                        Default None.
    """

    def __init__(self,
                 backbone,
                 num_classes=2,
                 output_stride=16,
                 backbone_indices=(0, 3),
                 backbone_channels=(256, 2048),
                 ignore_index=255,
                 using_sep_conv=True,
                 pretrained_model=None):

        super(DeepLabV3P, self).__init__()

        self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
        self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv)
        self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv)
        self.ignore_index = ignore_index
        # Guards the loss normalisation against division by zero when every
        # pixel of the batch carries ignore_index.
        self.EPS = 1e-5
        self.backbone_indices = backbone_indices
        self.init_weight(pretrained_model)

    def forward(self, input, label=None):
        """
        Run the model.

        Args:
            input (Variable): input images; the output logits are resized to
                `input.shape[2:]` (presumably NCHW — TODO confirm).
            label (Variable, optional): ground truth, required in training mode.

        Returns:
            Training mode: the scalar average loss computed by `_get_loss`.
            Eval mode: a tuple `(pred, score_map)`.
                - `score_map` is the softmax over axis 1, transposed with
                  [0, 2, 3, 1].
                - `pred` is the argmax of `score_map` over axis 3, with a
                  trailing axis of size 1 re-added.

        Raises:
            ValueError: if called in training mode without a label.
        """
        if self.training and label is None:
            raise ValueError("label is required to compute the loss in training mode")

        # The backbone returns a pair; only its list of per-stage features is used.
        _, feat_list = self.backbone(input)
        low_level_feat = feat_list[self.backbone_indices[0]]
        x = feat_list[self.backbone_indices[1]]
        x = self.aspp(x)
        logit = self.decoder(x, low_level_feat)
        logit = fluid.layers.resize_bilinear(logit, input.shape[2:])

        if self.training:
            return self._get_loss(logit, label)

        score_map = fluid.layers.softmax(logit, axis=1)
        score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
        pred = fluid.layers.argmax(score_map, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])
        return pred, score_map

    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Only the backbone is loaded from `pretrained_model`; ASPP and Decoder
        keep their initial weights.

        Args:
            pretrained_model (str, optional): the pretrained_model path of backbone. Defaults to None.
                A path that does not exist triggers a warning instead of being silently ignored.
        """
        import warnings

        if pretrained_model is None:
            return
        if os.path.exists(pretrained_model):
            utils.load_pretrained_model(self.backbone, pretrained_model)
        else:
            warnings.warn(
                "Pretrained model path '{}' does not exist; the backbone keeps "
                "its initial weights.".format(pretrained_model))

    def _get_loss(self, logit, label):
        """
        Compute the forward loss of the model.

        The per-pixel softmax cross entropy is masked with
        `label != ignore_index`. Its mean is then divided by the fraction of
        non-ignored pixels (plus EPS), giving the average over valid pixels.

        Args:
            logit (tensor): the logit of model output, transposed here with [0, 2, 3, 1].
            label (tensor): ground truth, transposed the same way.

        Returns:
            avg_loss (tensor): forward loss
        """
        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
        label = fluid.layers.transpose(label, [0, 2, 3, 1])
        label.stop_gradient = True

        mask = fluid.layers.cast(label != self.ignore_index, 'float32')
        mask.stop_gradient = True

        loss = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            axis=-1)

        loss = loss * mask
        avg_loss = fluid.layers.mean(loss) / (fluid.layers.mean(mask) + self.EPS)
        return avg_loss
280 281 282 283 284


def build_aspp(output_stride, using_sep_conv, in_channels=2048):
    """
    Build an ASPP module.

    ASPP requires `in_channels`; the original call omitted it and always
    raised TypeError.

    Args:
        output_stride (int): 8 or 16, forwarded to ASPP.
        using_sep_conv (bool): whether ASPP uses separable convolutions.
        in_channels (int, optional): number of channels of the ASPP input.
            Defaults to 2048, the DeepLabV3P default for backbone_channels[1].

    Returns:
        ASPP: the constructed module.
    """
    return ASPP(output_stride=output_stride,
                in_channels=in_channels,
                using_sep_conv=using_sep_conv)

285

286 287 288
def build_decoder(num_classes, using_sep_conv, in_channels=256):
    """
    Build a DeepLabV3P Decoder module.

    Decoder requires `in_channels`; the original call omitted it and always
    raised TypeError.

    Args:
        num_classes (int): the number of output classes.
        using_sep_conv (bool): whether the Decoder uses separable convolutions.
        in_channels (int, optional): channels of the low-level feature.
            Defaults to 256, the DeepLabV3P default for backbone_channels[0].

    Returns:
        Decoder: the constructed module.
    """
    return Decoder(num_classes, in_channels, using_sep_conv=using_sep_conv)

M
michaelowenliu 已提交
289

290
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd(*args, **kwargs):
    """
    Build DeepLabV3P with a 'ResNet101_vd' backbone.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='ResNet101_vd', **kwargs)

M
michaelowenliu 已提交
295

296
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd_os8(*args, **kwargs):
    """
    Build DeepLabV3P with a 'ResNet101_vd' backbone and output_stride 8.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, **kwargs)

M
michaelowenliu 已提交
301

302
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd(*args, **kwargs):
    """
    Build DeepLabV3P with a 'ResNet50_vd' backbone.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='ResNet50_vd', **kwargs)

M
michaelowenliu 已提交
307

308
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd_os8(*args, **kwargs):
    """
    Build DeepLabV3P with a 'ResNet50_vd' backbone and output_stride 8.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, **kwargs)

M
michaelowenliu 已提交
313

314
@manager.MODELS.add_component
def deeplabv3p_xception65_deeplab(*args, **kwargs):
    """
    Build DeepLabV3P with an 'Xception65_deeplab' backbone.

    The backbone outputs at indices (0, 1), with 128 and 2048 channels, feed
    the Decoder and ASPP respectively.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='Xception65_deeplab',
                      backbone_indices=(0, 1),
                      backbone_channels=(128, 2048),
                      **kwargs)

M
michaelowenliu 已提交
323

324 325 326 327 328 329 330 331 332
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_large(*args, **kwargs):
    """
    Build DeepLabV3P with a 'MobileNetV3_large_x1_0' backbone.

    The backbone outputs at indices (0, 3), with 24 and 160 channels, feed
    the Decoder and ASPP respectively.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='MobileNetV3_large_x1_0',
                      backbone_indices=(0, 3),
                      backbone_channels=(24, 160),
                      **kwargs)

M
michaelowenliu 已提交
333

334 335 336 337 338 339 340 341
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_small(*args, **kwargs):
    """
    Build DeepLabV3P with a 'MobileNetV3_small_x1_0' backbone.

    The backbone outputs at indices (0, 3), with 16 and 96 channels, feed
    the Decoder and ASPP respectively.

    Positional `args` are accepted for interface compatibility but ignored.
    Keyword arguments are forwarded to DeepLabV3P, so `pretrained_model` can
    now be supplied by the caller (it defaults to None there).
    """
    return DeepLabV3P(backbone='MobileNetV3_small_x1_0',
                      backbone_indices=(0, 3),
                      backbone_channels=(16, 96),
                      **kwargs)