Commit ea9c7cfd authored by michaelowenliu

merge layer_utils and model_utils into layer_libs

Parent 661c2ffc
@@ -13,5 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import layer_utils
from . import model_utils
\ No newline at end of file
from . import layer_libs
from . import activation
from . import pyramid_pool
\ No newline at end of file
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
from paddle.nn.layer import activation
class Activation(nn.Layer):
"""
The wrapper of activations
For example:
>>> relu = Activation("relu")
>>> print(relu)
<class 'paddle.nn.layer.activation.ReLU'>
>>> sigmoid = Activation("sigmoid")
>>> print(sigmoid)
<class 'paddle.nn.layer.activation.Sigmoid'>
>>> not_exit_one = Activation("not_exit_one")
KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
Args:
act (str): the activation name in lowercase
"""
def __init__(self, act=None):
super(Activation, self).__init__()
self._act = act
upper_act_names = activation.__all__
lower_act_names = [act.lower() for act in upper_act_names]
act_dict = dict(zip(lower_act_names, upper_act_names))
if act is not None:
if act in act_dict.keys():
act_name = act_dict[act]
self.act_func = eval("activation.{}()".format(act_name))
else:
raise KeyError("{} does not exist in the current {}".format(
act, act_dict.keys()))
def forward(self, x):
if self._act is not None:
return self.act_func(x)
else:
return x
\ No newline at end of file
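A minimal usage sketch of the Activation wrapper above (the import path, input shape, and paddle.randn call are illustrative assumptions, not part of this commit):

import paddle
from paddleseg.models.common.activation import Activation  # assumed module path

x = paddle.randn([2, 8, 32, 32])  # hypothetical NCHW feature map
relu = Activation("relu")         # lowercase name resolved against paddle.nn activation classes
identity = Activation()           # act=None: forward returns the input unchanged
print(relu(x).shape, identity(x).shape)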
@@ -70,18 +70,6 @@ class ConvReluPool(nn.Layer):
return x
# class ConvBnReluUpsample(nn.Layer):
# def __init__(self, in_channels, out_channels):
# super(ConvBnReluUpsample, self).__init__()
# self.conv_bn_relu = ConvBnRelu(in_channels, out_channels)
# def forward(self, x, upsample_scale=2):
# x = self.conv_bn_relu(x)
# new_shape = [x.shape[2] * upsample_scale, x.shape[3] * upsample_scale]
# x = F.resize_bilinear(x, new_shape)
# return x
class DepthwiseConvBnRelu(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(DepthwiseConvBnRelu, self).__init__()
@@ -100,44 +88,43 @@ class DepthwiseConvBnRelu(nn.Layer):
return x
class Activation(nn.Layer):
    """
    The wrapper of activations
    For example:
    >>> relu = Activation("relu")
    >>> print(relu)
    <class 'paddle.nn.layer.activation.ReLU'>
    >>> sigmoid = Activation("sigmoid")
    >>> print(sigmoid)
    <class 'paddle.nn.layer.activation.Sigmoid'>
    >>> not_exit_one = Activation("not_exit_one")
    KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
    'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
    'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"

    Args:
        act (str): the activation name in lowercase
    """

    def __init__(self, act=None):
        super(Activation, self).__init__()

        self._act = act
        upper_act_names = activation.__all__
        lower_act_names = [act.lower() for act in upper_act_names]
        act_dict = dict(zip(lower_act_names, upper_act_names))

        if act is not None:
            if act in act_dict.keys():
                act_name = act_dict[act]
                self.act_func = eval("activation.{}()".format(act_name))
            else:
                raise KeyError("{} does not exist in the current {}".format(
                    act, act_dict.keys()))

    def forward(self, x):
        if self._act is not None:
            return self.act_func(x)
        else:
            return x


class AuxLayer(nn.Layer):
    """
    The auxiliary layer implementation for auxiliary loss

    Args:
        in_channels (int): the number of input channels.
        inter_channels (int): intermediate channels.
        out_channels (int): the number of output channels, which is usually num_classes.
        dropout_prob (float): the dropout rate. Defaults to 0.1.
    """

    def __init__(self,
                 in_channels,
                 inter_channels,
                 out_channels,
                 dropout_prob=0.1):
        super(AuxLayer, self).__init__()

        self.conv_bn_relu = ConvBnRelu(
            in_channels=in_channels,
            out_channels=inter_channels,
            kernel_size=3,
            padding=1)

        self.conv = nn.Conv2d(
            in_channels=inter_channels,
            out_channels=out_channels,
            kernel_size=1)

        self.dropout_prob = dropout_prob

    def forward(self, x):
        x = self.conv_bn_relu(x)
        x = F.dropout(x, p=self.dropout_prob)
        x = self.conv(x)
        return x
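A short sketch of how the AuxLayer added to layer_libs might be wired up as an auxiliary segmentation head; the channel counts, class count, and feature shape are illustrative only:

import paddle
from paddleseg.models.common import layer_libs

aux_head = layer_libs.AuxLayer(
    in_channels=1024,   # backbone stage channels (illustrative)
    inter_channels=256,
    out_channels=19)    # e.g. number of classes
feat = paddle.randn([2, 1024, 64, 64])
aux_logits = aux_head(feat)  # [2, 19, 64, 64], fed to the auxiliary loss during training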
@@ -13,85 +13,96 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.models.common import layer_utils
from paddleseg.models.common import layer_libs
class FCNHead(nn.Layer):
    """
    The FCNHead implementation used in the auxiliary layer

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
    """

    def __init__(self, in_channels, out_channels):
        super(FCNHead, self).__init__()

        inter_channels = in_channels // 4
        self.conv_bn_relu = layer_utils.ConvBnRelu(
            in_channels=in_channels,
            out_channels=inter_channels,
            kernel_size=3,
            padding=1)

        self.conv = nn.Conv2d(
            in_channels=inter_channels,
            out_channels=out_channels,
            kernel_size=1)

    def forward(self, x):
        x = self.conv_bn_relu(x)
        x = F.dropout(x, p=0.1)
        x = self.conv(x)
        return x


class AuxLayer(nn.Layer):
    """
    The auxiliary layer implementation for auxiliary loss

    Args:
        in_channels (int): the number of input channels.
        inter_channels (int): intermediate channels.
        out_channels (int): the number of output channels, which is usually num_classes.
    """

    def __init__(self,
                 in_channels,
                 inter_channels,
                 out_channels,
                 dropout_prob=0.1):
        super(AuxLayer, self).__init__()

        self.conv_bn_relu = layer_utils.ConvBnRelu(
            in_channels=in_channels,
            out_channels=inter_channels,
            kernel_size=3,
            padding=1)

        self.conv = nn.Conv2d(
            in_channels=inter_channels,
            out_channels=out_channels,
            kernel_size=1)

        self.dropout_prob = dropout_prob

    def forward(self, x):
        x = self.conv_bn_relu(x)
        x = F.dropout(x, p=self.dropout_prob)
        x = self.conv(x)
        return x


class ASPPModule(nn.Layer):
    """
    Atrous Spatial Pyramid Pooling

    Args:
        aspp_ratios (tuple): the dilation rates used in the ASPP module.
        in_channels (int): the number of input channels.
        out_channels (int): the number of output channels.
        sep_conv (bool): whether to use separable convolutions in the ASPP module.
        image_pooling (bool): whether to augment with image-level features.
    """

    def __init__(self,
                 aspp_ratios,
                 in_channels,
                 out_channels,
                 sep_conv=False,
                 image_pooling=False):
        super(ASPPModule, self).__init__()

        self.aspp_blocks = []
        for ratio in aspp_ratios:
            if sep_conv and ratio > 1:
                conv_func = layer_libs.DepthwiseConvBnRelu
            else:
                conv_func = layer_libs.ConvBnRelu

            block = conv_func(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1 if ratio == 1 else 3,
                dilation=ratio,
                padding=0 if ratio == 1 else ratio)
            self.aspp_blocks.append(block)

        out_size = len(self.aspp_blocks)

        if image_pooling:
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                layer_libs.ConvBnRelu(
                    in_channels, out_channels, kernel_size=1, bias_attr=False))
            out_size += 1
        self.image_pooling = image_pooling

        self.conv_bn_relu = layer_libs.ConvBnRelu(
            in_channels=out_channels * out_size,
            out_channels=out_channels,
            kernel_size=1)

        self.dropout = nn.Dropout(p=0.1)  # drop rate

    def forward(self, x):
        outputs = []
        for block in self.aspp_blocks:
            outputs.append(block(x))

        if self.image_pooling:
            img_avg = self.global_avg_pool(x)
            img_avg = F.resize_bilinear(img_avg, out_shape=x.shape[2:])
            outputs.append(img_avg)

        x = paddle.concat(outputs, axis=1)
        x = self.conv_bn_relu(x)
        x = self.dropout(x)
        return x
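A usage sketch for the new ASPPModule with typical DeepLabV3+-style settings; the dilation rates, channel counts, and input shape are illustrative and not taken from any config in this commit:

import paddle
from paddleseg.models.common.pyramid_pool import ASPPModule  # assumed module path

aspp = ASPPModule(
    aspp_ratios=(1, 6, 12, 18),  # one 1x1 branch plus three dilated 3x3 branches
    in_channels=2048,
    out_channels=256,
    sep_conv=True,               # dilated branches use DepthwiseConvBnRelu
    image_pooling=True)          # append a global-average-pooled branch
feat = paddle.randn([2, 2048, 32, 32])
out = aspp(feat)  # branch outputs are concatenated, fused by a 1x1 conv, then dropout is applied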
class PPModule(nn.Layer):
"""
Pyramid pooling module
Pyramid pooling module originally used in PSPNet
Args:
in_channels (int): the number of input channels to the pyramid pooling module.
@@ -109,6 +120,7 @@ class PPModule(nn.Layer):
bin_sizes=(1, 2, 3, 6),
dim_reduction=True):
super(PPModule, self).__init__()
self.bin_sizes = bin_sizes
inter_channels = in_channels
@@ -121,7 +133,7 @@ class PPModule(nn.Layer):
for size in bin_sizes
])
self.conv_bn_relu2 = layer_utils.ConvBnRelu(
self.conv_bn_relu2 = layer_libs.ConvBnRelu(
in_channels=in_channels + inter_channels * len(bin_sizes),
out_channels=out_channels,
kernel_size=3,
@@ -147,20 +159,17 @@ class PPModule(nn.Layer):
conv (tensor): a tensor after Pyramid Pooling Module
"""
# this paddle version does not support AdaptiveAvgPool2d, so skip it here.
# prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv = layer_utils.ConvBnRelu(
prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv = layer_libs.ConvBnRelu(
in_channels=in_channels, out_channels=out_channels, kernel_size=1)
return conv
return nn.Sequential(prior, conv)
def forward(self, input):
cat_layers = []
for i, stage in enumerate(self.stages):
size = self.bin_sizes[i]
x = F.adaptive_pool2d(
input, pool_size=(size, size), pool_type="max")
x = stage(x)
x = stage(input)
x = F.resize_bilinear(x, out_shape=input.shape[2:])
cat_layers.append(x)
cat_layers = [input] + cat_layers[::-1]
......
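Likewise, a minimal sketch for the updated PPModule; the keyword names follow the signature fragments shown above and the feature shape is an assumption:

import paddle
from paddleseg.models.common.pyramid_pool import PPModule  # assumed module path

ppm = PPModule(
    in_channels=2048,
    out_channels=512,
    bin_sizes=(1, 2, 3, 6))  # adaptive average pooling output sizes, as in PSPNet
feat = paddle.randn([2, 2048, 60, 60])
out = ppm(feat)  # pooled branches are upsampled, concatenated with the input, then fused by conv_bn_relu2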