未验证 提交 968b754e 编写于 作者: B Bubbliiiing 提交者: GitHub

Update yolo_training.py

上级 a16c5082
......@@ -9,6 +9,32 @@ from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
from PIL import Image
from utils.utils import bbox_iou, merge_bboxes
def jaccard(_box_a, _box_b):
b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
box_a = torch.zeros_like(_box_a)
box_b = torch.zeros_like(_box_b)
box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
A = box_a.size(0)
B = box_b.size(0)
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
box_b[:, :2].unsqueeze(0).expand(A, B, 2))
inter = torch.clamp((max_xy - min_xy), min=0)
inter = inter[:, :, 0] * inter[:, :, 1]
# 计算先验框和真实框各自的面积
area_a = ((box_a[:, 2]-box_a[:, 0]) *
(box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
area_b = ((box_b[:, 2]-box_b[:, 0]) *
(box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
# 求IOU
union = area_a + area_b - inter
return inter / union # [A,B]
#---------------------------------------------------#
# 平滑标签
#---------------------------------------------------#
......@@ -261,17 +287,17 @@ class YOLOLoss(nn.Module):
for i in range(bs):
pred_boxes_for_ignore = pred_boxes[i]
pred_boxes_for_ignore = pred_boxes_for_ignore.view(-1, 4)
for t in range(target[i].shape[0]):
gx = target[i][t, 0] * in_w
gy = target[i][t, 1] * in_h
gw = target[i][t, 2] * in_w
gh = target[i][t, 3] * in_h
gt_box = torch.FloatTensor(np.array([gx, gy, gw, gh])).unsqueeze(0).type(FloatTensor)
anch_ious = bbox_iou(gt_box, pred_boxes_for_ignore, x1y1x2y2=False)
anch_ious = anch_ious.view(pred_boxes[i].size()[:3])
noobj_mask[i][anch_ious>self.ignore_threshold] = 0
if len(target[i]) > 0:
gx = target[i][:, 0:1] * in_w
gy = target[i][:, 1:2] * in_h
gw = target[i][:, 2:3] * in_w
gh = target[i][:, 3:4] * in_h
gt_box = torch.FloatTensor(np.concatenate([gx, gy, gw, gh],-1)).type(FloatTensor)
anch_ious = jaccard(gt_box, pred_boxes_for_ignore)
for t in range(target[i].shape[0]):
anch_iou = anch_ious[t].view(pred_boxes[i].size()[:3])
noobj_mask[i][anch_iou>self.ignore_threshold] = 0
return noobj_mask, pred_boxes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册