提交 a7c29abb 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修改解耦

上级 d59a72cb
import numpy as np
import torch
from torchvision.ops import nms
from utils.nms_rotated import obb_nms
# from utils.nms_rotated import obb_nms
class DecodeBox():
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
......@@ -23,7 +23,7 @@ class DecodeBox():
#-----------------------------------------------#
# 输入的input一共有三个,他们的shape分别是
# batch_size = 1
# batch_size, 3 * (4 + 1 + 80), 20, 20
# batch_size, 3 * (5 + 1 + 80), 20, 20
# batch_size, 255, 40, 40
# batch_size, 255, 80, 80
#-----------------------------------------------#
......@@ -64,11 +64,11 @@ class DecodeBox():
#-----------------------------------------------#
# 获得置信度,是否有物体
#-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 4])
conf = torch.sigmoid(prediction[..., 5])
#-----------------------------------------------#
# 种类置信度
#-----------------------------------------------#
pred_cls = torch.sigmoid(prediction[..., 5:])
pred_cls = torch.sigmoid(prediction[..., 6:])
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
......@@ -232,7 +232,7 @@ class DecodeBox():
output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output
def non_max_suppression_obb(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
def non_max_suppression_obb(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results
......@@ -339,7 +339,7 @@ if __name__ == "__main__":
#---------------------------------------------------#
def get_anchors_and_decode(input, input_shape, anchors, anchors_mask, num_classes):
#-----------------------------------------------#
# input batch_size, 3 * (4 + 1 + num_classes), 20, 20
# input batch_size, 3 * (5 + 1 + num_classes), 20, 20
#-----------------------------------------------#
batch_size = input.size(0)
input_height = input.size(2)
......@@ -364,7 +364,7 @@ if __name__ == "__main__":
# batch_size, 3, 20, 20, 4 + 1 + num_classes
#-----------------------------------------------#
prediction = input.view(batch_size, len(anchors_mask[2]),
num_classes + 5, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
num_classes + 6, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
#-----------------------------------------------#
# 先验框的中心位置的调整参数
......@@ -379,11 +379,11 @@ if __name__ == "__main__":
#-----------------------------------------------#
# 获得置信度,是否有物体 0 - 1
#-----------------------------------------------#
conf = torch.sigmoid(prediction[..., 4])
conf = torch.sigmoid(prediction[..., 5])
#-----------------------------------------------#
# 种类置信度 0 - 1
#-----------------------------------------------#
pred_cls = torch.sigmoid(prediction[..., 5:])
pred_cls = torch.sigmoid(prediction[..., 6:])
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
......@@ -498,7 +498,7 @@ if __name__ == "__main__":
plt.show()
#
feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 255, 20, 20])).float()
feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 258, 20, 20])).float()
anchors = np.array([[116, 90], [156, 198], [373, 326], [30,61], [62,45], [59,119], [10,13], [16,30], [33,23]])
anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
get_anchors_and_decode(feat, [640, 640], anchors, anchors_mask, 80)
......@@ -10,7 +10,7 @@ from PIL import ImageDraw, ImageFont
from nets.yolo import YoloBody
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
resize_image, show_config)
from utils.utils_bbox import non_max_suppression_obb
from utils.utils_bbox import DecodeBox
from utils.utils_rbox import rbox2poly
'''
训练自己的数据集必看注释!
......@@ -84,7 +84,7 @@ class YOLO(object):
#---------------------------------------------------#
self.class_names, self.num_classes = get_classes(self.classes_path)
self.anchors, self.num_anchors = get_anchors(self.anchors_path)
self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
#---------------------------------------------------#
# 画框设置不同的颜色
#---------------------------------------------------#
......@@ -144,10 +144,11 @@ class YOLO(object):
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = self.bbox_util.decode_box(outputs)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results = non_max_suppression_obb(outputs, self.confidence, self.nms_iou, classes=self.num_classes)
results = self.bbox_util.non_max_suppression_obb(torch.cat(outputs, 1), self.confidence, self.nms_iou, classes=self.num_classes)
if results[0] is None:
return image
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册