提交 2f147565 编写于 作者: W wuzewu

Update layer lib

上级 b59a3db5
...@@ -21,11 +21,17 @@ from paddle.nn import SyncBatchNorm as BatchNorm ...@@ -21,11 +21,17 @@ from paddle.nn import SyncBatchNorm as BatchNorm
class ConvBNReLU(nn.Layer): class ConvBNReLU(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs): def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super(ConvBNReLU, self).__init__() super(ConvBNReLU, self).__init__()
self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs) self._conv = Conv2d(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = BatchNorm(out_channels) self._batch_norm = BatchNorm(out_channels)
...@@ -37,10 +43,16 @@ class ConvBNReLU(nn.Layer): ...@@ -37,10 +43,16 @@ class ConvBNReLU(nn.Layer):
class ConvBN(nn.Layer): class ConvBN(nn.Layer):
def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             padding='same',
             **kwargs):
    """Create the convolution followed by batch normalization.

    Args:
        in_channels: Forwarded to ``Conv2d`` as its first positional argument.
        out_channels: Forwarded to ``Conv2d`` and used as the size of the
            batch-norm layer.
        kernel_size: Forwarded to ``Conv2d`` as its third positional argument.
        padding: Forwarded to ``Conv2d`` as ``padding``. Defaults to ``'same'``.
        **kwargs: Any additional keyword arguments, passed through to
            ``Conv2d`` unchanged (callers in this file use e.g. ``stride``
            and ``groups``).
    """
    super(ConvBN, self).__init__()

    # `padding` is a named parameter, so it can never also appear in kwargs.
    self._conv = Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=padding,
        **kwargs)

    # BatchNorm is imported as an alias of paddle.nn.SyncBatchNorm.
    self._batch_norm = BatchNorm(out_channels)
def forward(self, x): def forward(self, x):
...@@ -67,17 +79,23 @@ class ConvReluPool(nn.Layer): ...@@ -67,17 +79,23 @@ class ConvReluPool(nn.Layer):
return x return x
class DepthwiseConvBNReLU(nn.Layer): class SeparableConvBNReLU(nn.Layer):
def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             padding='same',
             **kwargs):
    """Create a depthwise ConvBN followed by a pointwise (1x1) ConvBNReLU.

    Args:
        in_channels: Channels of the input; also the output channels and the
            group count of the depthwise stage.
        out_channels: Channels produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise stage only.
        padding: Padding of the depthwise stage only. Defaults to ``'same'``.
        **kwargs: Extra keyword arguments (e.g. ``stride``), forwarded to the
            depthwise stage only.
    """
    super(SeparableConvBNReLU, self).__init__()

    # Depthwise stage: one group per input channel, channel count unchanged.
    self.depthwise_conv = ConvBN(
        in_channels,
        out_channels=in_channels,
        kernel_size=kernel_size,
        padding=padding,
        groups=in_channels,
        **kwargs)

    # Pointwise stage: `padding` is deliberately NOT forwarded here. It is
    # chosen by the caller for `kernel_size` (callers pass padding=1 for 3x3
    # kernels); applying an integer padding to a 1x1 kernel would enlarge
    # the feature map instead of preserving its size. ConvBNReLU's own
    # default padding is used instead.
    # NOTE: the attribute name keeps its historical spelling ("piontwise");
    # renaming it would change the layer's parameter names.
    self.piontwise_conv = ConvBNReLU(
        in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x): def forward(self, x):
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
...@@ -86,20 +104,23 @@ class DepthwiseConvBNReLU(nn.Layer): ...@@ -86,20 +104,23 @@ class DepthwiseConvBNReLU(nn.Layer):
class DepthwiseConvBN(nn.Layer):
    """A single grouped ConvBN whose group count equals ``in_channels``.

    Args:
        in_channels: Channels of the input; also used as ``groups``.
        out_channels: Channels produced by the convolution.
            NOTE(review): grouped convolutions normally require this to be a
            multiple of ``groups`` -- confirm against the Conv2d in use.
        kernel_size: Kernel size of the convolution.
        padding: Padding forwarded to ``ConvBN``. Defaults to ``'same'``.
        **kwargs: Extra keyword arguments forwarded to ``ConvBN``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding='same',
                 **kwargs):
        super(DepthwiseConvBN, self).__init__()
        self.depthwise_conv = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            **kwargs)

    def forward(self, x):
        """Apply the depthwise ConvBN to ``x`` and return the result."""
        return self.depthwise_conv(x)
......
...@@ -46,7 +46,7 @@ class ASPPModule(nn.Layer): ...@@ -46,7 +46,7 @@ class ASPPModule(nn.Layer):
for ratio in aspp_ratios: for ratio in aspp_ratios:
if sep_conv and ratio > 1: if sep_conv and ratio > 1:
conv_func = layer_libs.DepthwiseConvBNReLU conv_func = layer_libs.SeparableConvBNReLU
else: else:
conv_func = layer_libs.ConvBNReLU conv_func = layer_libs.ConvBNReLU
...@@ -134,7 +134,7 @@ class PPModule(nn.Layer): ...@@ -134,7 +134,7 @@ class PPModule(nn.Layer):
Create one pooling layer. Create one pooling layer.
In our implementation, we adopt the same dimension reduction as the original paper that might be In our implementation, we adopt the same dimension reduction as the original paper that might be
slightly different from other implementations. slightly different from other implementations.
After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations
keep the channels the same. keep the channels the same.
......
...@@ -19,7 +19,7 @@ import paddle.nn.functional as F ...@@ -19,7 +19,7 @@ import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.models.common import pyramid_pool from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer from paddleseg.models.common.layer_libs import ConvBNReLU, SeparableConvBNReLU, AuxLayer
from paddleseg.utils import utils from paddleseg.utils import utils
__all__ = ['DeepLabV3P', 'DeepLabV3'] __all__ = ['DeepLabV3P', 'DeepLabV3']
...@@ -99,7 +99,7 @@ class DeepLabV3PHead(nn.Layer): ...@@ -99,7 +99,7 @@ class DeepLabV3PHead(nn.Layer):
if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18). if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
if output_stride=8, aspp_ratios is (1, 12, 24, 36). if output_stride=8, aspp_ratios is (1, 12, 24, 36).
aspp_out_channels (int): the output channels of ASPP module. aspp_out_channels (int): the output channels of ASPP module.
""" """
def __init__(self, def __init__(self,
...@@ -146,7 +146,7 @@ class DeepLabV3(nn.Layer): ...@@ -146,7 +146,7 @@ class DeepLabV3(nn.Layer):
(https://arxiv.org/pdf/1706.05587.pdf) (https://arxiv.org/pdf/1706.05587.pdf)
Args: Args:
Refer to DeepLabV3P above Refer to DeepLabV3P above
""" """
def __init__(self, def __init__(self,
...@@ -234,9 +234,9 @@ class Decoder(nn.Layer): ...@@ -234,9 +234,9 @@ class Decoder(nn.Layer):
self.conv_bn_relu1 = ConvBNReLU( self.conv_bn_relu1 = ConvBNReLU(
in_channels=in_channels, out_channels=48, kernel_size=1) in_channels=in_channels, out_channels=48, kernel_size=1)
self.conv_bn_relu2 = DepthwiseConvBNReLU( self.conv_bn_relu2 = SeparableConvBNReLU(
in_channels=304, out_channels=256, kernel_size=3, padding=1) in_channels=304, out_channels=256, kernel_size=3, padding=1)
self.conv_bn_relu3 = DepthwiseConvBNReLU( self.conv_bn_relu3 = SeparableConvBNReLU(
in_channels=256, out_channels=256, kernel_size=3, padding=1) in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels=256, out_channels=num_classes, kernel_size=1) in_channels=256, out_channels=num_classes, kernel_size=1)
......
...@@ -17,18 +17,19 @@ from paddle import nn ...@@ -17,18 +17,19 @@ from paddle import nn
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.models.common import pyramid_pool from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer from paddleseg.models.common.layer_libs import ConvBNReLU, SeparableConvBNReLU, AuxLayer
from paddleseg.utils import utils from paddleseg.utils import utils
@manager.MODELS.add_component @manager.MODELS.add_component
class FastSCNN(nn.Layer): class FastSCNN(nn.Layer):
""" """
The FastSCNN implementation based on PaddlePaddle. The FastSCNN implementation based on PaddlePaddle.
As mentioned in the original paper, FastSCNN is a real-time segmentation algorithm (123.5fps) As mentioned in the original paper, FastSCNN is a real-time segmentation algorithm (123.5fps)
even for high resolution images (1024x2048). even for high resolution images (1024x2048).
The original article refers to The original article refers to
Poudel, Rudra PK, et al. "Fast-scnn: Fast semantic segmentation network." Poudel, Rudra PK, et al. "Fast-scnn: Fast semantic segmentation network."
(https://arxiv.org/pdf/1902.04502.pdf) (https://arxiv.org/pdf/1902.04502.pdf)
...@@ -40,9 +41,7 @@ class FastSCNN(nn.Layer): ...@@ -40,9 +41,7 @@ class FastSCNN(nn.Layer):
pretrained (str): the path of pretrained model. Default to None. pretrained (str): the path of pretrained model. Default to None.
""" """
def __init__(self, def __init__(self, num_classes, enable_auxiliary_loss=True,
num_classes,
enable_auxiliary_loss=True,
pretrained=None): pretrained=None):
super(FastSCNN, self).__init__() super(FastSCNN, self).__init__()
...@@ -103,13 +102,13 @@ class LearningToDownsample(nn.Layer): ...@@ -103,13 +102,13 @@ class LearningToDownsample(nn.Layer):
self.conv_bn_relu = ConvBNReLU( self.conv_bn_relu = ConvBNReLU(
in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2) in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2)
self.dsconv_bn_relu1 = DepthwiseConvBNReLU( self.dsconv_bn_relu1 = SeparableConvBNReLU(
in_channels=dw_channels1, in_channels=dw_channels1,
out_channels=dw_channels2, out_channels=dw_channels2,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
padding=1) padding=1)
self.dsconv_bn_relu2 = DepthwiseConvBNReLU( self.dsconv_bn_relu2 = SeparableConvBNReLU(
in_channels=dw_channels2, in_channels=dw_channels2,
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
...@@ -127,7 +126,7 @@ class GlobalFeatureExtractor(nn.Layer): ...@@ -127,7 +126,7 @@ class GlobalFeatureExtractor(nn.Layer):
""" """
Global feature extractor module Global feature extractor module
This module consists of three LinearBottleneck blocks (like inverted residual introduced by MobileNetV2) and This module consists of three LinearBottleneck blocks (like inverted residual introduced by MobileNetV2) and
a PPModule (introduced by PSPNet). a PPModule (introduced by PSPNet).
Args: Args:
...@@ -297,13 +296,13 @@ class Classifier(nn.Layer): ...@@ -297,13 +296,13 @@ class Classifier(nn.Layer):
def __init__(self, input_channels, num_classes): def __init__(self, input_channels, num_classes):
super(Classifier, self).__init__() super(Classifier, self).__init__()
self.dsconv1 = DepthwiseConvBNReLU( self.dsconv1 = SeparableConvBNReLU(
in_channels=input_channels, in_channels=input_channels,
out_channels=input_channels, out_channels=input_channels,
kernel_size=3, kernel_size=3,
padding=1) padding=1)
self.dsconv2 = DepthwiseConvBNReLU( self.dsconv2 = SeparableConvBNReLU(
in_channels=input_channels, in_channels=input_channels,
out_channels=input_channels, out_channels=input_channels,
kernel_size=3, kernel_size=3,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册