# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer


class SpatialGatherBlock(nn.Layer):
    """Aggregation layer to compute the pixel-region representation."""

    def forward(self, pixels, regions):
        """
        Pool pixel features into one feature vector per region.

        Args:
            pixels (Tensor): pixel features, unpacked below as (n, c, h, w).
            regions (Tensor): region logits, unpacked below as (n, k, h, w).
                They are reshaped with the spatial size of `pixels`, so they
                must share the same h and w.

        Returns:
            Tensor: region features of shape (n, c, k, 1).
        """
        n, c, h, w = pixels.shape
        _, k, _, _ = regions.shape

        # pixels: from (n, c, h, w) to (n, h*w, c)
        pixels = paddle.reshape(pixels, (n, c, h * w))
        pixels = paddle.transpose(pixels, (0, 2, 1))

        # regions: from (n, k, h, w) to (n, k, h*w). The softmax over the
        # spatial axis turns each region map into pooling weights summing to 1.
        regions = paddle.reshape(regions, (n, k, h * w))
        regions = F.softmax(regions, axis=2)

        # feats: from (n, k, c) to (n, c, k, 1)
        feats = paddle.bmm(regions, pixels)
        feats = paddle.transpose(feats, (0, 2, 1))
        feats = paddle.unsqueeze(feats, axis=-1)

        return feats


class SpatialOCRModule(nn.Layer):
    """
    Aggregate the global object representation to update the representation for each pixel.

    Args:
        in_channels (int): the number of channels of the pixel features.
        key_channels (int): the number of key channels in ObjectAttentionBlock.
        out_channels (int): the number of output channels.
        dropout_rate (float, optional): drop probability of the Dropout2d that
            follows the 1x1 fusion convolution. Default: 0.1.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 dropout_rate=0.1):
        super(SpatialOCRModule, self).__init__()

        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.dropout_rate = dropout_rate
        # Use the configured rate instead of a fixed 0.1 so that the
        # `dropout_rate` argument actually takes effect.
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, 1),
            nn.Dropout2d(dropout_rate))

    def forward(self, pixels, regions):
        """
        Fuse the object context with the original pixel features.

        Args:
            pixels (Tensor): pixel features; the 1x1 fusion conv expects them
                to have `in_channels` channels.
            regions (Tensor): object-region features passed to the attention
                block as its proxy input.

        Returns:
            Tensor: fused features with `out_channels` channels.
        """
        context = self.attention_block(pixels, regions)
        # Concatenate along channels: context (in_channels) + pixels
        # (in_channels), matching the 2 * in_channels input of conv1x1.
        feats = paddle.concat([context, pixels], axis=1)
        feats = self.conv1x1(feats)

        return feats


class ObjectAttentionBlock(nn.Layer):
    """
    A self-attention module.

    Args:
        in_channels (int): the number of channels of the query and proxy inputs.
        key_channels (int): the number of channels of the query/key/value projections.
    """

    def __init__(self, in_channels, key_channels):
        super(ObjectAttentionBlock, self).__init__()

        self.in_channels = in_channels
        self.key_channels = key_channels

        self.f_pixel = nn.Sequential(
            ConvBNReLU(in_channels, key_channels, 1),
            ConvBNReLU(key_channels, key_channels, 1))

        self.f_object = nn.Sequential(
            ConvBNReLU(in_channels, key_channels, 1),
            ConvBNReLU(key_channels, key_channels, 1))

        self.f_down = ConvBNReLU(in_channels, key_channels, 1)

        self.f_up = ConvBNReLU(key_channels, in_channels, 1)

    def forward(self, x, proxy):
        """
        Attend from every pixel of `x` to every position of `proxy`.

        Args:
            x (Tensor): query features, unpacked below as (n, c1, h1, w1).
            proxy (Tensor): key/value features (n, c2, h2, w2).

        Returns:
            Tensor: context of shape (n, in_channels, h1, w1).
        """
        n, _, h, w = x.shape

        # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
        query = self.f_pixel(x)
        query = paddle.reshape(query, (n, self.key_channels, -1))
        query = paddle.transpose(query, (0, 2, 1))

        # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
        key = self.f_object(proxy)
        key = paddle.reshape(key, (n, self.key_channels, -1))

        # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
        value = self.f_down(proxy)
        value = paddle.reshape(value, (n, self.key_channels, -1))
        value = paddle.transpose(value, (0, 2, 1))

        # sim_map (n, h1*w1, h2*w2), scaled by 1/sqrt(key_channels) before
        # the softmax (scaled dot-product attention).
        sim_map = paddle.bmm(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        # context from (n, h1*w1, key_channels) to (n, key_channels, h1, w1);
        # f_up then projects it back to (n, in_channels, h1, w1).
        context = paddle.bmm(sim_map, value)
        context = paddle.transpose(context, (0, 2, 1))
        context = paddle.reshape(context, (n, self.key_channels, h, w))
        context = self.f_up(context)

        return context


class OCRHead(nn.Layer):
    """
    The Object contextual representation head.

    Args:
        num_classes(int): the unique number of target classes.
        in_channels(tuple): the number of input channels. Must be a non-empty
            sequence: with two or more entries the last two are used
            (shallow, deep); with a single entry it is used for both.
        ocr_mid_channels(int): the number of middle channels in OCRHead.
        ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
    """

    def __init__(self,
                 num_classes,
                 in_channels=None,
                 ocr_mid_channels=512,
                 ocr_key_channels=256):
        super(OCRHead, self).__init__()

        self.num_classes = num_classes
        self.spatial_gather = SpatialGatherBlock()
        self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                            ocr_mid_channels)

        # indices[0] selects the feature for the auxiliary (soft region) head,
        # indices[1] the feature used for the pixel representation.
        self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]

        self.conv3x3_ocr = ConvBNReLU(
            in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
        self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
        self.aux_head = AuxLayer(in_channels[self.indices[0]],
                                 in_channels[self.indices[0]],
                                 self.num_classes)

        self.init_weight()

    def forward(self, x, label=None):
        """
        Args:
            x (list[Tensor]): input feature maps, indexed by `self.indices`.
            label (Tensor, optional): unused here; accepted for interface
                compatibility.

        Returns:
            list[Tensor]: [logit, soft_regions], i.e. the main prediction and
            the auxiliary (soft object region) prediction.
        """
        feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]

        soft_regions = self.aux_head(feat_shallow)
        pixels = self.conv3x3_ocr(feat_deep)

        object_regions = self.spatial_gather(pixels, soft_regions)
        ocr = self.spatial_ocr(pixels, object_regions)

        logit = self.cls_head(ocr)
        return [logit, soft_regions]

    def init_weight(self):
        """Initialize the parameters of model parts."""
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2d):
                param_init.normal_init(sublayer.weight, scale=0.001)
            elif isinstance(sublayer, nn.SyncBatchNorm):
                param_init.constant_init(sublayer.weight, value=1)
                param_init.constant_init(sublayer.bias, value=0)


@manager.MODELS.add_component
class OCRNet(nn.Layer):
    """
    The OCRNet implementation based on PaddlePaddle.
    The original article refers to
        Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
        (https://arxiv.org/pdf/1909.11065.pdf)

    Args:
        num_classes(int): the unique number of target classes.
        backbone(Paddle.nn.Layer): backbone network.
        pretrained(str): the path or url of pretrained model. Default to None.
        backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
            the first index will be taken as a deep-supervision feature in auxiliary layer;
            the second one will be taken as input of pixel representation.
            Despite the None default, it must be provided.
        ocr_mid_channels(int): the number of middle channels in OCRHead.
        ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 pretrained=None,
                 backbone_indices=None,
                 ocr_mid_channels=512,
                 ocr_key_channels=256):
        super(OCRNet, self).__init__()

        self.backbone = backbone
        self.backbone_indices = backbone_indices
        in_channels = [self.backbone.channels[i] for i in backbone_indices]

        self.head = OCRHead(
            num_classes=num_classes,
            in_channels=in_channels,
            ocr_mid_channels=ocr_mid_channels,
            ocr_key_channels=ocr_key_channels)

        self.init_weight(pretrained)

    def forward(self, x, label=None):
        """
        Run the backbone and head and resize every prediction to the input size.

        Returns:
            list[Tensor]: the head outputs, bilinearly resized to x.shape[2:].
        """
        feats = self.backbone(x)
        feats = [feats[i] for i in self.backbone_indices]
        preds = self.head(feats, label)
        preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
        return preds

    def init_weight(self, pretrained=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained ([str], optional): the path of pretrained model.. Defaults to None.

        Raises:
            Exception: if `pretrained` is given but does not exist on disk.
        """
        if pretrained is not None:
            if os.path.exists(pretrained):
                utils.load_pretrained_model(self, pretrained)
            else:
                raise Exception(
                    'Pretrained model is not found: {}'.format(pretrained))