From fca7ec8545a85eb05c0201defba90c9e6b1e557f Mon Sep 17 00:00:00 2001
From: wuzewu
Date: Tue, 22 Sep 2020 15:27:22 +0800
Subject: [PATCH] Update ocrnet

---
Review notes (commentary only; this area between the three-dash marker and
the diffstat is the conventional place for notes in format-patch output and
is not part of the commit message or of the diff itself):

* This copy was restored from a whitespace-collapsed paste. Line breaks and
  indentation were reconstructed. The added/removed line totals match the
  diffstat (132/110) and both hunk headers (36/41 and 163/180 lines), but
  the exact position of blank context lines and the indentation of
  continuation lines are inferred. NOTE(review): confirm against the
  repository before applying.
* SpatialOCRModule: nn.Dropout2d(0.1) hard-codes the rate while
  self.dropout_rate is still stored, so the dropout_rate argument no longer
  has any effect after this change.
* OCRHead defaults in_channels to None and OCRNet defaults backbone_indices
  to None, yet __init__ calls len(in_channels) and iterates
  backbone_indices unconditionally; relying on either default raises
  TypeError.
* OCRHead.forward accepts label but never uses it.
* OCRNet.forward no longer computes the loss or returns (pred, score_map);
  it returns the resized [logit, soft_regions] list, so callers of the old
  interface have to be updated together with this patch.
* Docstring typos worth a follow-up: "orginal artile", "Defaullt",
  "pretrained model..".

 dygraph/paddleseg/models/ocrnet.py | 242 ++++++++++++++++-------------
 1 file changed, 132 insertions(+), 110 deletions(-)

diff --git a/dygraph/paddleseg/models/ocrnet.py b/dygraph/paddleseg/models/ocrnet.py
index 00cf079c..623502ad 100644
--- a/dygraph/paddleseg/models/ocrnet.py
+++ b/dygraph/paddleseg/models/ocrnet.py
@@ -14,36 +14,41 @@
 
 import os
 
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Sequential, Conv2D
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
 
-from paddleseg.cvlibs import manager
-from paddleseg.models.common.layer_libs import ConvBnRelu
 from paddleseg import utils
+from paddleseg.cvlibs import manager, param_init
+from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
 
 
-class SpatialGatherBlock(fluid.dygraph.Layer):
+class SpatialGatherBlock(nn.Layer):
+    """Aggregation layer to compute the pixel-region representation"""
+
     def forward(self, pixels, regions):
         n, c, h, w = pixels.shape
         _, k, _, _ = regions.shape
 
         # pixels: from (n, c, h, w) to (n, h*w, c)
-        pixels = fluid.layers.reshape(pixels, (n, c, h * w))
-        pixels = fluid.layers.transpose(pixels, (0, 2, 1))
+        pixels = paddle.reshape(pixels, (n, c, h * w))
+        pixels = paddle.transpose(pixels, (0, 2, 1))
 
         # regions: from (n, k, h, w) to (n, k, h*w)
-        regions = fluid.layers.reshape(regions, (n, k, h * w))
-        regions = fluid.layers.softmax(regions, axis=2)
+        regions = paddle.reshape(regions, (n, k, h * w))
+        regions = F.softmax(regions, axis=2)
 
         # feats: from (n, k, c) to (n, c, k, 1)
-        feats = fluid.layers.matmul(regions, pixels)
-        feats = fluid.layers.transpose(feats, (0, 2, 1))
-        feats = fluid.layers.unsqueeze(feats, axes=[-1])
+        feats = paddle.bmm(regions, pixels)
+        feats = paddle.transpose(feats, (0, 2, 1))
+        feats = paddle.unsqueeze(feats, axis=-1)
 
         return feats
 
 
-class SpatialOCRModule(fluid.dygraph.Layer):
+class SpatialOCRModule(nn.Layer):
+    """Aggregate the global object representation to update the representation for each pixel"""
+
     def __init__(self,
                  in_channels,
                  key_channels,
@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer):
 
         self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
         self.dropout_rate = dropout_rate
-        self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)
+        self.conv1x1 = nn.Sequential(
+            nn.Conv2d(2 * in_channels, out_channels, 1), nn.Dropout2d(0.1))
 
     def forward(self, pixels, regions):
         context = self.attention_block(pixels, regions)
-        feats = fluid.layers.concat([context, pixels], axis=1)
-
+        feats = paddle.concat([context, pixels], axis=1)
         feats = self.conv1x1(feats)
-        feats = fluid.layers.dropout(feats, self.dropout_rate)
 
         return feats
 
 
-class ObjectAttentionBlock(fluid.dygraph.Layer):
+class ObjectAttentionBlock(nn.Layer):
+    """A self-attention module."""
+
     def __init__(self, in_channels, key_channels):
         super(ObjectAttentionBlock, self).__init__()
 
         self.in_channels = in_channels
         self.key_channels = key_channels
 
-        self.f_pixel = Sequential(
-            ConvBnRelu(in_channels, key_channels, 1),
-            ConvBnRelu(key_channels, key_channels, 1))
+        self.f_pixel = nn.Sequential(
+            ConvBNReLU(in_channels, key_channels, 1),
+            ConvBNReLU(key_channels, key_channels, 1))
 
-        self.f_object = Sequential(
-            ConvBnRelu(in_channels, key_channels, 1),
-            ConvBnRelu(key_channels, key_channels, 1))
+        self.f_object = nn.Sequential(
+            ConvBNReLU(in_channels, key_channels, 1),
+            ConvBNReLU(key_channels, key_channels, 1))
 
-        self.f_down = ConvBnRelu(in_channels, key_channels, 1)
+        self.f_down = ConvBNReLU(in_channels, key_channels, 1)
 
-        self.f_up = ConvBnRelu(key_channels, in_channels, 1)
+        self.f_up = ConvBNReLU(key_channels, in_channels, 1)
 
     def forward(self, x, proxy):
         n, _, h, w = x.shape
         # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
         query = self.f_pixel(x)
-        query = fluid.layers.reshape(query, (n, self.key_channels, -1))
-        query = fluid.layers.transpose(query, (0, 2, 1))
+        query = paddle.reshape(query, (n, self.key_channels, -1))
+        query = paddle.transpose(query, (0, 2, 1))
 
         # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
         key = self.f_object(proxy)
-        key = fluid.layers.reshape(key, (n, self.key_channels, -1))
+        key = paddle.reshape(key, (n, self.key_channels, -1))
 
         # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
         value = self.f_down(proxy)
-        value = fluid.layers.reshape(value, (n, self.key_channels, -1))
-        value = fluid.layers.transpose(value, (0, 2, 1))
+        value = paddle.reshape(value, (n, self.key_channels, -1))
+        value = paddle.transpose(value, (0, 2, 1))
 
         # sim_map (n, h1*w1, h2*w2)
-        sim_map = fluid.layers.matmul(query, key)
+        sim_map = paddle.bmm(query, key)
         sim_map = (self.key_channels**-.5) * sim_map
-        sim_map = fluid.layers.softmax(sim_map, axis=-1)
+        sim_map = F.softmax(sim_map, axis=-1)
 
         # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
-        context = fluid.layers.matmul(sim_map, value)
-        context = fluid.layers.transpose(context, (0, 2, 1))
-        context = fluid.layers.reshape(context, (n, self.key_channels, h, w))
+        context = paddle.bmm(sim_map, value)
+        context = paddle.transpose(context, (0, 2, 1))
+        context = paddle.reshape(context, (n, self.key_channels, h, w))
         context = self.f_up(context)
 
         return context
 
 
-@manager.MODELS.add_component
-class OCRNet(fluid.dygraph.Layer):
+class OCRHead(nn.Layer):
+    """
+    The OCR Head.
+
+    Args:
+        num_classes(int): the unique number of target classes.
+        in_channels(tuple): the number of input channels.
+        ocr_mid_channels(int): the number of middle channels in OCRHead.
+        ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
+    """
+
     def __init__(self,
                  num_classes,
-                 backbone,
-                 model_pretrained=None,
                  in_channels=None,
                  ocr_mid_channels=512,
-                 ocr_key_channels=256,
-                 ignore_index=255):
-        super(OCRNet, self).__init__()
+                 ocr_key_channels=256):
+        super(OCRHead, self).__init__()
 
-        self.ignore_index = ignore_index
         self.num_classes = num_classes
-        self.EPS = 1e-5
-
-        self.backbone = backbone
         self.spatial_gather = SpatialGatherBlock()
         self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                             ocr_mid_channels)
 
-        self.conv3x3_ocr = ConvBnRelu(
-            in_channels, ocr_mid_channels, 3, padding=1)
-        self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
-        self.aux_head = Sequential(
-            ConvBnRelu(in_channels, in_channels, 3, padding=1),
-            Conv2D(in_channels, self.num_classes, 1))
+        self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]
 
-        self.init_weight(model_pretrained)
+        self.conv3x3_ocr = ConvBNReLU(
+            in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
+        self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
+        self.aux_head = AuxLayer(in_channels[self.indices[0]],
+                                 in_channels[self.indices[0]], self.num_classes)
+        self.init_weight()
 
     def forward(self, x, label=None):
-        feats = self.backbone(x)
+        feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]
 
-        soft_regions = self.aux_head(feats)
-        pixels = self.conv3x3_ocr(feats)
+        soft_regions = self.aux_head(feat_shallow)
+        pixels = self.conv3x3_ocr(feat_deep)
 
         object_regions = self.spatial_gather(pixels, soft_regions)
         ocr = self.spatial_ocr(pixels, object_regions)
 
         logit = self.cls_head(ocr)
-        logit = fluid.layers.resize_bilinear(logit, x.shape[2:])
-
-        if self.training:
-            soft_regions = fluid.layers.resize_bilinear(soft_regions,
-                                                        x.shape[2:])
-            cls_loss = self._get_loss(logit, label)
-            aux_loss = self._get_loss(soft_regions, label)
-            return cls_loss + 0.4 * aux_loss
-
-        score_map = fluid.layers.softmax(logit, axis=1)
-        score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
-        pred = fluid.layers.argmax(score_map, axis=3)
-        pred = fluid.layers.unsqueeze(pred, axes=[3])
-        return pred, score_map
-
-    def init_weight(self, pretrained_model=None):
+        return [logit, soft_regions]
+
+    def init_weight(self):
+        """Initialize the parameters of model parts."""
+        for sublayer in self.sublayers():
+            if isinstance(sublayer, nn.Conv2d):
+                param_init.normal_init(sublayer.weight, scale=0.001)
+            elif isinstance(sublayer, nn.SyncBatchNorm):
+                param_init.constant_init(sublayer.weight, value=1)
+                param_init.constant_init(sublayer.bias, value=0)
+
+
+@manager.MODELS.add_component
+class OCRNet(nn.Layer):
+    """
+    The OCRNet implementation based on PaddlePaddle.
+
+    The orginal artile refers to
+        Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
+        (https://arxiv.org/pdf/1909.11065.pdf)
+
+    Args:
+        num_classes(int): the unique number of target classes.
+        backbone(Paddle.nn.Layer): backbone network.
+        pretrained(str): the path or url of pretrained model. Defaullt to None.
+        backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
+            the first index will be taken as a deep-supervision feature in auxiliary layer;
+            the second one will be taken as input of pixel representation.
+        ocr_mid_channels(int): the number of middle channels in OCRHead.
+        ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 backbone,
+                 pretrained=None,
+                 backbone_indices=None,
+                 ocr_mid_channels=512,
+                 ocr_key_channels=256):
+        super(OCRNet, self).__init__()
+
+        self.backbone = backbone
+        self.backbone_indices = backbone_indices
+        in_channels = [self.backbone.channels[i] for i in backbone_indices]
+
+        self.head = OCRHead(
+            num_classes=num_classes,
+            in_channels=in_channels,
+            ocr_mid_channels=ocr_mid_channels,
+            ocr_key_channels=ocr_key_channels)
+
+        self.init_weight(pretrained)
+
+    def forward(self, x, label=None):
+        feats = self.backbone(x)
+        feats = [feats[i] for i in self.backbone_indices]
+        preds = self.head(feats, label)
+        preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
+        return preds
+
+    def init_weight(self, pretrained=None):
         """
         Initialize the parameters of model parts.
 
         Args:
-            pretrained_model ([str], optional): the path of pretrained model.. Defaults to None.
+            pretrained ([str], optional): the path of pretrained model.. Defaults to None.
         """
-        if pretrained_model is not None:
-            if os.path.exists(pretrained_model):
-                utils.load_pretrained_model(self, pretrained_model)
+        if pretrained is not None:
+            if os.path.exists(pretrained):
+                utils.load_pretrained_model(self, pretrained)
             else:
-                raise Exception('Pretrained model is not found: {}'.format(
-                    pretrained_model))
-
-    def _get_loss(self, logit, label):
-        """
-        compute forward loss of the model
-
-        Args:
-            logit (tensor): the logit of model output
-            label (tensor): ground truth
-
-        Returns:
-            avg_loss (tensor): forward loss
-        """
-        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
-        label = fluid.layers.transpose(label, [0, 2, 3, 1])
-        mask = label != self.ignore_index
-        mask = fluid.layers.cast(mask, 'float32')
-        loss, probs = fluid.layers.softmax_with_cross_entropy(
-            logit,
-            label,
-            ignore_index=self.ignore_index,
-            return_softmax=True,
-            axis=-1)
-
-        loss = loss * mask
-        avg_loss = fluid.layers.mean(loss) / (
-            fluid.layers.mean(mask) + self.EPS)
-
-        label.stop_gradient = True
-        mask.stop_gradient = True
-
-        return avg_loss
+                raise Exception(
+                    'Pretrained model is not found: {}'.format(pretrained))
-- 
GitLab