diff --git a/utils/utils_bbox.py b/utils/utils_bbox.py
index d1b0954539a32db287e32fabc0ce217ad589e4b1..6c203dd285d0f91f936aea9962418df20fad81e4 100644
--- a/utils/utils_bbox.py
+++ b/utils/utils_bbox.py
@@ -120,33 +120,6 @@ class DecodeBox():
             outputs.append(output.data)
         return outputs
 
-    def yolo_correct_boxes(self, box_xy, box_wh, angle, input_shape, image_shape, letterbox_image):
-        #-----------------------------------------------------------------#
-        #   Put the y axis first so the boxes can be multiplied by the image height and width conveniently
-        #-----------------------------------------------------------------#
-        box_yx = box_xy[..., ::-1]
-        box_hw = box_wh[..., ::-1]
-        input_shape = np.array(input_shape)
-        image_shape = np.array(image_shape)
-
-        if letterbox_image:
-            #-----------------------------------------------------------------#
-            #   The offset here is how far the valid image area is shifted from the top-left corner of the image
-            #   new_shape is the scaled height and width
-            #-----------------------------------------------------------------#
-            new_shape = np.round(image_shape * np.min(input_shape/image_shape))
-            offset = (input_shape - new_shape)/2./input_shape
-            scale = input_shape/new_shape
-
-            box_yx = (box_yx - offset) * scale
-            box_hw *= scale
-
-        box_xy = box_yx[..., ::-1] * image_shape
-        box_wh = box_hw[..., ::-1] * image_shape
-
-        boxes = np.concatenate([box_xy[..., 0:1], box_xy[..., 1:2], box_wh[..., 0:1], box_wh[..., 1:2], angle[..., 0:1] ], axis=-1)
-        return boxes
-
     def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
         #----------------------------------------------------------#
         #   prediction  [batch_size, num_anchors, 85]
@@ -205,28 +178,11 @@ class DecodeBox():
                 )
                 max_detections = detections_class[keep]
 
-                # # Sort by objectness confidence
-                # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
-                # detections_class = detections_class[conf_sort_index]
-                # # Perform non-maximum suppression
-                # max_detections = []
-                # while detections_class.size(0):
-                #     # Take the box with the highest confidence of this class, then go through the rest one by one and drop any whose overlap with it exceeds nms_thres
-                #     max_detections.append(detections_class[0].unsqueeze(0))
-                #     if len(detections_class) == 1:
-                #         break
-                #     ious = bbox_iou(max_detections[-1], detections_class[1:])
-                #     detections_class = detections_class[1:][ious < nms_thres]
-                #     # Stack
-                #     max_detections = torch.cat(max_detections).data
-
                 # Add max detections to outputs
                 output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
 
            if output[i] is not None:
-                output[i] = output[i].cpu().numpy()
-                box_xy, box_wh, angle = output[i][:, 0:2], output[i][:, 2:4], output[i][:, 4:5]
-                output[i][:, :5] = self.yolo_correct_boxes(box_xy, box_wh, angle, input_shape, image_shape, letterbox_image)
+                output[i] = output[i].cpu().numpy()
         return output
 
 if __name__ == "__main__":
diff --git a/yolo.py b/yolo.py
index e739683e97874d816fc226a6db232cf45a0581b4..ab053df1c05244b01498c1dad887b46152cf9f52 100644
--- a/yolo.py
+++ b/yolo.py
@@ -158,6 +158,11 @@ class YOLO(object):
         top_conf = results[0][:, 5] * results[0][:, 6]
         top_rboxes = results[0][:, :5]
         top_polys = rbox2poly(top_rboxes)
+        #---------------------------------------------------------#
+        #   Convert the normalized predictions into boxes in real image coordinates
+        #---------------------------------------------------------#
+        top_polys[..., [0, 2, 4, 6]] *= image_shape[0]
+        top_polys[..., [1, 3, 5, 7]] *= image_shape[1]
         #---------------------------------------------------------#
         #   Set the font and the thickness of the bounding box
         #---------------------------------------------------------#