ann.py 14.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings

import paddle
import paddle.nn.functional as F
from paddle import nn

from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_libs
from paddleseg.utils import utils


@manager.MODELS.add_component
class ANN(nn.Layer):
    """
    The ANN implementation based on PaddlePaddle.

    The original article refers to
        Zhen, Zhu, et al. "Asymmetric Non-local Neural Networks for Semantic Segmentation."
        (https://arxiv.org/pdf/1908.07678.pdf)

    It mainly consists of AFNB and APNB modules.

    Args:
        num_classes (int): the unique number of target classes.
        backbone (paddle.nn.Layer): backbone network, currently supports ResNet50/101.
        model_pretrained (str, optional): the path of pretrained model. Default to None.
        backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
            The first index will be taken as low-level features; the second one will be
            taken as high-level features in AFNB module. Usually backbone consists of four
            downsampling stages, and returns an output of each stage, so we set default (2, 3),
            which means taking feature maps of the third stage and the fourth stage in backbone.
        backbone_channels (tuple): the same length as "backbone_indices". It indicates the channels
            of the corresponding index. Default to (1024, 2048).
        key_value_channels (int): the key and value channels of self-attention map in both AFNB and
            APNB modules. Default to 256.
        inter_channels (int): both input and output channels of APNB modules. Default to 512.
        psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
        enable_auxiliary_loss (bool): whether to add an auxiliary loss. Default to True.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 model_pretrained=None,
                 backbone_indices=(2, 3),
                 backbone_channels=(1024, 2048),
                 key_value_channels=256,
                 inter_channels=512,
                 psp_size=(1, 3, 6, 8),
                 enable_auxiliary_loss=True):
        super(ANN, self).__init__()

        self.backbone = backbone

        low_in_channels = backbone_channels[0]
        high_in_channels = backbone_channels[1]

        # Fuses low-level features (keys/values) into high-level features
        # (queries); the output keeps the high-level channel count.
        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            key_channels=key_value_channels,
            value_channels=key_value_channels,
            dropout_prob=0.05,
            sizes=(1, ),
            psp_size=psp_size)

        # Reduce to inter_channels, then refine with APNB (same channel count
        # in and out).
        self.context = nn.Sequential(
            layer_libs.ConvBnRelu(
                in_channels=high_in_channels,
                out_channels=inter_channels,
                kernel_size=3,
                padding=1),
            APNB(
                in_channels=inter_channels,
                out_channels=inter_channels,
                key_channels=key_value_channels,
                value_channels=key_value_channels,
                dropout_prob=0.05,
                sizes=(1, ),
                psp_size=psp_size))

        self.cls = nn.Conv2d(
            in_channels=inter_channels,
            out_channels=num_classes,
            kernel_size=1)

        # Auxiliary head on the low-level features; only used in forward()
        # when enable_auxiliary_loss is True.
        self.auxlayer = layer_libs.AuxLayer(
            in_channels=low_in_channels,
            inter_channels=low_in_channels // 2,
            out_channels=num_classes,
            dropout_prob=0.05)

        self.backbone_indices = backbone_indices
        self.enable_auxiliary_loss = enable_auxiliary_loss

        self.init_weight(model_pretrained)

    def forward(self, input, label=None):
        """
        Run the segmentation network.

        Args:
            input (Tensor): input batch; its spatial size is read from input.shape[2:].
            label: unused here.

        Returns:
            list: the main logit, followed by the auxiliary logit when
                enable_auxiliary_loss is True. Each is bilinearly resized to
                the input's spatial size.
        """
        logit_list = []
        _, feat_list = self.backbone(input)
        low_level_x = feat_list[self.backbone_indices[0]]
        high_level_x = feat_list[self.backbone_indices[1]]
        x = self.fusion(low_level_x, high_level_x)
        x = self.context(x)
        logit = self.cls(x)
        logit = F.resize_bilinear(logit, input.shape[2:])
        logit_list.append(logit)

        if self.enable_auxiliary_loss:
            auxiliary_logit = self.auxlayer(low_level_x)
            auxiliary_logit = F.resize_bilinear(auxiliary_logit,
                                                input.shape[2:])
            logit_list.append(auxiliary_logit)

        return logit_list

    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained_model (str, optional): the pretrained_model path of backbone.
                Defaults to None, in which case nothing is loaded.
        """
        if pretrained_model is None:
            return

        if os.path.exists(pretrained_model):
            utils.load_pretrained_model(self.backbone, pretrained_model)
        else:
            # Surface a wrong path instead of silently skipping the load.
            warnings.warn(
                "Pretrained model is not found: {}. The backbone weights were "
                "not loaded from it.".format(pretrained_model))


class AFNB(nn.Layer):
    """
    Asymmetric Fusion Non-local Block

    Args:
        low_in_channels (int): low-level-feature channels.
        high_in_channels (int): high-level-feature channels.
        out_channels (int): out channels of AFNB module.
        key_channels (int): the key channels in self-attention block.
        value_channels (int): the value channels in self-attention block.
        dropout_prob (float): the dropout rate of output.
        sizes (tuple): one self-attention stage is built per entry, and each
            value is passed to that stage as its `scale`. Default to (1,).
        psp_size (tuple): the out size of pooled feature maps, forwarded to
            every stage. Default to (1, 3, 6, 8).
    """

    def __init__(self,
                 low_in_channels,
                 high_in_channels,
                 out_channels,
                 key_channels,
                 value_channels,
                 dropout_prob,
                 sizes=(1, ),
                 psp_size=(1, 3, 6, 8)):
        super(AFNB, self).__init__()

        self.psp_size = psp_size
        self.stages = nn.LayerList([
            SelfAttentionBlock_AFNB(
                low_in_channels=low_in_channels,
                high_in_channels=high_in_channels,
                key_channels=key_channels,
                value_channels=value_channels,
                out_channels=out_channels,
                scale=size,
                psp_size=psp_size) for size in sizes
        ])
        # Input is concat([context (out_channels), high_feats (high_in_channels)]).
        self.conv_bn = layer_libs.ConvBn(
            in_channels=out_channels + high_in_channels,
            out_channels=out_channels,
            kernel_size=1)
        self.dropout_prob = dropout_prob
        # A dropout layer follows train()/eval(), so no dropout is applied at
        # inference time.
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, low_feats, high_feats):
        """
        Fuse low-level context into the high-level features.

        Args:
            low_feats (Tensor): low-level features (source of keys and values).
            high_feats (Tensor): high-level features (source of queries).

        Returns:
            Tensor: fused features with out_channels channels.
        """
        priors = [stage(low_feats, high_feats) for stage in self.stages]
        context = priors[0]
        for prior in priors[1:]:
            context = context + prior

        output = self.conv_bn(paddle.concat([context, high_feats], axis=1))
        output = self.dropout(output)

        return output


class APNB(nn.Layer):
    """
    Asymmetric Pyramid Non-local Block

    Args:
        in_channels (int): the input channels of APNB module.
        out_channels (int): out channels of APNB module.
        key_channels (int): the key channels in self-attention block.
        value_channels (int): the value channels in self-attention block.
        dropout_prob (float): the dropout rate of output.
        sizes (tuple): one self-attention stage is built per entry, and each
            value is passed to that stage as its `scale`. Default to (1,).
        psp_size (tuple): the out size of pooled feature maps, forwarded to
            every stage. Default to (1, 3, 6, 8).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 key_channels,
                 value_channels,
                 dropout_prob,
                 sizes=(1, ),
                 psp_size=(1, 3, 6, 8)):
        super(APNB, self).__init__()

        self.psp_size = psp_size
        self.stages = nn.LayerList([
            SelfAttentionBlock_APNB(
                in_channels=in_channels,
                out_channels=out_channels,
                key_channels=key_channels,
                value_channels=value_channels,
                scale=size,
                psp_size=psp_size) for size in sizes
        ])
        # Input is concat([context (out_channels), feats (in_channels)]);
        # equals in_channels * 2 when in_channels == out_channels.
        self.conv_bn = layer_libs.ConvBnRelu(
            in_channels=out_channels + in_channels,
            out_channels=out_channels,
            kernel_size=1)
        self.dropout_prob = dropout_prob
        # A dropout layer follows train()/eval(), so no dropout is applied at
        # inference time.
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, feats):
        """
        Refine features with pyramid non-local attention.

        Args:
            feats (Tensor): input features with in_channels channels.

        Returns:
            Tensor: refined features with out_channels channels.
        """
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for prior in priors[1:]:
            context = context + prior

        output = self.conv_bn(paddle.concat([context, feats], axis=1))
        output = self.dropout(output)

        return output


def _pp_module(x, psp_size):
    """
    Pyramid-pool a 4-D feature map into a flat set of pooled cells.

    For every entry of `psp_size`, `x` is adaptively average-pooled to that
    output size, and the result is flattened over its spatial axes. The
    flattened pyramids are joined along the last axis.

    Args:
        x (Tensor): 4-D input; its first two axes are kept as (n, c).
        psp_size (tuple): output sizes of the pooled feature maps.

    Returns:
        Tensor: shape (n, c, total_cells), where total_cells is the number of
            pooled positions summed over all pyramid levels.
    """
    batch, channels, _, _ = x.shape
    flattened = [
        paddle.reshape(
            F.adaptive_pool2d(x, pool_size=level, pool_type="avg"),
            shape=(batch, channels, -1)) for level in psp_size
    ]
    return paddle.concat(flattened, axis=-1)


class SelfAttentionBlock_AFNB(nn.Layer):
    """
    Self-Attention Block for AFNB module.

    Queries are computed from the high-level features. Keys and values are
    computed from the low-level features and compressed by pyramid pooling
    (`_pp_module`).

    Args:
        low_in_channels (int): low-level-feature channels.
        high_in_channels (int): high-level-feature channels.
        key_channels (int): the key channels in self-attention block.
        value_channels (int): the value channels in self-attention block.
        out_channels (int, optional): out channels of AFNB module. Default to
            None, in which case high_in_channels is used.
        scale (int): pooling size. Default to 1. NOTE: stored and used to build
            `self.pool`, but forward() does not apply the pooling.
        psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
    """

    def __init__(self,
                 low_in_channels,
                 high_in_channels,
                 key_channels,
                 value_channels,
                 out_channels=None,
                 scale=1,
                 psp_size=(1, 3, 6, 8)):
        super(SelfAttentionBlock_AFNB, self).__init__()

        self.scale = scale
        self.in_channels = low_in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = high_in_channels
        # Stride equals the window so the pool downsamples by `scale`
        # (Pool2D's default stride is 1). Not applied in forward().
        self.pool = nn.Pool2D(
            pool_size=(scale, scale),
            pool_type="max",
            pool_stride=(scale, scale))
        self.f_key = layer_libs.ConvBnRelu(
            in_channels=low_in_channels,
            out_channels=key_channels,
            kernel_size=1)
        self.f_query = layer_libs.ConvBnRelu(
            in_channels=high_in_channels,
            out_channels=key_channels,
            kernel_size=1)
        self.f_value = nn.Conv2d(
            in_channels=low_in_channels,
            out_channels=value_channels,
            kernel_size=1)

        # Use the resolved channel count so out_channels=None works.
        self.W = nn.Conv2d(
            in_channels=value_channels,
            out_channels=self.out_channels,
            kernel_size=1)

        self.psp_size = psp_size

    def forward(self, low_feats, high_feats):
        """
        Attend from each high-level position over pooled low-level positions.

        Args:
            low_feats (Tensor): low-level features (keys and values).
            high_feats (Tensor): high-level features (queries); the output
                takes its batch size and spatial size.

        Returns:
            Tensor: context with self.out_channels channels.
        """
        batch_size = high_feats.shape[0]

        # (N, value_channels, S) -> (N, S, value_channels), S = pooled cells.
        value = self.f_value(low_feats)
        value = _pp_module(value, self.psp_size)
        value = paddle.transpose(value, (0, 2, 1))

        # (N, key_channels, H*W) -> (N, H*W, key_channels).
        query = self.f_query(high_feats)
        query = paddle.reshape(
            query, shape=(batch_size, self.key_channels, -1))
        query = paddle.transpose(query, perm=(0, 2, 1))

        # (N, key_channels, S).
        key = self.f_key(low_feats)
        key = _pp_module(key, self.psp_size)

        # Scaled dot-product similarity, normalized over the pooled cells.
        sim_map = paddle.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        context = paddle.matmul(sim_map, value)
        context = paddle.transpose(context, perm=(0, 2, 1))
        context = paddle.reshape(
            context,
            shape=[batch_size, self.value_channels, *high_feats.shape[2:]])

        context = self.W(context)

        return context


class SelfAttentionBlock_APNB(nn.Layer):
    """
    Self-Attention Block for APNB module.

    Queries, keys and values are all computed from the same input. Keys and
    values are compressed by pyramid pooling (`_pp_module`).

    Args:
        in_channels (int): the input channels of APNB module.
        out_channels (int): out channels of APNB module.
        key_channels (int): the key channels in self-attention block.
        value_channels (int): the value channels in self-attention block.
        scale (int): pooling size. When greater than 1, the input is
            max-pooled with a (scale, scale) window and stride before
            attention. The result is then resized back to the input's spatial
            size. Default to 1.
        psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 key_channels,
                 value_channels,
                 scale=1,
                 psp_size=(1, 3, 6, 8)):
        super(SelfAttentionBlock_APNB, self).__init__()

        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        # Stride equals the window so the pool downsamples by `scale`
        # (Pool2D's default stride is 1).
        self.pool = nn.Pool2D(
            pool_size=(scale, scale),
            pool_type="max",
            pool_stride=(scale, scale))
        self.f_key = layer_libs.ConvBnRelu(
            in_channels=self.in_channels,
            out_channels=self.key_channels,
            kernel_size=1)
        # Queries and keys share one projection layer.
        self.f_query = self.f_key
        self.f_value = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.value_channels,
            kernel_size=1)
        self.W = nn.Conv2d(
            in_channels=self.value_channels,
            out_channels=self.out_channels,
            kernel_size=1)

        self.psp_size = psp_size

    def forward(self, x):
        """
        Attend from each position of `x` over its pyramid-pooled positions.

        Args:
            x (Tensor): 4-D input features with in_channels channels.

        Returns:
            Tensor: context with out_channels channels, at x's spatial size.
        """
        batch_size, _, h, w = x.shape
        if self.scale > 1:
            x = self.pool(x)

        # (N, value_channels, S) -> (N, S, value_channels), S = pooled cells.
        value = self.f_value(x)
        value = _pp_module(value, self.psp_size)
        value = paddle.transpose(value, perm=(0, 2, 1))

        # (N, key_channels, H*W) -> (N, H*W, key_channels).
        query = self.f_query(x)
        query = paddle.reshape(
            query, shape=(batch_size, self.key_channels, -1))
        query = paddle.transpose(query, perm=(0, 2, 1))

        # (N, key_channels, S).
        key = self.f_key(x)
        key = _pp_module(key, self.psp_size)

        # Scaled dot-product similarity, normalized over the pooled cells.
        sim_map = paddle.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        context = paddle.matmul(sim_map, value)
        context = paddle.transpose(context, perm=(0, 2, 1))
        context = paddle.reshape(
            context, shape=[batch_size, self.value_channels, *x.shape[2:]])
        context = self.W(context)

        if self.scale > 1:
            # Restore the original resolution so APNB can concatenate the
            # context with its un-pooled input features.
            context = F.resize_bilinear(context, [h, w])

        return context