diff --git a/nets/backbone.py b/nets/backbone.py index 2a5c5d36b250d7bf2ee42681dcc32a72a0eac0be..73019761efaeb6d69be0dc510f3a09d689ab16eb 100644 --- a/nets/backbone.py +++ b/nets/backbone.py @@ -43,6 +43,7 @@ class Multi_Concat_Block(nn.Module): x_2 = self.cv2(x) x_all = [x_1, x_2] + # [-1, -3, -5, -6] => [5, 3, 1, 0] for i in range(len(self.cv3)): x_2 = self.cv3[i](x_2) x_all.append(x_2) @@ -68,12 +69,15 @@ class Transition_Block(nn.Module): self.mp = MP() def forward(self, x): + # 160, 160, 256 => 80, 80, 256 => 80, 80, 128 x_1 = self.mp(x) x_1 = self.cv1(x_1) + # 160, 160, 256 => 160, 160, 128 => 80, 80, 128 x_2 = self.cv2(x) x_2 = self.cv3(x_2) + # 80, 80, 128 cat 80, 80, 128 => 80, 80, 256 return torch.cat([x_2, x_1], 1) class Backbone(nn.Module): @@ -86,23 +90,28 @@ class Backbone(nn.Module): 'l' : [-1, -3, -5, -6], 'x' : [-1, -3, -5, -7, -8], }[phi] + # 640, 640, 3 => 640, 640, 32 => 320, 320, 64 self.stem = nn.Sequential( Conv(3, transition_channels, 3, 1), Conv(transition_channels, transition_channels * 2, 3, 2), Conv(transition_channels * 2, transition_channels * 2, 3, 1), ) + # 320, 320, 64 => 160, 160, 128 => 160, 160, 256 self.dark2 = nn.Sequential( Conv(transition_channels * 2, transition_channels * 4, 3, 2), Multi_Concat_Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids), ) + # 160, 160, 256 => 80, 80, 256 => 80, 80, 512 self.dark3 = nn.Sequential( Transition_Block(transition_channels * 8, transition_channels * 4), Multi_Concat_Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids), ) + # 80, 80, 512 => 40, 40, 512 => 40, 40, 1024 self.dark4 = nn.Sequential( Transition_Block(transition_channels * 16, transition_channels * 8), Multi_Concat_Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids), ) + # 40, 40, 1024 => 20, 20, 1024 => 20, 20, 1024 self.dark5 = nn.Sequential( Transition_Block(transition_channels * 32, transition_channels * 16), Multi_Concat_Block(transition_channels * 32, block_channels * 8, transition_channels * 32, n=n, ids=ids), @@ -121,12 +130,12 @@ class Backbone(nn.Module): x = self.stem(x) x = self.dark2(x) #-----------------------------------------------# - # dark3的输出为80, 80, 256,是一个有效特征层 + # dark3的输出为80, 80, 512,是一个有效特征层 #-----------------------------------------------# x = self.dark3(x) feat1 = x #-----------------------------------------------# - # dark4的输出为40, 40, 512,是一个有效特征层 + # dark4的输出为40, 40, 1024,是一个有效特征层 #-----------------------------------------------# x = self.dark4(x) feat2 = x diff --git a/nets/yolo.py b/nets/yolo.py index a443762c927ea73fd9a84b0a2d56c9b72e466cb3..0c0a7c5f3d0591b4d8efe920a3856af2bbf6aecb 100644 --- a/nets/yolo.py +++ b/nets/yolo.py @@ -17,6 +17,7 @@ class SPPCSPC(nn.Module): self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) self.cv5 = Conv(4 * c_, c_, 1, 1) self.cv6 = Conv(c_, c_, 3, 1) + # 输出通道数为c2 self.cv7 = Conv(2 * c_, c2, 1, 1) def forward(self, x): @@ -199,11 +200,13 @@ def fuse_conv_and_bn(conv, bn): w_conv = conv.weight.clone().view(conv.out_channels, -1) w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) + # fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape).detach()) b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) - fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + # fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + fusedconv.bias.copy_((torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn).detach()) return fusedconv #---------------------------------------------------# @@ -235,29 +238,49 @@ class YoloBody(nn.Module): #---------------------------------------------------# self.backbone = Backbone(transition_channels, block_channels, n, phi, pretrained=pretrained) + #------------------------加强特征提取网络------------------------# self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + # 20, 20, 1024 => 20, 20, 512 self.sppcspc = SPPCSPC(transition_channels * 32, transition_channels * 16) + # 20, 20, 512 => 20, 20, 256 => 40, 40, 256 self.conv_for_P5 = Conv(transition_channels * 16, transition_channels * 8) + # 40, 40, 1024 => 40, 40, 256 self.conv_for_feat2 = Conv(transition_channels * 32, transition_channels * 8) + # 40, 40, 512 => 40, 40, 256 self.conv3_for_upsample1 = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids) + # 40, 40, 256 => 40, 40, 128 => 80, 80, 128 self.conv_for_P4 = Conv(transition_channels * 8, transition_channels * 4) + # 80, 80, 512 => 80, 80, 128 self.conv_for_feat1 = Conv(transition_channels * 16, transition_channels * 4) + # 80, 80, 256 => 80, 80, 128 self.conv3_for_upsample2 = Multi_Concat_Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids) + # 80, 80, 128 => 40, 40, 256 self.down_sample1 = Transition_Block(transition_channels * 4, transition_channels * 4) + # 40, 40, 512 => 40, 40, 256 self.conv3_for_downsample1 = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids) + # 40, 40, 256 => 20, 20, 512 self.down_sample2 = Transition_Block(transition_channels * 8, transition_channels * 8) + # 20, 20, 1024 => 20, 20, 512 self.conv3_for_downsample2 = Multi_Concat_Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids) + #------------------------加强特征提取网络------------------------# + # 80, 80, 128 => 80, 80, 256 self.rep_conv_1 = conv(transition_channels * 4, transition_channels * 8, 3, 1) + # 40, 40, 256 => 40, 40, 512 self.rep_conv_2 = conv(transition_channels * 8, transition_channels * 16, 3, 1) + # 20, 20, 512 => 20, 20, 1024 self.rep_conv_3 = conv(transition_channels * 16, transition_channels * 32, 3, 1) + # 4 + 1 + num_classes + # 80, 80, 256 => 80, 80, 3 * 25 (4 + 1 + 20) & 85 (4 + 1 + 80) self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1) + # 40, 40, 512 => 40, 40, 3 * 25 & 85 self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1) + # 20, 20, 512 => 20, 20, 3 * 25 & 85 self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1) def fuse(self): @@ -275,24 +298,44 @@ class YoloBody(nn.Module): # backbone feat1, feat2, feat3 = self.backbone.forward(x) + #------------------------加强特征提取网络------------------------# + # 20, 20, 1024 => 20, 20, 512 P5 = self.sppcspc(feat3) + # 20, 20, 512 => 20, 20, 256 P5_conv = self.conv_for_P5(P5) + # 20, 20, 256 => 40, 40, 256 P5_upsample = self.upsample(P5_conv) + # 40, 40, 256 cat 40, 40, 256 => 40, 40, 512 P4 = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1) + # 40, 40, 512 => 40, 40, 256 P4 = self.conv3_for_upsample1(P4) + # 40, 40, 256 => 40, 40, 128 P4_conv = self.conv_for_P4(P4) + # 40, 40, 128 => 80, 80, 128 P4_upsample = self.upsample(P4_conv) + # 80, 80, 128 cat 80, 80, 128 => 80, 80, 256 P3 = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1) + # 80, 80, 256 => 80, 80, 128 P3 = self.conv3_for_upsample2(P3) + # 80, 80, 128 => 40, 40, 256 P3_downsample = self.down_sample1(P3) + # 40, 40, 256 cat 40, 40, 256 => 40, 40, 512 P4 = torch.cat([P3_downsample, P4], 1) + # 40, 40, 512 => 40, 40, 256 P4 = self.conv3_for_downsample1(P4) + # 40, 40, 256 => 20, 20, 512 P4_downsample = self.down_sample2(P4) + # 20, 20, 512 cat 20, 20, 512 => 20, 20, 1024 P5 = torch.cat([P4_downsample, P5], 1) + # 20, 20, 1024 => 20, 20, 512 P5 = self.conv3_for_downsample2(P5) + #------------------------加强特征提取网络------------------------# + # P3 80, 80, 128 + # P4 40, 40, 256 + # P5 20, 20, 512 P3 = self.rep_conv_1(P3) P4 = self.rep_conv_2(P4) diff --git a/yolo.py b/yolo.py index ab39d2ec797adb1187e717cdd7ac2654c23fc9ba..172c9e0afa83941086695f36f4380ef9701f6d03 100644 --- a/yolo.py +++ b/yolo.py @@ -133,6 +133,7 @@ class YOLO(object): image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image) #---------------------------------------------------------# # 添加上batch_size维度 + # h, w, 3 => 3, h, w => 1, 3, h, w #---------------------------------------------------------# image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)