提交 cfca9ce1 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修复解耦

上级 be8c0376
......@@ -120,33 +120,6 @@ class DecodeBox():
outputs.append(output.data)
return outputs
def yolo_correct_boxes(self, box_xy, box_wh, angle, input_shape, image_shape, letterbox_image):
#-----------------------------------------------------------------#
# 把y轴放前面是因为方便预测框和图像的宽高进行相乘
#-----------------------------------------------------------------#
box_yx = box_xy[..., ::-1]
box_hw = box_wh[..., ::-1]
input_shape = np.array(input_shape)
image_shape = np.array(image_shape)
if letterbox_image:
#-----------------------------------------------------------------#
# 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况
# new_shape指的是宽高缩放情况
#-----------------------------------------------------------------#
new_shape = np.round(image_shape * np.min(input_shape/image_shape))
offset = (input_shape - new_shape)/2./input_shape
scale = input_shape/new_shape
box_yx = (box_yx - offset) * scale
box_hw *= scale
box_xy = box_yx[..., ::-1] * image_shape
box_wh = box_hw[..., ::-1] * image_shape
boxes = np.concatenate([box_xy[..., 0:1], box_xy[..., 1:2], box_wh[..., 0:1], box_wh[..., 1:2], angle[..., 0:1] ], axis=-1)
return boxes
def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
#----------------------------------------------------------#
# prediction [batch_size, num_anchors, 85]
......@@ -205,28 +178,11 @@ class DecodeBox():
)
max_detections = detections_class[keep]
# # 按照存在物体的置信度排序
# _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# detections_class = detections_class[conf_sort_index]
# # 进行非极大抑制
# max_detections = []
# while detections_class.size(0):
# # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
# max_detections.append(detections_class[0].unsqueeze(0))
# if len(detections_class) == 1:
# break
# ious = bbox_iou(max_detections[-1], detections_class[1:])
# detections_class = detections_class[1:][ious < nms_thres]
# # 堆叠
# max_detections = torch.cat(max_detections).data
# Add max detections to outputs
output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
if output[i] is not None:
output[i] = output[i].cpu().numpy()
box_xy, box_wh, angle = output[i][:, 0:2], output[i][:, 2:4], output[i][:, 4:5]
output[i][:, :5] = self.yolo_correct_boxes(box_xy, box_wh, angle, input_shape, image_shape, letterbox_image)
output[i] = output[i].cpu().numpy()
return output
if __name__ == "__main__":
......
......@@ -158,6 +158,11 @@ class YOLO(object):
top_conf = results[0][:, 5] * results[0][:, 6]
top_rboxes = results[0][:, :5]
top_polys = rbox2poly(top_rboxes)
#---------------------------------------------------------#
# 将归一化的预测结果变为真实的预测框
#---------------------------------------------------------#
top_polys[..., [0, 2, 4, 6]] *= image_shape[0]
top_polys[..., [1, 3, 5, 7]] *= image_shape[1]
#---------------------------------------------------------#
# 设置字体与边框厚度
#---------------------------------------------------------#
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册