From cdc66baa477cd73782af5e445f461db5f722778b Mon Sep 17 00:00:00 2001
From: Bubbliiiing <3323290568@qq.com>
Date: Sat, 9 Jul 2022 21:44:20 +0800
Subject: [PATCH] update a lot

---
 nets/{CSPdarknet.py => backbone.py} |  48 +-
 nets/yolo.py                        | 103 +++--
 nets/yolo_training.py               | 671 +++++++++++++++++-----------
 summary.py                          |   3 +-
 train.py                            |  12 +-
 utils/dataloader.py                 |   9 +-
 utils/utils.py                      |   6 +-
 utils/utils_bbox.py                 |   6 +-
 yolo.py                             |   8 +-
 9 files changed, 523 insertions(+), 343 deletions(-)
 rename nets/{CSPdarknet.py => backbone.py} (70%)

diff --git a/nets/CSPdarknet.py b/nets/backbone.py
similarity index 70%
rename from nets/CSPdarknet.py
rename to nets/backbone.py
index 91cabed..0521ebc 100644
--- a/nets/CSPdarknet.py
+++ b/nets/backbone.py
@@ -25,10 +25,10 @@ class Conv(nn.Module):
     def fuseforward(self, x):
         return self.act(self.conv(x))
 
-class RCSPDark_Block(nn.Module):
-    def __init__(self, c1, c2, c3, n=4, e=0.5, ids=[0]):
-        super(RCSPDark_Block, self).__init__()
-        c_ = int(c1 * e)
+class Block(nn.Module):
+    def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
+        super(Block, self).__init__()
+        c_ = int(c2 * e)
 
         self.ids = ids
         self.cv1 = Conv(c1, c_, 1, 1)
@@ -58,9 +58,9 @@ class MP(nn.Module):
     def forward(self, x):
         return self.m(x)
 
-class RCSPDark_Transition(nn.Module):
+class Transition(nn.Module):
     def __init__(self, c1, c2):
-        super(RCSPDark_Transition, self).__init__()
+        super(Transition, self).__init__()
         self.cv1 = Conv(c1, c2, 1, 1)
         self.cv2 = Conv(c1, c2, 1, 1)
         self.cv3 = Conv(c2, c2, 3, 2)
@@ -76,40 +76,42 @@ class RCSPDark_Transition(nn.Module):
 
         return torch.cat([x_2, x_1], 1)
 
-class CSPDarknet(nn.Module):
-    def __init__(self, base_channels, pretrained=False):
+class Backbone(nn.Module):
+    def __init__(self, transition_channels, block_channels, n, phi, pretrained=False):
         super().__init__()
         #-----------------------------------------------#
         #   The input image is 640, 640, 3
-        #   The initial base channel count is 64
         #-----------------------------------------------#
-
+        ids = {
+            'l' : [-1, -3, -5, -6],
+            'x' : [-1, -3, -5, -7, -8],
+        }[phi]
         self.stem = nn.Sequential(
-            Conv(3, base_channels, 3, 1),
-            Conv(base_channels, base_channels * 2, 3, 2),
-            Conv(base_channels * 2, base_channels * 2, 3, 1),
+            Conv(3, transition_channels, 3, 1),
+            Conv(transition_channels, transition_channels * 2, 3, 2),
+            Conv(transition_channels * 2, transition_channels * 2, 3, 1),
         )
         self.dark2 = nn.Sequential(
-            Conv(base_channels * 2, base_channels * 4, 3, 2),
-            RCSPDark_Block(base_channels * 4, base_channels * 2, base_channels * 8, ids=[-1, -3, -5, -6]),
+            Conv(transition_channels * 2, transition_channels * 4, 3, 2),
+            Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids),
         )
         self.dark3 = nn.Sequential(
-            RCSPDark_Transition(base_channels * 8, base_channels * 4),
-            RCSPDark_Block(base_channels * 8, base_channels * 4, base_channels * 16, ids=[-1, -3, -5, -6]),
+            Transition(transition_channels * 8, transition_channels * 4),
+            Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids),
         )
         self.dark4 = nn.Sequential(
-            RCSPDark_Transition(base_channels * 16, base_channels * 8),
-            RCSPDark_Block(base_channels * 16, base_channels * 8, base_channels * 32, ids=[-1, -3, -5, -6]),
+            Transition(transition_channels * 16, transition_channels * 8),
+            Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids),
         )
         self.dark5 = nn.Sequential(
-            RCSPDark_Transition(base_channels * 32, base_channels * 16),
-            RCSPDark_Block(base_channels * 32, base_channels * 8, base_channels * 32, e=1/4, ids=[-1, -3, -5, -6]),
+            Transition(transition_channels * 32, transition_channels * 16),
+            Block(transition_channels * 32, block_channels * 8, transition_channels * 32, n=n, ids=ids),
        )
 
         if pretrained:
-            phi = 'l'
             url = {
-                "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/cspdarknet_backbone.pth',
+                "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone.pth',
+                "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone.pth',
             }[phi]
             checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data")
             self.load_state_dict(checkpoint, strict=False)
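
For reference, a minimal sketch (not part of the patch) of how the renamed Backbone is configured per version. The channel tables mirror the ones added to nets/yolo.py below; the forward pass is assumed to return the dark3/dark4/dark5 feature maps, as the shape comments in YoloBody indicate.

    import torch
    from nets.backbone import Backbone

    phi                 = 'l'
    transition_channels = {'l': 32, 'x': 40}[phi]
    block_channels      = 32
    n                   = {'l': 4, 'x': 6}[phi]

    backbone = Backbone(transition_channels, block_channels, n, phi, pretrained=False)
    feat1, feat2, feat3 = backbone(torch.randn(1, 3, 640, 640))
    # Expected for phi='l':
    # feat1 (1, 512, 80, 80), feat2 (1, 1024, 40, 40), feat3 (1, 1024, 20, 20)
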
diff --git a/nets/yolo.py b/nets/yolo.py
index 150d401..6a0445d 100644
--- a/nets/yolo.py
+++ b/nets/yolo.py
@@ -2,8 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from nets.CSPdarknet import (Conv, CSPDarknet, RCSPDark_Block,
-                             RCSPDark_Transition, SiLU, autopad)
+from nets.backbone import Backbone, Block, Conv, SiLU, Transition, autopad
 
 
 class SPPCSPC(nn.Module):
@@ -65,9 +64,9 @@ class RepConv(nn.Module):
         return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
 
     def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
-        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+        kernel3x3, bias3x3  = self._fuse_bn_tensor(self.rbr_dense)
+        kernel1x1, bias1x1  = self._fuse_bn_tensor(self.rbr_1x1)
+        kernelid, biasid    = self._fuse_bn_tensor(self.rbr_identity)
         return (
             kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
             bias3x3 + bias1x1 + biasid,
@@ -83,12 +82,12 @@ class RepConv(nn.Module):
         if branch is None:
             return 0, 0
         if isinstance(branch, nn.Sequential):
-            kernel = branch[0].weight
+            kernel      = branch[0].weight
             running_mean = branch[1].running_mean
             running_var = branch[1].running_var
-            gamma = branch[1].weight
-            beta = branch[1].bias
-            eps = branch[1].eps
+            gamma       = branch[1].weight
+            beta        = branch[1].bias
+            eps         = branch[1].eps
         else:
             assert isinstance(branch, nn.BatchNorm2d)
             if not hasattr(self, "id_tensor"):
@@ -99,14 +98,14 @@ class RepConv(nn.Module):
                 for i in range(self.in_channels):
                     kernel_value[i, i % input_dim, 1, 1] = 1
                 self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
-            kernel = self.id_tensor
+            kernel      = self.id_tensor
             running_mean = branch.running_mean
             running_var = branch.running_var
-            gamma = branch.weight
-            beta = branch.bias
-            eps = branch.eps
+            gamma       = branch.weight
+            beta        = branch.bias
+            eps         = branch.eps
         std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
+        t   = (gamma / std).reshape(-1, 1, 1, 1)
         return kernel * t, beta - running_mean * gamma / std
 
     def repvgg_convert(self):
@@ -164,18 +163,18 @@ class RepConv(nn.Module):
             identity_conv_1x1.weight.data.fill_diagonal_(1.0)
             identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
 
-            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
-            bias_identity_expanded = identity_conv_1x1.bias
-            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
+            identity_conv_1x1           = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
+            bias_identity_expanded      = identity_conv_1x1.bias
+            weight_identity_expanded    = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
         else:
-            bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
-            weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )
+            bias_identity_expanded      = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
+            weight_identity_expanded    = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )
 
-        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
-        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
+        self.rbr_dense.weight   = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
+        self.rbr_dense.bias     = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
 
-        self.rbr_reparam = self.rbr_dense
-        self.deploy = True
+        self.rbr_reparam    = self.rbr_dense
+        self.deploy         = True
 
         if self.rbr_identity is not None:
             del self.rbr_identity
@@ -211,47 +210,55 @@ def fuse_conv_and_bn(conv, bn):
 #   yolo_body
 #---------------------------------------------------#
 class YoloBody(nn.Module):
-    def __init__(self, anchors_mask, num_classes, pretrained=False):
+    def __init__(self, anchors_mask, num_classes, phi, pretrained=False):
         super(YoloBody, self).__init__()
-        base_channels = 32
+        #-----------------------------------------------#
+        #   Parameters for the different yolov7 versions
+        #-----------------------------------------------#
+        transition_channels = {'l' : 32, 'x' : 40}[phi]
+        block_channels      = 32
+        panet_channels      = {'l' : 32, 'x' : 64}[phi]
+        e       = {'l' : 2, 'x' : 1}[phi]
+        n       = {'l' : 4, 'x' : 6}[phi]
+        ids     = {'l' : [-1, -2, -3, -4, -5, -6], 'x' : [-1, -3, -5, -7, -8]}[phi]
+        conv    = {'l' : RepConv, 'x' : Conv}[phi]
 
         #-----------------------------------------------#
         #   The input image is 640, 640, 3
-        #   The initial base channel count is 64
         #-----------------------------------------------#
 
         #---------------------------------------------------#
         #   Build the backbone model
        #   It returns three effective feature layers, whose shapes are:
-        #   80,80,512
-        #   40,40,1024
-        #   20,20,1024
+        #   80, 80, 512
+        #   40, 40, 1024
+        #   20, 20, 1024
         #---------------------------------------------------#
-        self.backbone = CSPDarknet(base_channels, pretrained=pretrained)
+        self.backbone   = Backbone(transition_channels, block_channels, n, phi, pretrained=pretrained)
 
         self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
 
-        self.sppcspc = SPPCSPC(base_channels * 32, base_channels * 16)
-        self.conv_for_P5 = Conv(base_channels * 16, base_channels * 8)
-        self.conv_for_feat2 = Conv(base_channels * 32, base_channels * 8)
-        self.conv3_for_upsample1 = RCSPDark_Block(base_channels * 16, base_channels * 4, base_channels * 8, ids=[-1, -2, -3, -4, -5, -6])
+        self.sppcspc                = SPPCSPC(transition_channels * 32, transition_channels * 16)
+        self.conv_for_P5            = Conv(transition_channels * 16, transition_channels * 8)
+        self.conv_for_feat2         = Conv(transition_channels * 32, transition_channels * 8)
+        self.conv3_for_upsample1    = Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)
 
-        self.conv_for_P4 = Conv(base_channels * 8, base_channels * 4)
-        self.conv_for_feat1 = Conv(base_channels * 16, base_channels * 4)
-        self.conv3_for_upsample2 = RCSPDark_Block(base_channels * 8, base_channels * 2, base_channels * 4, ids=[-1, -2, -3, -4, -5, -6])
+        self.conv_for_P4            = Conv(transition_channels * 8, transition_channels * 4)
+        self.conv_for_feat1         = Conv(transition_channels * 16, transition_channels * 4)
+        self.conv3_for_upsample2    = Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids)
 
-        self.down_sample1 = RCSPDark_Transition(base_channels * 4, base_channels * 4)
-        self.conv3_for_downsample1 = RCSPDark_Block(base_channels * 16, base_channels * 4, base_channels * 8, ids=[-1, -2, -3, -4, -5, -6])
+        self.down_sample1           = Transition(transition_channels * 4, transition_channels * 4)
+        self.conv3_for_downsample1  = Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)
 
-        self.down_sample2 = RCSPDark_Transition(base_channels * 8, base_channels * 8)
-        self.conv3_for_downsample2 = RCSPDark_Block(base_channels * 32, base_channels * 8, base_channels * 16, ids=[-1, -2, -3, -4, -5, -6])
+        self.down_sample2           = Transition(transition_channels * 8, transition_channels * 8)
+        self.conv3_for_downsample2  = Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids)
 
-        self.rep_conv_1 = RepConv(base_channels * 4, base_channels * 8, 3, 1)
-        self.rep_conv_2 = RepConv(base_channels * 8, base_channels * 16, 3, 1)
-        self.rep_conv_3 = RepConv(base_channels * 16, base_channels * 32, 3, 1)
+        self.rep_conv_1 = conv(transition_channels * 4, transition_channels * 8, 3, 1)
+        self.rep_conv_2 = conv(transition_channels * 8, transition_channels * 16, 3, 1)
+        self.rep_conv_3 = conv(transition_channels * 16, transition_channels * 32, 3, 1)
 
-        self.yolo_head_P3 = nn.Conv2d(base_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
-        self.yolo_head_P4 = nn.Conv2d(base_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
-        self.yolo_head_P5 = nn.Conv2d(base_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)
+        self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
+        self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
+        self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)
 
     def fuse(self):
         print('Fusing layers... ')
@@ -306,4 +313,4 @@ class YoloBody(nn.Module):
         #---------------------------------------------------#
         out0 = self.yolo_head_P5(P5)
 
-        return [out0, out1, out2]
\ No newline at end of file
+        return [out0, out1, out2]
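
The RepConv deploy-time fusion above relies on the linearity of convolution: a 3x3 branch, a 1x1 branch zero-padded to 3x3, and an identity branch written as a 3x3 kernel collapse into a single convolution whose kernel is the sum of the three. A self-contained numeric check of that identity (a standalone sketch, independent of the repo's classes; BatchNorm folding is omitted):

    import torch
    import torch.nn.functional as F

    x   = torch.randn(1, 4, 8, 8)
    k3  = torch.randn(4, 4, 3, 3)
    k1  = torch.randn(4, 4, 1, 1)
    kid = torch.zeros(4, 4, 3, 3)
    for i in range(4):
        kid[i, i, 1, 1] = 1.0                       # identity as a centered 3x3 kernel

    y_branches = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1) + x
    k_fused    = k3 + F.pad(k1, [1, 1, 1, 1]) + kid # mirrors _pad_1x1_to_3x3_tensor
    y_fused    = F.conv2d(x, k_fused, padding=1)
    assert torch.allclose(y_branches, y_fused, atol=1e-4)
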
diff --git a/nets/yolo_training.py b/nets/yolo_training.py
index 6516d20..eac5964 100644
--- a/nets/yolo_training.py
+++ b/nets/yolo_training.py
@@ -8,82 +8,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def box_iou(box1, box2):
-    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
-    """
-    Return intersection-over-union (Jaccard index) of boxes.
-    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-    Arguments:
-        box1 (Tensor[N, 4])
-        box2 (Tensor[M, 4])
-    Returns:
-        iou (Tensor[N, M]): the NxM matrix containing the pairwise
-            IoU values for every element in boxes1 and boxes2
-    """
-    def box_area(box):
-        # box = 4xn
-        return (box[2] - box[0]) * (box[3] - box[1])
-
-    area1 = box_area(box1.T)
-    area2 = box_area(box2.T)
-
-    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
-    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
-    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
-
-def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
-    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
-    box2 = box2.T
-
-    # Get the coordinates of bounding boxes
-    if x1y1x2y2:  # x1, y1, x2, y2 = box1
-        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
-        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
-    else:  # transform from xywh to xyxy
-        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
-        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
-        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
-        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
-
-    # Intersection area
-    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
-            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
-
-    # Union Area
-    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
-    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
-    union = w1 * h1 + w2 * h2 - inter + eps
-
-    iou = inter / union
-
-    if GIoU or DIoU or CIoU:
-        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
-        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
-        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
-            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
-            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
-                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
-            if DIoU:
-                return iou - rho2 / c2  # DIoU
-            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
-                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
-                with torch.no_grad():
-                    alpha = v / (v - iou + (1 + eps))
-                return iou - (rho2 / c2 + v * alpha)  # CIoU
-        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
-            c_area = cw * ch + eps  # convex area
-            return iou - (c_area - union) / c_area  # GIoU
-    else:
-        return iou  # IoU
-
-def xywh2xyxy(x):
-    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
-    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
-    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
-    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
-    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
-    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
-    return y
 
 def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
     # return positive, negative label smoothing BCE targets
@@ -93,9 +17,9 @@ class YOLOLoss(nn.Module):
     def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
         super(YOLOLoss, self).__init__()
         #-----------------------------------------------------------#
-        #   The 20x20 feature layer corresponds to anchors [116,90],[156,198],[373,326]
-        #   The 40x40 feature layer corresponds to anchors [30,61],[62,45],[59,119]
-        #   The 80x80 feature layer corresponds to anchors [10,13],[16,30],[33,23]
+        #   The 20x20 feature layer corresponds to anchors [142, 110],[192, 243],[459, 401]
+        #   The 40x40 feature layer corresponds to anchors [36, 75],[76, 55],[72, 146]
+        #   The 80x80 feature layer corresponds to anchors [12, 16],[19, 36],[40, 28]
         #-----------------------------------------------------------#
         self.anchors = [anchors[mask] for mask in anchors_mask]
         self.num_classes = num_classes
@@ -104,106 +28,253 @@ class YOLOLoss(nn.Module):
         self.balance = [0.4, 1.0, 4]
         self.stride = [32, 16, 8]
+        self.box_ratio = 0.05
         self.obj_ratio = 1 * (input_shape[0] * input_shape[1]) / (640 ** 2)
         self.cls_ratio = 0.5 * (num_classes / 80)
         self.threshold = 4
 
-        self.cp, self.cn = smooth_BCE(eps=label_smoothing)
-        self.BCEcls, self.BCEobj, self.gr = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss(), 1
-
-    def __call__(self, p, targets, imgs):  # predictions, targets, model
-        for i in range(len(p)):
-            bs, _, h, w = p[i].size()
-            p[i] = p[i].view(bs, len(self.anchors_mask[i]), -1, h, w).permute(0, 1, 3, 4, 2).contiguous()
-
-        device = targets.device
-        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
-        bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
-        pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]].type_as(pp) for pp in p]
-
-        # Losses
-        for i, pi in enumerate(p):  # layer index, layer predictions
-            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]  # image, anchor, gridy, gridx
-            tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj
+        self.cp, self.cn    = smooth_BCE(eps=label_smoothing)
+        self.BCEcls, self.BCEobj, self.gr = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss(), 1
 
-            n = b.shape[0]  # number of targets
+    def bbox_iou(self, box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+        box2 = box2.T
+
+        if x1y1x2y2:
+            b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+            b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+        else:
+            b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
+            b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
+            b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
+            b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
+
+        inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
+                (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
+
+        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+        union = w1 * h1 + w2 * h2 - inter + eps
+
+        iou = inter / union
+
+        if GIoU or DIoU or CIoU:
+            cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
+            ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
+            if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+                c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
+                rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
+                        (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
+                if DIoU:
+                    return iou - rho2 / c2  # DIoU
+                elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                    v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
+                    with torch.no_grad():
+                        alpha = v / (v - iou + (1 + eps))
+                    return iou - (rho2 / c2 + v * alpha)  # CIoU
+            else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+                c_area = cw * ch + eps  # convex area
+                return iou - (c_area - union) / c_area  # GIoU
+        else:
+            return iou  # IoU
+
+    def __call__(self, predictions, targets, imgs):
+        #-------------------------------------------#
+        #   Reshape the incoming predictions
+        #   bs, 255, 20, 20 => bs, 3, 20, 20, 85
+        #   bs, 255, 40, 40 => bs, 3, 40, 40, 85
+        #   bs, 255, 80, 80 => bs, 3, 80, 80, 85
+        #-------------------------------------------#
+        for i in range(len(predictions)):
+            bs, _, h, w = predictions[i].size()
+            predictions[i] = predictions[i].view(bs, len(self.anchors_mask[i]), -1, h, w).permute(0, 1, 3, 4, 2).contiguous()
+
+        #-------------------------------------------#
+        #   Get the working device
+        #-------------------------------------------#
+        device = targets.device
+        #-------------------------------------------#
+        #   Initialize the three loss components
+        #-------------------------------------------#
+        cls_loss, box_loss, obj_loss = torch.zeros(1, device = device), torch.zeros(1, device = device), torch.zeros(1, device = device)
+
+        #-------------------------------------------#
+        #   Match positive samples
+        #-------------------------------------------#
+        bs, as_, gjs, gis, targets, anchors = self.build_targets(predictions, targets, imgs)
+        #-------------------------------------------#
+        #   Compute the height and width of each feature layer
+        #-------------------------------------------#
+        feature_map_sizes = [torch.tensor(prediction.shape, device=device)[[3, 2, 3, 2]].type_as(prediction) for prediction in predictions]
+
+        #-------------------------------------------#
+        #   Compute the loss, processing each of the
+        #   three feature layers in turn
+        #-------------------------------------------#
+        for i, prediction in enumerate(predictions):
+            #-------------------------------------------#
+            #   image, anchor, gridy, gridx
+            #-------------------------------------------#
+            b, a, gj, gi    = bs[i], as_[i], gjs[i], gis[i]
+            tobj            = torch.zeros_like(prediction[..., 0], device=device)  # target obj
+
+            #-------------------------------------------#
+            #   Get the number of targets; if it is greater
+            #   than 0, compute the classification and
+            #   regression losses
+            #-------------------------------------------#
+            n = b.shape[0]
             if n:
-                ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets
-
-                # Regression
-                grid = torch.stack([gi, gj], dim=1)
-                pxy = ps[:, :2].sigmoid() * 2. - 0.5
-                #pxy = ps[:, :2].sigmoid() * 3. - 1.
-                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
-                pbox = torch.cat((pxy, pwh), 1)  # predicted box
-                selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
-                selected_tbox[:, :2] -= grid.type_as(pi)
-                iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
-                lbox += (1.0 - iou).mean()  # iou loss
-
-                # Objectness
+                prediction_pos = prediction[b, a, gj, gi]  # prediction subset corresponding to targets
+
+                #-------------------------------------------#
+                #   Compute the regression loss of the
+                #   matched positive samples
+                #-------------------------------------------#
+                #-------------------------------------------#
+                #   grid: x and y coordinates of the positive samples
+                #-------------------------------------------#
+                grid    = torch.stack([gi, gj], dim=1)
+                #-------------------------------------------#
+                #   Decode to obtain the predicted boxes
+                #-------------------------------------------#
+                xy      = prediction_pos[:, :2].sigmoid() * 2. - 0.5
+                wh      = (prediction_pos[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
+                box     = torch.cat((xy, wh), 1)
+                #-------------------------------------------#
+                #   Process the ground-truth boxes and map
+                #   them onto the feature layer
+                #-------------------------------------------#
+                selected_tbox           = targets[i][:, 2:6] * feature_map_sizes[i]
+                selected_tbox[:, :2]    -= grid.type_as(prediction)
+                #-------------------------------------------#
+                #   Compute the regression loss between the
+                #   predicted and ground-truth boxes
+                #-------------------------------------------#
+                iou         = self.bbox_iou(box.T, selected_tbox, x1y1x2y2=False, CIoU=True)
+                box_loss    += (1.0 - iou).mean()
+                #-------------------------------------------#
+                #   Use the iou of the predictions as the
+                #   ground truth for the objectness loss
+                #-------------------------------------------#
                 tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio
 
-                # Classification
-                selected_tcls = targets[i][:, 1].long()
-                if self.num_classes > 1:  # cls loss (only if multiple classes)
-                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
-                    t[range(n), selected_tcls] = self.cp
-                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE
-
-            # Append targets to text file
-            # with open('targets.txt', 'a') as file:
-            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
-
-            obji = self.BCEobj(pi[..., 4], tobj)
-            lobj += obji * self.balance[i]  # obj loss
+                #-------------------------------------------#
+                #   Compute the classification loss of the
+                #   matched positive samples
+                #-------------------------------------------#
+                selected_tcls               = targets[i][:, 1].long()
+                t                           = torch.full_like(prediction_pos[:, 5:], self.cn, device=device)  # targets
+                t[range(n), selected_tcls]  = self.cp
+                cls_loss                    += self.BCEcls(prediction_pos[:, 5:], t)  # BCE
+
+            #-------------------------------------------#
+            #   Compute the objectness loss, weighted by
+            #   each feature layer's balance factor
+            #-------------------------------------------#
+            obj_loss += self.BCEobj(prediction[..., 4], tobj) * self.balance[i]  # obj loss
 
-        lbox *= self.box_ratio
-        lobj *= self.obj_ratio
-        lcls *= self.cls_ratio
-        bs = tobj.shape[0]  # batch size
-
-        loss = lbox + lobj + lcls
-        return loss * bs
-
-    def build_targets(self, p, targets, imgs):
+        #-------------------------------------------#
+        #   Scale each loss component; after summing
+        #   them, multiply by the batch_size
+        #-------------------------------------------#
+        box_loss    *= self.box_ratio
+        obj_loss    *= self.obj_ratio
+        cls_loss    *= self.cls_ratio
+        bs          = tobj.shape[0]
 
-        indices, anch = self.find_3_positive(p, targets)
-
-        matching_bs = [[] for pp in p]
-        matching_as = [[] for pp in p]
-        matching_gjs = [[] for pp in p]
-        matching_gis = [[] for pp in p]
-        matching_targets = [[] for pp in p]
-        matching_anchs = [[] for pp in p]
+        loss = box_loss + obj_loss + cls_loss
+        return loss * bs
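
The decode above is the YOLOv5/v7-style parameterization: the xy offset lies in (-0.5, 1.5) around the grid cell and wh is bounded to (0, 4) times the anchor, which is also why find_3_positive below can reject candidates whose wh ratio exceeds self.threshold = 4. A small standalone sketch with made-up values:

    import torch

    raw    = torch.tensor([0.3, -0.2, 0.1, 0.4])    # tx, ty, tw, th logits
    grid   = torch.tensor([12., 7.])                # feature-map cell (stride 32)
    anchor = torch.tensor([142., 110.]) / 32.       # anchor scaled by the stride

    xy = raw[:2].sigmoid() * 2. - 0.5 + grid        # in (-0.5, 1.5) around the cell
    wh = (raw[2:].sigmoid() * 2) ** 2 * anchor      # in (0, 4) times the anchor
    print(xy * 32, wh * 32)                         # back to input-image pixels
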
-        nl = len(p)
+    def xywh2xyxy(self, x):
+        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2]
+        y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+        y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+        y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+        return y
 
-        for batch_idx in range(p[0].shape[0]):
+    def box_iou(self, box1, box2):
+        # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+        """
+        Return intersection-over-union (Jaccard index) of boxes.
+        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+        Arguments:
+            box1 (Tensor[N, 4])
+            box2 (Tensor[M, 4])
+        Returns:
+            iou (Tensor[N, M]): the NxM matrix containing the pairwise
+                IoU values for every element in boxes1 and boxes2
+        """
+        def box_area(box):
+            # box = 4xn
+            return (box[2] - box[0]) * (box[3] - box[1])
+
+        area1 = box_area(box1.T)
+        area2 = box_area(box2.T)
+
+        # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+        inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+        return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
+
+    def build_targets(self, predictions, targets, imgs):
+        #-------------------------------------------#
+        #   Match positive samples
+        #-------------------------------------------#
+        indices, anch       = self.find_3_positive(predictions, targets)
+
+        matching_bs         = [[] for _ in predictions]
+        matching_as         = [[] for _ in predictions]
+        matching_gjs        = [[] for _ in predictions]
+        matching_gis        = [[] for _ in predictions]
+        matching_targets    = [[] for _ in predictions]
+        matching_anchs      = [[] for _ in predictions]
 
-            b_idx = targets[:, 0]==batch_idx
+        #-------------------------------------------#
+        #   There are three layers in total
+        #-------------------------------------------#
+        num_layer = len(predictions)
+        #-------------------------------------------#
+        #   Loop over the batch for OTA matching;
+        #   inside the batch loop, loop over the layers
+        #-------------------------------------------#
+        for batch_idx in range(predictions[0].shape[0]):
+            #-------------------------------------------#
+            #   First determine which matched ground-truth
+            #   boxes belong to this image
+            #-------------------------------------------#
+            b_idx       = targets[:, 0]==batch_idx
             this_target = targets[b_idx]
+            #-------------------------------------------#
+            #   If no ground-truth box belongs to this
+            #   image, continue
+            #-------------------------------------------#
             if this_target.shape[0] == 0:
                 continue
-
+
+            #-------------------------------------------#
+            #   Scale the ground-truth box coordinates
+            #-------------------------------------------#
             txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
-            txyxy = xywh2xyxy(txywh)
-
-            pxyxys = []
-            p_cls = []
-            p_obj = []
+            #-------------------------------------------#
+            #   From center, width, height to top-left
+            #   and bottom-right corners
+            #-------------------------------------------#
+            txyxy = self.xywh2xyxy(txywh)
+
+            pxyxys      = []
+            p_cls       = []
+            p_obj       = []
             from_which_layer = []
-            all_b = []
-            all_a = []
-            all_gj = []
-            all_gi = []
-            all_anch = []
+            all_b       = []
+            all_a       = []
+            all_gj      = []
+            all_gi      = []
+            all_anch    = []
 
-            for i, pi in enumerate(p):
-
-                b, a, gj, gi = indices[i]
-                idx = (b == batch_idx)
-                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
+            #-------------------------------------------#
+            #   Loop over the three layers
+            #-------------------------------------------#
+            for i, prediction in enumerate(predictions):
+                #-------------------------------------------#
+                #   b: image index; a: anchor index
+                #   gj: y axis; gi: x axis
+                #-------------------------------------------#
+                b, a, gj, gi    = indices[i]
+                idx             = (b == batch_idx)
+                b, a, gj, gi    = b[idx], a[idx], gj[idx], gi[idx]
+
                 all_b.append(b)
                 all_a.append(a)
                 all_gj.append(gj)
@@ -211,88 +282,118 @@ class YOLOLoss(nn.Module):
                 all_anch.append(anch[i][idx])
                 from_which_layer.append(torch.ones(size=(len(b),)) * i)
 
-                fg_pred = pi[b, a, gj, gi]
+                #-------------------------------------------#
+                #   Take the predictions corresponding to
+                #   this ground-truth box
+                #-------------------------------------------#
+                fg_pred = prediction[b, a, gj, gi]
                 p_obj.append(fg_pred[:, 4:5])
                 p_cls.append(fg_pred[:, 5:])
 
-                grid = torch.stack([gi, gj], dim=1).type_as(fg_pred)
-                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]  #/ 8.
-                #pxy = (fg_pred[:, :2].sigmoid() * 3. - 1. + grid) * self.stride[i]
-                pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]  #/ 8.
-                pxywh = torch.cat([pxy, pwh], dim=-1)
-                pxyxy = xywh2xyxy(pxywh)
+                #-------------------------------------------#
+                #   After obtaining the grid, decode
+                #-------------------------------------------#
+                grid    = torch.stack([gi, gj], dim=1).type_as(fg_pred)
+                pxy     = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
+                pwh     = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
+                pxywh   = torch.cat([pxy, pwh], dim=-1)
+                pxyxy   = self.xywh2xyxy(pxywh)
                 pxyxys.append(pxyxy)
 
+            #-------------------------------------------#
+            #   Check whether any corresponding predicted
+            #   boxes exist; skip if not
+            #-------------------------------------------#
             pxyxys = torch.cat(pxyxys, dim=0)
             if pxyxys.shape[0] == 0:
                 continue
-            p_obj = torch.cat(p_obj, dim=0)
-            p_cls = torch.cat(p_cls, dim=0)
+
+            #-------------------------------------------#
+            #   Stack them
+            #-------------------------------------------#
+            p_obj       = torch.cat(p_obj, dim=0)
+            p_cls       = torch.cat(p_cls, dim=0)
             from_which_layer = torch.cat(from_which_layer, dim=0)
-            all_b = torch.cat(all_b, dim=0)
-            all_a = torch.cat(all_a, dim=0)
-            all_gj = torch.cat(all_gj, dim=0)
-            all_gi = torch.cat(all_gi, dim=0)
-            all_anch = torch.cat(all_anch, dim=0)
+            all_b       = torch.cat(all_b, dim=0)
+            all_a       = torch.cat(all_a, dim=0)
+            all_gj      = torch.cat(all_gj, dim=0)
+            all_gi      = torch.cat(all_gi, dim=0)
+            all_anch    = torch.cat(all_anch, dim=0)
 
-            pair_wise_iou = box_iou(txyxy, pxyxys)
-
-            pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
-
-            top_k, _ = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
-            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
-
-            gt_cls_per_image = (
-                F.one_hot(this_target[:, 1].to(torch.int64), self.num_classes)
-                .float()
-                .unsqueeze(1)
-                .repeat(1, pxyxys.shape[0], 1)
-            )
-
-            num_gt = this_target.shape[0]
-            cls_preds_ = (
-                p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-                * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
-            )
-
-            y = cls_preds_.sqrt_()
-            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
-                torch.log(y/(1-y)) , gt_cls_per_image, reduction="none"
-            ).sum(-1)
+            #-------------------------------------------------------------#
+            #   Compute the overlap of ground-truth and predicted boxes
+            #   in the current image.
+            #   iou ranges over 0-1; taking -log maps it to 0 to inf.
+            #   The larger the overlap, the smaller the -log value,
+            #   so the more a prediction overlaps a ground truth,
+            #   the smaller pair_wise_iou_loss becomes.
+            #-------------------------------------------------------------#
+            pair_wise_iou       = self.box_iou(txyxy, pxyxys)
+            pair_wise_iou_loss  = -torch.log(pair_wise_iou + 1e-8)
+
+            #-------------------------------------------#
+            #   Sum the overlap of at most twenty predicted
+            #   boxes per ground truth to decide how many
+            #   predicted boxes each ground truth is assigned
+            #-------------------------------------------#
+            top_k, _    = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
+            dynamic_ks  = torch.clamp(top_k.sum(1).int(), min=1)
+
+            #-------------------------------------------#
+            #   gt_cls_per_image    ground-truth class information
+            #-------------------------------------------#
+            gt_cls_per_image = F.one_hot(this_target[:, 1].to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, pxyxys.shape[0], 1)
+
+            #-------------------------------------------#
+            #   cls_preds_  predicted class confidence
+            #   The closer cls_preds_ is to 1, the closer y is to 1,
+            #   and the larger y / (1 - y) becomes;
+            #   i.e. the more accurate the class prediction,
+            #   the smaller pair_wise_cls_loss
+            #-------------------------------------------#
+            num_gt      = this_target.shape[0]
+            cls_preds_  = p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+            y           = cls_preds_.sqrt_()
+            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(torch.log(y / (1 - y)), gt_cls_per_image, reduction="none").sum(-1)
             del cls_preds_
 
+            #-------------------------------------------#
+            #   Total cost
+            #-------------------------------------------#
             cost = (
                 pair_wise_cls_loss
                 + 3.0 * pair_wise_iou_loss
             )
 
+            #-------------------------------------------#
+            #   Take the k predicted boxes with the smallest cost
+            #-------------------------------------------#
             matching_matrix = torch.zeros_like(cost)
-
             for gt_idx in range(num_gt):
-                _, pos_idx = torch.topk(
-                    cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
-                )
+                _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
                 matching_matrix[gt_idx][pos_idx] = 1.0
 
             del top_k, dynamic_ks
+
+            #-------------------------------------------#
+            #   If one predicted box matches several ground
+            #   truths, keep only the ground truth it
+            #   matches best
+            #-------------------------------------------#
             anchor_matching_gt = matching_matrix.sum(0)
             if (anchor_matching_gt > 1).sum() > 0:
                 _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
+                matching_matrix[:, anchor_matching_gt > 1]          *= 0.0
                 matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
             fg_mask_inboxes = matching_matrix.sum(0) > 0.0
             matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+            #-------------------------------------------#
+            #   Take the boxes that satisfy the conditions
+            #-------------------------------------------#
+            from_which_layer    = from_which_layer[fg_mask_inboxes]
+            all_b               = all_b[fg_mask_inboxes]
+            all_a               = all_a[fg_mask_inboxes]
+            all_gj              = all_gj[fg_mask_inboxes]
+            all_gi              = all_gi[fg_mask_inboxes]
+            all_anch            = all_anch[fg_mask_inboxes]
+            this_target         = this_target[matched_gt_inds]
 
-            from_which_layer = from_which_layer[fg_mask_inboxes]
-            all_b = all_b[fg_mask_inboxes]
-            all_a = all_a[fg_mask_inboxes]
-            all_gj = all_gj[fg_mask_inboxes]
-            all_gi = all_gi[fg_mask_inboxes]
-            all_anch = all_anch[fg_mask_inboxes]
-
-            this_target = this_target[matched_gt_inds]
-
-            for i in range(nl):
+            for i in range(num_layer):
                 layer_idx = from_which_layer == i
                 matching_bs[i].append(all_b[layer_idx])
                 matching_as[i].append(all_a[layer_idx])
@@ -301,68 +402,118 @@ class YOLOLoss(nn.Module):
                 matching_targets[i].append(this_target[layer_idx])
                 matching_anchs[i].append(all_anch[layer_idx])
 
-        for i in range(nl):
-            matching_bs[i] = torch.cat(matching_bs[i], dim=0)
-            matching_as[i] = torch.cat(matching_as[i], dim=0)
-            matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
-            matching_gis[i] = torch.cat(matching_gis[i], dim=0)
+        for i in range(num_layer):
+            matching_bs[i]      = torch.cat(matching_bs[i], dim=0)
+            matching_as[i]      = torch.cat(matching_as[i], dim=0)
+            matching_gjs[i]     = torch.cat(matching_gjs[i], dim=0)
+            matching_gis[i]     = torch.cat(matching_gis[i], dim=0)
             matching_targets[i] = torch.cat(matching_targets[i], dim=0)
-            matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
+            matching_anchs[i]   = torch.cat(matching_anchs[i], dim=0)
 
         return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
 
-    def find_3_positive(self, p, targets):
-        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
-        na, nt = len(self.anchors_mask[0]), targets.shape[0]  # number of anchors, targets
-        indices, anch = [], []
-        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
-        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
-        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices
-
-        g = 0.5  # bias
-        off = torch.tensor([[0, 0],
-                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
-                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
-                            ], device=targets.device).float() * g  # offsets
-
-        for i in range(len(p)):
-            anchors = torch.from_numpy(self.anchors[i]).type_as(p[i])
-            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain
-
-            # Match targets to anchors
+    def find_3_positive(self, predictions, targets):
+        #------------------------------------#
+        #   Number of anchors per feature layer
+        #   and number of ground-truth boxes
+        #------------------------------------#
+        num_anchor, num_gt = len(self.anchors_mask[0]), targets.shape[0]
+        #------------------------------------#
+        #   Create empty lists to store indices and anchors
+        #------------------------------------#
+        indices, anchors = [], []
+        #------------------------------------#
+        #   Create 7 ones:
+        #   indices 0 and 1 stay 1;
+        #   indices 2:6 hold the feature layer height and width;
+        #   index 6 stays 1
+        #------------------------------------#
+        gain = torch.ones(7, device=targets.device)
+        #------------------------------------#
+        #   ai      [num_anchor, num_gt]
+        #   targets [num_gt, 6] => [num_anchor, num_gt, 7]
+        #------------------------------------#
+        ai = torch.arange(num_anchor, device=targets.device).float().view(num_anchor, 1).repeat(1, num_gt)
+        targets = torch.cat((targets.repeat(num_anchor, 1, 1), ai[:, :, None]), 2)  # append anchor indices
+
+        g = 0.5  # offsets
+        off = torch.tensor([
+            [0, 0],
+            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
+            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
+        ], device=targets.device).float() * g
+
+        for i in range(len(predictions)):
+            #----------------------------------------------------#
+            #   Divide the anchors by the stride to express them
+            #   relative to the feature layer.
+            #   anchors_i [num_anchor, 2]
+            #----------------------------------------------------#
+            anchors_i = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i])
+            #-------------------------------------------#
+            #   Compute the height and width of the feature layer
+            #-------------------------------------------#
+            gain[2:6] = torch.tensor(predictions[i].shape)[[3, 2, 3, 2]]
+
+            #-------------------------------------------#
+            #   Multiply the ground-truth boxes by gain,
+            #   i.e. map them onto the feature layer
+            #-------------------------------------------#
             t = targets * gain
-            if nt:
-                # Matches
-                r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
-                j = torch.max(r, 1. / r).max(2)[0] < self.threshold  # compare
-                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
+            if num_gt:
+                #-------------------------------------------#
+                #   Compute the ratio of ground-truth wh to
+                #   anchor wh, then threshold the ratio; the
+                #   result selects the ground-truth boxes
+                #   assigned to each anchor
+                #   r [num_anchor, num_gt, 2]
+                #   t [num_anchor, num_gt, 7] => [num_matched_anchor, 7]
+                #-------------------------------------------#
+                r = t[:, :, 4:6] / anchors_i[:, None]
+                j = torch.max(r, 1. / r).max(2)[0] < self.threshold
                 t = t[j]  # filter
-
-                # Offsets
-                gxy = t[:, 2:4]  # grid xy
-                gxi = gain[[2, 3]] - gxy  # inverse
-                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
-                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
-                j = torch.stack((torch.ones_like(j), j, k, l, m))
-                t = t.repeat((5, 1, 1))[j]
+
+                #-------------------------------------------#
+                #   gxy     x, y coordinates of the matched
+                #           ground-truth boxes
+                #   gxi     coordinates relative to the
+                #           bottom-right corner of the feature layer
+                #-------------------------------------------#
+                gxy = t[:, 2:4]  # grid xy
+                gxi = gain[[2, 3]] - gxy  # inverse
+                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
+                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
+                j = torch.stack((torch.ones_like(j), j, k, l, m))
+
+                #-------------------------------------------#
+                #   t is repeated 5 times; j selects the valid boxes.
+                #   j has five rows, indicating whether the current
+                #   feature point exists in the five directions
+                #   [0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]
+                #-------------------------------------------#
+                t = t.repeat((5, 1, 1))[j]
                 offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
             else:
                 t = targets[0]
                 offsets = 0
 
-            # Define
-            b, c = t[:, :2].long().T  # image, class
-            gxy = t[:, 2:4]  # grid xy
-            gwh = t[:, 4:6]  # grid wh
-            gij = (gxy - offsets).long()
-            gi, gj = gij.T  # grid xy indices
-
-            # Append
+            #-------------------------------------------#
+            #   b       which image the box belongs to
+            #   gxy     x, y center coordinates of the ground-truth box
+            #   gwh     width and height of the ground-truth box
+            #   gij     feature point the ground-truth box belongs to
+            #-------------------------------------------#
+            b, c = t[:, :2].long().T  # image, class
+            gxy = t[:, 2:4]  # grid xy
+            gwh = t[:, 4:6]  # grid wh
+            gij = (gxy - offsets).long()
+            gi, gj = gij.T  # grid xy indices
+
+            #-------------------------------------------#
+            #   gj, gi must not exceed the feature layer bounds;
+            #   a is which anchor of that feature point
+            #-------------------------------------------#
             a = t[:, 6].long()  # anchor indices
             indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
-            anch.append(anchors[a])  # anchors
+            anchors.append(anchors_i[a])  # anchors
 
-        return indices, anch
+        return indices, anchors
 
 def is_parallel(model):
     # Returns True if model is of type DP or DDP
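
A toy illustration of the dynamic-k selection performed in build_targets above: each ground truth receives roughly as many predictions as its summed top-20 IoU suggests, chosen as the lowest-cost candidates. Standalone sketch with made-up numbers; the real cost also adds the classification term.

    import torch

    pair_wise_iou = torch.tensor([[0.82, 0.65, 0.10, 0.05],
                                  [0.20, 0.15, 0.70, 0.60]])   # [num_gt, num_pred]
    cost = -torch.log(pair_wise_iou + 1e-8)                    # iou part of the cost only

    top_k, _   = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
    dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)        # here: tensor([1, 1])

    matching_matrix = torch.zeros_like(cost)
    for gt_idx in range(cost.shape[0]):
        _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
        matching_matrix[gt_idx][pos_idx] = 1.0
    print(matching_matrix)                                     # gt 0 -> pred 0, gt 1 -> pred 2
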
diff --git a/summary.py b/summary.py
index 0b27ac1..0ea3d58 100644
--- a/summary.py
+++ b/summary.py
@@ -11,9 +11,10 @@ if __name__ == "__main__":
     input_shape = [640, 640]
     anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
     num_classes = 80
+    phi = 'l'
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    m = YoloBody(anchors_mask, num_classes, False).to(device)
+    m = YoloBody(anchors_mask, num_classes, phi, False).to(device)
     summary(m, (3, input_shape[0], input_shape[1]))
 
     dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
diff --git a/train.py b/train.py
index e447fcb..f53cb13 100644
--- a/train.py
+++ b/train.py
@@ -98,6 +98,12 @@ if __name__ == "__main__":
     #   input_shape     The input shape; it must be a multiple of 32
     #------------------------------------------------------#
     input_shape = [640, 640]
+    #------------------------------------------------------#
+    #   phi             The yolov7 version to use. This repo provides two:
+    #                   l : corresponds to yolov7
+    #                   x : corresponds to yolov7_x
+    #------------------------------------------------------#
+    phi = 'l'
     #----------------------------------------------------------------------------------------------------------------------------#
     #   pretrained      Whether to use the backbone's pretrained weights; they are loaded when the model is constructed.
     #                   If model_path is set, the backbone weights need not be loaded and the value of pretrained is irrelevant.
@@ -268,15 +274,15 @@ if __name__ == "__main__":
     if pretrained:
         if distributed:
             if local_rank == 0:
-                download_weights()
+                download_weights(phi)
             dist.barrier()
         else:
-            download_weights()
+            download_weights(phi)
 
     #------------------------------------------------------#
     #   Create the yolo model
     #------------------------------------------------------#
-    model = YoloBody(anchors_mask, num_classes, pretrained=pretrained)
+    model = YoloBody(anchors_mask, num_classes, phi, pretrained=pretrained)
     if not pretrained:
         weights_init(model)
     if model_path != '':
diff --git a/utils/dataloader.py b/utils/dataloader.py
index 01c5d54..c14a20a 100644
--- a/utils/dataloader.py
+++ b/utils/dataloader.py
@@ -57,7 +57,10 @@ class YoloDataset(Dataset):
         image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
         box = np.array(box, dtype=np.float32)
 
-        nL = len(box)  # number of labels
+        #---------------------------------------------------#
+        #   Preprocess the ground-truth boxes
+        #---------------------------------------------------#
+        nL = len(box)
         labels_out = np.zeros((nL, 6))
         if nL:
             #---------------------------------------------------#
@@ -73,6 +76,10 @@ class YoloDataset(Dataset):
             box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
             box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
 
+            #---------------------------------------------------#
+            #   Reorder to the training format;
+            #   column 0 of labels_out is filled during collate
+            #---------------------------------------------------#
             labels_out[:, 1] = box[:, -1]
             labels_out[:, 2:] = box[:, :4]
 
diff --git a/utils/utils.py b/utils/utils.py
index 4c22c2c..3b578de 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -71,13 +71,13 @@ def show_config(**kwargs):
         print('|%25s | %40s|' % (str(key), str(value)))
     print('-' * 70)
 
-def download_weights(model_dir="./model_data"):
+def download_weights(phi, model_dir="./model_data"):
     import os
     from torch.hub import load_state_dict_from_url
 
-    phi = "l"
     download_urls = {
-        "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/cspdarknet_backbone.pth',
+        "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone.pth',
+        "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone.pth',
     }
     url = download_urls[phi]
diff --git a/utils/utils_bbox.py b/utils/utils_bbox.py
index b7e6971..1f6a04d 100644
--- a/utils/utils_bbox.py
+++ b/utils/utils_bbox.py
@@ -11,9 +11,9 @@ class DecodeBox():
         self.bbox_attrs = 5 + num_classes
         self.input_shape = input_shape
         #-----------------------------------------------------------#
-        #   The 20x20 feature layer corresponds to anchors [116,90],[156,198],[373,326]
-        #   The 40x40 feature layer corresponds to anchors [30,61],[62,45],[59,119]
-        #   The 80x80 feature layer corresponds to anchors [10,13],[16,30],[33,23]
+        #   The 20x20 feature layer corresponds to anchors [142, 110],[192, 243],[459, 401]
+        #   The 40x40 feature layer corresponds to anchors [36, 75],[76, 55],[72, 146]
+        #   The 80x80 feature layer corresponds to anchors [12, 16],[19, 36],[40, 28]
         #-----------------------------------------------------------#
         self.anchors_mask = anchors_mask
diff --git a/yolo.py b/yolo.py
index 9664fb9..ab39d2e 100644
--- a/yolo.py
+++ b/yolo.py
@@ -37,6 +37,12 @@ class YOLO(object):
         #   The input image size; it must be a multiple of 32.
         #---------------------------------------------------------------------#
         "input_shape"       : [640, 640],
+        #------------------------------------------------------#
+        #   The yolov7 version to use. This repo provides two:
+        #   l : corresponds to yolov7
+        #   x : corresponds to yolov7_x
+        #------------------------------------------------------#
+        "phi"               : 'l',
         #---------------------------------------------------------------------#
         #   Only predicted boxes with a score greater than the
         #   confidence threshold are kept
         #---------------------------------------------------------------------#
@@ -97,7 +103,7 @@ class YOLO(object):
         #---------------------------------------------------#
         #   Build the yolo model and load its weights
         #---------------------------------------------------#
-        self.net = YoloBody(self.anchors_mask, self.num_classes)
+        self.net = YoloBody(self.anchors_mask, self.num_classes, self.phi)
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.net.load_state_dict(torch.load(self.model_path, map_location=device))
         self.net = self.net.fuse().eval()
-- 
GitLab
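
A quick sanity check of the new constructor signature, mirroring summary.py above (a minimal sketch; shapes assume phi='l', num_classes=80 and a 640x640 input):

    import torch
    from nets.yolo import YoloBody

    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    model   = YoloBody(anchors_mask, 80, 'l', pretrained=False).eval()
    outputs = model(torch.randn(1, 3, 640, 640))
    for out in outputs:
        print(out.shape)    # 3 * (5 + 80) = 255 channels per head
    # Expected: (1, 255, 20, 20), (1, 255, 40, 40), (1, 255, 80, 80)
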