yolo.py 16.0 KB
Newer Older
Bubbliiiing's avatar
Bubbliiiing 已提交
1
import numpy as np
Bubbliiiing's avatar
Bubbliiiing 已提交
2 3 4
import torch
import torch.nn as nn

Bubbliiiing's avatar
Bubbliiiing 已提交
5
from nets.backbone import Backbone, Multi_Concat_Block, Conv, SiLU, Transition_Block, autopad
Bubbliiiing's avatar
Bubbliiiing 已提交
6

Bubbliiiing's avatar
Bubbliiiing 已提交
7 8 9

class SPPCSPC(nn.Module):
    """CSP-style spatial pyramid pooling block.

    Reference: https://github.com/WongKinYiu/CrossStagePartialNetworks

    Two paths are computed from the input:
      * main path: ``cv1 -> cv3 -> cv4``; the result is concatenated with one
        stride-1 max-pooled copy of itself per kernel size in ``k`` and fused by
        ``cv5 -> cv6``;
      * shortcut path: ``cv2``.
    Both paths are concatenated and projected to ``c2`` channels by ``cv7``.

    Args:
        c1: input channels.
        c2: output channels.
        n, shortcut, g: accepted for signature compatibility; unused here.
        e: expansion ratio; the hidden width is ``int(2 * c2 * e)``.
        k: max-pool kernel sizes. Each pool uses stride 1 and padding ``x // 2``,
           so odd kernels preserve spatial size.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
        super(SPPCSPC, self).__init__()
        c_ = int(2 * c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 1)
        self.cv4 = Conv(c_, c_, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
        # cv5 sees the un-pooled tensor plus one pooled copy per kernel size.
        # (Previously hard-coded as 4 * c_, which only worked for len(k) == 3.)
        self.cv5 = Conv((len(k) + 1) * c_, c_, 1, 1)
        self.cv6 = Conv(c_, c_, 3, 1)
        # Output has c2 channels.
        self.cv7 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        x1 = self.cv4(self.cv3(self.cv1(x)))
        # Channel-wise concat of x1 and its pyramid of max-pooled versions.
        pyramid = [x1] + [m(x1) for m in self.m]
        y1 = self.cv6(self.cv5(torch.cat(pyramid, 1)))
        y2 = self.cv2(x)
        return self.cv7(torch.cat((y1, y2), dim=1))

class RepConv(nn.Module):
    """Re-parameterizable 3x3 convolution (RepVGG style), https://arxiv.org/abs/2101.03697.

    In training form the output is
    ``act(rbr_dense(x) + rbr_1x1(x) + rbr_identity(x))``. Each branch carries its
    own BatchNorm. The identity branch exists only when ``c1 == c2`` and ``s == 1``.
    ``fuse_repvgg_block`` folds all branches into one biased 3x3 ``nn.Conv2d``
    (``rbr_reparam``) for inference.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size; must be 3.
        s: stride.
        p: padding; ``autopad(k, p)`` must resolve to 1.
        g: convolution groups.
        act: activation. ``True`` selects ``LeakyReLU(0.1)``; an ``nn.Module`` is
            used as-is; anything else disables activation (``nn.Identity``).
        deploy: build the fused single-conv form directly.
    """

    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=SiLU(), deploy=False):
        super(RepConv, self).__init__()
        self.deploy         = deploy
        self.groups         = g
        self.in_channels    = c1
        self.out_channels   = c2

        # Only 3x3 with padding 1 is supported: the 1x1 and identity branches are
        # merged by padding them to 3x3 around the centre tap.
        assert k == 3
        assert autopad(k, p) == 1

        # Padding for the 1x1 branch (0 under the asserts above) so that its output
        # has the same spatial size as the 3x3 branch.
        padding_11  = autopad(k, p) - k // 2
        self.act    = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

        if deploy:
            # Inference form: a single biased conv holding the fused weights.
            self.rbr_reparam    = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)
        else:
            # Identity branch (BN only) is only possible when shapes are preserved.
            self.rbr_identity   = (nn.BatchNorm2d(num_features=c1, eps=0.001, momentum=0.03) if c2 == c1 and s == 1 else None)
            self.rbr_dense      = nn.Sequential(
                nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2, eps=0.001, momentum=0.03),
            )
            self.rbr_1x1        = nn.Sequential(
                nn.Conv2d( c1, c2, 1, s, padding_11, groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2, eps=0.001, momentum=0.03),
            )

    def forward(self, inputs):
        """Apply the fused conv if present, otherwise sum the three training branches."""
        if hasattr(self, "rbr_reparam"):
            return self.act(self.rbr_reparam(inputs))
        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)
        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        """Return the (3x3 kernel, bias) pair equivalent to the sum of all branches."""
        kernel3x3, bias3x3  = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1  = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid    = self._fuse_bn_tensor(self.rbr_identity)
        return (
            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
            bias3x3 + bias1x1 + biasid,
        )

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel to 3x3 (value at the centre tap); ``None`` -> 0."""
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's BatchNorm into its kernel.

        Args:
            branch: ``None``, a ``Sequential(conv, bn)``, or a bare BatchNorm
                (identity branch, either ``BatchNorm2d`` or ``SyncBatchNorm``).

        Returns:
            ``(kernel, bias)`` tensors, or ``(0, 0)`` for a missing branch.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel       = branch[0].weight
            running_mean = branch[1].running_mean
            running_var  = branch[1].running_var
            gamma        = branch[1].weight
            beta         = branch[1].bias
            eps          = branch[1].eps
        else:
            # SyncBatchNorm is not a BatchNorm2d subclass but fuses identically;
            # fuse_repvgg_block explicitly accepts it as an identity branch.
            assert isinstance(branch, (nn.BatchNorm2d, nn.SyncBatchNorm))
            if not hasattr(self, "id_tensor"):
                # Group-aware identity: output channel i reads input channel
                # i % (in_channels / groups) inside its group, at the centre tap.
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
                )
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            # The cached kernel may live on an old device if the module was moved.
            kernel       = self.id_tensor.to(branch.weight.device)
            running_mean = branch.running_mean
            running_var  = branch.running_var
            gamma        = branch.weight
            beta         = branch.bias
            eps          = branch.eps
        std = (running_var + eps).sqrt()
        t   = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def repvgg_convert(self):
        """Return the equivalent fused kernel and bias as NumPy arrays."""
        kernel, bias = self.get_equivalent_kernel_bias()
        return (
            kernel.detach().cpu().numpy(),
            bias.detach().cpu().numpy(),
        )

    def fuse_conv_bn(self, conv, bn):
        """Return a new biased ``nn.Conv2d`` equivalent to ``bn(conv(x))`` (BN in eval mode)."""
        std     = (bn.running_var + bn.eps).sqrt()
        scale   = bn.weight / std
        bias    = bn.bias - bn.running_mean * scale
        if conv.bias is not None:
            # An existing conv bias is scaled by BN like any other pre-BN offset.
            bias = bias + conv.bias * scale

        weights = conv.weight * scale.reshape(-1, 1, 1, 1)

        fused   = nn.Conv2d(in_channels = conv.in_channels,
                            out_channels = conv.out_channels,
                            kernel_size = conv.kernel_size,
                            stride=conv.stride,
                            padding = conv.padding,
                            dilation = conv.dilation,
                            groups = conv.groups,
                            bias = True,
                            padding_mode = conv.padding_mode)

        fused.weight = torch.nn.Parameter(weights)
        fused.bias   = torch.nn.Parameter(bias)
        return fused

    def fuse_repvgg_block(self):
        """Fold the dense, 1x1 and identity branches into ``self.rbr_reparam`` in place.

        This is a no-op when already in deploy form. Afterwards ``rbr_dense``,
        ``rbr_1x1`` and ``rbr_identity`` are removed and set to ``None``.
        """
        if self.deploy:
            return
        print(f"RepConv.fuse_repvgg_block")

        self.rbr_dense  = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])

        self.rbr_1x1    = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
        rbr_1x1_bias    = self.rbr_1x1.bias
        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])

        # Fuse self.rbr_identity. _fuse_bn_tensor builds the 3x3 identity kernel
        # correctly for any group count. The previous squeeze()/fill_diagonal_
        # construction was only right for groups == 1 with more than one channel.
        if isinstance(self.rbr_identity, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            weight_identity_expanded, bias_identity_expanded = self._fuse_bn_tensor(self.rbr_identity)
        else:
            bias_identity_expanded      = torch.zeros_like(rbr_1x1_bias)
            weight_identity_expanded    = torch.zeros_like(weight_1x1_expanded)

        self.rbr_dense.weight   = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
        self.rbr_dense.bias     = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)

        self.rbr_reparam    = self.rbr_dense
        self.deploy         = True

        if self.rbr_identity is not None:
            del self.rbr_identity
            self.rbr_identity = None

        if self.rbr_1x1 is not None:
            del self.rbr_1x1
            self.rbr_1x1 = None

        if self.rbr_dense is not None:
            del self.rbr_dense
            self.rbr_dense = None
            
def fuse_conv_and_bn(conv, bn):
    """Fold an eval-mode BatchNorm into the preceding convolution.

    Args:
        conv: ``nn.Conv2d`` (with or without bias).
        bn: ``nn.BatchNorm2d`` applied to ``conv``'s output; its running statistics
            are used.

    Returns:
        A new, non-trainable ``nn.Conv2d`` on ``conv``'s device computing
        ``bn(conv(x))``. It keeps the original kernel size, stride, padding,
        dilation, groups and padding mode.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True,
                          padding_mode=conv.padding_mode).requires_grad_(False).to(conv.weight.device)

    # Per-output-channel BN scale gamma / sqrt(var + eps). Broadcasting replaces the
    # former diag(scale) @ W matrix product.
    scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))
    fusedconv.weight.copy_((conv.weight * scale.reshape(-1, 1, 1, 1)).detach())

    b_conv  = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn    = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_((scale * b_conv + b_bn).detach())

    return fusedconv

#---------------------------------------------------#
#   yolo_body
#---------------------------------------------------#
class YoloBody(nn.Module):
    """YOLOv7 detector: backbone, PAN-style feature-fusion neck, and three 1x1 heads.

    Args:
        anchors_mask: three lists of anchor indices. ``anchors_mask[i]`` sizes the
            output channels of ``out_i``: index 0 for the coarsest map and
            index 2 for the finest.
        num_classes: number of object classes.
        phi: model variant, ``'l'`` or ``'x'``. Any other value raises ``KeyError``.
        pretrained: forwarded to ``Backbone``.

    ``forward`` returns ``[out0, out1, out2]`` ordered coarsest to finest. Each has
    ``len(anchors_mask[i]) * (5 + num_classes)`` channels (4 box + 1 objectness +
    class scores per anchor).
    """

    def __init__(self, anchors_mask, num_classes, phi, pretrained=False):
        super(YoloBody, self).__init__()
        #-----------------------------------------------#
        #   Per-variant hyper-parameters for YOLOv7 l / x
        #-----------------------------------------------#
        transition_channels = {'l' : 32, 'x' : 40}[phi]
        block_channels      = 32
        panet_channels      = {'l' : 32, 'x' : 64}[phi]
        e       = {'l' : 2, 'x' : 1}[phi]
        n       = {'l' : 4, 'x' : 6}[phi]
        ids     = {'l' : [-1, -2, -3, -4, -5, -6], 'x' : [-1, -3, -5, -7, -8]}[phi]
        conv    = {'l' : RepConv, 'x' : Conv}[phi]
        #-----------------------------------------------#
        #   Shape comments below (H, W, C) assume a 640x640x3 input and
        #   phi='l' (transition_channels=32). For phi='x' channel counts
        #   scale by 40/32.
        #-----------------------------------------------#

        #---------------------------------------------------#
        #   Backbone. It yields three feature maps:
        #   80, 80, 512
        #   40, 40, 1024
        #   20, 20, 1024
        #---------------------------------------------------#
        self.backbone   = Backbone(transition_channels, block_channels, n, phi, pretrained=pretrained)

        #------------------------ Feature-fusion neck ------------------------#
        self.upsample   = nn.Upsample(scale_factor=2, mode="nearest")

        # 20, 20, 1024 => 20, 20, 512
        self.sppcspc                = SPPCSPC(transition_channels * 32, transition_channels * 16)
        # 20, 20, 512 => 20, 20, 256 => 40, 40, 256
        self.conv_for_P5            = Conv(transition_channels * 16, transition_channels * 8)
        # 40, 40, 1024 => 40, 40, 256
        self.conv_for_feat2         = Conv(transition_channels * 32, transition_channels * 8)
        # 40, 40, 512 => 40, 40, 256
        self.conv3_for_upsample1    = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)

        # 40, 40, 256 => 40, 40, 128 => 80, 80, 128
        self.conv_for_P4            = Conv(transition_channels * 8, transition_channels * 4)
        # 80, 80, 512 => 80, 80, 128
        self.conv_for_feat1         = Conv(transition_channels * 16, transition_channels * 4)
        # 80, 80, 256 => 80, 80, 128
        self.conv3_for_upsample2    = Multi_Concat_Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids)

        # 80, 80, 128 => 40, 40, 256
        self.down_sample1           = Transition_Block(transition_channels * 4, transition_channels * 4)
        # 40, 40, 512 => 40, 40, 256
        self.conv3_for_downsample1  = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)

        # 40, 40, 256 => 20, 20, 512
        self.down_sample2           = Transition_Block(transition_channels * 8, transition_channels * 8)
        # 20, 20, 1024 => 20, 20, 512
        self.conv3_for_downsample2  = Multi_Concat_Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids)
        #------------------------ Feature-fusion neck ------------------------#

        # 80, 80, 128 => 80, 80, 256
        self.rep_conv_1 = conv(transition_channels * 4, transition_channels * 8, 3, 1)
        # 40, 40, 256 => 40, 40, 512
        self.rep_conv_2 = conv(transition_channels * 8, transition_channels * 16, 3, 1)
        # 20, 20, 512 => 20, 20, 1024
        self.rep_conv_3 = conv(transition_channels * 16, transition_channels * 32, 3, 1)

        # Each head emits len(anchors_mask[i]) * (4 + 1 + num_classes) channels,
        # e.g. 3 * 25 for 20 classes or 3 * 85 for 80 classes.
        # 80, 80, 256 => 80, 80, 3 * (5 + num_classes)
        self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
        # 40, 40, 512 => 40, 40, 3 * (5 + num_classes)
        self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
        # 20, 20, 1024 => 20, 20, 3 * (5 + num_classes)
        self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)

    def fuse(self):
        """Fold BatchNorm layers into convolutions for inference and return ``self``.

        ``RepConv`` blocks are re-parameterized via ``fuse_repvgg_block``. Plain
        ``Conv`` modules that still have a ``bn`` attribute get it folded into
        ``conv``; ``bn`` is then removed and ``forward`` is switched to
        ``fuseforward``.
        """
        print('Fusing layers... ')
        for m in self.modules():
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()
            elif type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)
                delattr(m, 'bn')
                m.forward = m.fuseforward
        return self

    def forward(self, x):
        # Backbone. Call the module itself rather than .forward() so that any
        # registered forward hooks on the backbone run.
        feat1, feat2, feat3 = self.backbone(x)

        #------------------------ Feature-fusion neck ------------------------#
        # 20, 20, 1024 => 20, 20, 512
        P5          = self.sppcspc(feat3)
        # 20, 20, 512 => 20, 20, 256
        P5_conv     = self.conv_for_P5(P5)
        # 20, 20, 256 => 40, 40, 256
        P5_upsample = self.upsample(P5_conv)
        # 40, 40, 256 cat 40, 40, 256 => 40, 40, 512
        P4          = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1)
        # 40, 40, 512 => 40, 40, 256
        P4          = self.conv3_for_upsample1(P4)

        # 40, 40, 256 => 40, 40, 128
        P4_conv     = self.conv_for_P4(P4)
        # 40, 40, 128 => 80, 80, 128
        P4_upsample = self.upsample(P4_conv)
        # 80, 80, 128 cat 80, 80, 128 => 80, 80, 256
        P3          = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1)
        # 80, 80, 256 => 80, 80, 128
        P3          = self.conv3_for_upsample2(P3)

        # 80, 80, 128 => 40, 40, 256
        P3_downsample = self.down_sample1(P3)
        # 40, 40, 256 cat 40, 40, 256 => 40, 40, 512
        P4 = torch.cat([P3_downsample, P4], 1)
        # 40, 40, 512 => 40, 40, 256
        P4 = self.conv3_for_downsample1(P4)

        # 40, 40, 256 => 20, 20, 512
        P4_downsample = self.down_sample2(P4)
        # 20, 20, 512 cat 20, 20, 512 => 20, 20, 1024
        P5 = torch.cat([P4_downsample, P5], 1)
        # 20, 20, 1024 => 20, 20, 512
        P5 = self.conv3_for_downsample2(P5)
        #------------------------ Feature-fusion neck ------------------------#
        # P3 80, 80, 128
        # P4 40, 40, 256
        # P5 20, 20, 512

        P3 = self.rep_conv_1(P3)
        P4 = self.rep_conv_2(P4)
        P5 = self.rep_conv_3(P5)
        #---------------------------------------------------#
        #   Third (finest) output:
        #   out2 = (batch_size, len(anchors_mask[2]) * (5 + num_classes), 80, 80)
        #---------------------------------------------------#
        out2 = self.yolo_head_P3(P3)
        #---------------------------------------------------#
        #   Second output:
        #   out1 = (batch_size, len(anchors_mask[1]) * (5 + num_classes), 40, 40)
        #---------------------------------------------------#
        out1 = self.yolo_head_P4(P4)
        #---------------------------------------------------#
        #   First (coarsest) output:
        #   out0 = (batch_size, len(anchors_mask[0]) * (5 + num_classes), 20, 20)
        #---------------------------------------------------#
        out0 = self.yolo_head_P5(P5)

        return [out0, out1, out2]