# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import SyncBatchNorm as BatchNorm

from paddleseg.models.common import layer_libs


class ASPPModule(nn.Layer):
    """
    Atrous Spatial Pyramid Pooling.

    Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        sep_conv (bool): Whether to use separable convolutions in the ASPP module. Default to False.
        image_pooling (bool): Whether to augment the ASPP branches with image-level features. Default to False.
    """

    def __init__(self,
                 aspp_ratios,
                 in_channels,
                 out_channels,
                 sep_conv=False,
                 image_pooling=False):
        super(ASPPModule, self).__init__()

        # Keep the ASPP branches in a LayerList so that their parameters are registered.
        self.aspp_blocks = nn.LayerList()

        for ratio in aspp_ratios:
            if sep_conv and ratio > 1:
                conv_func = layer_libs.SeparableConvBNReLU
            else:
                conv_func = layer_libs.ConvBNReLU

            block = conv_func(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1 if ratio == 1 else 3,
                dilation=ratio,
                padding=0 if ratio == 1 else ratio)
            self.aspp_blocks.append(block)

        out_size = len(self.aspp_blocks)

        if image_pooling:
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                layer_libs.ConvBNReLU(
                    in_channels, out_channels, kernel_size=1, bias_attr=False))
            out_size += 1
        self.image_pooling = image_pooling

        self.conv_bn_relu = layer_libs.ConvBNReLU(
            in_channels=out_channels * out_size,
            out_channels=out_channels,
            kernel_size=1)

        self.dropout = nn.Dropout(p=0.1)  # drop rate of the fused ASPP features

    def forward(self, x):

        outputs = []
        for block in self.aspp_blocks:
            y = block(x)
            y = F.resize_bilinear(y, out_shape=x.shape[2:])
            outputs.append(y)

        if self.image_pooling:
            img_avg = self.global_avg_pool(x)
            img_avg = F.resize_bilinear(img_avg, out_shape=x.shape[2:])
            outputs.append(img_avg)

        x = paddle.concat(outputs, axis=1)
        x = self.conv_bn_relu(x)
        x = self.dropout(x)

        return x
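
# A minimal usage sketch of ASPPModule (an illustrative assumption, not part of the
# original file). It builds the module with the dilation rates commonly paired with an
# output stride of 16 and a 2048-channel backbone feature map, and presumes the same
# PaddlePaddle/paddleseg environment this file targets (layer_libs, F.resize_bilinear):
#
#   aspp = ASPPModule(
#       aspp_ratios=(1, 6, 12, 18),
#       in_channels=2048,
#       out_channels=256,
#       sep_conv=True,
#       image_pooling=True)
#   feat = paddle.rand([2, 2048, 32, 32])
#   out = aspp(feat)
#   print(out.shape)  # [2, 256, 32, 32]; every branch is resized back to the input size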


class PPModule(nn.Layer):
    """
M
michaelowenliu 已提交
99
    Pyramid pooling module originally in PSPNet
100 101 102 103 104

    Args:
        in_channels (int): the number of intput channels to pyramid pooling module.
        out_channels (int): the number of output channels after pyramid pooling module.
        bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6).
M
michaelowenliu 已提交
105
        dim_reduction (bool): a bool value represent if reducing dimension after pooling. Default to True.
106 107 108 109 110 111 112 113
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bin_sizes=(1, 2, 3, 6),
                 dim_reduction=True):
        super(PPModule, self).__init__()

        self.bin_sizes = bin_sizes

        inter_channels = in_channels
        if dim_reduction:
            inter_channels = in_channels // len(bin_sizes)

        # We use the dimension reduction after pooling mentioned in the original implementation.
        self.stages = nn.LayerList([
            self._make_stage(in_channels, inter_channels, size)
            for size in bin_sizes
        ])

        self.conv_bn_relu2 = layer_libs.ConvBNReLU(
            in_channels=in_channels + inter_channels * len(bin_sizes),
            out_channels=out_channels,
            kernel_size=3,
            padding=1)

    def _make_stage(self, in_channels, out_channels, size):
        """
        Create one pooling layer.

M
michaelowenliu 已提交
137
        In our implementation, we adopt the same dimension reduction as the original paper that might be
W
wuzewu 已提交
138
        slightly different with other implementations.
139 140 141 142 143 144 145 146 147 148 149 150 151

        After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations
        keep the channels to be same.


        Args:
            in_channels (int): the number of intput channels to pyramid pooling module.
            size (int): the out size of the pooled layer.

        Returns:
            conv (tensor): a tensor after Pyramid Pooling Module
        """

        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = layer_libs.ConvBNReLU(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1)

        return nn.Sequential(prior, conv)

    def forward(self, input):
        cat_layers = []
        for i, stage in enumerate(self.stages):
            size = self.bin_sizes[i]
            x = stage(input)
            x = F.resize_bilinear(x, out_shape=input.shape[2:])
            cat_layers.append(x)
        cat_layers = [input] + cat_layers[::-1]
        cat = paddle.concat(cat_layers, axis=1)
        out = self.conv_bn_relu2(cat)

        return out
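
# A minimal smoke test for PPModule (an illustrative assumption, not part of the original
# file). It presumes the same PaddlePaddle/paddleseg environment this file already targets,
# where layer_libs and F.resize_bilinear are available. With the default bin_sizes
# (1, 2, 3, 6) and dim_reduction=True, each pooled branch carries in_channels // 4 channels
# before the final 3x3 fusion convolution.
if __name__ == "__main__":
    pp_module = PPModule(in_channels=2048, out_channels=512)
    feat = paddle.rand([2, 2048, 32, 32])
    out = pp_module(feat)
    print(out.shape)  # expected: [2, 512, 32, 32]; the 3x3 conv with padding=1 keeps the spatial size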