提交 9d3343d0 编写于 作者: W wuzewu

Update bisenetv2

上级 2f147565
...@@ -21,27 +21,29 @@ import paddle.nn.functional as F ...@@ -21,27 +21,29 @@ import paddle.nn.functional as F
from paddleseg import utils from paddleseg import utils
from paddleseg.cvlibs import manager, param_init from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, ConvBN, DepthwiseConvBN from paddleseg.models.common.layer_libs import ConvBNReLU, ConvBN, DepthwiseConvBN
from paddleseg.models.common.activation import Activation
class StemBlock(nn.Layer): class StemBlock(nn.Layer):
def __init__(self, in_dim, out_dim): def __init__(self, in_dim, out_dim):
super(StemBlock, self).__init__() super(StemBlock, self).__init__()
self.conv_3x3 = ConvBNReLU(in_dim, out_dim, 3, stride=2, padding=1) self.conv = ConvBNReLU(in_dim, out_dim, 3, stride=2)
self.conv_1x1 = ConvBNReLU(out_dim, out_dim // 2, 1)
self.conv2_3x3 = ConvBNReLU(
out_dim // 2, out_dim, 3, stride=2, padding=1)
self.conv3_3x3 = ConvBNReLU(out_dim * 2, out_dim, 3, padding=1)
self.mpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.left = nn.Sequential(
ConvBNReLU(out_dim, out_dim // 2, 1),
ConvBNReLU(out_dim // 2, out_dim, 3, stride=2))
self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.fuse = ConvBNReLU(out_dim * 2, out_dim, 3)
def forward(self, x): def forward(self, x):
conv1 = self.conv_3x3(x) x = self.conv(x)
conv2 = self.conv_1x1(conv1) left = self.left(x)
conv3 = self.conv2_3x3(conv2) right = self.right(x)
pool = self.mpool(conv1) concat = paddle.concat([left, right], axis=1)
concat = paddle.concat([conv3, pool], axis=1) return self.fuse(concat)
return self.conv3_3x3(concat)
class ContextEmbeddingBlock(nn.Layer): class ContextEmbeddingBlock(nn.Layer):
...@@ -61,31 +63,42 @@ class ContextEmbeddingBlock(nn.Layer): ...@@ -61,31 +63,42 @@ class ContextEmbeddingBlock(nn.Layer):
return self.conv_3x3(conv1) return self.conv_3x3(conv1)
class GatherAndExpandsionLayer(nn.Layer): class GatherAndExpandsionLayer1(nn.Layer):
def __init__(self, in_dim, out_dim, expand, stride): """Gather And Expandsion Layer with stride 1"""
super(GatherAndExpandsionLayer, self).__init__()
def __init__(self, in_dim, out_dim, expand):
super(GatherAndExpandsionLayer1, self).__init__()
expand_dim = expand * in_dim
self.stride = stride self.conv = nn.Sequential(
self.conv_3x3 = ConvBNReLU(in_dim, out_dim, 3, padding=1) ConvBNReLU(in_dim, in_dim, 3), DepthwiseConvBN(
self.dwconv = DepthwiseConvBN( in_dim, expand_dim, 3), ConvBN(expand_dim, out_dim, 1))
out_dim, expand * out_dim, 3, stride=stride, padding=1)
self.dwconv2 = DepthwiseConvBN(
expand * out_dim, expand * out_dim, 3, padding=1)
self.dwconv3 = DepthwiseConvBN(
in_dim, out_dim, 3, stride=stride, padding=1)
self.conv_1x1 = ConvBN(expand * out_dim, out_dim, 1)
self.conv2_1x1 = ConvBN(out_dim, out_dim, 1)
def forward(self, x): def forward(self, x):
conv1 = self.conv_3x3(x) return F.relu(self.conv(x) + x)
fm = self.dwconv(conv1)
residual = x
if self.stride == 2: class GatherAndExpandsionLayer2(nn.Layer):
fm = self.dwconv2(fm) """Gather And Expandsion Layer with stride 2"""
residual = self.dwconv3(residual)
residual = self.conv2_1x1(residual) def __init__(self, in_dim, out_dim, expand):
fm = self.conv_1x1(fm) super(GatherAndExpandsionLayer2, self).__init__()
return F.relu(fm + residual)
expand_dim = expand * in_dim
self.branch_1 = nn.Sequential(
ConvBNReLU(in_dim, in_dim, 3),
DepthwiseConvBN(in_dim, expand_dim, 3, stride=2),
DepthwiseConvBN(expand_dim, expand_dim, 3),
ConvBN(expand_dim, out_dim, 1))
self.branch_2 = nn.Sequential(
DepthwiseConvBN(in_dim, in_dim, 3, stride=2),
ConvBN(in_dim, out_dim, 1))
def forward(self, x):
return F.relu(self.branch_1(x) + self.branch_2(x))
class DetailBranch(nn.Layer): class DetailBranch(nn.Layer):
...@@ -98,16 +111,16 @@ class DetailBranch(nn.Layer): ...@@ -98,16 +111,16 @@ class DetailBranch(nn.Layer):
self.convs = nn.Sequential( self.convs = nn.Sequential(
# stage 1 # stage 1
ConvBNReLU(3, C1, 3, stride=2, padding=1), ConvBNReLU(3, C1, 3, stride=2),
ConvBNReLU(C1, C1, 3, padding=1), ConvBNReLU(C1, C1, 3),
# stage 2 # stage 2
ConvBNReLU(C1, C2, 3, stride=2, padding=1), ConvBNReLU(C1, C2, 3, stride=2),
ConvBNReLU(C2, C2, 3, padding=1), ConvBNReLU(C2, C2, 3),
ConvBNReLU(C2, C2, 3, padding=1), ConvBNReLU(C2, C2, 3),
# stage 3 # stage 3
ConvBNReLU(C2, C3, 3, stride=2, padding=1), ConvBNReLU(C2, C3, 3, stride=2),
ConvBNReLU(C3, C3, 3, padding=1), ConvBNReLU(C3, C3, 3),
ConvBNReLU(C3, C3, 3, padding=1), ConvBNReLU(C3, C3, 3),
) )
def forward(self, x): def forward(self, x):
...@@ -124,18 +137,18 @@ class SemanticBranch(nn.Layer): ...@@ -124,18 +137,18 @@ class SemanticBranch(nn.Layer):
self.stem = StemBlock(3, C1) self.stem = StemBlock(3, C1)
self.stage3 = nn.Sequential( self.stage3 = nn.Sequential(
GatherAndExpandsionLayer(C1, C3, 6, 2), GatherAndExpandsionLayer2(C1, C3, 6),
GatherAndExpandsionLayer(C3, C3, 6, 1)) GatherAndExpandsionLayer1(C3, C3, 6))
self.stage4 = nn.Sequential( self.stage4 = nn.Sequential(
GatherAndExpandsionLayer(C3, C4, 6, 2), GatherAndExpandsionLayer2(C3, C4, 6),
GatherAndExpandsionLayer(C4, C4, 6, 1)) GatherAndExpandsionLayer1(C4, C4, 6))
self.stage5_4 = nn.Sequential( self.stage5_4 = nn.Sequential(
GatherAndExpandsionLayer(C4, C5, 6, 2), GatherAndExpandsionLayer2(C4, C5, 6),
GatherAndExpandsionLayer(C5, C5, 6, 1), GatherAndExpandsionLayer1(C5, C5, 6),
GatherAndExpandsionLayer(C5, C5, 6, 1), GatherAndExpandsionLayer1(C5, C5, 6),
GatherAndExpandsionLayer(C5, C5, 6, 1)) GatherAndExpandsionLayer1(C5, C5, 6))
self.ce = ContextEmbeddingBlock(C5, C5) self.ce = ContextEmbeddingBlock(C5, C5)
...@@ -154,47 +167,49 @@ class BGA(nn.Layer): ...@@ -154,47 +167,49 @@ class BGA(nn.Layer):
def __init__(self, out_dim): def __init__(self, out_dim):
super(BGA, self).__init__() super(BGA, self).__init__()
self.db_dwconv = DepthwiseConvBN(out_dim, out_dim, 3, padding=1) self.db_branch_keep = nn.Sequential(
self.db_conv_1x1 = nn.Conv2d(out_dim, out_dim, 1, 1) DepthwiseConvBN(out_dim, out_dim, 3), nn.Conv2d(
self.db_conv_3x3 = ConvBN(out_dim, out_dim, 3, stride=2, padding=1) out_dim, out_dim, 1))
self.db_apool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
self.sb_conv_3x3 = ConvBN(out_dim, out_dim, 3, padding=1) self.db_branch_down = nn.Sequential(
self.sb_dwconv = DepthwiseConvBN(out_dim, out_dim, 3, padding=1) ConvBN(out_dim, out_dim, 3, stride=2),
self.sb_conv_1x1 = nn.Conv2d(out_dim, out_dim, 1) nn.AvgPool2d(kernel_size=3, stride=2, padding=1))
self.conv = ConvBN(out_dim, out_dim, 3, padding=1) self.sb_branch_keep = nn.Sequential(
DepthwiseConvBN(out_dim, out_dim, 3), nn.Conv2d(
out_dim, out_dim, 1), Activation(act='sigmoid'))
def forward(self, dfm, sfm): self.sb_branch_up = nn.Sequential(
dconv1 = self.db_dwconv(dfm) ConvBN(out_dim, out_dim, 3),
dconv2 = self.db_conv_1x1(dconv1) nn.UpsamplingBilinear2d(scale_factor=4), Activation(act='sigmoid'))
dconv3 = self.db_conv_3x3(dfm)
dpool = self.db_apool(dconv3)
sconv1 = self.sb_conv_3x3(sfm) self.conv = ConvBN(out_dim, out_dim, 3)
sconv1 = F.resize_bilinear(sconv1, dconv2.shape[2:])
att1 = F.sigmoid(sconv1)
sconv2 = self.sb_dwconv(sfm)
att2 = self.sb_conv_1x1(sconv2)
att2 = F.sigmoid(att2)
fm = F.resize_bilinear(att2 * dpool, dconv2.shape[2:]) def forward(self, dfm, sfm):
_sum = att1 * dconv2 + fm db_feat_keep = self.db_branch_keep(dfm)
return self.conv(_sum) db_feat_down = self.db_branch_down(dfm)
sb_feat_keep = self.sb_branch_keep(sfm)
sb_feat_up = self.sb_branch_up(sfm)
db_feat = db_feat_down * sb_feat_keep
sb_feat = db_feat_down * sb_feat_keep
sb_feat = F.resize_bilinear(sb_feat, db_feat.shape[2:])
return self.conv(db_feat + sb_feat)
class SegHead(nn.Layer): class SegHead(nn.Layer):
def __init__(self, in_dim, out_dim, num_classes): def __init__(self, in_dim, mid_dim, num_classes):
super(SegHead, self).__init__() super(SegHead, self).__init__()
self.conv_3x3 = ConvBNReLU(in_dim, out_dim, 3) self.conv_3x3 = nn.Sequential(
self.conv_1x1 = nn.Conv2d(out_dim, num_classes, 1, 1) ConvBNReLU(in_dim, mid_dim, 3), nn.Dropout(0.1))
def forward(self, x, label=None): self.conv_1x1 = nn.Conv2d(mid_dim, num_classes, 1, 1)
def forward(self, x):
conv1 = self.conv_3x3(x) conv1 = self.conv_3x3(x)
conv2 = self.conv_1x1(conv1) conv2 = self.conv_1x1(conv1)
pred = F.resize_bilinear(conv2, x.shape[2:]) return conv2
return pred
@manager.MODELS.add_component @manager.MODELS.add_component
...@@ -215,9 +230,9 @@ class BiSeNet(nn.Layer): ...@@ -215,9 +230,9 @@ class BiSeNet(nn.Layer):
def __init__(self, num_classes, lambd=0.25, pretrained=None): def __init__(self, num_classes, lambd=0.25, pretrained=None):
super(BiSeNet, self).__init__() super(BiSeNet, self).__init__()
C1, C2, C3, C4, C5 = 64, 64, 128, 64, 128 C1, C2, C3 = 64, 64, 128
db_channels = (C1, C2, C3) db_channels = (C1, C2, C3)
C1, C3 = int(C1 * lambd), int(C3 * lambd) C1, C3, C4, C5 = int(C1 * lambd), int(C3 * lambd), 64, 128
sb_channels = (C1, C3, C4, C5) sb_channels = (C1, C3, C4, C5)
mid_channels = 128 mid_channels = 128
...@@ -242,7 +257,10 @@ class BiSeNet(nn.Layer): ...@@ -242,7 +257,10 @@ class BiSeNet(nn.Layer):
logit4 = self.aux_head4(feat4) logit4 = self.aux_head4(feat4)
logit = self.head(self.bga(dfm, sfm)) logit = self.head(self.bga(dfm, sfm))
return [logit, logit1, logit2, logit3, logit4] logits = [logit, logit1, logit2, logit3, logit4]
logits = [F.resize_bilinear(logit, x.shape[2:]) for logit in logits]
return logits
def init_weight(self, pretrained=None): def init_weight(self, pretrained=None):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册