提交 a7c29abb 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修改解耦

上级 d59a72cb
import numpy as np import numpy as np
import torch import torch
from torchvision.ops import nms from torchvision.ops import nms
from utils.nms_rotated import obb_nms # from utils.nms_rotated import obb_nms
class DecodeBox(): class DecodeBox():
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]): def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
...@@ -23,7 +23,7 @@ class DecodeBox(): ...@@ -23,7 +23,7 @@ class DecodeBox():
#-----------------------------------------------# #-----------------------------------------------#
# 输入的input一共有三个,他们的shape分别是 # 输入的input一共有三个,他们的shape分别是
# batch_size = 1 # batch_size = 1
# batch_size, 3 * (4 + 1 + 80), 20, 20 # batch_size, 3 * (5 + 1 + 80), 20, 20
# batch_size, 255, 40, 40 # batch_size, 258, 40, 40
# batch_size, 255, 80, 80 # batch_size, 258, 80, 80
#-----------------------------------------------# #-----------------------------------------------#
...@@ -64,11 +64,11 @@ class DecodeBox(): ...@@ -64,11 +64,11 @@ class DecodeBox():
#-----------------------------------------------# #-----------------------------------------------#
# 获得置信度,是否有物体 # 获得置信度,是否有物体
#-----------------------------------------------# #-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 4]) conf = torch.sigmoid(prediction[..., 5])
#-----------------------------------------------# #-----------------------------------------------#
# 种类置信度 # 种类置信度
#-----------------------------------------------# #-----------------------------------------------#
pred_cls = torch.sigmoid(prediction[..., 5:]) pred_cls = torch.sigmoid(prediction[..., 6:])
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
...@@ -232,102 +232,102 @@ class DecodeBox(): ...@@ -232,102 +232,102 @@ class DecodeBox():
output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image) output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output return output
def non_max_suppression_obb(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                            labels=()):
    """Run rotated-box Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: tensor indexed as (image, candidate, 5 + 1 + nc). For each
            candidate, columns 0-3 hold the box parameters, column 4 the
            angle, column 5 the objectness confidence and columns 6.. the
            per-class scores.
            NOTE(review): columns 0-3 look like (cx, cy, w, h) and column 4
            like an angle normalised to [0, 1] -- confirm against the decoder.
        conf_thres: minimum objectness, and minimum final score, to keep a
            candidate.
        iou_thres: overlap threshold handed to ``obb_nms``.
        classes: optional class index (or sequence of indices); only
            detections whose class matches one of them are kept.
        agnostic: if true, boxes of different classes may suppress each other.
        multi_label: if true (and nc > 1), emit one detection per
            (box, class) pair whose score exceeds ``conf_thres`` instead of
            only the best class per box.
        labels: optional per-image apriori labels whose rows are
            (cls, five box values); they are appended as candidates with
            confidence 1.

    Returns:
        A list with one (n, 7) tensor per image. Columns are the four box
        parameters copied from the input, theta = (angle - 0.5) * pi,
        the final score, and the class index. Images without detections get
        their own empty (0, 7) tensor.
    """
    nc = prediction.shape[2] - 5 - 1  # number of classes
    xc = prediction[..., 5] > conf_thres  # candidates

    # Settings
    max_wh = 4096  # per-class coordinate offset used to keep classes apart in NMS
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes passed to NMS
    multi_label &= nc > 1  # multiple labels per box only make sense with several classes

    # Build one distinct empty tensor per image; ``[t] * n`` would make every
    # entry alias the same tensor object.
    output = [torch.zeros((0, 7), device=prediction.device) for _ in range(prediction.shape[0])]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Boolean-mask indexing returns a copy, so the in-place score updates
        # below do not write back into ``prediction``.
        x = x[xc[xi]]

        # Append apriori labels (when provided) as confidence-1 candidates.
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 6), device=x.device)
            v[:, :5] = lb[:, 1:6]  # box + angle
            v[:, 5] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 6] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute the final score per class.
        if nc == 1:
            # With a single class the class score carries no information,
            # so the objectness confidence is used directly.
            x[:, 6:6 + nc] = x[:, 5:6]
        else:
            x[:, 6:6 + nc] *= x[:, 5:6]  # conf = obj_conf * cls_conf

        theta_pred = (x[:, 4:5] - 0.5) * torch.pi

        # Detections matrix n x 7: (four box columns, theta, conf, cls)
        if multi_label:
            i, j = (x[:, 6:6 + nc] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((x[i, :4], theta_pred[i], x[i, j + 6, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 6:6 + nc].max(1, keepdim=True)
            x = torch.cat((x[:, :4], theta_pred, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 6:7] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes: keep the highest-scoring ones
            x = x[x[:, 5].argsort(descending=True)[:max_nms]]

        # Batched NMS: shift the first two box columns by a class-dependent
        # offset so that boxes of different classes never overlap.
        c = x[:, 6:7] * (0 if agnostic else max_wh)  # classes
        rboxes = x[:, :5].clone()
        rboxes[:, :2] = rboxes[:, :2] + c  # rboxes (offset by class)
        scores = x[:, 5]  # scores

        # Imported lazily so this module can be loaded even when the rotated
        # NMS module is unavailable; it is only required once boxes reach NMS.
        from utils.nms_rotated import obb_nms
        _, i = obb_nms(rboxes, scores, iou_thres)  # second return value: kept indices
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]

        output[xi] = x[i]
    return output
if __name__ == "__main__": if __name__ == "__main__":
...@@ -339,7 +339,7 @@ if __name__ == "__main__": ...@@ -339,7 +339,7 @@ if __name__ == "__main__":
#---------------------------------------------------# #---------------------------------------------------#
def get_anchors_and_decode(input, input_shape, anchors, anchors_mask, num_classes): def get_anchors_and_decode(input, input_shape, anchors, anchors_mask, num_classes):
#-----------------------------------------------# #-----------------------------------------------#
# input batch_size, 3 * (4 + 1 + num_classes), 20, 20 # input batch_size, 3 * (5 + 1 + num_classes), 20, 20
#-----------------------------------------------# #-----------------------------------------------#
batch_size = input.size(0) batch_size = input.size(0)
input_height = input.size(2) input_height = input.size(2)
...@@ -364,7 +364,7 @@ if __name__ == "__main__": ...@@ -364,7 +364,7 @@ if __name__ == "__main__":
# batch_size, 3, 20, 20, 4 + 1 + num_classes # batch_size, 3, 20, 20, 5 + 1 + num_classes
#-----------------------------------------------# #-----------------------------------------------#
prediction = input.view(batch_size, len(anchors_mask[2]), prediction = input.view(batch_size, len(anchors_mask[2]),
num_classes + 5, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous() num_classes + 6, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
#-----------------------------------------------# #-----------------------------------------------#
# 先验框的中心位置的调整参数 # 先验框的中心位置的调整参数
...@@ -379,11 +379,11 @@ if __name__ == "__main__": ...@@ -379,11 +379,11 @@ if __name__ == "__main__":
#-----------------------------------------------# #-----------------------------------------------#
# 获得置信度,是否有物体 0 - 1 # 获得置信度,是否有物体 0 - 1
#-----------------------------------------------# #-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 4]) conf = torch.sigmoid(prediction[..., 5])
#-----------------------------------------------# #-----------------------------------------------#
# 种类置信度 0 - 1 # 种类置信度 0 - 1
#-----------------------------------------------# #-----------------------------------------------#
pred_cls = torch.sigmoid(prediction[..., 5:]) pred_cls = torch.sigmoid(prediction[..., 6:])
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
...@@ -498,7 +498,7 @@ if __name__ == "__main__": ...@@ -498,7 +498,7 @@ if __name__ == "__main__":
plt.show() plt.show()
# #
feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 255, 20, 20])).float() feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 258, 20, 20])).float()
anchors = np.array([[116, 90], [156, 198], [373, 326], [30,61], [62,45], [59,119], [10,13], [16,30], [33,23]]) anchors = np.array([[116, 90], [156, 198], [373, 326], [30,61], [62,45], [59,119], [10,13], [16,30], [33,23]])
anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
get_anchors_and_decode(feat, [640, 640], anchors, anchors_mask, 80) get_anchors_and_decode(feat, [640, 640], anchors, anchors_mask, 80)
...@@ -10,7 +10,7 @@ from PIL import ImageDraw, ImageFont ...@@ -10,7 +10,7 @@ from PIL import ImageDraw, ImageFont
from nets.yolo import YoloBody from nets.yolo import YoloBody
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input, from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
resize_image, show_config) resize_image, show_config)
from utils.utils_bbox import non_max_suppression_obb from utils.utils_bbox import DecodeBox
from utils.utils_rbox import rbox2poly from utils.utils_rbox import rbox2poly
''' '''
训练自己的数据集必看注释! 训练自己的数据集必看注释!
...@@ -84,7 +84,7 @@ class YOLO(object): ...@@ -84,7 +84,7 @@ class YOLO(object):
#---------------------------------------------------# #---------------------------------------------------#
self.class_names, self.num_classes = get_classes(self.classes_path) self.class_names, self.num_classes = get_classes(self.classes_path)
self.anchors, self.num_anchors = get_anchors(self.anchors_path) self.anchors, self.num_anchors = get_anchors(self.anchors_path)
self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
#---------------------------------------------------# #---------------------------------------------------#
# 画框设置不同的颜色 # 画框设置不同的颜色
#---------------------------------------------------# #---------------------------------------------------#
...@@ -144,10 +144,11 @@ class YOLO(object): ...@@ -144,10 +144,11 @@ class YOLO(object):
# 将图像输入网络当中进行预测! # 将图像输入网络当中进行预测!
#---------------------------------------------------------# #---------------------------------------------------------#
outputs = self.net(images) outputs = self.net(images)
outputs = self.bbox_util.decode_box(outputs)
#---------------------------------------------------------# #---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制 # 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------# #---------------------------------------------------------#
results = non_max_suppression_obb(outputs, self.confidence, self.nms_iou, classes=self.num_classes) results = self.bbox_util.non_max_suppression_obb(torch.cat(outputs, 1), self.confidence, self.nms_iou, classes=self.num_classes)
if results[0] is None: if results[0] is None:
return image return image
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册