Commit 9d3343d0, authored by: W wuzewu

Update bisenetv2

Parent commit: 2f147565
......@@ -21,27 +21,29 @@ import paddle.nn.functional as F
from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, ConvBN, DepthwiseConvBN
from paddleseg.models.common.activation import Activation
class StemBlock(nn.Layer):
    """BiSeNetV2 stem: downsample the input 4x via two fused branches.

    A 3x3 stride-2 conv first halves the resolution. Then a conv branch
    (``left``: 1x1 bottleneck + 3x3 stride-2 conv) and a max-pool branch
    (``right``) each halve it again; their channel-wise concatenation is
    fused by a final 3x3 conv back to ``out_dim`` channels.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
    """

    def __init__(self, in_dim, out_dim):
        super(StemBlock, self).__init__()

        self.conv = ConvBNReLU(in_dim, out_dim, 3, stride=2)

        # Conv branch: bottleneck to out_dim // 2, then expand back while
        # downsampling.
        self.left = nn.Sequential(
            ConvBNReLU(out_dim, out_dim // 2, 1),
            ConvBNReLU(out_dim // 2, out_dim, 3, stride=2))

        # Pooling branch at the same stride as the conv branch.
        self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Concatenation doubles the channels; fuse them back to out_dim.
        self.fuse = ConvBNReLU(out_dim * 2, out_dim, 3)

    def forward(self, x):
        x = self.conv(x)
        left = self.left(x)
        right = self.right(x)
        concat = paddle.concat([left, right], axis=1)
        return self.fuse(concat)
class ContextEmbeddingBlock(nn.Layer):
......@@ -61,31 +63,42 @@ class ContextEmbeddingBlock(nn.Layer):
return self.conv_3x3(conv1)
class GatherAndExpandsionLayer1(nn.Layer):
    """Gather-and-Expansion layer with stride 1 (identity residual).

    Pipeline: 3x3 conv -> depthwise 3x3 expansion (by ``expand``) ->
    1x1 projection, added element-wise to the input and passed through ReLU.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        expand (int): Channel expansion ratio of the depthwise conv.
    """

    def __init__(self, in_dim, out_dim, expand):
        super(GatherAndExpandsionLayer1, self).__init__()

        expand_dim = expand * in_dim

        self.conv = nn.Sequential(
            ConvBNReLU(in_dim, in_dim, 3),
            DepthwiseConvBN(in_dim, expand_dim, 3),
            ConvBN(expand_dim, out_dim, 1))

    def forward(self, x):
        # NOTE(review): the residual add assumes in_dim == out_dim so that
        # self.conv(x) and x have matching channels — confirm at call sites.
        return F.relu(self.conv(x) + x)
class GatherAndExpandsionLayer2(nn.Layer):
    """Gather-and-Expansion layer with stride 2 (downsampling residual).

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        expand (int): Channel expansion ratio of the depthwise conv.
    """

    def __init__(self, in_dim, out_dim, expand):
        super(GatherAndExpandsionLayer2, self).__init__()

        expand_dim = expand * in_dim

        # Main branch: gather (3x3 conv), expand with a stride-2 depthwise
        # conv, refine with a second depthwise conv, then project to out_dim.
        self.branch_1 = nn.Sequential(
            ConvBNReLU(in_dim, in_dim, 3),
            DepthwiseConvBN(in_dim, expand_dim, 3, stride=2),
            DepthwiseConvBN(expand_dim, expand_dim, 3),
            ConvBN(expand_dim, out_dim, 1))

        # Shortcut branch: stride-2 depthwise conv plus 1x1 projection so the
        # two branches match in both resolution and channel count.
        self.branch_2 = nn.Sequential(
            DepthwiseConvBN(in_dim, in_dim, 3, stride=2),
            ConvBN(in_dim, out_dim, 1))

    def forward(self, x):
        main = self.branch_1(x)
        shortcut = self.branch_2(x)
        return F.relu(main + shortcut)
class DetailBranch(nn.Layer):
......@@ -98,16 +111,16 @@ class DetailBranch(nn.Layer):
self.convs = nn.Sequential(
# stage 1
ConvBNReLU(3, C1, 3, stride=2, padding=1),
ConvBNReLU(C1, C1, 3, padding=1),
ConvBNReLU(3, C1, 3, stride=2),
ConvBNReLU(C1, C1, 3),
# stage 2
ConvBNReLU(C1, C2, 3, stride=2, padding=1),
ConvBNReLU(C2, C2, 3, padding=1),
ConvBNReLU(C2, C2, 3, padding=1),
ConvBNReLU(C1, C2, 3, stride=2),
ConvBNReLU(C2, C2, 3),
ConvBNReLU(C2, C2, 3),
# stage 3
ConvBNReLU(C2, C3, 3, stride=2, padding=1),
ConvBNReLU(C3, C3, 3, padding=1),
ConvBNReLU(C3, C3, 3, padding=1),
ConvBNReLU(C2, C3, 3, stride=2),
ConvBNReLU(C3, C3, 3),
ConvBNReLU(C3, C3, 3),
)
def forward(self, x):
......@@ -124,18 +137,18 @@ class SemanticBranch(nn.Layer):
self.stem = StemBlock(3, C1)
self.stage3 = nn.Sequential(
GatherAndExpandsionLayer(C1, C3, 6, 2),
GatherAndExpandsionLayer(C3, C3, 6, 1))
GatherAndExpandsionLayer2(C1, C3, 6),
GatherAndExpandsionLayer1(C3, C3, 6))
self.stage4 = nn.Sequential(
GatherAndExpandsionLayer(C3, C4, 6, 2),
GatherAndExpandsionLayer(C4, C4, 6, 1))
GatherAndExpandsionLayer2(C3, C4, 6),
GatherAndExpandsionLayer1(C4, C4, 6))
self.stage5_4 = nn.Sequential(
GatherAndExpandsionLayer(C4, C5, 6, 2),
GatherAndExpandsionLayer(C5, C5, 6, 1),
GatherAndExpandsionLayer(C5, C5, 6, 1),
GatherAndExpandsionLayer(C5, C5, 6, 1))
GatherAndExpandsionLayer2(C4, C5, 6),
GatherAndExpandsionLayer1(C5, C5, 6),
GatherAndExpandsionLayer1(C5, C5, 6),
GatherAndExpandsionLayer1(C5, C5, 6))
self.ce = ContextEmbeddingBlock(C5, C5)
class BGA(nn.Layer):
    """Bilateral Guided Aggregation: fuse detail- and semantic-branch features.

    Each branch guides the other at two scales: the full-resolution detail
    feature is gated by the 4x-upsampled semantic attention, while a 4x-
    downsampled detail feature is gated by the low-resolution semantic
    attention; the low-resolution product is upsampled back and summed with
    the high-resolution one before a final 3x3 conv.

    Args:
        out_dim (int): Channel count of both input feature maps and the output.
    """

    def __init__(self, out_dim):
        super(BGA, self).__init__()

        # Detail branch kept at input resolution.
        self.db_branch_keep = nn.Sequential(
            DepthwiseConvBN(out_dim, out_dim, 3),
            nn.Conv2d(out_dim, out_dim, 1))

        # Detail branch downsampled 4x (stride-2 conv + stride-2 avg pool)
        # to match the semantic feature map.
        self.db_branch_down = nn.Sequential(
            ConvBN(out_dim, out_dim, 3, stride=2),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1))

        # Semantic attention at its native (low) resolution.
        self.sb_branch_keep = nn.Sequential(
            DepthwiseConvBN(out_dim, out_dim, 3),
            nn.Conv2d(out_dim, out_dim, 1),
            Activation(act='sigmoid'))

        # Semantic attention upsampled 4x to the detail resolution.
        self.sb_branch_up = nn.Sequential(
            ConvBN(out_dim, out_dim, 3),
            nn.UpsamplingBilinear2d(scale_factor=4),
            Activation(act='sigmoid'))

        self.conv = ConvBN(out_dim, out_dim, 3)

    def forward(self, dfm, sfm):
        db_feat_keep = self.db_branch_keep(dfm)
        db_feat_down = self.db_branch_down(dfm)
        sb_feat_keep = self.sb_branch_keep(sfm)
        sb_feat_up = self.sb_branch_up(sfm)

        # NOTE(review): the pasted source computed `db_feat_down *
        # sb_feat_keep` twice, leaving db_feat_keep and sb_feat_up unused.
        # The pairing below matches the resolutions of each pair (high with
        # high, low with low) — confirm against upstream.
        db_feat = db_feat_keep * sb_feat_up
        sb_feat = db_feat_down * sb_feat_keep
        sb_feat = F.resize_bilinear(sb_feat, db_feat.shape[2:])

        return self.conv(db_feat + sb_feat)
class SegHead(nn.Layer):
    """Segmentation head: 3x3 conv with dropout, then a 1x1 classifier.

    Args:
        in_dim (int): Number of input channels.
        mid_dim (int): Hidden channels of the 3x3 conv.
        num_classes (int): Number of output classes (logit channels).
    """

    def __init__(self, in_dim, mid_dim, num_classes):
        super(SegHead, self).__init__()

        self.conv_3x3 = nn.Sequential(
            ConvBNReLU(in_dim, mid_dim, 3), nn.Dropout(0.1))

        self.conv_1x1 = nn.Conv2d(mid_dim, num_classes, 1, 1)

    def forward(self, x):
        conv1 = self.conv_3x3(x)
        conv2 = self.conv_1x1(conv1)
        # Logits are returned at feature resolution; the caller
        # (BiSeNet.forward) resizes every logit map to the input size.
        return conv2
@manager.MODELS.add_component
......@@ -215,9 +230,9 @@ class BiSeNet(nn.Layer):
def __init__(self, num_classes, lambd=0.25, pretrained=None):
super(BiSeNet, self).__init__()
C1, C2, C3, C4, C5 = 64, 64, 128, 64, 128
C1, C2, C3 = 64, 64, 128
db_channels = (C1, C2, C3)
C1, C3 = int(C1 * lambd), int(C3 * lambd)
C1, C3, C4, C5 = int(C1 * lambd), int(C3 * lambd), 64, 128
sb_channels = (C1, C3, C4, C5)
mid_channels = 128
......@@ -242,7 +257,10 @@ class BiSeNet(nn.Layer):
logit4 = self.aux_head4(feat4)
logit = self.head(self.bga(dfm, sfm))
return [logit, logit1, logit2, logit3, logit4]
logits = [logit, logit1, logit2, logit3, logit4]
logits = [F.resize_bilinear(logit, x.shape[2:]) for logit in logits]
return logits
def init_weight(self, pretrained=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to post a comment.