From ec8413215763520061437a5591a5ce815524dfea Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Thu, 10 Sep 2020 09:53:36 +0800 Subject: [PATCH] Update utils.py --- utils/utils.py | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index f5060cf..d4d8443 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -5,10 +5,11 @@ import time import torch import torch.nn as nn import torch.nn.functional as F -from torch.autograd import Variable import numpy as np -from PIL import Image, ImageDraw, ImageFont import matplotlib.pyplot as plt +from torch.autograd import Variable +from PIL import Image, ImageDraw, ImageFont +from torchvision.ops import nms class DecodeBox(nn.Module): def __init__(self, anchors, num_classes, img_size): @@ -225,24 +226,37 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): if prediction.is_cuda: unique_labels = unique_labels.cuda() + detections = detections.cuda() for c in unique_labels: # 获得某一类初步筛选后全部的预测结果 detections_class = detections[detections[:, -1] == c] - # 按照存在物体的置信度排序 - _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True) - detections_class = detections_class[conf_sort_index] - # 进行非极大抑制 - max_detections = [] - while detections_class.size(0): - # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 - max_detections.append(detections_class[0].unsqueeze(0)) - if len(detections_class) == 1: - break - ious = bbox_iou(max_detections[-1], detections_class[1:]) - detections_class = detections_class[1:][ious < nms_thres] - # 堆叠 - max_detections = torch.cat(max_detections).data + + #------------------------------------------# + # 使用官方自带的非极大抑制会速度更快一些! + #------------------------------------------# + keep = nms( + detections_class[:, :4], + detections_class[:, 4], + nms_thres + ) + max_detections = detections_class[keep] + + # # 按照存在物体的置信度排序 + # _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True) + # detections_class = detections_class[conf_sort_index] + # # 进行非极大抑制 + # max_detections = [] + # while detections_class.size(0): + # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 + # max_detections.append(detections_class[0].unsqueeze(0)) + # if len(detections_class) == 1: + # break + # ious = bbox_iou(max_detections[-1], detections_class[1:]) + # detections_class = detections_class[1:][ious < nms_thres] + # # 堆叠 + # max_detections = torch.cat(max_detections).data + # Add max detections to outputs output[image_i] = max_detections if output[image_i] is None else torch.cat( (output[image_i], max_detections)) -- GitLab