ocrnet.py 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

W
wuzewu 已提交
17 18 19
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
20 21

from paddleseg import utils
W
wuzewu 已提交
22 23
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
24 25


class SpatialGatherBlock(nn.Layer):
    """
    Aggregation layer to compute the pixel-region representation.

    Every one of the k region maps is normalized with a softmax over all
    spatial positions and then used as the weights of a weighted sum of the
    pixel features, giving one c-dimensional descriptor per region.

    forward(pixels, regions):
        pixels (Tensor): pixel features of shape (n, c, h, w).
        regions (Tensor): region scores of shape (n, k, h, w); the softmax
            over positions is applied here.
    Returns:
        Tensor of shape (n, c, k, 1).
    """

    def forward(self, pixels, regions):
        batch, channels, height, width = pixels.shape
        _, num_regions, _, _ = regions.shape
        num_positions = height * width

        # Flatten the spatial axes and move them first:
        # (n, c, h, w) -> (n, c, h*w) -> (n, h*w, c).
        flat_pixels = paddle.transpose(
            paddle.reshape(pixels, (batch, channels, num_positions)),
            (0, 2, 1))

        # Flatten the spatial axes and normalize every region map over all
        # positions: (n, k, h, w) -> (n, k, h*w).
        region_weights = F.softmax(
            paddle.reshape(regions, (batch, num_regions, num_positions)),
            axis=2)

        # Weighted sum over positions: (n, k, h*w) x (n, h*w, c) -> (n, k, c),
        # then rearranged to (n, c, k, 1) so it can be consumed as a 4-D map.
        region_feats = paddle.bmm(region_weights, flat_pixels)
        return paddle.unsqueeze(
            paddle.transpose(region_feats, (0, 2, 1)), axis=-1)


class SpatialOCRModule(nn.Layer):
    """
    Aggregate the global object representation to update the representation
    for each pixel.

    The context computed by ObjectAttentionBlock is concatenated with the input
    pixel features along the channel axis and fused by a 1x1 convolution
    followed by dropout.

    Args:
        in_channels (int): number of channels of the pixel features.
        key_channels (int): number of key channels in ObjectAttentionBlock.
        out_channels (int): number of output channels.
        dropout_rate (float, optional): dropout probability applied after the
            1x1 fusion convolution. Default: 0.1.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 dropout_rate=0.1):
        super(SpatialOCRModule, self).__init__()

        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.dropout_rate = dropout_rate
        # Context (in_channels) + pixels (in_channels) are concatenated, hence
        # 2 * in_channels inputs. The dropout probability comes from the
        # constructor argument instead of a hard-coded constant.
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, 1),
            nn.Dropout2d(dropout_rate))

    def forward(self, pixels, regions):
        """
        Args:
            pixels (Tensor): pixel features with in_channels channels.
            regions (Tensor): object region representations, passed to
                ObjectAttentionBlock as its `proxy` input.
        Returns:
            Tensor with out_channels channels.
        """
        context = self.attention_block(pixels, regions)
        feats = paddle.concat([context, pixels], axis=1)
        feats = self.conv1x1(feats)

        return feats


class ObjectAttentionBlock(nn.Layer):
    """
    A self-attention module: scaled dot-product attention from pixel features
    (queries) to object region features (keys and values).

    Args:
        in_channels (int): number of channels of both the pixel input `x` and
            the region input `proxy`; the output also has in_channels channels.
        key_channels (int): number of channels of the query/key/value
            projections.
    """

    def __init__(self, in_channels, key_channels):
        super(ObjectAttentionBlock, self).__init__()

        self.in_channels = in_channels
        self.key_channels = key_channels

        # Query projection, applied to the pixel features.
        self.f_pixel = nn.Sequential(
            ConvBNReLU(in_channels, key_channels, 1),
            ConvBNReLU(key_channels, key_channels, 1))
        # Key projection, applied to the region (proxy) features.
        self.f_object = nn.Sequential(
            ConvBNReLU(in_channels, key_channels, 1),
            ConvBNReLU(key_channels, key_channels, 1))
        # Value projection, applied to the region (proxy) features.
        self.f_down = ConvBNReLU(in_channels, key_channels, 1)
        # Projects the attended context from key_channels back to in_channels.
        self.f_up = ConvBNReLU(key_channels, in_channels, 1)

    def forward(self, x, proxy):
        batch, _, height, width = x.shape
        channels = self.key_channels

        def flatten(feat):
            # (n, key_channels, H, W) -> (n, key_channels, H*W)
            return paddle.reshape(feat, (batch, channels, -1))

        # query: (n, c1, h1, w1) -> (n, h1*w1, key_channels)
        query = paddle.transpose(flatten(self.f_pixel(x)), (0, 2, 1))
        # key: (n, c2, h2, w2) -> (n, key_channels, h2*w2)
        key = flatten(self.f_object(proxy))
        # value: (n, c2, h2, w2) -> (n, h2*w2, key_channels)
        value = paddle.transpose(flatten(self.f_down(proxy)), (0, 2, 1))

        # Attention of every pixel over the region positions,
        # (n, h1*w1, h2*w2), scaled by 1/sqrt(key_channels) before the softmax.
        scale = channels**-.5
        attention = F.softmax(scale * paddle.bmm(query, key), axis=-1)

        # context: (n, h1*w1, key_channels) -> (n, key_channels, h1, w1),
        # then f_up maps it to (n, in_channels, h1, w1).
        context = paddle.transpose(paddle.bmm(attention, value), (0, 2, 1))
        context = paddle.reshape(context, (batch, channels, height, width))
        return self.f_up(context)


class OCRHead(nn.Layer):
    """
    The Object contextual representation head.

    Args:
        num_classes(int): the unique number of target classes.
        in_channels(list|tuple): the number of channels of each input feature
            map. The last entry is used for the pixel representation; the
            second-to-last (or the last one, if only one is given) is used by
            the auxiliary layer. Although it defaults to None, it is required.
        ocr_mid_channels(int, optional): the number of middle channels in OCRHead. Default: 512.
        ocr_key_channels(int, optional): the number of key channels in ObjectAttentionBlock. Default: 256.

    Raises:
        ValueError: if in_channels is None or empty.
    """

    def __init__(self,
                 num_classes,
                 in_channels=None,
                 ocr_mid_channels=512,
                 ocr_key_channels=256):
        super(OCRHead, self).__init__()

        # The head cannot be built without channel information; fail with a
        # clear message instead of a TypeError/IndexError further down.
        if in_channels is None or len(in_channels) == 0:
            raise ValueError(
                'OCRHead requires a non-empty `in_channels`, but got {}'.format(
                    in_channels))

        self.num_classes = num_classes
        self.spatial_gather = SpatialGatherBlock()
        self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                            ocr_mid_channels)

        # Indices of the shallow (auxiliary) and deep (pixel) features. With a
        # single input, the same feature map serves both roles.
        self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]

        self.conv3x3_ocr = ConvBNReLU(
            in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
        self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
        # The auxiliary prediction doubles as the soft object regions fed to
        # spatial_gather (presumably num_classes channels; AuxLayer is defined
        # elsewhere).
        self.aux_head = AuxLayer(in_channels[self.indices[0]],
                                 in_channels[self.indices[0]], self.num_classes)
        self.init_weight()

    def forward(self, x, label=None):
        """
        Args:
            x (list|tuple of Tensor): input feature maps, indexed like `in_channels`.
            label: unused by this head.

        Returns:
            list: [logit, soft_regions], where `logit` is the output of
            `cls_head` and `soft_regions` is the auxiliary prediction used to
            gather the object regions.
        """
        feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]

        soft_regions = self.aux_head(feat_shallow)
        pixels = self.conv3x3_ocr(feat_deep)

        object_regions = self.spatial_gather(pixels, soft_regions)
        ocr = self.spatial_ocr(pixels, object_regions)

        logit = self.cls_head(ocr)
        return [logit, soft_regions]

    def init_weight(self):
        """
        Initialize the parameters of model parts.

        Conv2d weights are initialized with `param_init.normal_init` (scale
        0.001); SyncBatchNorm weights are set to 1 and biases to 0. Other
        sublayers are left untouched.
        """
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2d):
                param_init.normal_init(sublayer.weight, scale=0.001)
            elif isinstance(sublayer, nn.SyncBatchNorm):
                param_init.constant_init(sublayer.weight, value=1.0)
                param_init.constant_init(sublayer.bias, value=0.0)


@manager.MODELS.add_component
class OCRNet(nn.Layer):
    """
    The OCRNet implementation based on PaddlePaddle.

    The original article refers to
        Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
        (https://arxiv.org/pdf/1909.11065.pdf)

    Args:
        num_classes(int): the unique number of target classes.
        backbone(Paddle.nn.Layer): backbone network. It must expose
            `feat_channels` and, when called, return an indexable sequence of
            feature maps.
        pretrained(str, optional): the local path of a pretrained model for the
            whole network. The path must exist on the local file system.
            Default: None.
        backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
            The first index will be taken as a deep-supervision feature in the auxiliary layer;
            the second one will be taken as input of pixel representation.
            Although it defaults to None, it is required.
        ocr_mid_channels(int, optional): the number of middle channels in OCRHead. Default: 512.
        ocr_key_channels(int, optional): the number of key channels in ObjectAttentionBlock. Default: 256.

    Raises:
        ValueError: if backbone_indices is None.
        FileNotFoundError: if `pretrained` is given but does not exist.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 pretrained=None,
                 backbone_indices=None,
                 ocr_mid_channels=512,
                 ocr_key_channels=256):
        super(OCRNet, self).__init__()

        # The default of None exists only to keep the keyword signature;
        # iterating it below would otherwise fail with an opaque TypeError.
        if backbone_indices is None:
            raise ValueError(
                'OCRNet requires `backbone_indices`, e.g. (0, 1), but got None')

        self.backbone = backbone
        self.backbone_indices = backbone_indices
        in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]

        self.head = OCRHead(
            num_classes=num_classes,
            in_channels=in_channels,
            ocr_mid_channels=ocr_mid_channels,
            ocr_key_channels=ocr_key_channels)

        self.init_weight(pretrained)

    def forward(self, x, label=None):
        feats = self.backbone(x)
        feats = [feats[i] for i in self.backbone_indices]
        preds = self.head(feats, label)
        # Resize every prediction returned by the head back to the input's
        # spatial size (x.shape[2:]).
        preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
        return preds

    def init_weight(self, pretrained=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained (str, optional): the path of pretrained model. Defaults to None,
                in which case nothing is loaded here.

        Raises:
            FileNotFoundError: if `pretrained` is given but does not exist on
                the local file system. It subclasses Exception, so existing
                `except Exception` handlers keep working.
        """
        if pretrained is None:
            return
        if not os.path.exists(pretrained):
            raise FileNotFoundError(
                'Pretrained model is not found: {}'.format(pretrained))
        utils.load_pretrained_model(self, pretrained)