diff --git a/dygraph/configs/_base_/cityscapes.yml b/dygraph/configs/_base_/cityscapes.yml index a17f195794afb53d6ca572ee6b8e1996cf2f2ee1..bd146b66d8ce50abbaafe2b43ea9fd4498d8e99f 100644 --- a/dygraph/configs/_base_/cityscapes.yml +++ b/dygraph/configs/_base_/cityscapes.yml @@ -4,7 +4,7 @@ learning_rate: 0.01 train_dataset: type: Cityscapes - dataset_root: data/cityscapes + dataset_root: /mnt/liuyi22/.cache/paddle/dataset/cityscapes transforms: - type: ResizeStepScaling min_scale_factor: 0.5 @@ -18,7 +18,7 @@ train_dataset: val_dataset: type: Cityscapes - dataset_root: data/cityscapes + dataset_root: /mnt/liuyi22/.cache/paddle/dataset/cityscapes transforms: - type: Normalize mode: val diff --git a/dygraph/paddleseg/core/seg_train.py b/dygraph/paddleseg/core/seg_train.py index 01ed64e1999ddc36cfd42acb4db28094b50cbf35..ff9cdcbcd42b2553776a673302db8199cd1f9eaa 100644 --- a/dygraph/paddleseg/core/seg_train.py +++ b/dygraph/paddleseg/core/seg_train.py @@ -87,7 +87,8 @@ def seg_train(model, out_labels = ["loss", "reader_cost", "batch_cost"] base_logger = callbacks.BaseLogger(period=log_iters) train_logger = callbacks.TrainLogger(log_freq=log_iters) - model_ckpt = callbacks.ModelCheckpoint(save_dir, save_params_only=False, period=save_interval_iters) + model_ckpt = callbacks.ModelCheckpoint( + save_dir, save_params_only=False, period=save_interval_iters) vdl = callbacks.VisualDL(log_dir=os.path.join(save_dir, "log")) cbks_list = [base_logger, train_logger, model_ckpt, vdl] @@ -120,7 +121,7 @@ def seg_train(model, iter += 1 if iter > iters: break - + logs["reader_cost"] = timer.elapsed_time() ############## 2 ################ cbks.on_iter_begin(iter, logs) @@ -136,7 +137,7 @@ def seg_train(model, loss = ddp_model.scale_loss(loss) loss.backward() ddp_model.apply_collective_grads() - + else: logits = model(images) loss = loss_computation(logits, labels, losses) @@ -148,7 +149,7 @@ def seg_train(model, model.clear_gradients() logs['loss'] = loss.numpy()[0] - + logs["batch_cost"] = timer.elapsed_time() ############## 3 ################ @@ -159,4 +160,6 @@ def seg_train(model, ############### 4 ############### cbks.on_train_end(logs) -################################# \ No newline at end of file + + +################################# diff --git a/dygraph/paddleseg/core/val.py b/dygraph/paddleseg/core/val.py index c104b2d8bf67419c58f15ba75989720662b0a2d8..b0f408c3b96f0040d8ca0882b701c7e56315c595 100644 --- a/dygraph/paddleseg/core/val.py +++ b/dygraph/paddleseg/core/val.py @@ -67,7 +67,7 @@ def evaluate(model, pred = pred[np.newaxis, :, :, np.newaxis] pred = pred.astype('int64') mask = label != ignore_index - + # To-DO Test Execution Time conf_mat.calculate(pred=pred, label=label, ignore=mask) _, iou = conf_mat.mean_iou() diff --git a/dygraph/paddleseg/cvlibs/callbacks.py b/dygraph/paddleseg/cvlibs/callbacks.py index 952c97d843fcfd8fc198d5d09f61f7f58652ec6e..e948344ae499f333e435fa41fbd8f458cb8b3e2e 100644 --- a/dygraph/paddleseg/cvlibs/callbacks.py +++ b/dygraph/paddleseg/cvlibs/callbacks.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os import time @@ -24,6 +23,7 @@ from visualdl import LogWriter from paddleseg.utils.progbar import Progbar import paddleseg.utils.logger as logger + class CallbackList(object): """Container abstracting a list of callbacks. # Arguments @@ -44,7 +44,7 @@ class CallbackList(object): def set_model(self, model): for callback in self.callbacks: callback.set_model(model) - + def set_optimizer(self, optimizer): for callback in self.callbacks: callback.set_optimizer(optimizer) @@ -82,6 +82,7 @@ class CallbackList(object): def __iter__(self): return iter(self.callbacks) + class Callback(object): """Abstract base class used to build new callbacks. """ @@ -94,7 +95,7 @@ class Callback(object): def set_model(self, model): self.model = model - + def set_optimizer(self, optimizer): self.optimizer = optimizer @@ -110,18 +111,18 @@ class Callback(object): def on_train_end(self, logs=None): pass -class BaseLogger(Callback): +class BaseLogger(Callback): def __init__(self, period=10): super(BaseLogger, self).__init__() self.period = period - + def _reset(self): self.totals = {} def on_train_begin(self, logs=None): self.totals = {} - + def on_iter_end(self, iter, logs=None): logs = logs or {} #(iter - 1) // iters_per_epoch + 1 @@ -132,13 +133,13 @@ class BaseLogger(Callback): self.totals[k] = v if iter % self.period == 0 and ParallelEnv().local_rank == 0: - + for k in self.totals: logs[k] = self.totals[k] / self.period self._reset() -class TrainLogger(Callback): +class TrainLogger(Callback): def __init__(self, log_freq=10): self.log_freq = log_freq @@ -154,7 +155,7 @@ class TrainLogger(Callback): return result.format(*arr) def on_iter_end(self, iter, logs=None): - + if iter % self.log_freq == 0 and ParallelEnv().local_rank == 0: total_iters = self.params["total_iters"] iters_per_epoch = self.params["iters_per_epoch"] @@ -167,49 +168,50 @@ class TrainLogger(Callback): reader_cost = logs["reader_cost"] logger.info( - "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}". - format(current_epoch, iter, total_iters, - loss, lr, batch_cost, reader_cost, eta)) + "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}" + .format(current_epoch, iter, total_iters, loss, lr, batch_cost, + reader_cost, eta)) -class ProgbarLogger(Callback): +class ProgbarLogger(Callback): def __init__(self): super(ProgbarLogger, self).__init__() def on_train_begin(self, logs=None): self.verbose = self.params["verbose"] self.total_iters = self.params["total_iters"] - self.target = self.params["total_iters"] + self.target = self.params["total_iters"] self.progbar = Progbar(target=self.target, verbose=self.verbose) self.seen = 0 self.log_values = [] - + def on_iter_begin(self, iter, logs=None): #self.seen = 0 if self.seen < self.target: self.log_values = [] - + def on_iter_end(self, iter, logs=None): logs = logs or {} self.seen += 1 for k in self.params['metrics']: if k in logs: self.log_values.append((k, logs[k])) - + #if self.verbose and self.seen < self.target and ParallelEnv.local_rank == 0: - #print(self.log_values) + #print(self.log_values) if self.seen < self.target: self.progbar.update(self.seen, self.log_values) - - - + class ModelCheckpoint(Callback): + def __init__(self, + save_dir, + monitor="miou", + save_best_only=False, + save_params_only=True, + mode="max", + period=1): - def __init__(self, save_dir, monitor="miou", - save_best_only=False, save_params_only=True, - mode="max", period=1): - super(ModelCheckpoint, self).__init__() self.monitor = monitor self.save_dir = save_dir @@ -241,7 +243,7 @@ class ModelCheckpoint(Callback): current_save_dir = os.path.join(self.save_dir, "iter_{}".format(iter)) current_save_dir = os.path.abspath(current_save_dir) #if self.iters_since_last_save % self.period and ParallelEnv().local_rank == 0: - #self.iters_since_last_save = 0 + #self.iters_since_last_save = 0 if iter % self.period == 0 and ParallelEnv().local_rank == 0: if self.verbose > 0: print("iter {iter_num}: saving model to {path}".format( @@ -252,11 +254,9 @@ class ModelCheckpoint(Callback): if not self.save_params_only: paddle.save(self.optimizer.state_dict(), filepath) - class VisualDL(Callback): - def __init__(self, log_dir="./log", freq=1): super(VisualDL, self).__init__() self.log_dir = log_dir @@ -274,4 +274,4 @@ class VisualDL(Callback): self.writer.flush() def on_train_end(self, logs=None): - self.writer.close() \ No newline at end of file + self.writer.close() diff --git a/dygraph/paddleseg/models/danet.py b/dygraph/paddleseg/models/danet.py new file mode 100644 index 0000000000000000000000000000000000000000..5af7668e7517ec8e8a3893806ad1e5c0b17440a3 --- /dev/null +++ b/dygraph/paddleseg/models/danet.py @@ -0,0 +1,217 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg.utils import utils +from paddleseg.cvlibs import manager, param_init +from paddleseg.models.common.layer_libs import ConvBNReLU + + +class PAM(nn.Layer): + """Position attention module""" + + def __init__(self, in_channels): + super(PAM, self).__init__() + mid_channels = in_channels // 8 + + self.query_conv = nn.Conv2d(in_channels, mid_channels, 1, 1) + self.key_conv = nn.Conv2d(in_channels, mid_channels, 1, 1) + self.value_conv = nn.Conv2d(in_channels, in_channels, 1, 1) + + self.gamma = self.create_parameter( + shape=[1], + dtype='float32', + default_initializer=nn.initializer.Constant(0)) + + def forward(self, x): + n, _, h, w = x.shape + + # query: n, h * w, c1 + query = self.query_conv(x) + query = paddle.reshape(query, (n, -1, h * w)) + query = paddle.transpose(query, (0, 2, 1)) + + # key: n, c1, h * w + key = self.key_conv(x) + key = paddle.reshape(key, (n, -1, h * w)) + + # sim: n, h * w, h * w + sim = paddle.bmm(query, key) + sim = F.softmax(sim, axis=-1) + + value = self.value_conv(x) + value = paddle.reshape(value, (n, -1, h * w)) + sim = paddle.transpose(sim, (0, 2, 1)) + + # feat: from (n, c2, h * w) -> (n, c2, h, w) + feat = paddle.bmm(value, sim) + feat = paddle.reshape(feat, (n, -1, h, w)) + + out = self.gamma * feat + x + return out + + +class CAM(nn.Layer): + """Channel attention module""" + + def __init__(self): + super(CAM, self).__init__() + + self.gamma = self.create_parameter( + shape=[1], + dtype='float32', + default_initializer=nn.initializer.Constant(0)) + + def forward(self, x): + n, c, h, w = x.shape + + # query: n, c, h * w + query = paddle.reshape(x, (n, c, h * w)) + # key: n, h * w, c + key = paddle.reshape(x, (n, c, h * w)) + key = paddle.transpose(key, (0, 2, 1)) + + # sim: n, c, c + sim = paddle.bmm(query, key) + # The danet author claims that this can avoid gradient divergence + sim = paddle.max(sim, axis=-1, keepdim=True).expand_as(sim) - sim + sim = F.softmax(sim, axis=-1) + + # feat: from (n, c, h * w) to (n, c, h, w) + value = paddle.reshape(x, (n, c, h * w)) + feat = paddle.bmm(sim, value) + feat = paddle.reshape(feat, (n, c, h, w)) + + out = self.gamma * feat + x + return out + + +class DAHead(nn.Layer): + """ + The Dual attention head. + + Args: + num_classes(int): the unique number of target classes. + in_channels(tuple): the number of input channels. + """ + + def __init__(self, num_classes, in_channels=None): + super(DAHead, self).__init__() + in_channels = in_channels[-1] + inter_channels = in_channels // 4 + + self.channel_conv = ConvBNReLU( + in_channels, inter_channels, 3, padding=1) + self.position_conv = ConvBNReLU( + in_channels, inter_channels, 3, padding=1) + self.pam = PAM(inter_channels) + self.cam = CAM() + self.conv1 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1) + self.conv2 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1) + + self.aux_head_pam = nn.Sequential( + nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1)) + + self.aux_head_cam = nn.Sequential( + nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1)) + + self.cls_head = nn.Sequential( + nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1)) + + self.init_weight() + + def forward(self, x, label=None): + feats = x[-1] + channel_feats = self.channel_conv(feats) + channel_feats = self.cam(channel_feats) + channel_feats = self.conv1(channel_feats) + cam_head = self.aux_head_cam(channel_feats) + + position_feats = self.position_conv(feats) + position_feats = self.pam(position_feats) + position_feats = self.conv2(position_feats) + pam_head = self.aux_head_pam(position_feats) + + feats_sum = position_feats + channel_feats + cam_logit = self.aux_head_cam(channel_feats) + pam_logit = self.aux_head_cam(position_feats) + logit = self.cls_head(feats_sum) + return [logit, cam_logit, pam_logit] + + def init_weight(self): + """Initialize the parameters of model parts.""" + for sublayer in self.sublayers(): + if isinstance(sublayer, nn.Conv2d): + param_init.normal_init(sublayer.weight, scale=0.001) + elif isinstance(sublayer, nn.SyncBatchNorm): + param_init.constant_init(sublayer.weight, value=1) + param_init.constant_init(sublayer.bias, value=0) + + +@manager.MODELS.add_component +class DANet(nn.Layer): + """ + The DANet implementation based on PaddlePaddle. + + The original article refers to + Fu, jun, et al. "Dual Attention Network for Scene Segmentation" + (https://arxiv.org/pdf/1809.02983.pdf) + + Args: + num_classes(int): the unique number of target classes. + backbone(Paddle.nn.Layer): backbone network. + pretrained(str): the path or url of pretrained model. Default to None. + backbone_indices(tuple): values in the tuple indicate the indices of output of backbone. + Only the last indice is used. + """ + + def __init__(self, + num_classes, + backbone, + pretrained=None, + backbone_indices=None): + super(DANet, self).__init__() + + self.backbone = backbone + self.backbone_indices = backbone_indices + in_channels = [self.backbone.channels[i] for i in backbone_indices] + + self.head = DAHead(num_classes=num_classes, in_channels=in_channels) + + self.init_weight(pretrained) + + def forward(self, x, label=None): + feats = self.backbone(x) + feats = [feats[i] for i in self.backbone_indices] + preds = self.head(feats, label) + preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds] + return preds + + def init_weight(self, pretrained=None): + """ + Initialize the parameters of model parts. + Args: + pretrained ([str], optional): the path of pretrained model.. Defaults to None. + """ + if pretrained is not None: + if os.path.exists(pretrained): + utils.load_pretrained_model(self, pretrained) + else: + raise Exception( + 'Pretrained model is not found: {}'.format(pretrained)) diff --git a/dygraph/paddleseg/models/ocrnet.py b/dygraph/paddleseg/models/ocrnet.py index 61f4815347d015ad03c89269b817099c2b460120..01f7214b7077ba17f6c5bf361783af6cc58bdb8d 100644 --- a/dygraph/paddleseg/models/ocrnet.py +++ b/dygraph/paddleseg/models/ocrnet.py @@ -14,36 +14,41 @@ import os -import paddle.fluid as fluid -from paddle.fluid.dygraph import Sequential, Conv2D +import paddle +import paddle.nn as nn +import paddle.nn.functional as F -from paddleseg.cvlibs import manager -from paddleseg.models.common.layer_libs import ConvBNReLU from paddleseg import utils +from paddleseg.cvlibs import manager, param_init +from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer -class SpatialGatherBlock(fluid.dygraph.Layer): +class SpatialGatherBlock(nn.Layer): + """Aggregation layer to compute the pixel-region representation""" + def forward(self, pixels, regions): n, c, h, w = pixels.shape _, k, _, _ = regions.shape # pixels: from (n, c, h, w) to (n, h*w, c) - pixels = fluid.layers.reshape(pixels, (n, c, h * w)) - pixels = fluid.layers.transpose(pixels, (0, 2, 1)) + pixels = paddle.reshape(pixels, (n, c, h * w)) + pixels = paddle.transpose(pixels, (0, 2, 1)) # regions: from (n, k, h, w) to (n, k, h*w) - regions = fluid.layers.reshape(regions, (n, k, h * w)) - regions = fluid.layers.softmax(regions, axis=2) + regions = paddle.reshape(regions, (n, k, h * w)) + regions = F.softmax(regions, axis=2) # feats: from (n, k, c) to (n, c, k, 1) - feats = fluid.layers.matmul(regions, pixels) - feats = fluid.layers.transpose(feats, (0, 2, 1)) - feats = fluid.layers.unsqueeze(feats, axes=[-1]) + feats = paddle.bmm(regions, pixels) + feats = paddle.transpose(feats, (0, 2, 1)) + feats = paddle.unsqueeze(feats, axis=-1) return feats -class SpatialOCRModule(fluid.dygraph.Layer): +class SpatialOCRModule(nn.Layer): + """Aggregate the global object representation to update the representation for each pixel""" + def __init__(self, in_channels, key_channels, @@ -53,30 +58,31 @@ class SpatialOCRModule(fluid.dygraph.Layer): self.attention_block = ObjectAttentionBlock(in_channels, key_channels) self.dropout_rate = dropout_rate - self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1) + self.conv1x1 = nn.Sequential( + nn.Conv2d(2 * in_channels, out_channels, 1), nn.Dropout2d(0.1)) def forward(self, pixels, regions): context = self.attention_block(pixels, regions) - feats = fluid.layers.concat([context, pixels], axis=1) - + feats = paddle.concat([context, pixels], axis=1) feats = self.conv1x1(feats) - feats = fluid.layers.dropout(feats, self.dropout_rate) return feats -class ObjectAttentionBlock(fluid.dygraph.Layer): +class ObjectAttentionBlock(nn.Layer): + """A self-attention module.""" + def __init__(self, in_channels, key_channels): super(ObjectAttentionBlock, self).__init__() self.in_channels = in_channels self.key_channels = key_channels - self.f_pixel = Sequential( + self.f_pixel = nn.Sequential( ConvBNReLU(in_channels, key_channels, 1), ConvBNReLU(key_channels, key_channels, 1)) - self.f_object = Sequential( + self.f_object = nn.Sequential( ConvBNReLU(in_channels, key_channels, 1), ConvBNReLU(key_channels, key_channels, 1)) @@ -89,127 +95,140 @@ class ObjectAttentionBlock(fluid.dygraph.Layer): # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels) query = self.f_pixel(x) - query = fluid.layers.reshape(query, (n, self.key_channels, -1)) - query = fluid.layers.transpose(query, (0, 2, 1)) + query = paddle.reshape(query, (n, self.key_channels, -1)) + query = paddle.transpose(query, (0, 2, 1)) # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2) key = self.f_object(proxy) - key = fluid.layers.reshape(key, (n, self.key_channels, -1)) + key = paddle.reshape(key, (n, self.key_channels, -1)) # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels) value = self.f_down(proxy) - value = fluid.layers.reshape(value, (n, self.key_channels, -1)) - value = fluid.layers.transpose(value, (0, 2, 1)) + value = paddle.reshape(value, (n, self.key_channels, -1)) + value = paddle.transpose(value, (0, 2, 1)) # sim_map (n, h1*w1, h2*w2) - sim_map = fluid.layers.matmul(query, key) + sim_map = paddle.bmm(query, key) sim_map = (self.key_channels**-.5) * sim_map - sim_map = fluid.layers.softmax(sim_map, axis=-1) + sim_map = F.softmax(sim_map, axis=-1) # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1) - context = fluid.layers.matmul(sim_map, value) - context = fluid.layers.transpose(context, (0, 2, 1)) - context = fluid.layers.reshape(context, (n, self.key_channels, h, w)) + context = paddle.bmm(sim_map, value) + context = paddle.transpose(context, (0, 2, 1)) + context = paddle.reshape(context, (n, self.key_channels, h, w)) context = self.f_up(context) return context -@manager.MODELS.add_component -class OCRNet(fluid.dygraph.Layer): +class OCRHead(nn.Layer): + """ + The Object contextual representation head. + Args: + num_classes(int): the unique number of target classes. + in_channels(tuple): the number of input channels. + ocr_mid_channels(int): the number of middle channels in OCRHead. + ocr_key_channels(int): the number of key channels in ObjectAttentionBlock. + """ + def __init__(self, num_classes, - backbone, - model_pretrained=None, in_channels=None, ocr_mid_channels=512, - ocr_key_channels=256, - ignore_index=255): - super(OCRNet, self).__init__() + ocr_key_channels=256): + super(OCRHead, self).__init__() - self.ignore_index = ignore_index self.num_classes = num_classes - self.EPS = 1e-5 - - self.backbone = backbone self.spatial_gather = SpatialGatherBlock() self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels, ocr_mid_channels) - self.conv3x3_ocr = ConvBNReLU( - in_channels, ocr_mid_channels, 3, padding=1) - self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1) - self.aux_head = Sequential( - ConvBNReLU(in_channels, in_channels, 3, padding=1), - Conv2D(in_channels, self.num_classes, 1)) + self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1] - self.init_weight(model_pretrained) + self.conv3x3_ocr = ConvBNReLU( + in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1) + self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1) + self.aux_head = AuxLayer(in_channels[self.indices[0]], + in_channels[self.indices[0]], self.num_classes) + self.init_weight() def forward(self, x, label=None): - feats = self.backbone(x) + feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]] - soft_regions = self.aux_head(feats) - pixels = self.conv3x3_ocr(feats) + soft_regions = self.aux_head(feat_shallow) + pixels = self.conv3x3_ocr(feat_deep) object_regions = self.spatial_gather(pixels, soft_regions) ocr = self.spatial_ocr(pixels, object_regions) logit = self.cls_head(ocr) - logit = fluid.layers.resize_bilinear(logit, x.shape[2:]) - - if self.training: - soft_regions = fluid.layers.resize_bilinear(soft_regions, - x.shape[2:]) - cls_loss = self._get_loss(logit, label) - aux_loss = self._get_loss(soft_regions, label) - return cls_loss + 0.4 * aux_loss - - score_map = fluid.layers.softmax(logit, axis=1) - score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) - pred = fluid.layers.argmax(score_map, axis=3) - pred = fluid.layers.unsqueeze(pred, axes=[3]) - return pred, score_map - - def init_weight(self, pretrained_model=None): + return [logit, soft_regions] + + def init_weight(self): + """Initialize the parameters of model parts.""" + for sublayer in self.sublayers(): + if isinstance(sublayer, nn.Conv2d): + param_init.normal_init(sublayer.weight, scale=0.001) + elif isinstance(sublayer, nn.SyncBatchNorm): + param_init.constant_init(sublayer.weight, value=1) + param_init.constant_init(sublayer.bias, value=0) + + +@manager.MODELS.add_component +class OCRNet(nn.Layer): + """ + The OCRNet implementation based on PaddlePaddle. + The original article refers to + Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation" + (https://arxiv.org/pdf/1909.11065.pdf) + Args: + num_classes(int): the unique number of target classes. + backbone(Paddle.nn.Layer): backbone network. + pretrained(str): the path or url of pretrained model. Default to None. + backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone. + the first index will be taken as a deep-supervision feature in auxiliary layer; + the second one will be taken as input of pixel representation. + ocr_mid_channels(int): the number of middle channels in OCRHead. + ocr_key_channels(int): the number of key channels in ObjectAttentionBlock. + """ + + def __init__(self, + num_classes, + backbone, + pretrained=None, + backbone_indices=None, + ocr_mid_channels=512, + ocr_key_channels=256): + super(OCRNet, self).__init__() + + self.backbone = backbone + self.backbone_indices = backbone_indices + in_channels = [self.backbone.channels[i] for i in backbone_indices] + + self.head = OCRHead( + num_classes=num_classes, + in_channels=in_channels, + ocr_mid_channels=ocr_mid_channels, + ocr_key_channels=ocr_key_channels) + + self.init_weight(pretrained) + + def forward(self, x, label=None): + feats = self.backbone(x) + feats = [feats[i] for i in self.backbone_indices] + preds = self.head(feats, label) + preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds] + return preds + + def init_weight(self, pretrained=None): """ Initialize the parameters of model parts. Args: - pretrained_model ([str], optional): the path of pretrained model.. Defaults to None. + pretrained ([str], optional): the path of pretrained model.. Defaults to None. """ - if pretrained_model is not None: - if os.path.exists(pretrained_model): - utils.load_pretrained_model(self, pretrained_model) + if pretrained is not None: + if os.path.exists(pretrained): + utils.load_pretrained_model(self, pretrained) else: - raise Exception('Pretrained model is not found: {}'.format( - pretrained_model)) - - def _get_loss(self, logit, label): - """ - compute forward loss of the model - - Args: - logit (tensor): the logit of model output - label (tensor): ground truth - - Returns: - avg_loss (tensor): forward loss - """ - logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) - label = fluid.layers.transpose(label, [0, 2, 3, 1]) - mask = label != self.ignore_index - mask = fluid.layers.cast(mask, 'float32') - loss, probs = fluid.layers.softmax_with_cross_entropy( - logit, - label, - ignore_index=self.ignore_index, - return_softmax=True, - axis=-1) - - loss = loss * mask - avg_loss = fluid.layers.mean(loss) / ( - fluid.layers.mean(mask) + self.EPS) - - label.stop_gradient = True - mask.stop_gradient = True - - return avg_loss + raise Exception( + 'Pretrained model is not found: {}'.format(pretrained)) \ No newline at end of file diff --git a/dygraph/paddleseg/utils/metrics.py b/dygraph/paddleseg/utils/metrics.py index b107cbd57a936fb909086567fc8b703fb86963b7..44eaf4ed0580699cdb0be35f53a343cb8f70b751 100644 --- a/dygraph/paddleseg/utils/metrics.py +++ b/dygraph/paddleseg/utils/metrics.py @@ -41,7 +41,7 @@ class ConfusionMatrix(object): label = np.asarray(label)[mask] pred = np.asarray(pred)[mask] one = np.ones_like(pred) - # Accumuate ([row=label, col=pred], 1) into sparse matrix + # Accumuate ([row=label, col=pred], 1) into sparse spm = csr_matrix((one, (label, pred)), shape=(self.num_classes, self.num_classes)) spm = spm.todense() diff --git a/dygraph/paddleseg/utils/progbar.py b/dygraph/paddleseg/utils/progbar.py index b5e82881abe238f2eb686e8dfd28214f70b97819..d57ce707c407b7837acb9cbd4d2ad244ff6575a7 100644 --- a/dygraph/paddleseg/utils/progbar.py +++ b/dygraph/paddleseg/utils/progbar.py @@ -17,8 +17,9 @@ import time import numpy as np + class Progbar(object): - """Displays a progress bar. + """Displays a progress bar. refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py Arguments: target: Total number of steps expected, None if unknown. @@ -31,39 +32,39 @@ class Progbar(object): unit_name: Display name for step counts (usually "step" or "sample"). """ - def __init__(self, - target, - width=30, - verbose=1, - interval=0.05, - stateful_metrics=None, - unit_name='step'): - self.target = target - self.width = width - self.verbose = verbose - self.interval = interval - self.unit_name = unit_name - if stateful_metrics: - self.stateful_metrics = set(stateful_metrics) - else: - self.stateful_metrics = set() - - self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and - sys.stdout.isatty()) or - 'ipykernel' in sys.modules or - 'posix' in sys.modules or - 'PYCHARM_HOSTED' in os.environ) - self._total_width = 0 - self._seen_so_far = 0 - # We use a dict + list to avoid garbage collection - # issues found in OrderedDict - self._values = {} - self._values_order = [] - self._start = time.time() - self._last_update = 0 - - def update(self, current, values=None, finalize=None): - """Updates the progress bar. + def __init__(self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name='step'): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') + and sys.stdout.isatty()) + or 'ipykernel' in sys.modules + or 'posix' in sys.modules + or 'PYCHARM_HOSTED' in os.environ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. Arguments: current: Index of current step. values: List of tuples: `(name, value_for_last_step)`. If `name` is in @@ -72,129 +73,131 @@ class Progbar(object): finalize: Whether this is the last update for the progress bar. If `None`, defaults to `current >= self.target`. """ - if finalize is None: - if self.target is None: - finalize = False - else: - finalize = current >= self.target - - values = values or [] - for k, v in values: - if k not in self._values_order: - self._values_order.append(k) - if k not in self.stateful_metrics: - # In the case that progress bar doesn't have a target value in the first - # epoch, both on_batch_end and on_epoch_end will be called, which will - # cause 'current' and 'self._seen_so_far' to have the same value. Force - # the minimal value to 1 here, otherwise stateful_metric will be 0s. - value_base = max(current - self._seen_so_far, 1) - if k not in self._values: - self._values[k] = [v * value_base, value_base] - else: - self._values[k][0] += v * value_base - self._values[k][1] += value_base - else: - # Stateful metrics output a numeric value. This representation - # means "take an average from a single value" but keeps the - # numeric formatting. - self._values[k] = [v, 1] - self._seen_so_far = current - - now = time.time() - info = ' - %.0fs' % (now - self._start) - if self.verbose == 1: - if now - self._last_update < self.interval and not finalize: - return - - prev_total_width = self._total_width - if self._dynamic_display: - sys.stdout.write('\b' * prev_total_width) - sys.stdout.write('\r') - else: - sys.stdout.write('\n') - - if self.target is not None: - numdigits = int(np.log10(self.target)) + 1 - bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) - prog = float(current) / self.target - prog_width = int(self.width * prog) - if prog_width > 0: - bar += ('=' * (prog_width - 1)) - if current < self.target: - bar += '>' - else: - bar += '=' - bar += ('.' * (self.width - prog_width)) - bar += ']' - else: - bar = '%7d/Unknown' % current - - self._total_width = len(bar) - sys.stdout.write(bar) - - if current: - time_per_unit = (now - self._start) / current - else: - time_per_unit = 0 - - if self.target is None or finalize: - if time_per_unit >= 1 or time_per_unit == 0: - info += ' %.0fs/%s' % (time_per_unit, self.unit_name) - elif time_per_unit >= 1e-3: - info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) - else: - info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) - else: - eta = time_per_unit * (self.target - current) - if eta > 3600: - eta_format = '%d:%02d:%02d' % (eta // 3600, - (eta % 3600) // 60, eta % 60) - elif eta > 60: - eta_format = '%d:%02d' % (eta // 60, eta % 60) - else: - eta_format = '%ds' % eta - - info = ' - ETA: %s' % eta_format - - for k in self._values_order: - info += ' - %s:' % k - if isinstance(self._values[k], list): - avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) - if abs(avg) > 1e-3: - info += ' %.4f' % avg - else: - info += ' %.4e' % avg - else: - info += ' %s' % self._values[k] - - self._total_width += len(info) - if prev_total_width > self._total_width: - info += (' ' * (prev_total_width - self._total_width)) - - if finalize: - info += '\n' - - sys.stdout.write(info) - sys.stdout.flush() - - elif self.verbose == 2: - if finalize: - numdigits = int(np.log10(self.target)) + 1 - count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) - info = count + info - for k in self._values_order: - info += ' - %s:' % k - avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) - if avg > 1e-3: - info += ' %.4f' % avg - else: - info += ' %.4e' % avg - info += '\n' - - sys.stdout.write(info) - sys.stdout.flush() - - self._last_update = now - - def add(self, n, values=None): - self.update(self._seen_so_far + n, values) \ No newline at end of file + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in the first + # epoch, both on_batch_end and on_epoch_end will be called, which will + # cause 'current' and 'self._seen_so_far' to have the same value. Force + # the minimal value to 1 here, otherwise stateful_metric will be 0s. + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + now = time.time() + info = ' - %.0fs' % (now - self._start) + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + prev_total_width = self._total_width + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') + + if self.target is not None: + numdigits = int(np.log10(self.target)) + 1 + bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += ('=' * (prog_width - 1)) + if current < self.target: + bar += '>' + else: + bar += '=' + bar += ('.' * (self.width - prog_width)) + bar += ']' + else: + bar = '%7d/Unknown' % current + + self._total_width = len(bar) + sys.stdout.write(bar) + + if current: + time_per_unit = (now - self._start) / current + else: + time_per_unit = 0 + + if self.target is None or finalize: + if time_per_unit >= 1 or time_per_unit == 0: + info += ' %.0fs/%s' % (time_per_unit, self.unit_name) + elif time_per_unit >= 1e-3: + info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) + else: + info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) + else: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = '%d:%02d:%02d' % (eta // 3600, + (eta % 3600) // 60, eta % 60) + elif eta > 60: + eta_format = '%d:%02d' % (eta // 60, eta % 60) + else: + eta_format = '%ds' % eta + + info = ' - ETA: %s' % eta_format + + for k in self._values_order: + info += ' - %s:' % k + if isinstance(self._values[k], list): + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1])) + if abs(avg) > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + else: + info += ' %s' % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += (' ' * (prev_total_width - self._total_width)) + + if finalize: + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + elif self.verbose == 2: + if finalize: + numdigits = int(np.log10(self.target)) + 1 + count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) + info = count + info + for k in self._values_order: + info += ' - %s:' % k + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1])) + if avg > 1e-3: + info += ' %.4f' % avg + else: + info += ' %.4e' % avg + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values)