提交 a7c29abb 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修改解耦

上级 d59a72cb
import numpy as np
import torch
from torchvision.ops import nms
from utils.nms_rotated import obb_nms
# from utils.nms_rotated import obb_nms
class DecodeBox():
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
......@@ -23,7 +23,7 @@ class DecodeBox():
#-----------------------------------------------#
# 输入的input一共有三个,他们的shape分别是
# batch_size = 1
# batch_size, 3 * (4 + 1 + 80), 20, 20
# batch_size, 3 * (5 + 1 + 80), 20, 20
# batch_size, 255, 40, 40
# batch_size, 255, 80, 80
#-----------------------------------------------#
......@@ -64,11 +64,11 @@ class DecodeBox():
#-----------------------------------------------#
# 获得置信度,是否有物体
#-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 4])
conf = torch.sigmoid(prediction[..., 5])
#-----------------------------------------------#
# 种类置信度
#-----------------------------------------------#
pred_cls = torch.sigmoid(prediction[..., 5:])
pred_cls = torch.sigmoid(prediction[..., 6:])
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
......@@ -232,102 +232,102 @@ class DecodeBox():
output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output
def non_max_suppression_obb(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
nc = prediction.shape[2] - 5 - 1 # number of classes
xc = prediction[..., 5] > conf_thres # candidates
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
output = [torch.zeros((0, 7), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling no used just now
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 6), device=x.device)
v[:, :5] = l[:, 1:6] # box
v[:, 5] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 6] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
if nc == 1:
x[:, 6: 6+nc] = x[:, 5:6] # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiplicate.
else:
x[:, 6:6+nc] *= x[:, 5:6] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
# box = xywh2xyxy(x[:, :4])
# _, theta_pred = torch.max(x[:, class_index:], 1, keepdim=True) # [n_conf_thres, 1] θ ∈ int[0, 179]
# theta_pred = (theta_pred - 90) / 180 * pi # [n_conf_thres, 1] θ ∈ [-pi/2, pi/2)
theta_pred = (x[:,4:5] - 0.5) * torch.pi
# Detections matrix nx7 (xyxy,theta, conf, cls)
if multi_label:
i, j = (x[:, 6:6+nc] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((x[i, :4], theta_pred[i], x[i, j + 6, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 6:6+nc].max(1, keepdim=True)
x = torch.cat((x[:, :4], theta_pred, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 6:7] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 5].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
c = x[:, 6:7] * (0 if agnostic else max_wh) # classes
rboxes = x[:, :5].clone()
rboxes[:, :2] = rboxes[:, :2] + c # rboxes (offset by class)
scores = x[:, 5] # scores
#boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
#i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
_, i = obb_nms(rboxes, scores, iou_thres) # obb NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
return output
def non_max_suppression_obb(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
nc = prediction.shape[2] - 5 - 1 # number of classes
xc = prediction[..., 5] > conf_thres # candidates
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
output = [torch.zeros((0, 7), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling no used just now
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 6), device=x.device)
v[:, :5] = l[:, 1:6] # box
v[:, 5] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 6] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
if nc == 1:
x[:, 6: 6+nc] = x[:, 5:6] # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiplicate.
else:
x[:, 6:6+nc] *= x[:, 5:6] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
# box = xywh2xyxy(x[:, :4])
# _, theta_pred = torch.max(x[:, class_index:], 1, keepdim=True) # [n_conf_thres, 1] θ ∈ int[0, 179]
# theta_pred = (theta_pred - 90) / 180 * pi # [n_conf_thres, 1] θ ∈ [-pi/2, pi/2)
theta_pred = (x[:,4:5] - 0.5) * torch.pi
# Detections matrix nx7 (xyxy,theta, conf, cls)
if multi_label:
i, j = (x[:, 6:6+nc] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((x[i, :4], theta_pred[i], x[i, j + 6, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 6:6+nc].max(1, keepdim=True)
x = torch.cat((x[:, :4], theta_pred, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 6:7] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 5].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
c = x[:, 6:7] * (0 if agnostic else max_wh) # classes
rboxes = x[:, :5].clone()
rboxes[:, :2] = rboxes[:, :2] + c # rboxes (offset by class)
scores = x[:, 5] # scores
#boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
#i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
_, i = obb_nms(rboxes, scores, iou_thres) # obb NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
return output
if __name__ == "__main__":
......@@ -339,7 +339,7 @@ if __name__ == "__main__":
#---------------------------------------------------#
def get_anchors_and_decode(input, input_shape, anchors, anchors_mask, num_classes):
#-----------------------------------------------#
# input batch_size, 3 * (4 + 1 + num_classes), 20, 20
# input batch_size, 3 * (5 + 1 + num_classes), 20, 20
#-----------------------------------------------#
batch_size = input.size(0)
input_height = input.size(2)
......@@ -364,7 +364,7 @@ if __name__ == "__main__":
# batch_size, 3, 20, 20, 4 + 1 + num_classes
#-----------------------------------------------#
prediction = input.view(batch_size, len(anchors_mask[2]),
num_classes + 5, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
num_classes + 6, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
#-----------------------------------------------#
# 先验框的中心位置的调整参数
......@@ -379,11 +379,11 @@ if __name__ == "__main__":
#-----------------------------------------------#
# 获得置信度,是否有物体 0 - 1
#-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 4])
conf = torch.sigmoid(prediction[..., 5])
#-----------------------------------------------#
# 种类置信度 0 - 1
#-----------------------------------------------#
pred_cls = torch.sigmoid(prediction[..., 5:])
pred_cls = torch.sigmoid(prediction[..., 6:])
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
......@@ -498,7 +498,7 @@ if __name__ == "__main__":
plt.show()
#
feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 255, 20, 20])).float()
feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 258, 20, 20])).float()
anchors = np.array([[116, 90], [156, 198], [373, 326], [30,61], [62,45], [59,119], [10,13], [16,30], [33,23]])
anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
get_anchors_and_decode(feat, [640, 640], anchors, anchors_mask, 80)
......@@ -10,7 +10,7 @@ from PIL import ImageDraw, ImageFont
from nets.yolo import YoloBody
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
resize_image, show_config)
from utils.utils_bbox import non_max_suppression_obb
from utils.utils_bbox import DecodeBox
from utils.utils_rbox import rbox2poly
'''
训练自己的数据集必看注释!
......@@ -84,7 +84,7 @@ class YOLO(object):
#---------------------------------------------------#
self.class_names, self.num_classes = get_classes(self.classes_path)
self.anchors, self.num_anchors = get_anchors(self.anchors_path)
self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
#---------------------------------------------------#
# 画框设置不同的颜色
#---------------------------------------------------#
......@@ -144,10 +144,11 @@ class YOLO(object):
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = self.bbox_util.decode_box(outputs)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results = non_max_suppression_obb(outputs, self.confidence, self.nms_iou, classes=self.num_classes)
results = self.bbox_util.non_max_suppression_obb(torch.cat(outputs, 1), self.confidence, self.nms_iou, classes=self.num_classes)
if results[0] is None:
return image
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册