diff --git a/dygraph/cvlibs/manager.py b/dygraph/cvlibs/manager.py
index 7e179e1ed5ad3ba9385ba0c206c382e4b822720e..e4a952b3cca0f451b18a7d5bb2c9c0c4654d8c11 100644
--- a/dygraph/cvlibs/manager.py
+++ b/dygraph/cvlibs/manager.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
+from collections.abc import Sequence
 import inspect
 
 
@@ -98,13 +98,14 @@ class ComponentManager:
         """
 
         # Check whether the type is a sequence
-        if isinstance(components, collections.Sequence):
+        if isinstance(components, Sequence):
             for component in components:
                 self._add_single_component(component)
         else:
             component = components
             self._add_single_component(component)
 
+        return components
 
 MODELS = ComponentManager()
 BACKBONES = ComponentManager()
\ No newline at end of file
diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py
index caa734fefb7bad40e82a51160495b5889d82aa34..52b73c3b7aa7e38868ca3588e0df6fd430431bf0 100644
--- a/dygraph/models/__init__.py
+++ b/dygraph/models/__init__.py
@@ -16,3 +16,4 @@ from .architectures import *
 from .unet import UNet
 from .deeplab import *
 from .fcn import *
+from .pspnet import *
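Note on the manager.py hunk: returning `components` is what lets `add_component` double as a decorator, which the new pspnet.py relies on (`@manager.MODELS.add_component`). A minimal sketch with a hypothetical component, just to show the mechanism:

    from dygraph.cvlibs import manager

    @manager.MODELS.add_component  # registers the callable, then hands it back unchanged
    def my_model(**kwargs):        # hypothetical factory, for illustration only
        pass                       # without `return components`, my_model would be rebound to None
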
diff --git a/dygraph/models/architectures/layer_utils.py b/dygraph/models/architectures/layer_utils.py
index 024748c93863de955abe828c0f9797c9a8b4bbb1..a9842f188276b6347f4f2ced100ff8c6c00f2715 100644
--- a/dygraph/models/architectures/layer_utils.py
+++ b/dygraph/models/architectures/layer_utils.py
@@ -13,24 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle.nn.functional as F
 from paddle import fluid
 from paddle.fluid import dygraph
 from paddle.fluid.dygraph import Conv2D
-from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
-import cv2
-import os
-import sys
+from paddle.nn import SyncBatchNorm as BatchNorm
+from paddle.nn.layer import activation
 
 
 class ConvBnRelu(dygraph.Layer):
-
     def __init__(self,
                  num_channels,
                  num_filters,
                  filter_size,
                  using_sep_conv=False,
                  **kwargs):
-
+
         super(ConvBnRelu, self).__init__()
 
         if using_sep_conv:
@@ -41,16 +39,16 @@ class ConvBnRelu(dygraph.Layer):
         else:
             self.conv = Conv2D(num_channels,
-                           num_filters,
-                           filter_size,
-                           **kwargs)
+                               num_filters,
+                               filter_size,
+                               **kwargs)
 
         self.batch_norm = BatchNorm(num_filters)
 
     def forward(self, x):
         x = self.conv(x)
         x = self.batch_norm(x)
-        x = fluid.layers.relu(x)
+        x = F.relu(x)
         return x
 
@@ -81,7 +79,7 @@ class ConvReluPool(dygraph.Layer):
     def forward(self, x):
         x = self.conv(x)
-        x = fluid.layers.relu(x)
+        x = F.relu(x)
         x = fluid.layers.pool2d(x, pool_size=2, pool_type="max", pool_stride=2)
         return x
 
 
@@ -106,15 +104,15 @@ class DepthwiseConvBnRelu(dygraph.Layer):
                  **kwargs):
         super(DepthwiseConvBnRelu, self).__init__()
         self.depthwise_conv = ConvBn(num_channels,
-                                 num_filters=num_channels,
-                                 filter_size=filter_size,
-                                 groups=num_channels,
-                                 use_cudnn=False,
-                                 **kwargs)
+                                     num_filters=num_channels,
+                                     filter_size=filter_size,
+                                     groups=num_channels,
+                                     use_cudnn=False,
+                                     **kwargs)
         self.piontwise_conv = ConvBnRelu(num_channels,
-                                     num_filters,
-                                     filter_size=1,
-                                     groups=1)
+                                         num_filters,
+                                         filter_size=1,
+                                         groups=1)
 
     def forward(self, x):
         x = self.depthwise_conv(x)
@@ -122,20 +120,43 @@ class DepthwiseConvBnRelu(dygraph.Layer):
         return x
 
-def compute_loss(logits, label, ignore_index=255):
-    mask = label != ignore_index
-    mask = fluid.layers.cast(mask, 'float32')
-    loss, probs = fluid.layers.softmax_with_cross_entropy(
-        logits,
-        label,
-        ignore_index=ignore_index,
-        return_softmax=True,
-        axis=1)
+class Activation(fluid.dygraph.Layer):
+    """
+    The wrapper of activations
+    For example:
+    >>> relu = Activation("relu")
+    >>> print(relu)
+    <class 'paddle.nn.layer.activation.ReLU'>
+
+    >>> sigmoid = Activation("sigmoid")
+    >>> print(sigmoid)
+    <class 'paddle.nn.layer.activation.Sigmoid'>
+
+    >>> not_exit_one = Activation("not_exit_one")
+    KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
+    'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
+    'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
+
+    Args:
+        act (str): the activation name in lowercase
+    """
+
+    def __init__(self, act=None):
+        super(Activation, self).__init__()
+
+        self._act = act
+        upper_act_names = activation.__all__
+        lower_act_names = [act.lower() for act in upper_act_names]
+        act_dict = dict(zip(lower_act_names, upper_act_names))
+
+        if act is not None:
+            if act in act_dict.keys():
+                act_name = act_dict[act]
+                self.act_func = eval("activation.{}()".format(act_name))
+            else:
+                raise KeyError("{} does not exist in the current {}".format(act, act_dict.keys()))
 
-    loss = loss * mask
-    avg_loss = fluid.layers.mean(loss) / (
-        fluid.layers.mean(mask) + 1e-5)
+    def forward(self, x):
-
-    label.stop_gradient = True
-    mask.stop_gradient = True
-    return avg_loss
\ No newline at end of file
+        if self._act is not None:
+            return self.act_func(x)
+        else:
+            return x
\ No newline at end of file
diff --git a/dygraph/models/architectures/mobilenetv3.py b/dygraph/models/architectures/mobilenetv3.py
index 91aa0563ebbca62284f399ffa37100bcca08042c..2899e3f76567cee638b07c86174896a19f51bd2f 100644
--- a/dygraph/models/architectures/mobilenetv3.py
+++ b/dygraph/models/architectures/mobilenetv3.py
@@ -16,15 +16,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
 import numpy as np
+
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.layer_helper import LayerHelper
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
-
-import math
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout
+from paddle.nn import SyncBatchNorm as BatchNorm
 
+from dygraph.models.architectures import layer_utils
 from dygraph.cvlibs import manager
 
 __all__ = [
@@ -251,19 +253,18 @@ class ConvBNLayer(fluid.dygraph.Layer):
             bias_attr=False,
             use_cudnn=use_cudnn,
             act=None)
-        self.bn = fluid.dygraph.BatchNorm(
-            num_channels=out_c,
-            act=None,
-            param_attr=ParamAttr(
+        self.bn = BatchNorm(
+            num_features=out_c,
+            weight_attr=ParamAttr(
                 name=name + "_bn_scale",
                 regularizer=fluid.regularizer.L2DecayRegularizer(
                     regularization_coeff=0.0)),
             bias_attr=ParamAttr(
                 name=name + "_bn_offset",
                 regularizer=fluid.regularizer.L2DecayRegularizer(
-                    regularization_coeff=0.0)),
-            moving_mean_name=name + "_bn_mean",
-            moving_variance_name=name + "_bn_variance")
+                    regularization_coeff=0.0)))
+
+        self._act_op = layer_utils.Activation(act=None)
 
     def forward(self, x):
         x = self.conv(x)
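The `Activation` wrapper above exists because `paddle.nn.SyncBatchNorm` no longer accepts the `act` argument that `paddle.fluid.dygraph.BatchNorm` fused in; the activation becomes an explicit layer. A minimal usage sketch (the input shape is illustrative; assumes a dygraph context):

    import numpy as np
    from paddle import fluid
    from dygraph.models.architectures import layer_utils

    with fluid.dygraph.guard():
        relu = layer_utils.Activation("relu")        # wraps paddle.nn.layer.activation.ReLU
        x = fluid.dygraph.to_variable(np.random.randn(2, 3).astype("float32"))
        y = relu(x)
        identity = layer_utils.Activation(act=None)  # act=None: forward passes x through unchanged
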
diff --git a/dygraph/models/architectures/resnet_vd.py b/dygraph/models/architectures/resnet_vd.py
index b08dcd90c97605ac22a376342d801c8ddc4f378f..c27c810c46c0bbdc06053e747c7a7eaeb22be6e1 100644
--- a/dygraph/models/architectures/resnet_vd.py
+++ b/dygraph/models/architectures/resnet_vd.py
@@ -24,10 +24,11 @@ import paddle
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.layer_helper import LayerHelper
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout
+from paddle.nn import SyncBatchNorm as BatchNorm
 
 from dygraph.utils import utils
-
+from dygraph.models.architectures import layer_utils
 from dygraph.cvlibs import manager
 
 __all__ = [
@@ -69,17 +70,17 @@ class ConvBNLayer(fluid.dygraph.Layer):
 
         bn_name = "bn" + name[3:]
         self._batch_norm = BatchNorm(
             num_filters,
-            act=act,
-            param_attr=ParamAttr(name=bn_name + '_scale'),
-            bias_attr=ParamAttr(bn_name + '_offset'),
-            moving_mean_name=bn_name + '_mean',
-            moving_variance_name=bn_name + '_variance')
+            weight_attr=ParamAttr(name=bn_name + '_scale'),
+            bias_attr=ParamAttr(bn_name + '_offset'))
+        self._act_op = layer_utils.Activation(act=act)
 
     def forward(self, inputs):
         if self.is_vd_mode:
             inputs = self._pool2d_avg(inputs)
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        y = self._act_op(y)
+
         return y
"/pointwise/BatchNorm/moving_mean", - moving_variance_name=name + "/pointwise/BatchNorm/moving_variance") + weight_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"), + bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta")) + + self._act_op2 = layer_utils.Activation(act=act) + def forward(self, inputs): x = self._conv1(inputs) x = self._bn1(x) + x = self._act_op1(x) x = self._conv2(x) x = self._bn2(x) + x = self._act_op2(x) return x diff --git a/dygraph/models/model_utils.py b/dygraph/models/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a88c355a78d98ff312aaa75cf175a2369ffa5d --- /dev/null +++ b/dygraph/models/model_utils.py @@ -0,0 +1,102 @@ +# -*- encoding: utf-8 -*- +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn.functional as F +from paddle import fluid +from paddle.fluid import dygraph +from paddle.fluid.dygraph import Conv2D +from paddle.nn import SyncBatchNorm as BatchNorm + +from dygraph.models.architectures import layer_utils + + +class FCNHead(fluid.dygraph.Layer): + """ + The FCNHead implementation used in auxilary layer + + Args: + in_channels (int): the number of input channels + out_channels (int): the number of output channels + """ + + def __init__(self, in_channels, out_channels): + super(FCNHead, self).__init__() + + inter_channels = in_channels // 4 + self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=in_channels, + num_filters=inter_channels, + filter_size=3, + padding=1) + + self.conv = Conv2D(num_channels=inter_channels, + num_filters=out_channels, + filter_size=1) + + def forward(self, x): + x = self.conv_bn_relu(x) + x = F.dropout(x, p=0.1) + x = self.conv(x) + return x + + +def get_loss(logit, label, ignore_index=255, EPS=1e-5): + """ + compute forward loss of the model + + Args: + logit (tensor): the logit of model output + label (tensor): ground truth + + Returns: + avg_loss (tensor): forward loss + """ + logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) + label = fluid.layers.transpose(label, [0, 2, 3, 1]) + mask = label != ignore_index + mask = fluid.layers.cast(mask, 'float32') + loss, probs = fluid.layers.softmax_with_cross_entropy( + logit, + label, + ignore_index=ignore_index, + return_softmax=True, + axis=-1) + + loss = loss * mask + avg_loss = paddle.mean(loss) / (paddle.mean(mask) + EPS) + + label.stop_gradient = True + mask.stop_gradient = True + + return avg_loss + + +def get_pred_score_map(logit): + """ + Get prediction and score map output in inference phase. 
diff --git a/dygraph/models/model_utils.py b/dygraph/models/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a88c355a78d98ff312aaa75cf175a2369ffa5d
--- /dev/null
+++ b/dygraph/models/model_utils.py
@@ -0,0 +1,102 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn.functional as F
+from paddle import fluid
+from paddle.fluid import dygraph
+from paddle.fluid.dygraph import Conv2D
+from paddle.nn import SyncBatchNorm as BatchNorm
+
+from dygraph.models.architectures import layer_utils
+
+
+class FCNHead(fluid.dygraph.Layer):
+    """
+    The FCNHead implementation used in the auxiliary layer
+
+    Args:
+        in_channels (int): the number of input channels
+        out_channels (int): the number of output channels
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(FCNHead, self).__init__()
+
+        inter_channels = in_channels // 4
+        self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=in_channels,
+                                                   num_filters=inter_channels,
+                                                   filter_size=3,
+                                                   padding=1)
+
+        self.conv = Conv2D(num_channels=inter_channels,
+                           num_filters=out_channels,
+                           filter_size=1)
+
+    def forward(self, x):
+        x = self.conv_bn_relu(x)
+        x = F.dropout(x, p=0.1)
+        x = self.conv(x)
+        return x
+
+
+def get_loss(logit, label, ignore_index=255, EPS=1e-5):
+    """
+    Compute the forward loss of the model
+
+    Args:
+        logit (tensor): the logit of model output
+        label (tensor): ground truth
+        ignore_index (int): the label value to be ignored in the loss. Default 255.
+        EPS (float): a small term added to the denominator for numerical stability. Default 1e-5.
+
+    Returns:
+        avg_loss (tensor): forward loss
+    """
+    logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+    label = fluid.layers.transpose(label, [0, 2, 3, 1])
+    mask = label != ignore_index
+    mask = fluid.layers.cast(mask, 'float32')
+    loss, probs = fluid.layers.softmax_with_cross_entropy(
+        logit,
+        label,
+        ignore_index=ignore_index,
+        return_softmax=True,
+        axis=-1)
+
+    loss = loss * mask
+    avg_loss = paddle.mean(loss) / (paddle.mean(mask) + EPS)
+
+    label.stop_gradient = True
+    mask.stop_gradient = True
+
+    return avg_loss
+
+
+def get_pred_score_map(logit):
+    """
+    Get the prediction and score map output in the inference phase.
+
+    Args:
+        logit (tensor): output logit of network
+
+    Returns:
+        pred (tensor): prediction map
+        score_map (tensor): score map
+    """
+    score_map = F.softmax(logit, axis=1)
+    score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
+    pred = fluid.layers.argmax(score_map, axis=3)
+    pred = fluid.layers.unsqueeze(pred, axes=[3])
+
+    return pred, score_map
\ No newline at end of file
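On `get_loss`: `avg_loss = mean(loss * mask) / (mean(mask) + EPS)` is the cross-entropy averaged over the non-ignored pixels only, since both means are taken over the same number of elements. A usage sketch (shapes and class count are illustrative; `logit` is NCHW, `label` is N1HW int64):

    import numpy as np
    from paddle import fluid
    from dygraph.models import model_utils

    with fluid.dygraph.guard():
        logit = fluid.dygraph.to_variable(np.random.randn(1, 19, 64, 64).astype("float32"))
        label = fluid.dygraph.to_variable(np.random.randint(0, 19, (1, 1, 64, 64)).astype("int64"))
        loss = model_utils.get_loss(logit, label, ignore_index=255)  # scalar tensor
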
+ """ + + def __init__(self, + backbone, + num_classes=2, + output_stride=16, + backbone_indices=(2, 3), + backbone_channels=(1024, 2048), + pp_out_channels=1024, + bin_sizes=(1, 2, 3, 6), + enable_auxiliary_loss=True, + ignore_index=255, + pretrained_model=None): + + super(PSPNet, self).__init__() + self.backbone = manager.BACKBONES[backbone](output_stride=output_stride, + multi_grid=(1, 1, 1)) + self.backbone_indices = backbone_indices + + self.psp_module = PPModule(in_channels=backbone_channels[1], + out_channels=pp_out_channels, + bin_sizes=bin_sizes) + + self.conv = Conv2D(num_channels=pp_out_channels, + num_filters=num_classes, + filter_size=1) + + if enable_auxiliary_loss: + self.fcn_head = model_utils.FCNHead(in_channels=backbone_channels[0], out_channels=num_classes) + + self.enable_auxiliary_loss = enable_auxiliary_loss + self.ignore_index = ignore_index + + self.init_weight(pretrained_model) + + def forward(self, input, label=None): + + _, feat_list = self.backbone(input) + + x = feat_list[self.backbone_indices[1]] + x = self.psp_module(x) + x = F.dropout(x, dropout_prob=0.1) + logit = self.conv(x) + logit = fluid.layers.resize_bilinear(logit, input.shape[2:]) + + if self.enable_auxiliary_loss: + auxiliary_feat = feat_list[self.backbone_indices[0]] + auxiliary_logit = self.fcn_head(auxiliary_feat) + auxiliary_logit = fluid.layers.resize_bilinear(auxiliary_logit, input.shape[2:]) + + if self.training: + loss = model_utils.get_loss(logit, label) + if self.enable_auxiliary_loss: + auxiliary_loss = model_utils.get_loss(auxiliary_logit, label) + loss += (0.4 * auxiliary_loss) + return loss + + + else: + pred, score_map = model_utils.get_pred_score_map(logit) + return pred, score_map + + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + + Args: + pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + """ + + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self.backbone, pretrained_model) + + +class PPModule(fluid.dygraph.Layer): + """ + Pyramid pooling module + + Args: + in_channels (int): the number of intput channels to pyramid pooling module. + + out_channels (int): the number of output channels after pyramid pooling module. + + bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6). + """ + + def __init__(self, in_channels, out_channels, bin_sizes=(1, 2, 3, 6)): + super(PPModule, self).__init__() + self.bin_sizes = bin_sizes + + # we use dimension reduction after pooling mentioned in original implementation. + self.stages = fluid.dygraph.LayerList([self._make_stage(in_channels, size) for size in bin_sizes]) + + self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=in_channels * 2, + num_filters=out_channels, + filter_size=3, + padding=1) + + def _make_stage(self, in_channels, size): + """ + Create one pooling layer. + + In our implementation, we adopt the same dimention reduction as the original paper that might be + slightly different with other implementations. + + After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations + keep the channels to be same. + + + Args: + in_channels (int): the number of intput channels to pyramid pooling module. + + size (int): the out size of the pooled layer. + + Returns: + conv (tensor): a tensor after Pyramid Pooling Module + """ + + # this paddle version does not support AdaptiveAvgPool2d, so skip it here. 
+        # prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+        conv = layer_utils.ConvBnRelu(num_channels=in_channels,
+                                      num_filters=in_channels // len(self.bin_sizes),
+                                      filter_size=1)
+
+        return conv
+
+    def forward(self, input):
+        cat_layers = []
+        for i, stage in enumerate(self.stages):
+            size = self.bin_sizes[i]
+            x = fluid.layers.adaptive_pool2d(input, pool_size=(size, size), pool_type="max")
+            x = stage(x)
+            x = fluid.layers.resize_bilinear(x, out_shape=input.shape[2:])
+            cat_layers.append(x)
+        cat_layers = [input] + cat_layers[::-1]
+        cat = fluid.layers.concat(cat_layers, axis=1)
+        out = self.conv_bn_relu2(cat)
+
+        return out
+
+
+@manager.MODELS.add_component
+def pspnet_resnet101_vd(*args, **kwargs):
+    pretrained_model = None
+    return PSPNet(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs)
+
+
+@manager.MODELS.add_component
+def pspnet_resnet101_vd_os8(*args, **kwargs):
+    pretrained_model = None
+    return PSPNet(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
+
+
+@manager.MODELS.add_component
+def pspnet_resnet50_vd(*args, **kwargs):
+    pretrained_model = None
+    return PSPNet(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs)
+
+
+@manager.MODELS.add_component
+def pspnet_resnet50_vd_os8(*args, **kwargs):
+    pretrained_model = None
+    return PSPNet(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
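A usage sketch for the registered entry points. This is a minimal sketch under stated assumptions: `num_classes` and the input shape are illustrative, and it assumes the `ResNet50_vd` backbone registered in resnet_vd.py accepts the `output_stride`/`multi_grid` arguments that the `PSPNet` constructor passes it:

    import numpy as np
    from paddle import fluid
    from dygraph.models import pspnet_resnet50_vd

    with fluid.dygraph.guard():
        model = pspnet_resnet50_vd(num_classes=19)  # pretrained_model=None, so no weights are loaded
        model.eval()                                # inference branch returns (pred, score_map)
        x = fluid.dygraph.to_variable(np.random.rand(1, 3, 512, 512).astype("float32"))
        pred, score_map = model(x)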