get_dr_txt.py 5.9 KB
Newer Older
J
JiaQi Xu 已提交
1 2 3 4 5 6 7
#-------------------------------------#
#       mAP所需文件计算代码
#       具体教程请查看Bilibili
#       Bubbliiiing
#-------------------------------------#
import colorsys
import os
B
Bubbliiiing 已提交
8 9 10

import cv2
import numpy as np
J
JiaQi Xu 已提交
11 12
import torch
import torch.backends.cudnn as cudnn
B
Bubbliiiing 已提交
13 14
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
J
JiaQi Xu 已提交
15
from torch.autograd import Variable
B
Bubbliiiing 已提交
16
from tqdm import tqdm
B
Bubbliiiing 已提交
17 18 19 20 21 22 23

from nets.yolo4 import YoloBody
from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
                         non_max_suppression, yolo_correct_boxes)
from yolo import YOLO


J
JiaQi Xu 已提交
24 25 26 27 28
class mAP_Yolo(YOLO):
    """YOLO detector that writes per-image detection files for mAP evaluation.

    Relies on attributes it does not set itself (``net``, ``yolo_decodes``,
    ``class_names``, ``model_image_size``, ``letterbox_image``, ``cuda``),
    presumably provided by :class:`YOLO`.
    """

    #---------------------------------------------------#
    #   Detect objects in one image and dump the results
    #---------------------------------------------------#
    def detect_image(self, image_id, image):
        """Detect objects in ``image`` and write them to a text file.

        Results go to ``./input/detection-results/<image_id>.txt``, one line
        per box in the form ``<class> <score> <left> <top> <right> <bottom>``
        (score truncated to 6 characters, coordinates as integers in the
        original image's pixel space). The file is created, empty, even when
        nothing is detected.

        Args:
            image_id: Basename used for the output ``.txt`` file.
            image: PIL image to run detection on.

        Returns:
            None.
        """
        # Very permissive thresholds so that low-scoring boxes are still
        # written out. These are stored on the instance and persist after
        # this call.
        self.confidence = 0.01
        self.iou = 0.5

        lines = self._predict_lines(image)

        # ``with`` closes the handle on every path, including "no detections".
        with open("./input/detection-results/" + image_id + ".txt", "w") as f:
            f.writelines(lines)
        return

    def _predict_lines(self, image):
        """Run the network on ``image`` and return formatted detection lines.

        Returns a (possibly empty) list of newline-terminated strings in the
        ``<class> <score> <left> <top> <right> <bottom>`` format.
        """
        # Give both resize branches below the same 3-channel RGB input.
        image = image.convert('RGB')
        image_shape = np.array(np.shape(image)[0:2])

        # model_image_size is indexed as (height, width, ...): PIL's resize
        # takes (width, height), and x coordinates are scaled by index 1 below.
        model_h, model_w = self.model_image_size[0], self.model_image_size[1]

        #---------------------------------------------------------#
        #   Either pad with grey bars for a distortion-free resize,
        #   or resize directly.
        #---------------------------------------------------------#
        if self.letterbox_image:
            crop_img = np.array(letterbox_image(image, (model_w, model_h)))
        else:
            crop_img = image.resize((model_w, model_h), Image.BICUBIC)
        photo = np.array(crop_img, dtype=np.float32) / 255.0
        # HWC -> CHW, then add the batch dimension.
        photo = np.transpose(photo, (2, 0, 1))
        images = [photo]

        with torch.no_grad():
            images = torch.from_numpy(np.asarray(images))
            if self.cuda:
                images = images.cuda()

            # Forward pass, then decode each of the three output heads.
            outputs = self.net(images)
            output_list = [self.yolo_decodes[i](outputs[i]) for i in range(3)]

            # Stack the boxes from all heads and apply non-maximum suppression.
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=self.iou)

        detections = batch_detections[0]
        # NOTE(review): non_max_suppression appears to yield None for an image
        # without surviving boxes (TODO confirm in utils.utils); an empty
        # result is treated the same way: no lines, but the file is still
        # written by the caller.
        if detections is None or len(detections) == 0:
            return []
        detections = detections.cpu().numpy()

        #---------------------------------------------------------#
        #   Score filtering: product of columns 4 and 5
        #   (presumably objectness * class confidence).
        #---------------------------------------------------------#
        scores = detections[:, 4] * detections[:, 5]
        keep = scores > self.confidence
        top_conf = scores[keep]
        top_label = np.array(detections[keep, -1], np.int32)
        top_bboxes = np.array(detections[keep, :4])
        # Each coordinate as an (N, 1) column.
        top_xmin = top_bboxes[:, 0:1]
        top_ymin = top_bboxes[:, 1:2]
        top_xmax = top_bboxes[:, 2:3]
        top_ymax = top_bboxes[:, 3:4]

        #-----------------------------------------------------------------#
        #   The boxes are relative to the network input. Map them back onto
        #   the original image, removing the grey bars when letterboxed.
        #-----------------------------------------------------------------#
        if self.letterbox_image:
            boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                       np.array([model_h, model_w]), image_shape)
        else:
            top_xmin = top_xmin / model_w * image_shape[1]
            top_ymin = top_ymin / model_h * image_shape[0]
            top_xmax = top_xmax / model_w * image_shape[1]
            top_ymax = top_ymax / model_h * image_shape[0]
            boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)

        lines = []
        for label, conf, (top, left, bottom, right) in zip(top_label, top_conf, boxes):
            predicted_class = self.class_names[label]
            score = str(conf)
            lines.append("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
        return lines

yolo = mAP_Yolo()
# One image id per whitespace-separated token of the VOC test split.
with open('VOCdevkit/VOC2007/ImageSets/Main/test.txt') as id_file:
    image_ids = id_file.read().strip().split()

# Output folders: detection-results receives the per-image .txt files written
# by detect_image; images-optional is for the optional image copies below.
for folder in ("./input", "./input/detection-results", "./input/images-optional"):
    os.makedirs(folder, exist_ok=True)

for image_id in tqdm(image_ids):
    image_path = "./VOCdevkit/VOC2007/JPEGImages/" + image_id + ".jpg"
    # Close each JPEG's file handle as soon as detection is done.
    with Image.open(image_path) as image:
        # Uncomment to also save the image so results can be visualised when
        # computing mAP later.
        # image.save("./input/images-optional/" + image_id + ".jpg")
        yolo.detect_image(image_id, image)

print("Conversion completed!")