Commit e0535031 authored by _白鹭先生_

Attempt to add OBB (oriented bounding box) support

Parent 681713de
......@@ -277,11 +277,11 @@ class YoloBody(nn.Module):
# 4 + 1 + 1 + num_classes (box, angle, objectness, classes)
# 80, 80, 256 => 80, 80, 3 * 26 (4 + 1 + 1 + 20) & 86 (4 + 1 + 1 + 80)
self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + 1 + num_classes), 1)
# 40, 40, 512 => 40, 40, 3 * 25 & 85
self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + 1 + num_classes), 1)
# 20, 20, 512 => 20, 20, 3 * 25 & 85
self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)
self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + 1 + num_classes), 1)
def fuse(self):
print('Fusing layers... ')
......
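With the extra angle channel, each anchor now predicts 4 + 1 + 1 + num_classes values instead of 5 + num_classes. A minimal sketch of splitting one head output into its components (shapes and variable names are illustrative, not from the repository):

import torch

num_classes, num_anchors = 20, 3
# hypothetical P3 output: 3 anchors * (4 + 1 + 1 + 20) = 78 channels on an 80x80 map
out = torch.randn(1, num_anchors * (5 + 1 + num_classes), 80, 80)
out = out.view(1, num_anchors, 5 + 1 + num_classes, 80, 80)
box = out[:, :, 0:4]    # tx, ty, tw, th offsets
theta = out[:, :, 4:5]  # rotation-angle logit, decoded later as (sigmoid - 0.5) * pi
obj = out[:, :, 5:6]    # objectness logit
cls = out[:, :, 6:]     # class logits
print(box.shape, theta.shape, obj.shape, cls.shape)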
......@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.kld_loss import compute_kld_loss, KLDloss
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
# return positive, negative label smoothing BCE targets
......@@ -35,6 +35,7 @@ class YOLOLoss(nn.Module):
self.cp, self.cn = smooth_BCE(eps=label_smoothing)
self.BCEcls, self.BCEobj, self.gr = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss(), 1
self.kldbbox = KLDloss(taf=1.0, fun='sqrt')
def bbox_iou(self, box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
box2 = box2.T
......@@ -123,7 +124,7 @@ class YOLOLoss(nn.Module):
n = b.shape[0]
if n:
prediction_pos = prediction[b, a, gj, gi] # prediction subset corresponding to targets
# prediction_pos [xywh angle conf cls ]
#-------------------------------------------#
# Compute the regression loss for the matched positive samples
#-------------------------------------------#
......@@ -136,35 +137,38 @@ class YOLOLoss(nn.Module):
#-------------------------------------------#
xy = prediction_pos[:, :2].sigmoid() * 2. - 0.5
wh = (prediction_pos[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
box = torch.cat((xy, wh), 1)
angle = (prediction_pos[:, 4:5].sigmoid() - 0.5) * torch.pi
box_theta = torch.cat((xy, wh, angle), 1)
#-------------------------------------------#
# Process the ground-truth boxes and map them onto the feature layer
#-------------------------------------------#
selected_tbox = targets[i][:, 2:6] * feature_map_sizes[i]
selected_tbox[:, :2] -= grid.type_as(prediction)
theta = targets[i][:, 6:7]
selected_tbox_theta = torch.cat((selected_tbox, theta),1)
#-------------------------------------------#
# Compute the regression loss between the predicted and ground-truth boxes
#-------------------------------------------#
iou = self.bbox_iou(box.T, selected_tbox, x1y1x2y2=False, CIoU=True)
box_loss += (1.0 - iou).mean()
kldloss = self.kldbbox(box_theta, selected_tbox_theta)
loss += kldloss.mean()
#-------------------------------------------#
# Build the ground truth for the confidence loss from the prediction's similarity (originally IoU, here 1 - KLD)
#-------------------------------------------#
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * (1 - kldloss).detach().clamp(0).type(tobj.dtype) # iou ratio
#-------------------------------------------#
# Compute the classification loss for the matched positive samples
#-------------------------------------------#
selected_tcls = targets[i][:, 1].long()
t = torch.full_like(prediction_pos[:, 5:], self.cn, device=device) # targets
t = torch.full_like(prediction_pos[:, 6:], self.cn, device=device) # targets
t[range(n), selected_tcls] = self.cp
cls_loss += self.BCEcls(prediction_pos[:, 5:], t) # BCE
cls_loss += self.BCEcls(prediction_pos[:, 6:], t) # BCE
#-------------------------------------------#
# Compute the objectness confidence loss
# and weight it by each feature layer's ratio
#-------------------------------------------#
obj_loss += self.BCEobj(prediction[..., 4], tobj) * self.balance[i] # obj loss
obj_loss += self.BCEobj(prediction[..., 5], tobj) * self.balance[i] # obj loss
#-------------------------------------------#
# Multiply each loss component by its weight
......@@ -237,6 +241,7 @@ class YOLOLoss(nn.Module):
#-------------------------------------------#
b_idx = targets[:, 0]==batch_idx
this_target = targets[b_idx]
# targets (tensor): (n_gt_all_batch, [img_index clsid cx cy l s theta ])
#-------------------------------------------#
# If no ground-truth box belongs to this image, continue
#-------------------------------------------#
......@@ -250,7 +255,7 @@ class YOLOLoss(nn.Module):
#-------------------------------------------#
# Keep [cx, cy, w, h] and append theta (rotated boxes are not converted to corners)
#-------------------------------------------#
txyxy = self.xywh2xyxy(txywh)
txyxy = torch.cat((txywh, this_target[:,6:]), dim=-1)
pxyxys = []
p_cls = []
......@@ -285,8 +290,8 @@ class YOLOLoss(nn.Module):
# Gather the predictions matched to this ground-truth box
#-------------------------------------------#
fg_pred = prediction[b, a, gj, gi]
p_obj.append(fg_pred[:, 4:5])
p_cls.append(fg_pred[:, 5:])
p_obj.append(fg_pred[:, 5:6]) # [4:5] = theta
p_cls.append(fg_pred[:, 6:])
#-------------------------------------------#
# After obtaining the grid, decode the predictions
......@@ -294,9 +299,9 @@ class YOLOLoss(nn.Module):
grid = torch.stack([gi, gj], dim=1).type_as(fg_pred)
pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
pxywh = torch.cat([pxy, pwh], dim=-1)
pxyxy = self.xywh2xyxy(pxywh)
pxyxys.append(pxyxy)
pangle = (fg_pred[:, 4:5].sigmoid() - 0.5) * torch.pi
pxywh = torch.cat([pxy, pwh, pangle], dim=-1)
pxyxys.append(pxywh)
#-------------------------------------------#
# 判断是否存在对应的预测框,不存在则跳过
......@@ -323,8 +328,8 @@ class YOLOLoss(nn.Module):
# The closer two boxes are, the smaller the KLD-based cost
# so a higher overlap between ground truth and prediction gives a smaller pair_wise_iou_loss
#-------------------------------------------------------------#
pair_wise_iou = self.box_iou(txyxy, pxyxys)
pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
pair_wise_iou_loss = compute_kld_loss(txyxy, pxyxys, taf=1.0, fun='sqrt')
pair_wise_iou = 1 - pair_wise_iou_loss
#-------------------------------------------#
# Similarity of at most the top-20 prediction boxes with each ground-truth box
......@@ -427,14 +432,14 @@ class YOLOLoss(nn.Module):
# Indices 2:6 are the feature-layer width/height gains
# index 6 (theta) stays 1, index 7 is the anchor index
#------------------------------------#
gain = torch.ones(7, device=targets.device)
gain = torch.ones(8, device=targets.device)
#------------------------------------#
# ai [num_anchor, num_gt]
# targets [num_gt, 6] => [num_anchor, num_gt, 7]
# targets [num_gt, 6] => [num_anchor, num_gt, 8]
#------------------------------------#
ai = torch.arange(num_anchor, device=targets.device).float().view(num_anchor, 1).repeat(1, num_gt)
targets = torch.cat((targets.repeat(num_anchor, 1, 1), ai[:, :, None]), 2) # append anchor indices
# targets (tensor): (na, n_gt_all_batch, [img_index, clsid, cx, cy, l, s, theta, anchor_index])
g = 0.5 # offsets
off = torch.tensor([
[0, 0],
......@@ -509,7 +514,7 @@ class YOLOLoss(nn.Module):
# gj and gi must not exceed the feature-layer bounds
# a is the index of the anchor box assigned to this feature point
#-------------------------------------------#
a = t[:, 6].long() # anchor indices
a = t[:, -1].long() # anchor indices
indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid indices
anchors.append(anchors_i[a]) # anchors
......
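Throughout the hunks above, a KLD-based similarity stands in for IoU: KLDloss returns 1 - 1/(taf + kld), so 1 - kldloss recovers a similarity in (0, 1] that feeds both the objectness target and the pair-wise matching cost. A hedged sketch of the objectness-target mapping, assuming gr = 1 as set in __init__:

import torch

gr = 1.0                                    # self.gr from __init__ above
kldloss = torch.tensor([0.05, 0.40, 0.95])  # illustrative per-box KLD losses
# mirrors: tobj[b, a, gj, gi] = (1.0 - gr) + gr * (1 - kldloss).detach().clamp(0)
tobj = (1.0 - gr) + gr * (1.0 - kldloss).clamp(0)
print(tobj)  # tensor([0.9500, 0.6000, 0.0500])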
......@@ -7,7 +7,7 @@ from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
from utils.utils_rbox import poly_filter, poly2rbox
class YoloDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, anchors, anchors_mask, epoch_length, \
......@@ -29,7 +29,7 @@ class YoloDataset(Dataset):
self.epoch_now = -1
self.length = len(self.annotation_lines)
self.bbox_attrs = 5 + num_classes
self.bbox_attrs = 5 + 1 + num_classes
def __len__(self):
return self.length
......@@ -61,7 +61,7 @@ class YoloDataset(Dataset):
# Preprocess the ground-truth boxes
#---------------------------------------------------#
nL = len(box)
labels_out = np.zeros((nL, 6))
labels_out = np.zeros((nL, 7))
if nL:
#---------------------------------------------------#
# Normalize the ground-truth boxes to the range 0-1
......@@ -71,7 +71,8 @@ class YoloDataset(Dataset):
#---------------------------------------------------#
# Indices 0 and 1 are the center of the ground-truth box
# Indices 2 and 3 are the width and height of the ground-truth box
# Index 4 is the class of the ground-truth box
# Index 4 is the rotation angle of the ground-truth box
# Index 5 is the class of the ground-truth box
#---------------------------------------------------#
box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
......@@ -81,7 +82,7 @@ class YoloDataset(Dataset):
# Index 0 of labels_out is filled in by the collate function
#---------------------------------------------------#
labels_out[:, 1] = box[:, -1]
labels_out[:, 2:] = box[:, :4]
labels_out[:, 2:] = box[:, :5]
return image, labels_out
......
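With the angle appended, each ground-truth row now carries 7 values instead of 6. An illustrative sketch of the row layout labels_out produces (values made up):

import numpy as np

# row layout: [image_index, class_id, cx, cy, w, h, theta]
# image_index (column 0) is filled in by the collate function when batching
labels_out = np.zeros((1, 7))
labels_out[0, 1] = 3                       # class id
labels_out[0, 2:6] = [0.5, 0.5, 0.2, 0.1]  # normalized cx, cy, w, h
labels_out[0, 6] = 0.35                    # rotation angle theta in radians
print(labels_out)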
'''
Author: [egrt]
Date: 2023-01-30 18:47:24
LastEditors: [egrt]
LastEditTime: 2023-01-30 18:48:35
Description:
'''
import torch
import torch.nn as nn
class KLDloss(nn.Module):
def __init__(self, taf=1.0, fun="sqrt"):
super(KLDloss, self).__init__()
self.fun = fun
self.taf = taf
def forward(self, pred, target): # pred [[x,y,w,h,angle], ...]
#assert pred.shape[0] == target.shape[0]
pred = pred.view(-1, 5)
target = target.view(-1, 5)
delta_x = pred[:, 0] - target[:, 0]
delta_y = pred[:, 1] - target[:, 1]
pre_angle_radian = pred[:, 4]
targrt_angle_radian = target[:, 4]
delta_angle_radian = pre_angle_radian - targrt_angle_radian
kld = 0.5 * (
4 * torch.pow( ( delta_x.mul(torch.cos(targrt_angle_radian)) + delta_y.mul(torch.sin(targrt_angle_radian)) ), 2) / torch.pow(target[:, 2], 2)
+ 4 * torch.pow( ( delta_y.mul(torch.cos(targrt_angle_radian)) - delta_x.mul(torch.sin(targrt_angle_radian)) ), 2) / torch.pow(target[:, 3], 2)
)\
+ 0.5 * (
torch.pow(pred[:, 3], 2) / torch.pow(target[:, 2], 2) * torch.pow(torch.sin(delta_angle_radian), 2)
+ torch.pow(pred[:, 2], 2) / torch.pow(target[:, 3], 2) * torch.pow(torch.sin(delta_angle_radian), 2)
+ torch.pow(pred[:, 3], 2) / torch.pow(target[:, 3], 2) * torch.pow(torch.cos(delta_angle_radian), 2)
+ torch.pow(pred[:, 2], 2) / torch.pow(target[:, 2], 2) * torch.pow(torch.cos(delta_angle_radian), 2)
)\
+ 0.5 * (
torch.log(torch.pow(target[:, 3], 2) / torch.pow(pred[:, 3], 2))
+ torch.log(torch.pow(target[:, 2], 2) / torch.pow(pred[:, 2], 2))
)\
- 1.0
if self.fun == "sqrt":
kld = kld.clamp(1e-7).sqrt()
elif self.fun == "log1p":
kld = torch.log1p(kld.clamp(1e-7))
else:
pass
kld_loss = 1 - 1 / (self.taf + kld)
return kld_loss
def compute_kld_loss(targets, preds, taf=1.0, fun='sqrt'):
with torch.no_grad():
kld_loss_ts_ps = torch.zeros(0, preds.shape[0], device=targets.device)
for target in targets:
target = target.unsqueeze(0).repeat(preds.shape[0], 1)
kld_loss_t_p = kld_loss(preds, target, taf=taf, fun=fun)
kld_loss_ts_ps = torch.cat((kld_loss_ts_ps, kld_loss_t_p.unsqueeze(0)), dim=0)
return kld_loss_ts_ps
def kld_loss(pred, target, taf=1.0, fun='sqrt'): # pred [[x,y,w,h,angle], ...]
#assert pred.shape[0] == target.shape[0]
pred = pred.view(-1, 5)
target = target.view(-1, 5)
delta_x = pred[:, 0] - target[:, 0]
delta_y = pred[:, 1] - target[:, 1]
pre_angle_radian = pred[:, 4] #3.141592653589793 * pred[:, 4] / 180.0
targrt_angle_radian = target[:, 4] #3.141592653589793 * target[:, 4] / 180.0
delta_angle_radian = pre_angle_radian - targrt_angle_radian
kld = 0.5 * (
4 * torch.pow((delta_x.mul(torch.cos(targrt_angle_radian)) + delta_y.mul(torch.sin(targrt_angle_radian))),
2) / torch.pow(target[:, 2], 2)
+ 4 * torch.pow((delta_y.mul(torch.cos(targrt_angle_radian)) - delta_x.mul(torch.sin(targrt_angle_radian))),
2) / torch.pow(target[:, 3], 2)
) \
+ 0.5 * (
torch.pow(pred[:, 3], 2) / torch.pow(target[:, 2], 2) * torch.pow(torch.sin(delta_angle_radian), 2)
+ torch.pow(pred[:, 2], 2) / torch.pow(target[:, 3], 2) * torch.pow(torch.sin(delta_angle_radian), 2)
+ torch.pow(pred[:, 3], 2) / torch.pow(target[:, 3], 2) * torch.pow(torch.cos(delta_angle_radian), 2)
+ torch.pow(pred[:, 2], 2) / torch.pow(target[:, 2], 2) * torch.pow(torch.cos(delta_angle_radian), 2)
) \
+ 0.5 * (
torch.log(torch.pow(target[:, 3], 2) / torch.pow(pred[:, 3], 2))
+ torch.log(torch.pow(target[:, 2], 2) / torch.pow(pred[:, 2], 2))
) \
- 1.0
if fun == "sqrt":
kld = kld.clamp(1e-7).sqrt()
elif fun == "log1p":
kld = torch.log1p(kld.clamp(1e-7))
else:
pass
kld_loss = 1 - 1 / (taf + kld)
return kld_loss
if __name__ == '__main__':
'''
Test the loss function
'''
kld_loss_n = KLDloss(taf=1.0, fun='log1p')
pred = torch.tensor([[5, 5, 5, 23, 0.15], [6, 6, 5, 28, 0]]).type(torch.float32)
target = torch.tensor([[5, 5, 5, 24, 0], [6, 6, 5, 28, 0]]).type(torch.float32)
kld = kld_loss_n(pred, target)
print(kld)
\ No newline at end of file
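Complementing the test above, a hedged sketch of compute_kld_loss, which builds the (num_targets, num_preds) cost matrix used for matching in the training code (import path assumed from the diff above):

import torch
from utils.kld_loss import compute_kld_loss

targets = torch.tensor([[5., 5., 5., 24., 0.00]])             # (1, [x y w h theta])
preds = torch.tensor([[5., 5., 5., 23., 0.15],
                      [6., 6., 5., 28., 0.00]])               # (2, 5)
cost = compute_kld_loss(targets, preds, taf=1.0, fun='sqrt')  # (1, 2) cost matrix
print(cost.shape, cost)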
from .nms_rotated_wrapper import obb_nms, poly_nms
__all__ = ['obb_nms', 'poly_nms']
import numpy as np
import torch
from . import nms_rotated_ext
def obb_nms(dets, scores, iou_thr, device_id=None):
"""
RIoU (rotated-IoU) NMS at threshold iou_thr.
Args:
dets (tensor/array): (num, [cx cy w h θ]) θ∈[-pi/2, pi/2)
scores (tensor/array): (num)
iou_thr (float): (1)
Returns:
dets (tensor): (n_nms, [cx cy w h θ])
inds (tensor): (n_nms), nms index of dets
"""
if isinstance(dets, torch.Tensor):
is_numpy = False
dets_th = dets
elif isinstance(dets, np.ndarray):
is_numpy = True
device = 'cpu' if device_id is None else f'cuda:{device_id}'
dets_th = torch.from_numpy(dets).to(device)
else:
raise TypeError('dets must be either a Tensor or numpy array, '
f'but got {type(dets)}')
if dets_th.numel() == 0: # len(dets)
inds = dets_th.new_zeros(0, dtype=torch.int64)
else:
# the same bug occurs when the bboxes are too small
too_small = dets_th[:, [2, 3]].min(1)[0] < 0.001 # [n]
if too_small.all(): # all the bboxes are too small
inds = dets_th.new_zeros(0, dtype=torch.int64)
else:
ori_inds = torch.arange(dets_th.size(0)) # 0 ~ n-1
ori_inds = ori_inds[~too_small]
dets_th = dets_th[~too_small] # (n_filter, 5)
scores = scores[~too_small]
inds = nms_rotated_ext.nms_rotated(dets_th, scores, iou_thr)
inds = ori_inds[inds]
if is_numpy:
inds = inds.cpu().numpy()
return dets[inds, :], inds
def poly_nms(dets, iou_thr, device_id=None):
if isinstance(dets, torch.Tensor):
is_numpy = False
dets_th = dets
elif isinstance(dets, np.ndarray):
is_numpy = True
device = 'cpu' if device_id is None else f'cuda:{device_id}'
dets_th = torch.from_numpy(dets).to(device)
else:
raise TypeError('dets must be either a Tensor or numpy array, '
f'but got {type(dets)}')
if dets_th.device == torch.device('cpu'):
raise NotImplementedError
inds = nms_rotated_ext.nms_poly(dets_th.float(), iou_thr)
if is_numpy:
inds = inds.cpu().numpy()
return dets[inds, :], inds
if __name__ == '__main__':
rboxes_opencv = torch.tensor(([136.6, 111.6, 200, 100, -60],
[136.6, 111.6, 100, 200, -30],
[100, 100, 141.4, 141.4, -45],
[100, 100, 141.4, 141.4, -45]))
rboxes_longedge = torch.tensor(([136.6, 111.6, 200, 100, -60],
[136.6, 111.6, 200, 100, 120],
[100, 100, 141.4, 141.4, 45],
[100, 100, 141.4, 141.4, 135]))
\ No newline at end of file
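The tensors above are defined but never exercised; a hedged sketch of feeding the opencv-format boxes through obb_nms (requires the compiled nms_rotated_ext extension; note the modified kernel treats the angle as radians, while these test values are degrees):

import torch
from utils.nms_rotated import obb_nms  # requires the compiled nms_rotated_ext

rboxes = torch.tensor([[136.6, 111.6, 200.0, 100.0, -60.0],
                       [136.6, 111.6, 100.0, 200.0, -30.0],
                       [100.0, 100.0, 141.4, 141.4, -45.0],
                       [100.0, 100.0, 141.4, 141.4, -45.0]])
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
rboxes[:, 4] = rboxes[:, 4] * torch.pi / 180.0  # degrees -> radians for this kernel
keep_dets, keep_inds = obb_nms(rboxes, scores, iou_thr=0.1)
print(keep_inds)  # the two identical -45 boxes collapse into one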
#!/usr/bin/env python
import os
import subprocess
import time
from setuptools import find_packages, setup
import torch
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
def make_cuda_ext(name, module, sources, sources_cuda=[]):
define_macros = []
extra_compile_args = {'cxx': []}
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
sources += sources_cuda
else:
print(f'Compiling {name} without CUDA')
extension = CppExtension
# raise EnvironmentError('CUDA is required to compile MMDetection!')
return extension(
name=f'{module}.{name}',
sources=[os.path.join(*module.split('.'), p) for p in sources],
define_macros=define_macros,
extra_compile_args=extra_compile_args)
# python setup.py develop
if __name__ == '__main__':
#write_version_py()
setup(
name='nms_rotated',
ext_modules=[
make_cuda_ext(
name='nms_rotated_ext',
module='',
sources=[
'src/nms_rotated_cpu.cpp',
'src/nms_rotated_ext.cpp'
],
sources_cuda=[
'src/nms_rotated_cuda.cu',
'src/poly_nms_cuda.cu',
]),
],
cmdclass={'build_ext': BuildExtension},
zip_safe=False)
\ No newline at end of file
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/box_iou_rotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#pragma once
#include <cassert>
#include <cmath>
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
template <typename T>
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T>
struct Point {
T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
return Point(x + p.x, y + p.y);
}
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
x += p.x;
y += p.y;
return *this;
}
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
return Point(x - p.x, y - p.y);
}
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
return Point(x * coeff, y * coeff);
}
};
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}
// R: result type. can be different from input type
template <typename T, typename R = T>
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
return static_cast<R>(A.x) * static_cast<R>(B.y) -
static_cast<R>(B.x) * static_cast<R>(A.y);
}
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
const RotatedBox<T>& box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
// double theta = box.a * 0.01745329251; // original degree-based version
double theta = box.a; // box.a is already in radians here
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
const Point<T> (&p)[24],
const int& num_in,
Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto& start = p[t]; // starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
// compute distance to origin before sort, and sort them together with the
// points
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) ||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
#else
// CPU version
std::sort(
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else {
return temp > 0;
}
});
// compute distance to origin after sort, since the points are now different.
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1) {
auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
// cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
// q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
// compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
// round to nearest floating point).
if (q1.x * q2.y >= q2.x * q1.y)
m--;
else
break;
}
// Using double also helps, but float can solve the issue for now.
// while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
// >= 0) {
// m--;
// }
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T rotated_boxes_intersection(
const RotatedBox<T>& box1,
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
get_rotated_vertices<T>(box1, pts1);
get_rotated_vertices<T>(box2, pts2);
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
template <typename T>
HOST_DEVICE_INLINE T
single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
// shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
box1.x_ctr = box1_raw[0] - center_shift_x;
box1.y_ctr = box1_raw[1] - center_shift_y;
box1.w = box1_raw[2];
box1.h = box1_raw[3];
box1.a = box1_raw[4];
box2.x_ctr = box2_raw[0] - center_shift_x;
box2.y_ctr = box2_raw[1] - center_shift_y;
box2.w = box2_raw[2];
box2.h = box2_raw[3];
box2.a = box2_raw[4];
T area1 = box1.w * box1.h;
T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
T intersection = rotated_boxes_intersection<T>(box1, box2);
T iou = intersection / (area1 + area2 - intersection);
return iou;
}
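single_box_iou_rotated can be cross-checked on the host with OpenCV, which performs the same polygon intersection; a hedged Python sketch (OpenCV expects the angle in degrees, whereas the modified header above takes radians):

import math
import cv2

def rotated_iou_cv2(box1, box2):
    # box = (cx, cy, w, h, theta_radians); cross-check for single_box_iou_rotated
    r1 = ((box1[0], box1[1]), (box1[2], box1[3]), math.degrees(box1[4]))
    r2 = ((box2[0], box2[1]), (box2[2], box2[3]), math.degrees(box2[4]))
    _, pts = cv2.rotatedRectangleIntersection(r1, r2)
    inter = 0.0 if pts is None else cv2.contourArea(cv2.convexHull(pts))
    union = box1[2] * box1[3] + box2[2] * box2[3] - inter
    return inter / union if union > 0 else 0.0

print(rotated_iou_cv2((100, 100, 60, 30, 0.0), (100, 100, 60, 30, math.pi / 6)))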
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/nms_rotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <torch/types.h>
#include "box_iou_rotated_utils.h"
template <typename scalar_t>
at::Tensor nms_rotated_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold) {
// nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel,
// however, the code in this function is much shorter because
// we delegate the IoU computation for rotated boxes to
// the single_box_iou_rotated function in box_iou_rotated_utils.h
AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor");
AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor");
AT_ASSERTM(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1) {
continue;
}
keep[num_to_keep++] = i;
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1) {
continue;
}
auto ovr = single_box_iou_rotated<scalar_t>(
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>());
if (ovr >= iou_threshold) {
suppressed[j] = 1;
}
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
at::Tensor nms_rotated_cpu(
// input must be contiguous
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold) {
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] {
result = nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
});
return result;
}
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/nms_rotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "box_iou_rotated_utils.h"
int const threadsPerBlock = sizeof(unsigned long long) * 8;
template <typename T>
__global__ void nms_rotated_cuda_kernel(
const int n_boxes,
const float iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
// nms_rotated_cuda_kernel is modified from torchvision's nms_cuda_kernel
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
// Compared to nms_cuda_kernel, where each box is represented with 4 values
// (x1, y1, x2, y2), each rotated box is represented with 5 values
// (x_center, y_center, width, height, angle_degrees) here.
__shared__ T block_boxes[threadsPerBlock * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
// Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_rotated function from box_iou_rotated_utils.h
if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 5) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
at::Tensor nms_rotated_cuda(
// input must be contiguous
const at::Tensor& dets,
const at::Tensor& scores,
float iou_threshold) {
// using scalar_t = float;
AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t);
auto dets_num = dets.size(0);
const int col_blocks =
at::cuda::ATenCeilDiv(static_cast<int>(dets_num), threadsPerBlock);
at::Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(
dets_sorted.scalar_type(), "nms_rotated_kernel_cuda", [&] {
nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
iou_threshold,
dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>());
});
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/nms_rotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <ATen/ATen.h>
#include <torch/extension.h>
#ifdef WITH_CUDA
at::Tensor nms_rotated_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
at::Tensor poly_nms_cuda(
const at::Tensor boxes,
float nms_overlap_thresh);
#endif
at::Tensor nms_rotated_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
inline at::Tensor nms_rotated(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold) {
assert(dets.device().is_cuda() == scores.device().is_cuda());
if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
return nms_rotated_cuda(
dets.contiguous(), scores.contiguous(), iou_threshold);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold);
}
inline at::Tensor nms_poly(
const at::Tensor& dets,
const float iou_threshold) {
if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
if (dets.numel() == 0)
return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
return poly_nms_cuda(dets, iou_threshold);
#else
AT_ERROR("POLY_NMS is not compiled with GPU support");
#endif
}
AT_ERROR("POLY_NMS is not implemented on CPU");
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms_rotated", &nms_rotated, "nms for rotated bboxes");
m.def("nms_poly", &nms_poly, "nms for poly bboxes");
}
#include <torch/extension.h>
template <typename scalar_t>
at::Tensor poly_nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <vector>
#include <iostream>
#define CUDA_CHECK(condition) \
/* Code block avoids redefinition of cudaError_t error */ \
do { \
cudaError_t error = condition; \
if (error != cudaSuccess) { \
std::cout << cudaGetErrorString(error) << std::endl; \
} \
} while (0)
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;
#define maxn 10
const double eps=1E-8;
__device__ inline int sig(float d){
return(d>1E-8)-(d<-1E-8);
}
__device__ inline int point_eq(const float2 a, const float2 b) {
return sig(a.x - b.x) == 0 && sig(a.y - b.y)==0;
}
__device__ inline void point_swap(float2 *a, float2 *b) {
float2 temp = *a;
*a = *b;
*b = temp;
}
__device__ inline void point_reverse(float2 *first, float2* last)
{
while ((first!=last)&&(first!=--last)) {
point_swap (first,last);
++first;
}
}
__device__ inline float cross(float2 o,float2 a,float2 b){ // cross product
return(a.x-o.x)*(b.y-o.y)-(b.x-o.x)*(a.y-o.y);
}
__device__ inline float area(float2* ps,int n){
ps[n]=ps[0];
float res=0;
for(int i=0;i<n;i++){
res+=ps[i].x*ps[i+1].y-ps[i].y*ps[i+1].x;
}
return res/2.0;
}
__device__ inline int lineCross(float2 a,float2 b,float2 c,float2 d,float2&p){
float s1,s2;
s1=cross(a,b,c);
s2=cross(a,b,d);
if(sig(s1)==0&&sig(s2)==0) return 2;
if(sig(s2-s1)==0) return 0;
p.x=(c.x*s2-d.x*s1)/(s2-s1);
p.y=(c.y*s2-d.y*s1)/(s2-s1);
return 1;
}
__device__ inline void polygon_cut(float2*p,int&n,float2 a,float2 b, float2* pp){
int m=0;p[n]=p[0];
for(int i=0;i<n;i++){
if(sig(cross(a,b,p[i]))>0) pp[m++]=p[i];
if(sig(cross(a,b,p[i]))!=sig(cross(a,b,p[i+1])))
lineCross(a,b,p[i],p[i+1],pp[m++]);
}
n=0;
for(int i=0;i<m;i++)
if(!i||!(point_eq(pp[i], pp[i-1])))
p[n++]=pp[i];
// while(n>1&&p[n-1]==p[0])n--;
while(n>1&&point_eq(p[n-1], p[0]))n--;
}
//----------------- separator -----------------//
// Returns the signed intersection area of triangles oab and ocd, where o is the origin //
__device__ inline float intersectArea(float2 a,float2 b,float2 c,float2 d){
float2 o = make_float2(0,0);
int s1=sig(cross(o,a,b));
int s2=sig(cross(o,c,d));
if(s1==0||s2==0)return 0.0; // degenerate, zero area
// if(s1==-1) swap(a,b);
// if(s2==-1) swap(c,d);
if (s1 == -1) point_swap(&a, &b);
if (s2 == -1) point_swap(&c, &d);
float2 p[10]={o,a,b};
int n=3;
float2 pp[maxn];
polygon_cut(p,n,o,c,pp);
polygon_cut(p,n,c,d,pp);
polygon_cut(p,n,d,o,pp);
float res=fabs(area(p,n));
if(s1*s2==-1) res=-res;
return res;
}
// Compute the intersection area of two polygons
__device__ inline float intersectArea(float2*ps1,int n1,float2*ps2,int n2){
if(area(ps1,n1)<0) point_reverse(ps1,ps1+n1);
if(area(ps2,n2)<0) point_reverse(ps2,ps2+n2);
ps1[n1]=ps1[0];
ps2[n2]=ps2[0];
float res=0;
for(int i=0;i<n1;i++){
for(int j=0;j<n2;j++){
res+=intersectArea(ps1[i],ps1[i+1],ps2[j],ps2[j+1]);
}
}
return res; // assume res is positive!
}
// TODO: could be optimized by first computing the IoU between the two hbbs
__device__ inline float devPolyIoU(float const * const p, float const * const q) {
float2 ps1[maxn], ps2[maxn];
int n1 = 4;
int n2 = 4;
for (int i = 0; i < 4; i++) {
ps1[i].x = p[i * 2];
ps1[i].y = p[i * 2 + 1];
ps2[i].x = q[i * 2];
ps2[i].y = q[i * 2 + 1];
}
float inter_area = intersectArea(ps1, n1, ps2, n2);
float union_area = fabs(area(ps1, n1)) + fabs(area(ps2, n2)) - inter_area;
float iou = 0;
if (union_area == 0) {
iou = (inter_area + 1) / (union_area + 1);
} else {
iou = inter_area / union_area;
}
return iou;
}
__global__ void poly_nms_kernel(const int n_polys, const float nms_overlap_thresh,
const float *dev_polys, unsigned long long *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
const int row_size =
min(n_polys - row_start * threadsPerBlock, threadsPerBlock);
const int cols_size =
min(n_polys - col_start * threadsPerBlock, threadsPerBlock);
__shared__ float block_polys[threadsPerBlock * 9];
if (threadIdx.x < cols_size) {
block_polys[threadIdx.x * 9 + 0] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 0];
block_polys[threadIdx.x * 9 + 1] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 1];
block_polys[threadIdx.x * 9 + 2] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 2];
block_polys[threadIdx.x * 9 + 3] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 3];
block_polys[threadIdx.x * 9 + 4] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 4];
block_polys[threadIdx.x * 9 + 5] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 5];
block_polys[threadIdx.x * 9 + 6] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 6];
block_polys[threadIdx.x * 9 + 7] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 7];
block_polys[threadIdx.x * 9 + 8] =
dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 8];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const float *cur_box = dev_polys + cur_box_idx * 9;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < cols_size; i++) {
if (devPolyIoU(cur_box, block_polys + i * 9) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = THCCeilDiv(n_polys, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
// boxes is a N x 9 tensor
at::Tensor poly_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
at::DeviceGuard guard(boxes.device());
using scalar_t = float;
AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
auto scores = boxes.select(1, 8);
auto order_t = std::get<1>(scores.sort(0, /*descending=*/true));
auto boxes_sorted = boxes.index_select(0, order_t);
int boxes_num = boxes.size(0);
const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();
THCState *state = at::globalContext().lazyInitCUDA();
unsigned long long* mask_dev = NULL;
mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
THCCeilDiv(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
poly_nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);
std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
THCudaCheck(cudaMemcpyAsync(
&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost,
at::cuda::getCurrentCUDAStream()
));
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long *p = &mask_host[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
THCudaFree(state, mask_dev);
return order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
order_t.device(), keep.scalar_type())});
}
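devPolyIoU above can be cross-checked on the host with shapely, which computes polygon intersections exactly; a hedged sketch (shapely is an assumed extra dependency, not part of this repository):

from shapely.geometry import Polygon

def poly_iou(p, q):
    # p, q: 8 floats each, [x1 y1 x2 y2 x3 y3 x4 y4], matching the kernel's layout
    a = Polygon(list(zip(p[0::2], p[1::2])))
    b = Polygon(list(zip(q[0::2], q[1::2])))
    inter = a.intersection(b).area
    return inter / (a.area + b.area - inter)

print(poly_iou([0, 0, 2, 0, 2, 2, 0, 2], [1, 1, 3, 1, 3, 3, 1, 3]))  # 1/7 ≈ 0.1429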
import numpy as np
import torch
from torchvision.ops import nms
from utils.nms_rotated import obb_nms
class DecodeBox():
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
super(DecodeBox, self).__init__()
self.anchors = anchors
self.num_classes = num_classes
self.bbox_attrs = 5 + num_classes
self.bbox_attrs = 5 + 1 + num_classes
self.input_shape = input_shape
#-----------------------------------------------------------#
# the 13x13 feature layer corresponds to anchors [142, 110], [192, 243], [459, 401]
......@@ -231,6 +231,103 @@ class DecodeBox():
box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output
def non_max_suppression_obb(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, one (n, 7) tensor per image [xywh, θ(rad), conf, cls]
"""
nc = prediction.shape[2] - 5 - 1 # number of classes
xc = prediction[..., 5] > conf_thres # candidates
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
output = [torch.zeros((0, 7), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Concatenate a priori labels if autolabelling (not used just now)
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 6), device=x.device)
v[:, :5] = l[:, 1:6] # box
v[:, 5] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 6] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
if nc == 1:
x[:, 6: 6+nc] = x[:, 5:6] # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiply.
else:
x[:, 6:6+nc] *= x[:, 5:6] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
# box = xywh2xyxy(x[:, :4])
# _, theta_pred = torch.max(x[:, class_index:], 1, keepdim=True) # [n_conf_thres, 1] θ ∈ int[0, 179]
# theta_pred = (theta_pred - 90) / 180 * pi # [n_conf_thres, 1] θ ∈ [-pi/2, pi/2)
theta_pred = (x[:,4:5] - 0.5) * torch.pi
# Detections matrix nx7 (xyxy,theta, conf, cls)
if multi_label:
i, j = (x[:, 6:6+nc] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((x[i, :4], theta_pred[i], x[i, j + 6, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 6:6+nc].max(1, keepdim=True)
x = torch.cat((x[:, :4], theta_pred, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 6:7] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 5].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
c = x[:, 6:7] * (0 if agnostic else max_wh) # classes
rboxes = x[:, :5].clone()
rboxes[:, :2] = rboxes[:, :2] + c # rboxes (offset by class)
scores = x[:, 5] # scores
#boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
#i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
_, i = obb_nms(rboxes, scores, iou_thres) # obb NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
return output
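A hedged usage sketch of non_max_suppression_obb on a dummy decoded prediction (layout [x, y, w, h, theta, obj, cls...] with sigmoids already applied, matching the indexing above; the internal obb_nms call still needs the compiled extension):

import torch

num_classes = 2
# (batch=1, boxes=3, 5 + 1 + num_classes): two near-duplicates plus one low-confidence box
pred = torch.tensor([[[100., 100., 60., 30., 0.55, 0.90, 0.80, 0.10],
                      [102., 101., 58., 31., 0.56, 0.85, 0.70, 0.20],
                      [300., 300., 40., 20., 0.50, 0.05, 0.50, 0.50]]])
results = non_max_suppression_obb(pred, conf_thres=0.25, iou_thres=0.45)
print(results[0])  # (n, 7): [x, y, w, h, theta(rad), conf, cls]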
if __name__ == "__main__":
......
'''
Author: [egrt]
Date: 2023-01-30 19:00:28
LastEditors: [egrt]
LastEditTime: 2023-01-30 19:34:35
Description: Oriented Bounding Boxes utils
'''
import numpy as np
pi = np.pi
import cv2
import torch
def gaussian_label_cpu(label, num_class, u=0, sig=4.0):
"""
Convert to CSL (Circular Smooth) labels:
a Gaussian window gives the gt labels the same periodicity as the angle θ, so that at the angular boundary the loss stays small even when the numeric difference is large;
the labels are circular, reflecting the angular distance between different θ values.
Args:
label (float32):[1], theta class
num_class (int): [1], number of theta classes
u (float32):[1], μ in gaussian function
sig (float32):[1], σ in gaussian function, which is window radius for Circular Smooth Label
Returns:
csl_label (array): [num_theta_class], gaussian function smooth label
"""
x = np.arange(-num_class/2, num_class/2)
y_sig = np.exp(-(x - u) ** 2 / (2 * sig ** 2))
index = int(num_class/2 - label)
return np.concatenate([y_sig[index:],
y_sig[:index]], axis=0)
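For intuition: the Gaussian peak sits at the labelled angle class and the window wraps around the boundary, which is the point of the circular smooth label. A quick illustrative check:

csl = gaussian_label_cpu(label=2, num_class=180, u=0, sig=4.0)
print(csl.shape)     # (180,)
print(csl.argmax())  # 2: the peak sits at the labelled class
print(round(csl[179], 3), round(csl[5], 3))  # equal: the window wraps across 0/180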
def regular_theta(theta, mode='180', start=-pi/2):
"""
limit theta ∈ [-pi/2, pi/2)
"""
assert mode in ['360', '180']
cycle = 2 * pi if mode == '360' else pi
theta = theta - start
theta = theta % cycle
return theta + start
def poly2rbox(polys, img_size=(), num_cls_thata=180, radius=6.0, use_pi=False, use_gaussian=False):
"""
Trans poly format to rbox format.
Args:
polys (array): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4])
num_cls_thata (int): [1], theta class num
radius (float32): [1], window radius for Circular Smooth Label
use_pi (bool): True θ∈[-pi/2, pi/2) , False θ∈[0, 180)
Returns:
use_gaussian True:
rboxes (array):
csl_labels (array): (num_gts, num_cls_thata)
elif
rboxes (array): (num_gts, [cx cy l s θ])
"""
assert polys.shape[-1] == 8
img_h, img_w = img_size[0], img_size[1]
if use_gaussian:
csl_labels = []
rboxes = []
for poly in polys:
poly = np.float32(poly.reshape(4, 2))
(x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90] for opencv>=4.5.1; earlier versions return θ ∈ [-90, 0)
angle = -angle # θ ∈ [-90, 0]; rbox2poly negates the angle again, since it is defined ccw (counter-clockwise)
# # The closed ends of the interval are swapped, so the conversion differs between boundary and non-boundary angles.
# if angle >= 90:
# angle = angle - 180
# else:
# w, h = h, w
# angle = angle -90
theta = angle / 180 * pi # convert to radians
# trans opencv format to longedge format θ ∈ [-pi/2, pi/2]
if w != max(w, h):
x = x / img_w
y = y / img_h
w, h = h, w
w = w / img_h
h = h / img_w
theta += pi/2
else:
w = w / img_w
h = h / img_h
x = x / img_w
y = y / img_h
theta = regular_theta(theta) # limit theta ∈ [-pi/2, pi/2)
angle = (theta * 180 / pi) + 90 # θ ∈ [0, 180)
if not use_pi: # use degrees, θ ∈ [0, 180)
rboxes.append([x, y, w, h, angle])
else: # use radians
rboxes.append([x, y, w, h, theta])
if use_gaussian:
csl_label = gaussian_label_cpu(label=angle, num_class=num_cls_thata, u=0, sig=radius)
csl_labels.append(csl_label)
if use_gaussian:
return np.array(rboxes), np.array(csl_labels)
return np.array(rboxes)
def poly2rbox_new(polys, num_cls_thata=5,angle_w=36, radius=6.0, use_pi=False, use_gaussian=False):
"""
Trans poly format to rbox format.
Args:
polys (array): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4])
num_cls_thata (int): [1], theta class num
radius (float32): [1], window radius for Circular Smooth Label
use_pi (bool): True θ∈[-pi/2, pi/2) , False θ∈[0, 180)
Returns:
use_gaussian True:
rboxes (array):
csl_labels (array): (num_gts, num_cls_thata)
elif
rboxes (array): (num_gts, [cx cy l s θ])
"""
assert polys.shape[-1] == 8
if use_gaussian:
csl_labels = []
rboxes = []
for poly in polys:
poly = np.float32(poly.reshape(4, 2))
(x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90] for opencv>=4.5.1; earlier versions return θ ∈ [-90, 0)
angle = -angle # θ ∈ [-90, 0]; rbox2poly negates the angle again
# # The closed ends of the interval are swapped, so the conversion differs between boundary and non-boundary angles.
# if angle >= 90:
# angle = angle - 180
# else:
# w, h = h, w
# angle = angle -90
theta = angle / 180 * pi # convert to radians
# trans opencv format to longedge format θ ∈ [-pi/2, pi/2]
if w != max(w, h):
w, h = h, w
theta += pi/2
theta = regular_theta(theta) # limit theta ∈ [-pi/2, pi/2)
# while not pi / 2 > theta >= -pi / 2:
# if theta >= pi / 2:
# theta -= pi
# else:
# theta += pi
angle = (theta * 180 / pi) + 90 # θ ∈ [0, 180)
if not use_pi: # use degrees, θ ∈ [0, 180)
rboxes.append([x, y, w, h, angle])
else: # use radians
rboxes.append([x, y, w, h, theta])
if use_gaussian:
csl_label = gaussian_label_cpu(label=angle, num_class=num_cls_thata, u=0, sig=radius)
csl_labels.append(csl_label)
if use_gaussian:
return np.array(rboxes), np.array(csl_labels)
return np.array(rboxes)
def rbox2poly(obboxes):
"""
Trans rbox format to poly format.
Args:
rboxes (array/tensor): (num_gts, [cx cy l s θ]) θ∈[-pi/2, pi/2)
Returns:
polys (array/tensor): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4])
"""
if isinstance(obboxes, torch.Tensor):
center, w, h, theta = obboxes[:, :2], obboxes[:, 2:3], obboxes[:, 3:4], obboxes[:, 4:5]
Cos, Sin = torch.cos(theta), torch.sin(theta)
vector1 = torch.cat(
(w/2 * Cos, -w/2 * Sin), dim=-1)
vector2 = torch.cat(
(-h/2 * Sin, -h/2 * Cos), dim=-1)
point1 = center + vector1 + vector2
point2 = center + vector1 - vector2
point3 = center - vector1 - vector2
point4 = center - vector1 + vector2
order = obboxes.shape[:-1]
return torch.cat(
(point1, point2, point3, point4), dim=-1).reshape(*order, 8)
else:
center, w, h, theta = np.split(obboxes, (2, 3, 4), axis=-1)
Cos, Sin = np.cos(theta), np.sin(theta)
vector1 = np.concatenate(
[w/2 * Cos, -w/2 * Sin], axis=-1)
vector2 = np.concatenate(
[-h/2 * Sin, -h/2 * Cos], axis=-1)
point1 = center + vector1 + vector2
point2 = center + vector1 - vector2
point3 = center - vector1 - vector2
point4 = center - vector1 + vector2
order = obboxes.shape[:-1]
return np.concatenate(
[point1, point2, point3, point4], axis=-1).reshape(*order, 8)
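A hedged round-trip sketch: rbox2poly expands a rotated box to its four corners, and the unnormalized poly2rbox_new (with use_pi=True, keeping θ in radians) approximately recovers it:

rbox = np.array([[100.0, 100.0, 80.0, 40.0, 0.3]])  # [cx, cy, l, s, theta(rad)]
poly = rbox2poly(rbox)                              # (1, 8) corner coordinates
back = poly2rbox_new(poly, use_pi=True)             # recover [cx, cy, l, s, theta]
print(poly.round(2))
print(back.round(3))  # approximately the original rbox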
def poly2hbb(polys):
"""
Trans poly format to hbb format
Args:
rboxes (array/tensor): (num_gts, poly)
Returns:
hbboxes (array/tensor): (num_gts, [xc yc w h])
"""
assert polys.shape[-1] == 8
if isinstance(polys, torch.Tensor):
x = polys[:, 0::2] # (num, 4)
y = polys[:, 1::2]
x_max = torch.amax(x, dim=1) # (num)
x_min = torch.amin(x, dim=1)
y_max = torch.amax(y, dim=1)
y_min = torch.amin(y, dim=1)
x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0 # (num)
h = y_max - y_min # (num)
w = x_max - x_min
x_ctr, y_ctr, w, h = x_ctr.reshape(-1, 1), y_ctr.reshape(-1, 1), w.reshape(-1, 1), h.reshape(-1, 1) # (num, 1)
hbboxes = torch.cat((x_ctr, y_ctr, w, h), dim=1)
else:
x = polys[:, 0::2] # (num, 4)
y = polys[:, 1::2]
x_max = np.amax(x, axis=1) # (num)
x_min = np.amin(x, axis=1)
y_max = np.amax(y, axis=1)
y_min = np.amin(y, axis=1)
x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0 # (num)
h = y_max - y_min # (num)
w = x_max - x_min
x_ctr, y_ctr, w, h = x_ctr.reshape(-1, 1), y_ctr.reshape(-1, 1), w.reshape(-1, 1), h.reshape(-1, 1) # (num, 1)
hbboxes = np.concatenate((x_ctr, y_ctr, w, h), axis=1)
return hbboxes
def poly_filter(polys, h, w):
"""
Filter the poly labels which is out of the image.
Args:
polys (array): (num, 8)
Return:
keep_masks (array): (num)
"""
x = polys[:, 0::2] # (num, 4)
y = polys[:, 1::2]
x_max = np.amax(x, axis=1) # (num)
x_min = np.amin(x, axis=1)
y_max = np.amax(y, axis=1)
y_min = np.amin(y, axis=1)
x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0 # (num)
keep_masks = (x_ctr > 0) & (x_ctr < w) & (y_ctr > 0) & (y_ctr < h)
return keep_masks
if __name__ == "__main__":
#print(np.pi)
poly = np.array([[204., 197., 273., 154., 290., 170., 217., 218.]])
print(poly2rbox_new(poly,use_pi=True))
\ No newline at end of file
......@@ -10,8 +10,8 @@ from PIL import ImageDraw, ImageFont
from nets.yolo import YoloBody
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
resize_image, show_config)
from utils.utils_bbox import DecodeBox
from utils.utils_bbox import non_max_suppression_obb
from utils.utils_rbox import rbox2poly
'''
Essential notes for training on your own dataset!
'''
......@@ -84,7 +84,6 @@ class YOLO(object):
#---------------------------------------------------#
self.class_names, self.num_classes = get_classes(self.classes_path)
self.anchors, self.num_anchors = get_anchors(self.anchors_path)
self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
#---------------------------------------------------#
# Set different colors for the drawn boxes
......@@ -145,19 +144,18 @@ class YOLO(object):
# Feed the image into the network to get predictions!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = self.bbox_util.decode_box(outputs)
#---------------------------------------------------------#
# Stack the prediction boxes, then apply non-maximum suppression
#---------------------------------------------------------#
results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
results = non_max_suppression_obb(outputs, self.confidence, self.nms_iou, classes=self.num_classes)
if len(results[0]) == 0:
return image
top_label = np.array(results[0][:, 6], dtype = 'int32')
top_conf = results[0][:, 4] * results[0][:, 5]
top_boxes = results[0][:, :4]
top_label = np.array(results[0][:, 6], dtype = 'int32')
top_conf = results[0][:, 5]
top_rboxes = results[0][:, :5]
top_polys = rbox2poly(top_rboxes)
#---------------------------------------------------------#
# Set the font and the border thickness
#---------------------------------------------------------#
......@@ -176,51 +174,25 @@ class YOLO(object):
classes_nums[i] = num
print("classes_nums:", classes_nums)
#---------------------------------------------------------#
# Whether to crop the detected objects
#---------------------------------------------------------#
if crop:
for i, c in list(enumerate(top_boxes)):
top, left, bottom, right = top_boxes[i]
top = max(0, np.floor(top).astype('int32'))
left = max(0, np.floor(left).astype('int32'))
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
right = min(image.size[0], np.floor(right).astype('int32'))
dir_save_path = "img_crop"
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
crop_image = image.crop([left, top, right, bottom])
crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
print("save crop_" + str(i) + ".png to " + dir_save_path)
#---------------------------------------------------------#
# Draw on the image
#---------------------------------------------------------#
for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)]
box = top_boxes[i]
poly = top_polys[i]
score = top_conf[i]
top, left, bottom, right = box
top = max(0, np.floor(top).astype('int32'))
left = max(0, np.floor(left).astype('int32'))
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
right = min(image.size[0], np.floor(right).astype('int32'))
polygon_list = [(poly[0], poly[1]), (poly[2], poly[3]), \
(poly[4], poly[5]), (poly[6], poly[7])]
label = '{} {:.2f}'.format(predicted_class, score)
draw = ImageDraw.Draw(image)
label_size = draw.textsize(label, font)
label = label.encode('utf-8')
print(label, top, left, bottom, right)
print(label, polygon_list)
if top - label_size[1] >= 0:
text_origin = np.array([left, top - label_size[1]])
else:
text_origin = np.array([left, top + 1])
text_origin = np.array([poly[0], poly[1]], np.int32)
for i in range(thickness):
draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
draw.polygon(xy=polygon_list, outline=self.colors[c])
draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
del draw
......