diff --git a/dygraph/paddleseg/models/bisenet.py b/dygraph/paddleseg/models/bisenet.py
index a1ae897ef264812dbc3cc623317d290e36e37ff5..996b694fe1e31374a2c64e129556cbb9428f8abc 100644
--- a/dygraph/paddleseg/models/bisenet.py
+++ b/dygraph/paddleseg/models/bisenet.py
@@ -21,27 +21,29 @@ import paddle.nn.functional as F
 
 from paddleseg import utils
 from paddleseg.cvlibs import manager, param_init
 from paddleseg.models.common.layer_libs import ConvBNReLU, ConvBN, DepthwiseConvBN
+from paddleseg.models.common.activation import Activation
 
 
 class StemBlock(nn.Layer):
     def __init__(self, in_dim, out_dim):
         super(StemBlock, self).__init__()
 
-        self.conv_3x3 = ConvBNReLU(in_dim, out_dim, 3, stride=2, padding=1)
-        self.conv_1x1 = ConvBNReLU(out_dim, out_dim // 2, 1)
-        self.conv2_3x3 = ConvBNReLU(
-            out_dim // 2, out_dim, 3, stride=2, padding=1)
-        self.conv3_3x3 = ConvBNReLU(out_dim * 2, out_dim, 3, padding=1)
+        self.conv = ConvBNReLU(in_dim, out_dim, 3, stride=2)
 
-        self.mpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.left = nn.Sequential(
+            ConvBNReLU(out_dim, out_dim // 2, 1),
+            ConvBNReLU(out_dim // 2, out_dim, 3, stride=2))
+
+        self.right = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.fuse = ConvBNReLU(out_dim * 2, out_dim, 3)
 
     def forward(self, x):
-        conv1 = self.conv_3x3(x)
-        conv2 = self.conv_1x1(conv1)
-        conv3 = self.conv2_3x3(conv2)
-        pool = self.mpool(conv1)
-        concat = paddle.concat([conv3, pool], axis=1)
-        return self.conv3_3x3(concat)
+        x = self.conv(x)
+        left = self.left(x)
+        right = self.right(x)
+        concat = paddle.concat([left, right], axis=1)
+        return self.fuse(concat)
 
 
 class ContextEmbeddingBlock(nn.Layer):
@@ -61,31 +63,42 @@ class ContextEmbeddingBlock(nn.Layer):
 
         return self.conv_3x3(conv1)
 
 
-class GatherAndExpandsionLayer(nn.Layer):
-    def __init__(self, in_dim, out_dim, expand, stride):
-        super(GatherAndExpandsionLayer, self).__init__()
+class GatherAndExpandsionLayer1(nn.Layer):
+    """Gather And Expansion Layer with stride 1"""
+
+    def __init__(self, in_dim, out_dim, expand):
+        super(GatherAndExpandsionLayer1, self).__init__()
+
+        expand_dim = expand * in_dim
 
-        self.stride = stride
-        self.conv_3x3 = ConvBNReLU(in_dim, out_dim, 3, padding=1)
-        self.dwconv = DepthwiseConvBN(
-            out_dim, expand * out_dim, 3, stride=stride, padding=1)
-        self.dwconv2 = DepthwiseConvBN(
-            expand * out_dim, expand * out_dim, 3, padding=1)
-        self.dwconv3 = DepthwiseConvBN(
-            in_dim, out_dim, 3, stride=stride, padding=1)
-        self.conv_1x1 = ConvBN(expand * out_dim, out_dim, 1)
-        self.conv2_1x1 = ConvBN(out_dim, out_dim, 1)
+        self.conv = nn.Sequential(
+            ConvBNReLU(in_dim, in_dim, 3), DepthwiseConvBN(
+                in_dim, expand_dim, 3), ConvBN(expand_dim, out_dim, 1))
 
     def forward(self, x):
-        conv1 = self.conv_3x3(x)
-        fm = self.dwconv(conv1)
-        residual = x
-        if self.stride == 2:
-            fm = self.dwconv2(fm)
-            residual = self.dwconv3(residual)
-            residual = self.conv2_1x1(residual)
-        fm = self.conv_1x1(fm)
-        return F.relu(fm + residual)
+        return F.relu(self.conv(x) + x)
+
+
+class GatherAndExpandsionLayer2(nn.Layer):
+    """Gather And Expansion Layer with stride 2"""
+
+    def __init__(self, in_dim, out_dim, expand):
+        super(GatherAndExpandsionLayer2, self).__init__()
+
+        expand_dim = expand * in_dim
+
+        self.branch_1 = nn.Sequential(
+            ConvBNReLU(in_dim, in_dim, 3),
+            DepthwiseConvBN(in_dim, expand_dim, 3, stride=2),
+            DepthwiseConvBN(expand_dim, expand_dim, 3),
+            ConvBN(expand_dim, out_dim, 1))
+
+        self.branch_2 = nn.Sequential(
+            DepthwiseConvBN(in_dim, in_dim, 3, stride=2),
+            ConvBN(in_dim, out_dim, 1))
+
+    def forward(self, x):
+        return F.relu(self.branch_1(x) + self.branch_2(x))
 
 
 class DetailBranch(nn.Layer):
@@ -98,16 +111,16 @@ class DetailBranch(nn.Layer):
 
         self.convs = nn.Sequential(
             # stage 1
-            ConvBNReLU(3, C1, 3, stride=2, padding=1),
-            ConvBNReLU(C1, C1, 3, padding=1),
+            ConvBNReLU(3, C1, 3, stride=2),
+            ConvBNReLU(C1, C1, 3),
             # stage 2
-            ConvBNReLU(C1, C2, 3, stride=2, padding=1),
-            ConvBNReLU(C2, C2, 3, padding=1),
-            ConvBNReLU(C2, C2, 3, padding=1),
+            ConvBNReLU(C1, C2, 3, stride=2),
+            ConvBNReLU(C2, C2, 3),
+            ConvBNReLU(C2, C2, 3),
             # stage 3
-            ConvBNReLU(C2, C3, 3, stride=2, padding=1),
-            ConvBNReLU(C3, C3, 3, padding=1),
-            ConvBNReLU(C3, C3, 3, padding=1),
+            ConvBNReLU(C2, C3, 3, stride=2),
+            ConvBNReLU(C3, C3, 3),
+            ConvBNReLU(C3, C3, 3),
         )
 
     def forward(self, x):
@@ -124,18 +137,18 @@ class SemanticBranch(nn.Layer):
         self.stem = StemBlock(3, C1)
 
         self.stage3 = nn.Sequential(
-            GatherAndExpandsionLayer(C1, C3, 6, 2),
-            GatherAndExpandsionLayer(C3, C3, 6, 1))
+            GatherAndExpandsionLayer2(C1, C3, 6),
+            GatherAndExpandsionLayer1(C3, C3, 6))
 
         self.stage4 = nn.Sequential(
-            GatherAndExpandsionLayer(C3, C4, 6, 2),
-            GatherAndExpandsionLayer(C4, C4, 6, 1))
+            GatherAndExpandsionLayer2(C3, C4, 6),
+            GatherAndExpandsionLayer1(C4, C4, 6))
 
         self.stage5_4 = nn.Sequential(
-            GatherAndExpandsionLayer(C4, C5, 6, 2),
-            GatherAndExpandsionLayer(C5, C5, 6, 1),
-            GatherAndExpandsionLayer(C5, C5, 6, 1),
-            GatherAndExpandsionLayer(C5, C5, 6, 1))
+            GatherAndExpandsionLayer2(C4, C5, 6),
+            GatherAndExpandsionLayer1(C5, C5, 6),
+            GatherAndExpandsionLayer1(C5, C5, 6),
+            GatherAndExpandsionLayer1(C5, C5, 6))
 
         self.ce = ContextEmbeddingBlock(C5, C5)
 
@@ -154,47 +167,49 @@ class BGA(nn.Layer):
     def __init__(self, out_dim):
         super(BGA, self).__init__()
 
-        self.db_dwconv = DepthwiseConvBN(out_dim, out_dim, 3, padding=1)
-        self.db_conv_1x1 = nn.Conv2d(out_dim, out_dim, 1, 1)
-        self.db_conv_3x3 = ConvBN(out_dim, out_dim, 3, stride=2, padding=1)
-        self.db_apool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+        self.db_branch_keep = nn.Sequential(
+            DepthwiseConvBN(out_dim, out_dim, 3), nn.Conv2d(
+                out_dim, out_dim, 1))
 
-        self.sb_conv_3x3 = ConvBN(out_dim, out_dim, 3, padding=1)
-        self.sb_dwconv = DepthwiseConvBN(out_dim, out_dim, 3, padding=1)
-        self.sb_conv_1x1 = nn.Conv2d(out_dim, out_dim, 1)
+        self.db_branch_down = nn.Sequential(
+            ConvBN(out_dim, out_dim, 3, stride=2),
+            nn.AvgPool2d(kernel_size=3, stride=2, padding=1))
 
-        self.conv = ConvBN(out_dim, out_dim, 3, padding=1)
+        self.sb_branch_keep = nn.Sequential(
+            DepthwiseConvBN(out_dim, out_dim, 3), nn.Conv2d(
+                out_dim, out_dim, 1), Activation(act='sigmoid'))
 
-    def forward(self, dfm, sfm):
-        dconv1 = self.db_dwconv(dfm)
-        dconv2 = self.db_conv_1x1(dconv1)
-        dconv3 = self.db_conv_3x3(dfm)
-        dpool = self.db_apool(dconv3)
+        self.sb_branch_up = nn.Sequential(
+            ConvBN(out_dim, out_dim, 3),
+            nn.UpsamplingBilinear2d(scale_factor=4), Activation(act='sigmoid'))
 
-        sconv1 = self.sb_conv_3x3(sfm)
-        sconv1 = F.resize_bilinear(sconv1, dconv2.shape[2:])
-        att1 = F.sigmoid(sconv1)
-        sconv2 = self.sb_dwconv(sfm)
-        att2 = self.sb_conv_1x1(sconv2)
-        att2 = F.sigmoid(att2)
+        self.conv = ConvBN(out_dim, out_dim, 3)
 
-        fm = F.resize_bilinear(att2 * dpool, dconv2.shape[2:])
-        _sum = att1 * dconv2 + fm
-        return self.conv(_sum)
+    def forward(self, dfm, sfm):
+        db_feat_keep = self.db_branch_keep(dfm)
+        db_feat_down = self.db_branch_down(dfm)
+        sb_feat_keep = self.sb_branch_keep(sfm)
+        sb_feat_up = self.sb_branch_up(sfm)
+        db_feat = db_feat_keep * sb_feat_up
+        sb_feat = db_feat_down * sb_feat_keep
+        sb_feat = F.resize_bilinear(sb_feat, db_feat.shape[2:])
+
+        return self.conv(db_feat + sb_feat)
 
 
 class SegHead(nn.Layer):
-    def __init__(self, in_dim, out_dim, num_classes):
+    def __init__(self, in_dim, mid_dim, num_classes):
         super(SegHead, self).__init__()
 
-        self.conv_3x3 = ConvBNReLU(in_dim, out_dim, 3)
-        self.conv_1x1 = nn.Conv2d(out_dim, num_classes, 1, 1)
+        self.conv_3x3 = nn.Sequential(
+            ConvBNReLU(in_dim, mid_dim, 3), nn.Dropout(0.1))
 
-    def forward(self, x, label=None):
+        self.conv_1x1 = nn.Conv2d(mid_dim, num_classes, 1, 1)
+
+    def forward(self, x):
         conv1 = self.conv_3x3(x)
         conv2 = self.conv_1x1(conv1)
-        pred = F.resize_bilinear(conv2, x.shape[2:])
-        return pred
+
+        return conv2
 
 
 @manager.MODELS.add_component
@@ -215,9 +230,9 @@ class BiSeNet(nn.Layer):
     def __init__(self, num_classes, lambd=0.25, pretrained=None):
         super(BiSeNet, self).__init__()
 
-        C1, C2, C3, C4, C5 = 64, 64, 128, 64, 128
+        C1, C2, C3 = 64, 64, 128
        db_channels = (C1, C2, C3)
-        C1, C3 = int(C1 * lambd), int(C3 * lambd)
+        C1, C3, C4, C5 = int(C1 * lambd), int(C3 * lambd), 64, 128
         sb_channels = (C1, C3, C4, C5)
         mid_channels = 128
 
@@ -242,7 +257,10 @@ class BiSeNet(nn.Layer):
         logit4 = self.aux_head4(feat4)
         logit = self.head(self.bga(dfm, sfm))
 
-        return [logit, logit1, logit2, logit3, logit4]
+        logits = [logit, logit1, logit2, logit3, logit4]
+        logits = [F.resize_bilinear(logit, x.shape[2:]) for logit in logits]
+
+        return logits
 
     def init_weight(self, pretrained=None):
         """
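
Reviewer note, not part of the patch: a minimal smoke test of the refactored model, assuming a Paddle 2.0-style environment where this `dygraph` package is importable (the `from paddleseg.models import BiSeNet` path, the 19-class count, and the 512x1024 input are illustrative assumptions). Since the final `F.resize_bilinear` now happens once in `BiSeNet.forward`, the main logit and all four auxiliary logits should come back at input resolution:

    # Hypothetical shape check (names and sizes are illustrative only).
    import paddle
    from paddleseg.models import BiSeNet  # registered via manager.MODELS

    model = BiSeNet(num_classes=19)
    model.eval()  # use BN running statistics

    x = paddle.rand([2, 3, 512, 1024])  # NCHW input batch
    logits = model(x)  # [main head, aux1, aux2, aux3, aux4]

    assert len(logits) == 5
    for logit in logits:
        # every head is resized to x.shape[2:] inside forward()
        assert logit.shape == [2, 19, 512, 1024]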