提交 837e72e6 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修复

上级 8b7c5d20
......@@ -113,7 +113,7 @@ class YoloDataset(Dataset):
image = image.resize((w,h), Image.BICUBIC)
image_data = np.array(image, np.float32)
return image_data, box
return image_data, rbox
def merge_bboxes(self, bboxes, cutx, cuty):
merge_bbox = []
......
......@@ -229,104 +229,6 @@ class DecodeBox():
output[i][:, :5] = self.yolo_correct_boxes(box_xy, box_wh, angle, input_shape, image_shape, letterbox_image)
return output
def non_max_suppression_obb(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
nc = prediction.shape[2] - 5 - 1 # number of classes
xc = prediction[..., 5] > conf_thres # candidates
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
output = [torch.zeros((0, 7), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling no used just now
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 6), device=x.device)
v[:, :5] = l[:, 1:6] # box
v[:, 5] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 6] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
if nc == 1:
x[:, 6: 6+nc] = x[:, 5:6] # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiplicate.
else:
x[:, 6:6+nc] *= x[:, 5:6] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
# box = xywh2xyxy(x[:, :4])
# _, theta_pred = torch.max(x[:, class_index:], 1, keepdim=True) # [n_conf_thres, 1] θ ∈ int[0, 179]
# theta_pred = (theta_pred - 90) / 180 * pi # [n_conf_thres, 1] θ ∈ [-pi/2, pi/2)
theta_pred = (x[:,4:5] - 0.5) * torch.pi
# Detections matrix nx7 (xyxy,theta, conf, cls)
if multi_label:
i, j = (x[:, 6:6+nc] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((x[i, :4], theta_pred[i], x[i, j + 6, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 6:6+nc].max(1, keepdim=True)
x = torch.cat((x[:, :4], theta_pred, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 6:7] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 5].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
c = x[:, 6:7] * (0 if agnostic else max_wh) # classes
rboxes = x[:, :5].clone()
rboxes[:, :2] = rboxes[:, :2] + c # rboxes (offset by class)
scores = x[:, 5] # scores
#boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
#i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
_, i = obb_nms(rboxes, scores, iou_thres) # obb NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
return output
if __name__ == "__main__":
import matplotlib.pyplot as plt
import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册