From ea9c7cfdc3261fd99125a3c5c07e75ef1f142895 Mon Sep 17 00:00:00 2001 From: michaelowenliu Date: Tue, 15 Sep 2020 17:17:11 +0800 Subject: [PATCH] merge layer_utils and model_utils into layer_libs --- dygraph/paddleseg/models/common/__init__.py | 5 +- dygraph/paddleseg/models/common/activation.py | 60 ++++++++ .../common/{layer_utils.py => layer_libs.py} | 73 ++++------ .../{model_utils.py => pyramid_pool.py} | 133 ++++++++++-------- 4 files changed, 164 insertions(+), 107 deletions(-) create mode 100644 dygraph/paddleseg/models/common/activation.py rename dygraph/paddleseg/models/common/{layer_utils.py => layer_libs.py} (59%) rename dygraph/paddleseg/models/common/{model_utils.py => pyramid_pool.py} (58%) diff --git a/dygraph/paddleseg/models/common/__init__.py b/dygraph/paddleseg/models/common/__init__.py index 9f30b50f..33b2611d 100644 --- a/dygraph/paddleseg/models/common/__init__.py +++ b/dygraph/paddleseg/models/common/__init__.py @@ -13,5 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import layer_utils -from . import model_utils \ No newline at end of file +from . import layer_libs +from . import activation +from . import pyramid_pool \ No newline at end of file diff --git a/dygraph/paddleseg/models/common/activation.py b/dygraph/paddleseg/models/common/activation.py new file mode 100644 index 00000000..69af72e0 --- /dev/null +++ b/dygraph/paddleseg/models/common/activation.py @@ -0,0 +1,60 @@ +# -*- encoding: utf-8 -*- +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
from paddle import nn
from paddle.nn.layer import activation


class Activation(nn.Layer):
    """
    A wrapper that creates an activation layer from its lowercase name.
    For example:
    >>> relu = Activation("relu")
    >>> print(relu)
    <class 'paddle.nn.layer.activation.ReLU'>
    >>> sigmoid = Activation("sigmoid")
    >>> print(sigmoid)
    <class 'paddle.nn.layer.activation.Sigmoid'>
    >>> not_exit_one = Activation("not_exit_one")
    KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
    'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
    'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"

    Args:
        act (str): the activation name in lowercase.
    """

    def __init__(self, act=None):
        super(Activation, self).__init__()

        self._act = act
        upper_act_names = activation.__all__
        lower_act_names = [act.lower() for act in upper_act_names]
        act_dict = dict(zip(lower_act_names, upper_act_names))

        if act is not None:
            if act in act_dict.keys():
                act_name = act_dict[act]
                # map the lowercase name to its class and instantiate it
                self.act_func = eval("activation.{}()".format(act_name))
            else:
                raise KeyError("{} does not exist in the current {}".format(
                    act, act_dict.keys()))

    def forward(self, x):

        if self._act is not None:
            return self.act_func(x)
        else:
            return x
\ No newline at end of file
diff --git a/dygraph/paddleseg/models/common/layer_utils.py b/dygraph/paddleseg/models/common/layer_libs.py
similarity index 59%
rename from dygraph/paddleseg/models/common/layer_utils.py
rename to dygraph/paddleseg/models/common/layer_libs.py
index 8d41ebb1..8da38bca 100644
--- a/dygraph/paddleseg/models/common/layer_utils.py
+++ b/dygraph/paddleseg/models/common/layer_libs.py
@@ -70,18 +70,6 @@ class ConvReluPool(nn.Layer):
         return x
 
 
-# class ConvBnReluUpsample(nn.Layer):
-#     def __init__(self, in_channels, out_channels):
-#         super(ConvBnReluUpsample, self).__init__()
-#         self.conv_bn_relu = ConvBnRelu(in_channels, out_channels)
-
-#     def forward(self, x, upsample_scale=2):
-#         x = self.conv_bn_relu(x)
-#         new_shape = [x.shape[2] * upsample_scale, x.shape[3] * upsample_scale]
-#         x = F.resize_bilinear(x, new_shape)
-#         return x
-
-
 class DepthwiseConvBnRelu(nn.Layer):
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
         super(DepthwiseConvBnRelu, self).__init__()
@@ -100,44 +88,43 @@ class DepthwiseConvBnRelu(nn.Layer):
         return x
 
 
-class Activation(nn.Layer):
+class AuxLayer(nn.Layer):
     """
-    The wrapper of activations
-    For example:
-    >>> relu = Activation("relu")
-    >>> print(relu)
-    <class 'paddle.nn.layer.activation.ReLU'>
-    >>> sigmoid = Activation("sigmoid")
-    >>> print(sigmoid)
-    <class 'paddle.nn.layer.activation.Sigmoid'>
-    >>> not_exit_one = Activation("not_exit_one")
-    KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
-    'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
-    'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
+    The auxiliary layer implementation, used for computing the auxiliary loss
 
     Args:
-        act (str): the activation name in lowercase
+        in_channels (int): the number of input channels.
+
+        inter_channels (int): the number of intermediate channels.
+
+        out_channels (int): the number of output channels, which is usually num_classes.
+
+        dropout_prob (float): the dropout rate. Default: 0.1.
""" - def __init__(self, act=None): - super(Activation, self).__init__() + def __init__(self, + in_channels, + inter_channels, + out_channels, + dropout_prob=0.1): + super(AuxLayer, self).__init__() + + self.conv_bn_relu = ConvBnRelu( + in_channels=in_channels, + out_channels=inter_channels, + kernel_size=3, + padding=1) - self._act = act - upper_act_names = activation.__all__ - lower_act_names = [act.lower() for act in upper_act_names] - act_dict = dict(zip(lower_act_names, upper_act_names)) + self.conv = nn.Conv2d( + in_channels=inter_channels, + out_channels=out_channels, + kernel_size=1) - if act is not None: - if act in act_dict.keys(): - act_name = act_dict[act] - self.act_func = eval("activation.{}()".format(act_name)) - else: - raise KeyError("{} does not exist in the current {}".format( - act, act_dict.keys())) + self.dropout_prob = dropout_prob def forward(self, x): + x = self.conv_bn_relu(x) + x = F.dropout(x, p=self.dropout_prob) + x = self.conv(x) + return x - if self._act is not None: - return self.act_func(x) - else: - return x diff --git a/dygraph/paddleseg/models/common/model_utils.py b/dygraph/paddleseg/models/common/pyramid_pool.py similarity index 58% rename from dygraph/paddleseg/models/common/model_utils.py rename to dygraph/paddleseg/models/common/pyramid_pool.py index 7de39c8e..a69eb0f6 100644 --- a/dygraph/paddleseg/models/common/model_utils.py +++ b/dygraph/paddleseg/models/common/pyramid_pool.py @@ -13,85 +13,96 @@ # See the License for the specific language governing permissions and # limitations under the License. + import paddle from paddle import nn import paddle.nn.functional as F from paddle.nn import SyncBatchNorm as BatchNorm -from paddleseg.models.common import layer_utils +from paddleseg.models.common import layer_libs -class FCNHead(nn.Layer): +class ASPPModule(nn.Layer): """ - The FCNHead implementation used in auxilary layer + Atrous Spatial Pyramid Pooling Args: - in_channels (int): the number of input channels - out_channels (int): the number of output channels - """ + aspp_ratios (tuple): the dilation rate using in ASSP module. - def __init__(self, in_channels, out_channels): - super(FCNHead, self).__init__() - - inter_channels = in_channels // 4 - self.conv_bn_relu = layer_utils.ConvBnRelu( - in_channels=in_channels, - out_channels=inter_channels, - kernel_size=3, - padding=1) - - self.conv = nn.Conv2d( - in_channels=inter_channels, - out_channels=out_channels, - kernel_size=1) + in_channels (int): the number of input channels. - def forward(self, x): - x = self.conv_bn_relu(x) - x = F.dropout(x, p=0.1) - x = self.conv(x) - return x + out_channels (int): the number of output channels. + sep_conv (bool): if using separable conv in ASPP module. -class AuxLayer(nn.Layer): - """ - The auxilary layer implementation for auxilary loss + image_pooling: if augmented with image-level features. - Args: - in_channels (int): the number of input channels. - inter_channels (int): intermediate channels. - out_channels (int): the number of output channels, which is usually num_classes. 
""" - def __init__(self, - in_channels, - inter_channels, - out_channels, - dropout_prob=0.1): - super(AuxLayer, self).__init__() - - self.conv_bn_relu = layer_utils.ConvBnRelu( - in_channels=in_channels, - out_channels=inter_channels, - kernel_size=3, - padding=1) - - self.conv = nn.Conv2d( - in_channels=inter_channels, - out_channels=out_channels, + def __init__(self, + aspp_ratios, + in_channels, + out_channels, + sep_conv=False, + image_pooling=False): + super(ASPPModule, self).__init__() + + self.aspp_blocks = [] + + for ratio in aspp_ratios: + + if sep_conv and ratio > 1: + conv_func = layer_libs.DepthwiseConvBnRelu + else: + conv_func = layer_libs.ConvBnRelu + + block = conv_func( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1 if ratio == 1 else 3, + dilation=ratio, + padding=0 if ratio == 1 else ratio + ) + self.aspp_blocks.append(block) + + out_size = len(self.aspp_blocks) + + if image_pooling: + self.global_avg_pool = nn.Sequential( + nn.AdaptiveAvgPool2d(output_size=(1, 1)), + layer_libs.ConvBnRelu(in_channels, out_channels, kernel_size=1, bias_attr=False) + ) + out_size += 1 + self.image_pooling = image_pooling + + self.conv_bn_relu = layer_libs.ConvBnRelu( + in_channels=out_channels * out_size, + out_channels=out_channels, kernel_size=1) - self.dropout_prob = dropout_prob + self.dropout = nn.Dropout(p=0.1) # drop rate def forward(self, x): + + outputs = [] + for block in self.aspp_blocks: + outputs.append(block(x)) + + if self.image_pooling: + img_avg = self.global_avg_pool(x) + img_avg = F.resize_bilinear(img_avg, out_shape=x.shape[2:]) + outputs.append(img_avg) + + x = paddle.concat(outputs, axis=1) x = self.conv_bn_relu(x) - x = F.dropout(x, p=self.dropout_prob) - x = self.conv(x) - return x + x = self.dropout(x) + return x + class PPModule(nn.Layer): """ - Pyramid pooling module + Pyramid pooling module orginally in PSPNet Args: in_channels (int): the number of intput channels to pyramid pooling module. @@ -109,6 +120,7 @@ class PPModule(nn.Layer): bin_sizes=(1, 2, 3, 6), dim_reduction=True): super(PPModule, self).__init__() + self.bin_sizes = bin_sizes inter_channels = in_channels @@ -121,7 +133,7 @@ class PPModule(nn.Layer): for size in bin_sizes ]) - self.conv_bn_relu2 = layer_utils.ConvBnRelu( + self.conv_bn_relu2 = layer_libs.ConvBnRelu( in_channels=in_channels + inter_channels * len(bin_sizes), out_channels=out_channels, kernel_size=3, @@ -147,24 +159,21 @@ class PPModule(nn.Layer): conv (tensor): a tensor after Pyramid Pooling Module """ - # this paddle version does not support AdaptiveAvgPool2d, so skip it here. - # prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) - conv = layer_utils.ConvBnRelu( + prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) + conv = layer_libs.ConvBnRelu( in_channels=in_channels, out_channels=out_channels, kernel_size=1) - return conv + return nn.Sequential(prior, conv) def forward(self, input): cat_layers = [] for i, stage in enumerate(self.stages): size = self.bin_sizes[i] - x = F.adaptive_pool2d( - input, pool_size=(size, size), pool_type="max") - x = stage(x) + x = stage(input) x = F.resize_bilinear(x, out_shape=input.shape[2:]) cat_layers.append(x) cat_layers = [input] + cat_layers[::-1] cat = paddle.concat(cat_layers, axis=1) out = self.conv_bn_relu2(cat) - return out + return out \ No newline at end of file -- GitLab