From 16c84698d5b98e88a5b8dd872489a052683435e2 Mon Sep 17 00:00:00 2001
From: Bubbliiiing <3323290568@qq.com>
Date: Fri, 8 Jul 2022 17:29:02 +0800
Subject: [PATCH] update loss

---
 README.md             |   2 +-
 nets/yolo.py          |   2 +-
 nets/yolo_training.py | 659 +++++++++++++++++++++---------------------
 train.py              |   2 +-
 utils/dataloader.py   | 132 +--------
 utils/utils_fit.py    |  25 +-
 6 files changed, 355 insertions(+), 467 deletions(-)

diff --git a/README.md b/README.md
index db92ce1..ae39f87 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ YoloV7 | https://github.com/bubbliiiing/yolov7-pytorch
 ## Performance
 | Training dataset | Weight file | Test dataset | Input image size | mAP 0.5:0.95 | mAP 0.5 |
 | :-----: | :-----: | :------: | :------: | :------: | :-----: |
-| COCO-Train2017 | [yolov7_weights.pth](https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_weights.pth) | COCO-Val2017 | 640x640 | 27.6 | 45.0
+| COCO-Train2017 | [yolov7_weights.pth](https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_weights.pth) | COCO-Val2017 | 640x640 | 50.7 | 69.2
 
 ## Required environment
 torch==1.2.0
diff --git a/nets/yolo.py b/nets/yolo.py
index 5b53a99..150d401 100644
--- a/nets/yolo.py
+++ b/nets/yolo.py
@@ -305,5 +305,5 @@ class YoloBody(nn.Module):
         #   y1=(batch_size, 75, 20, 20)
         #---------------------------------------------------#
         out0 = self.yolo_head_P5(P5)
-        return out0, out1, out2
+        return [out0, out1, out2]
\ No newline at end of file
diff --git a/nets/yolo_training.py b/nets/yolo_training.py
index 240e515..6516d20 100644
--- a/nets/yolo_training.py
+++ b/nets/yolo_training.py
@@ -5,353 +5,364 @@ from functools import partial
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
+def box_iou(box1, box2):
+    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+    """
+    Return intersection-over-union (Jaccard index) of boxes.
+    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+    Arguments:
+        box1 (Tensor[N, 4])
+        box2 (Tensor[M, 4])
+    Returns:
+        iou (Tensor[N, M]): the NxM matrix containing the pairwise
+            IoU values for every element in boxes1 and boxes2
+    """
+    def box_area(box):
+        # box = 4xn
+        return (box[2] - box[0]) * (box[3] - box[1])
+
+    area1 = box_area(box1.T)
+    area2 = box_area(box2.T)
+
+    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
+
+def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
+    box2 = box2.T
+
+    # Get the coordinates of bounding boxes
+    if x1y1x2y2:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+    else:  # transform from xywh to xyxy
+        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
+        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
+        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
+        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
+
+    # Intersection area
+    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
+            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
+
+    # Union Area
+    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    iou = inter / union
+
+    if GIoU or DIoU or CIoU:
+        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
+        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
+            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
+                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
+            if DIoU:
+                return iou - rho2 / c2  # DIoU
+            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+            c_area = cw * ch + eps  # convex area
+            return iou - (c_area - union) / c_area  # GIoU
+    else:
+        return iou  # IoU
+
+def xywh2xyxy(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+    return y
+
+def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
+    # return positive, negative label smoothing BCE targets
+    return 1.0 - 0.5 * eps, 0.5 * eps
+
 class YOLOLoss(nn.Module):
-    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
+    def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
         super(YOLOLoss, self).__init__()
         #-----------------------------------------------------------#
         #   The 20x20 feature map uses the anchors [116,90],[156,198],[373,326]
         #   The 40x40 feature map uses the anchors [30,61],[62,45],[59,119]
         #   The 80x80 feature map uses the anchors [10,13],[16,30],[33,23]
         #-----------------------------------------------------------#
-        self.anchors = anchors
+        self.anchors = [anchors[mask] for mask in anchors_mask]
         self.num_classes = num_classes
-        self.bbox_attrs = 5 + num_classes
         self.input_shape = input_shape
         self.anchors_mask = anchors_mask
-        self.label_smoothing = label_smoothing
-
-        self.threshold = 4
 
         self.balance = [0.4, 1.0, 4]
+        self.stride = [32, 16, 8]
         self.box_ratio = 0.05
         self.obj_ratio = 1 * (input_shape[0] * input_shape[1]) / (640 ** 2)
         self.cls_ratio = 0.5 * (num_classes / 80)
-        self.cuda = cuda
-
-    def clip_by_tensor(self, t, t_min, t_max):
-        t = t.float()
-        result = (t >= t_min).float() * t + (t < t_min).float() * t_min
-        result = (result <= t_max).float() * result + (result > t_max).float() * t_max
-        return result
-
-    def MSELoss(self, pred, target):
-        return torch.pow(pred - target, 2)
-
-    def BCELoss(self, pred, target):
-        epsilon = 1e-7
-        pred = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
-        output = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
-        return output
-
-    def box_giou(self, b1, b2):
-        """
-        Inputs:
-        ----------
-        b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
-        b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
-
-        Returns:
-        -------
-        giou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
-        """
-        #----------------------------------------------------#
-        #   Top-left and bottom-right corners of the predicted boxes
-        #----------------------------------------------------#
-        b1_xy = b1[..., :2]
-        b1_wh = b1[..., 2:4]
-        b1_wh_half = b1_wh/2.
-        b1_mins = b1_xy - b1_wh_half
-        b1_maxes = b1_xy + b1_wh_half
-        #----------------------------------------------------#
-        #   Top-left and bottom-right corners of the ground-truth boxes
-        #----------------------------------------------------#
-        b2_xy = b2[..., :2]
-        b2_wh = b2[..., 2:4]
-        b2_wh_half = b2_wh/2.
-        b2_mins = b2_xy - b2_wh_half
-        b2_maxes = b2_xy + b2_wh_half
-
-        #----------------------------------------------------#
-        #   IoU between all ground-truth and predicted boxes
-        #----------------------------------------------------#
-        intersect_mins = torch.max(b1_mins, b2_mins)
-        intersect_maxes = torch.min(b1_maxes, b2_maxes)
-        intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
-        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
-        b1_area = b1_wh[..., 0] * b1_wh[..., 1]
-        b2_area = b2_wh[..., 0] * b2_wh[..., 1]
-        union_area = b1_area + b2_area - intersect_area
-        iou = intersect_area / union_area
-
-        #----------------------------------------------------#
-        #   Top-left and bottom-right corners of the smallest box enclosing both
-        #----------------------------------------------------#
-        enclose_mins = torch.min(b1_mins, b2_mins)
-        enclose_maxes = torch.max(b1_maxes, b2_maxes)
-        enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
-        #----------------------------------------------------#
-        #   Area of the enclosing box
-        #----------------------------------------------------#
-        enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
-        giou = iou - (enclose_area - union_area) / enclose_area
-
-        return giou
-
-    #---------------------------------------------------#
-    #   Label smoothing
-    #---------------------------------------------------#
-    def smooth_labels(self, y_true, label_smoothing, num_classes):
-        return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes
-
-    def forward(self, l, input, targets=None, y_true=None):
-        #----------------------------------------------------#
-        #   l is the index of the feature map in use
-        #   input has shape  bs, 3*(5+num_classes), 20, 20
-        #                    bs, 3*(5+num_classes), 40, 40
-        #                    bs, 3*(5+num_classes), 80, 80
-        #   targets holds the ground-truth boxes [batch_size, num_gt, 5]
-        #----------------------------------------------------#
-        #--------------------------------#
-        #   Number of images, and feature map height and width
-        #   20, 20
-        #--------------------------------#
-        bs = input.size(0)
-        in_h = input.size(2)
-        in_w = input.size(3)
-        #-----------------------------------------------------------------------#
-        #   Compute the stride: how many pixels of the original image
-        #   one feature point covers.
-        #   For [640, 640] input the stride is 640 / 20 = 32 in both directions.
-        #   A 20x20 feature map corresponds to 32 pixels per feature point,
-        #   a 40x40 feature map to 16 pixels,
-        #   and an 80x80 feature map to 8 pixels.
-        #   stride_h = stride_w = 32, 16, 8
-        #-----------------------------------------------------------------------#
-        stride_h = self.input_shape[0] / in_h
-        stride_w = self.input_shape[1] / in_w
-        #-------------------------------------------------#
-        #   The scaled_anchors obtained here are relative to the feature map
-        #-------------------------------------------------#
-        scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
-        #-----------------------------------------------#
-        #   There are three inputs; their shapes are
-        #   bs, 3 * (5+num_classes), 20, 20 => bs, 3, 5 + num_classes, 20, 20 => batch_size, 3, 20, 20, 5 + num_classes
-
-        #   batch_size, 3, 20, 20, 5 + num_classes
-        #   batch_size, 3, 40, 40, 5 + num_classes
-        #   batch_size, 3, 80, 80, 5 + num_classes
-        #-----------------------------------------------#
-        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
-
-        #-----------------------------------------------#
-        #   Adjustment parameters for the anchor box centers
-        #-----------------------------------------------#
-        x = torch.sigmoid(prediction[..., 0])
-        y = torch.sigmoid(prediction[..., 1])
-        #-----------------------------------------------#
-        #   Adjustment parameters for the anchor box width and height
-        #-----------------------------------------------#
-        w = torch.sigmoid(prediction[..., 2])
-        h = torch.sigmoid(prediction[..., 3])
-        #-----------------------------------------------#
-        #   Objectness confidence: whether an object is present
-        #-----------------------------------------------#
-        conf = torch.sigmoid(prediction[..., 4])
-        #-----------------------------------------------#
-        #   Class confidence
-        #-----------------------------------------------#
-        pred_cls = torch.sigmoid(prediction[..., 5:])
-        #-----------------------------------------------#
-        #   self.get_target has been merged into the dataloader,
-        #   because running it here was too slow and lengthened training a lot
-        #-----------------------------------------------#
-        # y_true, noobj_mask = self.get_target(l, targets, scaled_anchors, in_h, in_w)
-
-        #---------------------------------------------------------------#
-        #   Decode the predictions and measure their overlap with the ground truth.
-        #   Feature points that overlap the ground truth too much are ignored,
-        #   since they already predict accurately and make poor negative samples.
-        #----------------------------------------------------------------#
-        pred_boxes = self.get_pred_boxes(l, x, y, h, w, targets, scaled_anchors, in_h, in_w)
-
-        if self.cuda:
-            y_true = y_true.type_as(x)
+        self.threshold = 4
+
+        self.cp, self.cn = smooth_BCE(eps=label_smoothing)
+        self.BCEcls, self.BCEobj, self.gr = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss(), 1
+
+    def __call__(self, p, targets, imgs):  # predictions, targets, model
+        for i in range(len(p)):
+            bs, _, h, w = p[i].size()
+            p[i] = p[i].view(bs, len(self.anchors_mask[i]), -1, h, w).permute(0, 1, 3, 4, 2).contiguous()
+        device = targets.device
+        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
+        bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
+        pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]].type_as(pp) for pp in p]
+
+        # Losses
+        for i, pi in enumerate(p):  # layer index, layer predictions
+            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]  # image, anchor, gridy, gridx
+            tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj
+
+            n = b.shape[0]  # number of targets
+            if n:
+                ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets
+
+                # Regression
+                grid = torch.stack([gi, gj], dim=1)
+                pxy = ps[:, :2].sigmoid() * 2. - 0.5
+                #pxy = ps[:, :2].sigmoid() * 3. - 1.
+                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
+                pbox = torch.cat((pxy, pwh), 1)  # predicted box
+                selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
+                selected_tbox[:, :2] -= grid.type_as(pi)
+                iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
+                lbox += (1.0 - iou).mean()  # iou loss
+
+                # Objectness
+                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio
+
+                # Classification
+                selected_tcls = targets[i][:, 1].long()
+                if self.num_classes > 1:  # cls loss (only if multiple classes)
+                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
+                    t[range(n), selected_tcls] = self.cp
+                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE
+
+                # Append targets to text file
+                # with open('targets.txt', 'a') as file:
+                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
+
+            obji = self.BCEobj(pi[..., 4], tobj)
+            lobj += obji * self.balance[i]  # obj loss
+
+        lbox *= self.box_ratio
+        lobj *= self.obj_ratio
+        lcls *= self.cls_ratio
+        bs = tobj.shape[0]  # batch size
+
+        loss = lbox + lobj + lcls
+        return loss * bs
+
+    def build_targets(self, p, targets, imgs):
-        loss = 0
-        n = torch.sum(y_true[..., 4] == 1)
-        if n != 0:
-            #---------------------------------------------------------------#
-            #   Compute the GIoU between predictions and ground truth,
-            #   i.e. the GIoU loss for anchors matched to a ground-truth box.
-            #   loss_cls is the classification loss for those anchors.
-            #----------------------------------------------------------------#
-            giou = self.box_giou(pred_boxes, y_true[..., :4]).type_as(x)
-            loss_loc = torch.mean((1 - giou)[y_true[..., 4] == 1])
-            loss_cls = torch.mean(self.BCELoss(pred_cls[y_true[..., 4] == 1], self.smooth_labels(y_true[..., 5:][y_true[..., 4] == 1], self.label_smoothing, self.num_classes)))
-            loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
-            #-----------------------------------------------------------#
-            #   Compute the confidence loss: the anchor whose predicted box
-            #   is more accurate is the one used to predict this object.
-            #-----------------------------------------------------------#
-            tobj = torch.where(y_true[..., 4] == 1, giou.detach().clamp(0), torch.zeros_like(y_true[..., 4]))
-        else:
-            tobj = torch.zeros_like(y_true[..., 4])
-        loss_conf = torch.mean(self.BCELoss(conf, tobj))
+        indices, anch = self.find_3_positive(p, targets)
+
+        matching_bs = [[] for pp in p]
+        matching_as = [[] for pp in p]
+        matching_gjs = [[] for pp in p]
+        matching_gis = [[] for pp in p]
+        matching_targets = [[] for pp in p]
+        matching_anchs = [[] for pp in p]
 
-        loss += loss_conf * self.balance[l] * self.obj_ratio
-        # if n != 0:
-        #     print(loss_loc * self.box_ratio, loss_cls * self.cls_ratio, loss_conf * self.balance[l] * self.obj_ratio)
-        return loss
+        nl = len(p)
 
-    def get_near_points(self, x, y, i, j):
-        sub_x = x - i
-        sub_y = y - j
-        if sub_x > 0.5 and sub_y > 0.5:
-            return [[0, 0], [1, 0], [0, 1]]
-        elif sub_x < 0.5 and sub_y > 0.5:
-            return [[0, 0], [-1, 0], [0, 1]]
-        elif sub_x < 0.5 and sub_y < 0.5:
-            return [[0, 0], [-1, 0], [0, -1]]
-        else:
-            return [[0, 0], [1, 0], [0, -1]]
-
-    def get_target(self, l, targets, anchors, in_h, in_w):
-        #-----------------------------------------------------#
-        #   Count how many images there are
-        #-----------------------------------------------------#
-        bs = len(targets)
-        #-----------------------------------------------------#
-        #   Used to select which anchors contain no object
-        #   bs, 3, 20, 20
-        #-----------------------------------------------------#
-        noobj_mask = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
-        #-----------------------------------------------------#
-        #   Helps find the ground-truth box that best matches each anchor
-        #-----------------------------------------------------#
-        box_best_ratio = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
-        #-----------------------------------------------------#
-        #   batch_size, 3, 20, 20, 5 + num_classes
-        #-----------------------------------------------------#
-        y_true = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad = False)
-        for b in range(bs):
-            if len(targets[b])==0:
+        for batch_idx in range(p[0].shape[0]):
+
+            b_idx = targets[:, 0]==batch_idx
+            this_target = targets[b_idx]
+            if this_target.shape[0] == 0:
                 continue
-            batch_target = torch.zeros_like(targets[b])
-            #-------------------------------------------------------#
-            #   Compute the positive-sample centers on the feature map,
-            #   and the ground-truth sizes relative to the feature map
-            #-------------------------------------------------------#
-            batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
-            batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
-            batch_target[:, 4] = targets[b][:, 4]
-            batch_target = batch_target.cpu()
+
+            txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
+            txyxy = xywh2xyxy(txywh)
+
+            pxyxys = []
+            p_cls = []
+            p_obj = []
+            from_which_layer = []
+            all_b = []
+            all_a = []
+            all_gj = []
+            all_gi = []
+            all_anch = []
+
+            for i, pi in enumerate(p):
+
+                b, a, gj, gi = indices[i]
+                idx = (b == batch_idx)
+                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
+                all_b.append(b)
+                all_a.append(a)
+                all_gj.append(gj)
+                all_gi.append(gi)
+                all_anch.append(anch[i][idx])
+                from_which_layer.append(torch.ones(size=(len(b),)) * i)
+
+                fg_pred = pi[b, a, gj, gi]
+                p_obj.append(fg_pred[:, 4:5])
+                p_cls.append(fg_pred[:, 5:])
+
+                grid = torch.stack([gi, gj], dim=1).type_as(fg_pred)
+                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]  #/ 8.
+                #pxy = (fg_pred[:, :2].sigmoid() * 3. - 1. + grid) * self.stride[i]
+                pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]  #/ 8.
+                pxywh = torch.cat([pxy, pwh], dim=-1)
+                pxyxy = xywh2xyxy(pxywh)
+                pxyxys.append(pxyxy)
 
-            #-----------------------------------------------------------------------------#
-            #   batch_target                                    : num_true_box, 5
-            #   batch_target[:, 2:4]                            : num_true_box, 2
-            #   torch.unsqueeze(batch_target[:, 2:4], 1)        : num_true_box, 1, 2
-            #   anchors                                         : 9, 2
-            #   torch.unsqueeze(torch.FloatTensor(anchors), 0)  : 1, 9, 2
-            #   ratios_of_gt_anchors                            : num_true_box, 9, 2
-            #   ratios_of_anchors_gt                            : num_true_box, 9, 2
-            #
-            #   ratios                                          : num_true_box, 9, 4
-            #   max_ratios                                      : num_true_box, 9
-            #   max_ratios is the max width/height ratio of each ground-truth box to each anchor!
-            #------------------------------------------------------------------------------#
-            ratios_of_gt_anchors = torch.unsqueeze(batch_target[:, 2:4], 1) / torch.unsqueeze(torch.FloatTensor(anchors), 0)
-            ratios_of_anchors_gt = torch.unsqueeze(torch.FloatTensor(anchors), 0) / torch.unsqueeze(batch_target[:, 2:4], 1)
-            ratios = torch.cat([ratios_of_gt_anchors, ratios_of_anchors_gt], dim = -1)
-            max_ratios, _ = torch.max(ratios, dim = -1)
-
-            for t, ratio in enumerate(max_ratios):
-                #-------------------------------------------------------#
-                #   ratio : 9
-                #-------------------------------------------------------#
-                over_threshold = ratio < self.threshold
-                over_threshold[torch.argmin(ratio)] = True
-                for k, mask in enumerate(self.anchors_mask[l]):
-                    if not over_threshold[mask]:
-                        continue
-                    #----------------------------------------#
-                    #   Find which grid cell the ground-truth box falls in
-                    #   x 1.25 => 1
-                    #   y 3.75 => 3
-                    #----------------------------------------#
-                    i = torch.floor(batch_target[t, 0]).long()
-                    j = torch.floor(batch_target[t, 1]).long()
-
-                    offsets = self.get_near_points(batch_target[t, 0], batch_target[t, 1], i, j)
-                    for offset in offsets:
-                        local_i = i + offset[0]
-                        local_j = j + offset[1]
-
-                        if local_i >= in_w or local_i < 0 or local_j >= in_h or local_j < 0:
-                            continue
-
-                        if box_best_ratio[b, k, local_j, local_i] != 0:
-                            if box_best_ratio[b, k, local_j, local_i] > ratio[mask]:
-                                y_true[b, k, local_j, local_i, :] = 0
-                            else:
-                                continue
-
-                        #----------------------------------------#
-                        #   Get the class of the ground-truth box
-                        #----------------------------------------#
-                        c = batch_target[t, 4].long()
-
-                        #----------------------------------------#
-                        #   noobj_mask marks the feature points with no object
-                        #----------------------------------------#
-                        noobj_mask[b, k, local_j, local_i] = 0
-                        #----------------------------------------#
-                        #   tx and ty are the ground-truth center adjustment values
-                        #----------------------------------------#
-                        y_true[b, k, local_j, local_i, 0] = batch_target[t, 0]
-                        y_true[b, k, local_j, local_i, 1] = batch_target[t, 1]
-                        y_true[b, k, local_j, local_i, 2] = batch_target[t, 2]
-                        y_true[b, k, local_j, local_i, 3] = batch_target[t, 3]
-                        y_true[b, k, local_j, local_i, 4] = 1
-                        y_true[b, k, local_j, local_i, c + 5] = 1
-                        #----------------------------------------#
-                        #   Record the best ratio for this anchor
-                        #----------------------------------------#
-                        box_best_ratio[b, k, local_j, local_i] = ratio[mask]
-
-        return y_true, noobj_mask
-
-    def get_pred_boxes(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w):
-        #-----------------------------------------------------#
-        #   Count how many images there are
-        #-----------------------------------------------------#
-        bs = len(targets)
-
-        #-----------------------------------------------------#
-        #   Generate the grid: anchor centers at the grid top-left corners
-        #-----------------------------------------------------#
-        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
-            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type_as(x)
-        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
-            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type_as(x)
-
-        #   Generate the anchor widths and heights
-        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
-        anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
-        anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)
+            pxyxys = torch.cat(pxyxys, dim=0)
+            if pxyxys.shape[0] == 0:
+                continue
+            p_obj = torch.cat(p_obj, dim=0)
+            p_cls = torch.cat(p_cls, dim=0)
+            from_which_layer = torch.cat(from_which_layer, dim=0)
+            all_b = torch.cat(all_b, dim=0)
+            all_a = torch.cat(all_a, dim=0)
+            all_gj = torch.cat(all_gj, dim=0)
+            all_gi = torch.cat(all_gi, dim=0)
+            all_anch = torch.cat(all_anch, dim=0)
 
-        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
-        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
-        #-------------------------------------------------------#
-        #   Compute the adjusted anchor centers and sizes
-        #-------------------------------------------------------#
-        pred_boxes_x = torch.unsqueeze(x * 2. - 0.5 + grid_x, -1)
-        pred_boxes_y = torch.unsqueeze(y * 2. - 0.5 + grid_y, -1)
-        pred_boxes_w = torch.unsqueeze((w * 2) ** 2 * anchor_w, -1)
-        pred_boxes_h = torch.unsqueeze((h * 2) ** 2 * anchor_h, -1)
-        pred_boxes = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
-        return pred_boxes
+            pair_wise_iou = box_iou(txyxy, pxyxys)
+
+            pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
+
+            top_k, _ = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
+            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
+
+            gt_cls_per_image = (
+                F.one_hot(this_target[:, 1].to(torch.int64), self.num_classes)
+                .float()
+                .unsqueeze(1)
+                .repeat(1, pxyxys.shape[0], 1)
+            )
+
+            num_gt = this_target.shape[0]
+            cls_preds_ = (
+                p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+                * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+            )
+
+            y = cls_preds_.sqrt_()
+            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
+                torch.log(y / (1 - y)), gt_cls_per_image, reduction="none"
+            ).sum(-1)
+            del cls_preds_
+
+            cost = (
+                pair_wise_cls_loss
+                + 3.0 * pair_wise_iou_loss
+            )
+
+            matching_matrix = torch.zeros_like(cost)
+
+            for gt_idx in range(num_gt):
+                _, pos_idx = torch.topk(
+                    cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
+                )
+                matching_matrix[gt_idx][pos_idx] = 1.0
+
+            del top_k, dynamic_ks
+            anchor_matching_gt = matching_matrix.sum(0)
+            if (anchor_matching_gt > 1).sum() > 0:
+                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
+                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
+            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
+            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+            from_which_layer = from_which_layer[fg_mask_inboxes]
+            all_b = all_b[fg_mask_inboxes]
+            all_a = all_a[fg_mask_inboxes]
+            all_gj = all_gj[fg_mask_inboxes]
+            all_gi = all_gi[fg_mask_inboxes]
+            all_anch = all_anch[fg_mask_inboxes]
+
+            this_target = this_target[matched_gt_inds]
+
+            for i in range(nl):
+                layer_idx = from_which_layer == i
+                matching_bs[i].append(all_b[layer_idx])
+                matching_as[i].append(all_a[layer_idx])
+                matching_gjs[i].append(all_gj[layer_idx])
+                matching_gis[i].append(all_gi[layer_idx])
+                matching_targets[i].append(this_target[layer_idx])
+                matching_anchs[i].append(all_anch[layer_idx])
+
+        for i in range(nl):
+            matching_bs[i] = torch.cat(matching_bs[i], dim=0)
+            matching_as[i] = torch.cat(matching_as[i], dim=0)
+            matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
+            matching_gis[i] = torch.cat(matching_gis[i], dim=0)
+            matching_targets[i] = torch.cat(matching_targets[i], dim=0)
+            matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
+
+        return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
+
+    def find_3_positive(self, p, targets):
+        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
+        na, nt = len(self.anchors_mask[0]), targets.shape[0]  # number of anchors, targets
+        indices, anch = [], []
+        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
+        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
+        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices
+
+        g = 0.5  # bias
+        off = torch.tensor([[0, 0],
+                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
+                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
+                            ], device=targets.device).float() * g  # offsets
+
+        for i in range(len(p)):
+            anchors = torch.from_numpy(self.anchors[i]).type_as(p[i])
+            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain
+
+            # Match targets to anchors
+            t = targets * gain
+            if nt:
+                # Matches
+                r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
+                j = torch.max(r, 1. / r).max(2)[0] < self.threshold  # compare
+                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
+                t = t[j]  # filter
+
+                # Offsets
+                gxy = t[:, 2:4]  # grid xy
+                gxi = gain[[2, 3]] - gxy  # inverse
+                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
+                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
+                j = torch.stack((torch.ones_like(j), j, k, l, m))
+                t = t.repeat((5, 1, 1))[j]
+                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
+            else:
+                t = targets[0]
+                offsets = 0
+
+            # Define
+            b, c = t[:, :2].long().T  # image, class
+            gxy = t[:, 2:4]  # grid xy
+            gwh = t[:, 4:6]  # grid wh
+            gij = (gxy - offsets).long()
+            gi, gj = gij.T  # grid xy indices
+
+            # Append
+            a = t[:, 6].long()  # anchor indices
+            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
+            anch.append(anchors[a])  # anchors
+
+        return indices, anch
 
 def is_parallel(model):
     # Returns True if model is of type DP or DDP
diff --git a/train.py b/train.py
index 6e85303..e447fcb 100644
--- a/train.py
+++ b/train.py
@@ -311,7 +311,7 @@ if __name__ == "__main__":
     #----------------------#
     #   Get the loss function
     #----------------------#
-    yolo_loss = YOLOLoss(anchors, num_classes, input_shape, Cuda, anchors_mask, label_smoothing)
+    yolo_loss = YOLOLoss(anchors, num_classes, input_shape, anchors_mask, label_smoothing)
     #----------------------#
    #   Record the loss
     #----------------------#
diff --git a/utils/dataloader.py b/utils/dataloader.py
index 0bb2008..01c5d54 100644
--- a/utils/dataloader.py
+++ b/utils/dataloader.py
@@ -30,7 +30,6 @@ class YoloDataset(Dataset):
         self.length = len(self.annotation_lines)
         self.bbox_attrs = 5 + num_classes
-        self.threshold = 4
 
     def __len__(self):
         return self.length
@@ -57,7 +56,10 @@ class YoloDataset(Dataset):
         image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
         box = np.array(box, dtype=np.float32)
-        if len(box) != 0:
+
+        nL = len(box)  # number of labels
+        labels_out = np.zeros((nL, 6))
+        if nL:
             #---------------------------------------------------#
             #   Normalize the ground-truth boxes to the range 0-1
             #---------------------------------------------------#
@@ -70,8 +72,11 @@ class YoloDataset(Dataset):
             #---------------------------------------------------#
             box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
             box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
-            y_true = self.get_target(box)
-            return image, box, y_true
+
+            labels_out[:, 1] = box[:, -1]
+            labels_out[:, 2:] = box[:, :4]
+
+        return image, labels_out
 
     def rand(self, a=0, b=1):
         return np.random.rand()*(b-a) + a
@@ -378,127 +383,16 @@ class YoloDataset(Dataset):
         new_boxes = np.concatenate([box_1, box_2], axis=0)
         return new_image, new_boxes
 
-    def get_near_points(self, x, y, i, j):
-        sub_x = x - i
-        sub_y = y - j
-        if sub_x > 0.5 and sub_y > 0.5:
-            return [[0, 0], [1, 0], [0, 1]]
-        elif sub_x < 0.5 and sub_y > 0.5:
-            return [[0, 0], [-1, 0], [0, 1]]
-        elif sub_x < 0.5 and sub_y < 0.5:
-            return [[0, 0], [-1, 0], [0, -1]]
-        else:
-            return [[0, 0], [1, 0], [0, -1]]
-
-    def get_target(self, targets):
-        #-----------------------------------------------------------#
-        #   There are three feature maps in total
-        #-----------------------------------------------------------#
-        num_layers = len(self.anchors_mask)
-
-        input_shape = np.array(self.input_shape, dtype='int32')
-        grid_shapes = [input_shape // {0:32, 1:16, 2:8, 3:4}[l] for l in range(num_layers)]
-        y_true = [np.zeros((len(self.anchors_mask[l]), grid_shapes[l][0], grid_shapes[l][1], self.bbox_attrs), dtype='float32') for l in range(num_layers)]
-        box_best_ratio = [np.zeros((len(self.anchors_mask[l]), grid_shapes[l][0], grid_shapes[l][1]), dtype='float32') for l in range(num_layers)]
-
-        if len(targets) == 0:
-            return y_true
-
-        for l in range(num_layers):
-            in_h, in_w = grid_shapes[l]
-            anchors = np.array(self.anchors) / {0:32, 1:16, 2:8, 3:4}[l]
-
-            batch_target = np.zeros_like(targets)
-            #-------------------------------------------------------#
-            #   Compute the positive-sample centers on the feature map
-            #-------------------------------------------------------#
-            batch_target[:, [0,2]] = targets[:, [0,2]] * in_w
-            batch_target[:, [1,3]] = targets[:, [1,3]] * in_h
-            batch_target[:, 4] = targets[:, 4]
-            #-------------------------------------------------------#
-            #   wh                          : num_true_box, 2
-            #   np.expand_dims(wh, 1)       : num_true_box, 1, 2
-            #   anchors                     : 9, 2
-            #   np.expand_dims(anchors, 0)  : 1, 9, 2
-            #
-            #   ratios_of_gt_anchors is the width/height ratio of each ground-truth box to each anchor
-            #   ratios_of_gt_anchors        : num_true_box, 9, 2
-            #   ratios_of_anchors_gt is the width/height ratio of each anchor to each ground-truth box
-            #   ratios_of_anchors_gt        : num_true_box, 9, 2
-            #
-            #   ratios                      : num_true_box, 9, 4
-            #   max_ratios is the maximum of these width/height ratios per ground-truth box and anchor
-            #   max_ratios                  : num_true_box, 9
-            #-------------------------------------------------------#
-            ratios_of_gt_anchors = np.expand_dims(batch_target[:, 2:4], 1) / np.expand_dims(anchors, 0)
-            ratios_of_anchors_gt = np.expand_dims(anchors, 0) / np.expand_dims(batch_target[:, 2:4], 1)
-            ratios = np.concatenate([ratios_of_gt_anchors, ratios_of_anchors_gt], axis = -1)
-            max_ratios = np.max(ratios, axis = -1)
-
-            for t, ratio in enumerate(max_ratios):
-                #-------------------------------------------------------#
-                #   ratio : 9
-                #-------------------------------------------------------#
-                over_threshold = ratio < self.threshold
-                over_threshold[np.argmin(ratio)] = True
-                for k, mask in enumerate(self.anchors_mask[l]):
-                    if not over_threshold[mask]:
-                        continue
-                    #----------------------------------------#
-                    #   Find which grid cell the ground-truth box falls in
-                    #   x 1.25 => 1
-                    #   y 3.75 => 3
-                    #----------------------------------------#
-                    i = int(np.floor(batch_target[t, 0]))
-                    j = int(np.floor(batch_target[t, 1]))
-
-                    offsets = self.get_near_points(batch_target[t, 0], batch_target[t, 1], i, j)
-                    for offset in offsets:
-                        local_i = i + offset[0]
-                        local_j = j + offset[1]
-
-                        if local_i >= in_w or local_i < 0 or local_j >= in_h or local_j < 0:
-                            continue
-
-                        if box_best_ratio[l][k, local_j, local_i] != 0:
-                            if box_best_ratio[l][k, local_j, local_i] > ratio[mask]:
-                                y_true[l][k, local_j, local_i, :] = 0
-                            else:
-                                continue
-
-                        #----------------------------------------#
-                        #   Get the class of the ground-truth box
-                        #----------------------------------------#
-                        c = int(batch_target[t, 4])
-
-                        #----------------------------------------#
-                        #   tx and ty are the ground-truth center adjustment values
-                        #----------------------------------------#
-                        y_true[l][k, local_j, local_i, 0] = batch_target[t, 0]
-                        y_true[l][k, local_j, local_i, 1] = batch_target[t, 1]
-                        y_true[l][k, local_j, local_i, 2] = batch_target[t, 2]
-                        y_true[l][k, local_j, local_i, 3] = batch_target[t, 3]
-                        y_true[l][k, local_j, local_i, 4] = 1
-                        y_true[l][k, local_j, local_i, c + 5] = 1
-                        #----------------------------------------#
-                        #   Record the best ratio for this anchor
-                        #----------------------------------------#
-                        box_best_ratio[l][k, local_j, local_i] = ratio[mask]
-
-        return y_true
 
 # Used as collate_fn by the DataLoader
 def yolo_dataset_collate(batch):
     images = []
     bboxes = []
-    y_trues = [[] for _ in batch[0][2]]
-    for img, box, y_true in batch:
+    for i, (img, box) in enumerate(batch):
         images.append(img)
+        box[:, 0] = i
         bboxes.append(box)
-        for i, sub_y_true in enumerate(y_true):
-            y_trues[i].append(sub_y_true)
 
     images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
-    bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
-    y_trues = [torch.from_numpy(np.array(ann, np.float32)).type(torch.FloatTensor) for ann in y_trues]
-    return images, bboxes, y_trues
+    bboxes = torch.from_numpy(np.concatenate(bboxes, 0)).type(torch.FloatTensor)
+    return images, bboxes
diff --git a/utils/utils_fit.py b/utils/utils_fit.py
index b86224c..f5e6bc8 100644
--- a/utils/utils_fit.py
+++ b/utils/utils_fit.py
@@ -17,12 +17,11 @@ def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callbac
         if iteration >= epoch_step:
             break
 
-        images, targets, y_trues = batch[0], batch[1], batch[2]
+        images, targets = batch[0], batch[1]
        with torch.no_grad():
             if cuda:
                 images = images.cuda(local_rank)
-                targets = [ann.cuda(local_rank) for ann in targets]
-                y_trues = [ann.cuda(local_rank) for ann in y_trues]
+                targets = targets.cuda(local_rank)
         #----------------------#
         #   Zero the gradients
         #----------------------#
@@ -32,15 +31,7 @@ def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callbac
             #   Forward pass
             #----------------------#
             outputs = model_train(images)
-
-            loss_value_all = 0
-            #----------------------#
-            #   Compute the loss
-            #----------------------#
-            for l in range(len(outputs)):
-                loss_item = yolo_loss(l, outputs[l], targets, y_trues[l])
-                loss_value_all += loss_item
-            loss_value = loss_value_all
+            loss_value = yolo_loss(outputs, targets, images)
 
             #----------------------#
             #   Backward pass
@@ -54,15 +45,7 @@ def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callbac
             #   Forward pass
             #----------------------#
             outputs = model_train(images)
-
-            loss_value_all = 0
-            #----------------------#
-            #   Compute the loss
-            #----------------------#
-            for l in range(len(outputs)):
-                loss_item = yolo_loss(l, outputs[l], targets, y_trues[l])
-                loss_value_all += loss_item
-            loss_value = loss_value_all
+            loss_value = yolo_loss(outputs, targets, images)
 
             #----------------------#
             #   Backward pass
-- 
GitLab
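
Notes (placed after the patch trailer, so they are not part of the diff; illustrative sketches only, with made-up values).

A minimal check of the new bbox_iou as the loss calls it: box1 arrives transposed (4xn, in xywh), box2 is nx4, and CIoU=True subtracts center-distance and aspect-ratio penalties from plain IoU. This assumes the patched nets/yolo_training.py is importable.

    import torch
    from nets.yolo_training import bbox_iou

    pred  = torch.tensor([[5., 5., 4., 4.],      # xywh, identical to its target
                          [6., 5., 4., 4.]])     # xywh, shifted 1 unit right
    truth = torch.tensor([[5., 5., 4., 4.],
                          [5., 5., 4., 4.]])

    ciou = bbox_iou(pred.T, truth, x1y1x2y2=False, CIoU=True)
    print(ciou)  # ~[1.0, 0.58]: the shifted box pays a center-distance penalty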
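
The anchor pre-filter in find_3_positive keeps a (ground truth, anchor) pair only when max(w/w_a, w_a/w, h/h_a, h_a/h) < self.threshold (4). The same rule in isolation, with invented numbers:

    import torch

    gt_wh   = torch.tensor([[40., 40.]])                          # one box, grid units
    anchors = torch.tensor([[10., 13.], [16., 30.], [33., 23.]])  # one layer's anchors

    r = gt_wh[:, None, :] / anchors[None, :, :]                   # (1, 3, 2) wh ratios
    keep = torch.max(r, 1. / r).max(2)[0] < 4.0
    print(keep)  # tensor([[False,  True,  True]]) -- 40/10 = 4.0 fails the strict < 4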
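
The dataloader change replaces the per-image y_true tensors with one flat (n, 6) label tensor per batch, rows of [image_idx, class, cx, cy, w, h] with boxes normalized to 0-1; build_targets later selects rows by image_idx. A sketch of what the new yolo_dataset_collate produces, with invented labels standing in for YoloDataset output:

    import numpy as np
    import torch

    labels_img0 = np.array([[0, 14, 0.50, 0.50, 0.20, 0.30],
                            [0,  3, 0.25, 0.40, 0.10, 0.10]], dtype=np.float32)
    labels_img1 = np.array([[0,  7, 0.70, 0.60, 0.30, 0.20]], dtype=np.float32)

    bboxes = []
    for i, box in enumerate([labels_img0, labels_img1]):
        box[:, 0] = i                    # stamp the batch index, as the collate_fn does
        bboxes.append(box)
    targets = torch.from_numpy(np.concatenate(bboxes, 0)).type(torch.FloatTensor)
    print(targets.shape)                 # torch.Size([3, 6])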
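
End to end, the training loop now makes a single loss call per batch instead of one call per head (see the utils_fit.py hunks). A shape-level sketch with a recent PyTorch, random tensors standing in for a real forward pass, and anchors passed the way train.py does after this patch:

    import numpy as np
    import torch
    from nets.yolo_training import YOLOLoss

    num_classes  = 80
    anchors      = np.array([[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                             [59, 119], [116, 90], [156, 198], [373, 326]], dtype=np.float32)
    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    criterion    = YOLOLoss(anchors, num_classes, (640, 640), anchors_mask, label_smoothing=0)

    images  = torch.rand(2, 3, 640, 640)
    outputs = [torch.randn(2, 3 * (5 + num_classes), s, s) for s in (20, 40, 80)]
    targets = torch.tensor([[0, 1, 0.5, 0.5, 0.2, 0.3],    # image 0, class 1
                            [1, 0, 0.7, 0.6, 0.3, 0.2]])   # image 1, class 0
    loss = criterion(outputs, targets, images)             # one call covers all three heads
    print(float(loss))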