提交 8b7c5d20 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

poly2rbox

上级 a7c29abb
因为 它太大了无法显示 source diff 。你可以改为 查看blob
此差异已折叠。
...@@ -41,7 +41,7 @@ if __name__ == "__main__": ...@@ -41,7 +41,7 @@ if __name__ == "__main__":
# Cuda 是否使用Cuda # Cuda 是否使用Cuda
# 没有GPU可以设置成False # 没有GPU可以设置成False
#---------------------------------# #---------------------------------#
Cuda = True Cuda = False
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
# distributed 用于指定是否使用单机多卡分布式运行 # distributed 用于指定是否使用单机多卡分布式运行
# 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。
......
...@@ -76,7 +76,6 @@ class YoloDataset(Dataset): ...@@ -76,7 +76,6 @@ class YoloDataset(Dataset):
#---------------------------------------------------# #---------------------------------------------------#
# box[:, 2:4] = box[:, 2:4] - box[:, 0:2] # box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
# box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2 # box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
#---------------------------------------------------# #---------------------------------------------------#
# 调整顺序,符合训练的格式 # 调整顺序,符合训练的格式
# labels_out中序号为0的部分在collect时处理 # labels_out中序号为0的部分在collect时处理
...@@ -105,6 +104,12 @@ class YoloDataset(Dataset): ...@@ -105,6 +104,12 @@ class YoloDataset(Dataset):
# 获得预测框 # 获得预测框
#------------------------------# #------------------------------#
box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
#------------------------------#
# 将polygon转换为rbox
#------------------------------#
rbox = np.zeros((box.shape[0], 6))
rbox[..., :5] = poly2rbox(box[..., :8], (h, w), use_pi=True)
rbox[..., 5] = box[..., 8]
image = image.resize((w,h), Image.BICUBIC) image = image.resize((w,h), Image.BICUBIC)
image_data = np.array(image, np.float32) image_data = np.array(image, np.float32)
......
import numpy as np import numpy as np
import torch import torch
import math
from torchvision.ops import nms from torchvision.ops import nms
# from utils.nms_rotated import obb_nms from utils.nms_rotated import obb_nms
class DecodeBox(): class DecodeBox():
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]): def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
super(DecodeBox, self).__init__() super(DecodeBox, self).__init__()
self.anchors = anchors self.anchors = anchors
self.num_classes = num_classes self.num_classes = num_classes
self.bbox_attrs = 5 + 1 + num_classes self.bbox_attrs = 6 + num_classes
self.input_shape = input_shape self.input_shape = input_shape
#-----------------------------------------------------------# #-----------------------------------------------------------#
# 13x13的特征层对应的anchor是[142, 110],[192, 243],[459, 401] # 13x13的特征层对应的anchor是[142, 110],[192, 243],[459, 401]
...@@ -62,6 +63,10 @@ class DecodeBox(): ...@@ -62,6 +63,10 @@ class DecodeBox():
w = torch.sigmoid(prediction[..., 2]) w = torch.sigmoid(prediction[..., 2])
h = torch.sigmoid(prediction[..., 3]) h = torch.sigmoid(prediction[..., 3])
#-----------------------------------------------# #-----------------------------------------------#
# 获取旋转角度
#-----------------------------------------------#
angle = torch.sigmoid(prediction[..., 4])
#-----------------------------------------------#
# 获得置信度,是否有物体 # 获得置信度,是否有物体
#-----------------------------------------------# #-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 5]) conf = torch.sigmoid(prediction[..., 5])
...@@ -105,17 +110,17 @@ class DecodeBox(): ...@@ -105,17 +110,17 @@ class DecodeBox():
pred_boxes[..., 1] = y.data * 2. - 0.5 + grid_y pred_boxes[..., 1] = y.data * 2. - 0.5 + grid_y
pred_boxes[..., 2] = (w.data * 2) ** 2 * anchor_w pred_boxes[..., 2] = (w.data * 2) ** 2 * anchor_w
pred_boxes[..., 3] = (h.data * 2) ** 2 * anchor_h pred_boxes[..., 3] = (h.data * 2) ** 2 * anchor_h
pred_theta = (angle.data - 0.5) * math.pi
#----------------------------------------------------------# #----------------------------------------------------------#
# 将输出结果归一化成小数的形式 # 将输出结果归一化成小数的形式
#----------------------------------------------------------# #----------------------------------------------------------#
_scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor) _scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale, output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale, pred_theta.view(batch_size, -1, 1),
conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1) conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
outputs.append(output.data) outputs.append(output.data)
return outputs return outputs
def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image): def yolo_correct_boxes(self, box_xy, box_wh, angle, input_shape, image_shape, letterbox_image):
#-----------------------------------------------------------------# #-----------------------------------------------------------------#
# 把y轴放前面是因为方便预测框和图像的宽高进行相乘 # 把y轴放前面是因为方便预测框和图像的宽高进行相乘
#-----------------------------------------------------------------# #-----------------------------------------------------------------#
...@@ -136,23 +141,16 @@ class DecodeBox(): ...@@ -136,23 +141,16 @@ class DecodeBox():
box_yx = (box_yx - offset) * scale box_yx = (box_yx - offset) * scale
box_hw *= scale box_hw *= scale
box_mins = box_yx - (box_hw / 2.) box_xy = box_yx[..., ::-1] * image_shape
box_maxes = box_yx + (box_hw / 2.) box_wh = box_hw[..., ::-1] * image_shape
boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
boxes *= np.concatenate([image_shape, image_shape], axis=-1) boxes = np.concatenate([box_xy[..., 0:1], box_xy[..., 1:2], box_wh[..., 0:1], box_wh[..., 1:2], angle[..., 0:1] ], axis=-1)
return boxes return boxes
def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4): def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
#----------------------------------------------------------# #----------------------------------------------------------#
# 将预测结果的格式转换成左上角右下角的格式。
# prediction [batch_size, num_anchors, 85] # prediction [batch_size, num_anchors, 85]
#----------------------------------------------------------# #----------------------------------------------------------#
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))] output = [None for _ in range(len(prediction))]
for i, image_pred in enumerate(prediction): for i, image_pred in enumerate(prediction):
...@@ -161,13 +159,12 @@ class DecodeBox(): ...@@ -161,13 +159,12 @@ class DecodeBox():
# class_conf [num_anchors, 1] 种类置信度 # class_conf [num_anchors, 1] 种类置信度
# class_pred [num_anchors, 1] 种类 # class_pred [num_anchors, 1] 种类
#----------------------------------------------------------# #----------------------------------------------------------#
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) class_conf, class_pred = torch.max(image_pred[:, 6:6 + num_classes], 1, keepdim=True)
#----------------------------------------------------------# #----------------------------------------------------------#
# 利用置信度进行第一轮筛选 # 利用置信度进行第一轮筛选
#----------------------------------------------------------# #----------------------------------------------------------#
conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze() conf_mask = (image_pred[:, 5] * class_conf[:, 0] >= conf_thres).squeeze()
#----------------------------------------------------------# #----------------------------------------------------------#
# 根据置信度进行预测结果的筛选 # 根据置信度进行预测结果的筛选
#----------------------------------------------------------# #----------------------------------------------------------#
...@@ -177,10 +174,10 @@ class DecodeBox(): ...@@ -177,10 +174,10 @@ class DecodeBox():
if not image_pred.size(0): if not image_pred.size(0):
continue continue
#-------------------------------------------------------------------------# #-------------------------------------------------------------------------#
# detections [num_anchors, 7] # detections [num_anchors, 8]
# 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred # 8的内容为:x, y, w, h, angle, obj_conf, class_conf, class_pred
#-------------------------------------------------------------------------# #-------------------------------------------------------------------------#
detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1) detections = torch.cat((image_pred[:, :6], class_conf.float(), class_pred.float()), 1)
#------------------------------------------# #------------------------------------------#
# 获得预测结果中包含的所有种类 # 获得预测结果中包含的所有种类
...@@ -201,9 +198,9 @@ class DecodeBox(): ...@@ -201,9 +198,9 @@ class DecodeBox():
# 使用官方自带的非极大抑制会速度更快一些! # 使用官方自带的非极大抑制会速度更快一些!
# 筛选出一定区域内,属于同一种类得分最大的框 # 筛选出一定区域内,属于同一种类得分最大的框
#------------------------------------------# #------------------------------------------#
keep = nms( _, keep = obb_nms(
detections_class[:, :4], detections_class[:, :5],
detections_class[:, 4] * detections_class[:, 5], detections_class[:, 5] * detections_class[:, 6],
nms_thres nms_thres
) )
max_detections = detections_class[keep] max_detections = detections_class[keep]
...@@ -227,9 +224,9 @@ class DecodeBox(): ...@@ -227,9 +224,9 @@ class DecodeBox():
output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections)) output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
if output[i] is not None: if output[i] is not None:
output[i] = output[i].cpu().numpy() output[i] = output[i].cpu().numpy()
box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2] box_xy, box_wh, angle = output[i][:, 0:2], output[i][:, 2:4], output[i][:, 4:5]
output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image) output[i][:, :5] = self.yolo_correct_boxes(box_xy, box_wh, angle, input_shape, image_shape, letterbox_image)
return output return output
def non_max_suppression_obb(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, def non_max_suppression_obb(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
......
...@@ -56,7 +56,10 @@ def convert_annotation(year, image_id, list_file): ...@@ -56,7 +56,10 @@ def convert_annotation(year, image_id, list_file):
continue continue
cls_id = classes.index(cls) cls_id = classes.index(cls)
xmlbox = obj.find('rotated_bndbox') xmlbox = obj.find('rotated_bndbox')
b = (int(float(xmlbox.find('rotated_bbox_cx').text)), int(float(xmlbox.find('rotated_bbox_cy').text)), int(float(xmlbox.find('rotated_bbox_w').text)), int(float(xmlbox.find('rotated_bbox_h').text)), int(float(xmlbox.find('rotated_bbox_theta').text))) b = (int(float(xmlbox.find('x1').text)), int(float(xmlbox.find('y1').text)), \
int(float(xmlbox.find('x2').text)), int(float(xmlbox.find('y2').text)), \
int(float(xmlbox.find('x3').text)), int(float(xmlbox.find('y3').text)), \
int(float(xmlbox.find('x4').text)), int(float(xmlbox.find('y4').text)))
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
nums[classes.index(cls)] = nums[classes.index(cls)] + 1 nums[classes.index(cls)] = nums[classes.index(cls)] + 1
......
...@@ -25,8 +25,8 @@ class YOLO(object): ...@@ -25,8 +25,8 @@ class YOLO(object):
# 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。 # 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。
# 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改 # 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
#--------------------------------------------------------------------------# #--------------------------------------------------------------------------#
"model_path" : 'model_data/yolov7_weights.pth', "model_path" : 'logs/best_epoch_weights.pth',
"classes_path" : 'model_data/coco_classes.txt', "classes_path" : 'model_data/ssdd_classes.txt',
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
# anchors_path代表先验框对应的txt文件,一般不修改。 # anchors_path代表先验框对应的txt文件,一般不修改。
# anchors_mask用于帮助代码找到对应的先验框,一般不修改。 # anchors_mask用于帮助代码找到对应的先验框,一般不修改。
...@@ -46,7 +46,7 @@ class YOLO(object): ...@@ -46,7 +46,7 @@ class YOLO(object):
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
# 只有得分大于置信度的预测框会被保留下来 # 只有得分大于置信度的预测框会被保留下来
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
"confidence" : 0.5, "confidence" : 0.05,
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
# 非极大抑制所用到的nms_iou大小 # 非极大抑制所用到的nms_iou大小
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
...@@ -60,7 +60,7 @@ class YOLO(object): ...@@ -60,7 +60,7 @@ class YOLO(object):
# 是否使用Cuda # 是否使用Cuda
# 没有GPU可以设置成False # 没有GPU可以设置成False
#-------------------------------# #-------------------------------#
"cuda" : True, "cuda" : False,
} }
@classmethod @classmethod
...@@ -148,7 +148,8 @@ class YOLO(object): ...@@ -148,7 +148,8 @@ class YOLO(object):
#---------------------------------------------------------# #---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制 # 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------# #---------------------------------------------------------#
results = self.bbox_util.non_max_suppression_obb(torch.cat(outputs, 1), self.confidence, self.nms_iou, classes=self.num_classes) results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
if results[0] is None: if results[0] is None:
return image return image
...@@ -179,12 +180,10 @@ class YOLO(object): ...@@ -179,12 +180,10 @@ class YOLO(object):
#---------------------------------------------------------# #---------------------------------------------------------#
for i, c in list(enumerate(top_label)): for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)] predicted_class = self.class_names[int(c)]
poly = top_polys[i] poly = top_polys[i].astype(np.int32)
score = top_conf[i] score = top_conf[i]
polygon_list = [(poly[0], poly[1]), (poly[2], poly[3]), \ polygon_list = list(poly)
(poly[4], poly[5]), (poly[6], poly[7])]
label = '{} {:.2f}'.format(predicted_class, score) label = '{} {:.2f}'.format(predicted_class, score)
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
label_size = draw.textsize(label, font) label_size = draw.textsize(label, font)
...@@ -193,7 +192,7 @@ class YOLO(object): ...@@ -193,7 +192,7 @@ class YOLO(object):
text_origin = np.array([poly[0], poly[1]], np.int32) text_origin = np.array([poly[0], poly[1]], np.int32)
draw.polygon(xy=polygon_list, fill=(0, 0, 0), outline=self.colors[i], width=label_size) draw.polygon(xy=polygon_list, outline=self.colors[c])
draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font) draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
del draw del draw
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册