# CSPdarknet.py (4.9 KB)
# Committed by Bubbliiiing
import torch
import torch.nn as nn


def autopad(k, p=None):
    """Return the padding to use for a kernel of size ``k``.

    Args:
        k: Kernel size, either a single int or an iterable of per-dimension
            sizes.
        p: Explicit padding. Any value other than ``None`` is returned
            unchanged.

    Returns:
        ``p`` when it is given; otherwise ``k // 2`` for an int kernel, or a
        list with each kernel dimension floor-divided by two.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [dim // 2 for dim in k]

class SiLU(nn.Module):
    """Activation module computing ``x * sigmoid(x)`` element-wise."""

    @staticmethod
    def forward(x):
        """Return ``x`` multiplied by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
    
class Conv(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an activation.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size.
        s: Stride.
        p: Padding; when ``None`` it is derived from ``k`` by ``autopad``.
        g: Number of convolution groups.
        act: ``True`` selects ``LeakyReLU(0.1)``, an ``nn.Module`` instance is
            used as given, and anything else disables the activation.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()):
        super().__init__()
        self.conv = nn.Conv2d(
            c1,
            c2,
            kernel_size=k,
            stride=s,
            padding=autopad(k, p),
            groups=g,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(c2)

        if act is True:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and activation, in that order."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def fuseforward(self, x):
        """Apply convolution and activation only, skipping batch norm.

        NOTE(review): presumably meant for use once BatchNorm has been folded
        into the conv weights; the folding itself is not done in this file.
        """
        y = self.conv(x)
        return self.act(y)
    
class RCSPDark_Block(nn.Module):
    """Two 1x1 branches, a chain of 3x3 convs, and a 1x1 fusion conv.

    ``forward`` collects the tensors ``[cv1(x), cv2(x), cv3[0](...), ...,
    cv3[n-1](...)]``, concatenates the ones selected by ``ids`` along the
    channel axis and fuses them with ``cv4``.

    Args:
        c1: Input channels.
        c2: Output channels of every 3x3 conv in the chain.
        c3: Output channels of the block.
        n: Number of chained 3x3 convs.
        e: Ratio used to derive the width of the two 1x1 branches,
            ``int(c1 * e)``.
        ids: Indices (negative allowed) into the collected tensor list that
            are concatenated before ``cv4``. Defaults to ``[0]``.
    """

    def __init__(self, c1, c2, c3, n=4, e=0.5, ids=None):
        super(RCSPDark_Block, self).__init__()
        c_ = int(c1 * e)

        # A fresh list per instance avoids sharing one mutable default.
        self.ids = [0] if ids is None else ids
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = nn.ModuleList(
            [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
        )

        # Channel width of each tensor that forward() collects, in order.
        # Summing the widths actually selected by ``ids`` keeps cv4 consistent
        # with the concatenation for any index choice; for selections with
        # exactly two ``c_``-wide branches this equals the previous
        # ``c_ * 2 + c2 * (len(ids) - 2)``.
        branch_channels = [c_, c_] + [c2] * n
        cat_channels = sum(branch_channels[idx] for idx in self.ids)
        self.cv4 = Conv(cat_channels, c3, 1, 1)

    def forward(self, x):
        """Return the fused feature map for input ``x``."""
        x_1 = self.cv1(x)
        x_2 = self.cv2(x)

        x_all = [x_1, x_2]
        # Each 3x3 conv consumes the previous output; every stage is kept.
        for conv in self.cv3:
            x_2 = conv(x_2)
            x_all.append(x_2)

        selected = [x_all[idx] for idx in self.ids]
        return self.cv4(torch.cat(selected, 1))

class MP(nn.Module):
    """Max pooling whose stride equals its kernel size.

    Args:
        k: Kernel size and stride of the pooling window.
    """

    def __init__(self, k=2):
        super().__init__()
        self.m = nn.MaxPool2d(k, k)

    def forward(self, x):
        """Return the max-pooled version of ``x``."""
        pooled = self.m(x)
        return pooled
    
class RCSPDark_Transition(nn.Module):
    """Downsampling layer built from two parallel stride-2 branches.

    One branch max-pools then applies a 1x1 conv; the other applies a 1x1
    conv followed by a 3x3 stride-2 conv. Both produce ``c2`` channels and
    are concatenated, so the output has ``2 * c2`` channels.

    Args:
        c1: Input channels.
        c2: Output channels of each branch.
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c1, c2, 1, 1)
        self.cv3 = Conv(c2, c2, 3, 2)

        self.mp = MP()

    def forward(self, x):
        """Return the conv branch and the pooled branch joined on dim 1."""
        pooled_branch = self.cv1(self.mp(x))
        conv_branch = self.cv3(self.cv2(x))
        # The conv branch comes first in the concatenation.
        return torch.cat((conv_branch, pooled_branch), dim=1)
    
class CSPDarknet(nn.Module):
    """Backbone returning the feature maps produced by dark3, dark4 and dark5.

    Args:
        base_channels: Output width of the first stem conv; every later layer
            width is a fixed multiple of it.
        pretrained: When true, download the "l" backbone checkpoint and load
            it with ``strict=False``.
    """

    def __init__(self, base_channels, pretrained=False):
        super().__init__()
        ch = base_channels
        #-----------------------------------------------#
        #   The original notes assume a 640, 640, 3 input image.
        #   NOTE(review): they also quoted a base width of 64 and
        #   output widths of 256 / 512 / 1024, which do not match
        #   the multipliers below; widths are therefore given in
        #   terms of base_channels.
        #-----------------------------------------------#

        # Stem: three 3x3 convs, the middle one with stride 2.
        self.stem = nn.Sequential(
            Conv(3, ch, 3, 1),
            Conv(ch, ch * 2, 3, 2),
            Conv(ch * 2, ch * 2, 3, 1),
        )
        # dark2: stride-2 conv, then a block producing ch * 8 channels.
        self.dark2 = nn.Sequential(
            Conv(ch * 2, ch * 4, 3, 2),
            RCSPDark_Block(ch * 4, ch * 2, ch * 8, ids=[-1, -3, -5, -6]),
        )
        # dark3..dark5: a stride-2 transition followed by a block.
        self.dark3 = nn.Sequential(
            RCSPDark_Transition(ch * 8, ch * 4),
            RCSPDark_Block(ch * 8, ch * 4, ch * 16, ids=[-1, -3, -5, -6]),
        )
        self.dark4 = nn.Sequential(
            RCSPDark_Transition(ch * 16, ch * 8),
            RCSPDark_Block(ch * 16, ch * 8, ch * 32, ids=[-1, -3, -5, -6]),
        )
        self.dark5 = nn.Sequential(
            RCSPDark_Transition(ch * 32, ch * 16),
            RCSPDark_Block(ch * 32, ch * 8, ch * 32, e=1/4, ids=[-1, -3, -5, -6]),
        )

        if pretrained:
            # Only the "l" variant has a published checkpoint here.
            weight_urls = {
                "l": 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/cspdarknet_backbone_l.pth',
            }
            phi = 'l'
            url = weight_urls[phi]
            state = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", model_dir="./model_data"
            )
            self.load_state_dict(state, strict=False)
            print("Load weights from ", url.split('/')[-1])

    def forward(self, x):
        """Return ``(feat1, feat2, feat3)`` from dark3, dark4 and dark5."""
        # Run the stem one layer at a time.
        for layer in self.stem:
            x = layer(x)
        x = self.dark2(x)
        #-----------------------------------------------#
        #   dark3 output: base_channels * 16 channels
        #   (80 x 80 for a 640 x 640 input); first feature map.
        #-----------------------------------------------#
        feat1 = self.dark3(x)
        #-----------------------------------------------#
        #   dark4 output: base_channels * 32 channels
        #   (40 x 40 for a 640 x 640 input); second feature map.
        #-----------------------------------------------#
        feat2 = self.dark4(feat1)
        #-----------------------------------------------#
        #   dark5 output: base_channels * 32 channels
        #   (20 x 20 for a 640 x 640 input); third feature map.
        #-----------------------------------------------#
        feat3 = self.dark5(feat2)
        return feat1, feat2, feat3