提交 2f147565 编写于 作者: W wuzewu

Update layer lib

上级 b59a3db5
......@@ -21,11 +21,17 @@ from paddle.nn import SyncBatchNorm as BatchNorm
class ConvBNReLU(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super(ConvBNReLU, self).__init__()
self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self._conv = Conv2d(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = BatchNorm(out_channels)
......@@ -37,10 +43,16 @@ class ConvBNReLU(nn.Layer):
class ConvBN(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super(ConvBN, self).__init__()
self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self._conv = Conv2d(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = BatchNorm(out_channels)
def forward(self, x):
......@@ -67,17 +79,23 @@ class ConvReluPool(nn.Layer):
return x
class DepthwiseConvBNReLU(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(DepthwiseConvBNReLU, self).__init__()
class SeparableConvBNReLU(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super(SeparableConvBNReLU, self).__init__()
self.depthwise_conv = ConvBN(
in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
groups=in_channels,
**kwargs)
self.piontwise_conv = ConvBNReLU(
in_channels, out_channels, kernel_size=1, groups=1)
in_channels, out_channels, kernel_size=1, padding=padding, groups=1)
def forward(self, x):
x = self.depthwise_conv(x)
......@@ -86,20 +104,23 @@ class DepthwiseConvBNReLU(nn.Layer):
class DepthwiseConvBN(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super(DepthwiseConvBN, self).__init__()
self.depthwise_conv = ConvBN(
in_channels,
out_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
groups=in_channels,
**kwargs)
self.piontwise_conv = ConvBN(
in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x):
x = self.depthwise_conv(x)
x = self.piontwise_conv(x)
return x
......
......@@ -46,7 +46,7 @@ class ASPPModule(nn.Layer):
for ratio in aspp_ratios:
if sep_conv and ratio > 1:
conv_func = layer_libs.DepthwiseConvBNReLU
conv_func = layer_libs.SeparableConvBNReLU
else:
conv_func = layer_libs.ConvBNReLU
......
......@@ -19,7 +19,7 @@ import paddle.nn.functional as F
from paddle import nn
from paddleseg.cvlibs import manager
from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer
from paddleseg.models.common.layer_libs import ConvBNReLU, SeparableConvBNReLU, AuxLayer
from paddleseg.utils import utils
__all__ = ['DeepLabV3P', 'DeepLabV3']
......@@ -234,9 +234,9 @@ class Decoder(nn.Layer):
self.conv_bn_relu1 = ConvBNReLU(
in_channels=in_channels, out_channels=48, kernel_size=1)
self.conv_bn_relu2 = DepthwiseConvBNReLU(
self.conv_bn_relu2 = SeparableConvBNReLU(
in_channels=304, out_channels=256, kernel_size=3, padding=1)
self.conv_bn_relu3 = DepthwiseConvBNReLU(
self.conv_bn_relu3 = SeparableConvBNReLU(
in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv = nn.Conv2d(
in_channels=256, out_channels=num_classes, kernel_size=1)
......
......@@ -17,9 +17,10 @@ from paddle import nn
from paddleseg.cvlibs import manager
from paddleseg.models.common import pyramid_pool
from paddleseg.models.common.layer_libs import ConvBNReLU, DepthwiseConvBNReLU, AuxLayer
from paddleseg.models.common.layer_libs import ConvBNReLU, SeparableConvBNReLU, AuxLayer
from paddleseg.utils import utils
@manager.MODELS.add_component
class FastSCNN(nn.Layer):
"""
......@@ -40,9 +41,7 @@ class FastSCNN(nn.Layer):
pretrained (str): the path of pretrained model. Default to None.
"""
def __init__(self,
num_classes,
enable_auxiliary_loss=True,
def __init__(self, num_classes, enable_auxiliary_loss=True,
pretrained=None):
super(FastSCNN, self).__init__()
......@@ -103,13 +102,13 @@ class LearningToDownsample(nn.Layer):
self.conv_bn_relu = ConvBNReLU(
in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2)
self.dsconv_bn_relu1 = DepthwiseConvBNReLU(
self.dsconv_bn_relu1 = SeparableConvBNReLU(
in_channels=dw_channels1,
out_channels=dw_channels2,
kernel_size=3,
stride=2,
padding=1)
self.dsconv_bn_relu2 = DepthwiseConvBNReLU(
self.dsconv_bn_relu2 = SeparableConvBNReLU(
in_channels=dw_channels2,
out_channels=out_channels,
kernel_size=3,
......@@ -297,13 +296,13 @@ class Classifier(nn.Layer):
def __init__(self, input_channels, num_classes):
super(Classifier, self).__init__()
self.dsconv1 = DepthwiseConvBNReLU(
self.dsconv1 = SeparableConvBNReLU(
in_channels=input_channels,
out_channels=input_channels,
kernel_size=3,
padding=1)
self.dsconv2 = DepthwiseConvBNReLU(
self.dsconv2 = SeparableConvBNReLU(
in_channels=input_channels,
out_channels=input_channels,
kernel_size=3,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册