import datetime
import os

import torch
import matplotlib
# Select the non-interactive backend before pyplot is imported.
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from utils.utils_rbox import rbox2poly, poly2hbb
import shutil
import numpy as np

from PIL import Image
from tqdm import tqdm
from .utils import cvtColor, preprocess_input, resize_image
from .utils_bbox import DecodeBox
from .utils_map import get_coco_map, get_map


class LossHistory():
    """Record per-epoch training/validation losses.

    Every call to :meth:`append_loss` appends the values to
    ``epoch_loss.txt`` / ``epoch_val_loss.txt``, logs them to TensorBoard and
    re-renders ``epoch_loss.png``, all inside ``log_dir``.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Directory that receives the text logs, the plot and the
                TensorBoard event files.
            model: Network passed to ``SummaryWriter.add_graph``.
            input_shape: Indexable whose first two items are used as the
                last two dimensions of the dummy input for the graph export.
        """
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: do not fail when the directory has already been created.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Exporting the graph is best effort; logging works without it.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Store one epoch's losses, persist them and refresh the plot.

        Args:
            epoch: Step value used for the TensorBoard scalars.
            loss: Training loss of the epoch.
            val_loss: Validation loss of the epoch.
        """
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Render raw and (when possible) smoothed loss curves to ``epoch_loss.png``."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            # Smoothing window: 5 points for short histories, 15 otherwise.
            num = 5 if len(self.losses) < 25 else 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except Exception:
            # Smoothing is optional (e.g. the filter rejects histories shorter
            # than the window); the raw curves are still drawn.
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")


class EvalCallback():
    """Periodically evaluate mAP on the validation lines and log/plot it.

    Detection and ground-truth text files are written below
    ``map_out_path``, scored with ``get_coco_map`` (falling back to
    ``get_map``), and the directory is removed afterwards.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
                 map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=False,
                 MINOVERLAP=0.5, eval_flag=True, period=1):
        """Store the evaluation settings and build the box decoder.

        Args:
            net: Model used for inference (replaced on each evaluation).
            input_shape: Network input size; items 0 and 1 are used.
            anchors: Anchors forwarded to ``DecodeBox``.
            anchors_mask: Anchor mask forwarded to ``DecodeBox``.
            class_names: Sequence mapping class index to class name.
            num_classes: Number of classes.
            val_lines: Annotation lines; each is whitespace separated, the
                first field being the image path and the remaining fields
                comma-separated boxes (8 polygon values followed by the
                class index).
            log_dir: Directory receiving ``epoch_map.txt``/``epoch_map.png``.
            cuda: When true, inputs are moved to the GPU.
            map_out_path: Temporary working directory for the txt files.
            max_boxes: Maximum number of detections kept per image.
            confidence: Confidence threshold passed to NMS.
            nms_iou: IoU threshold passed to NMS.
            letterbox_image: Forwarded to ``resize_image`` and NMS.
            MINOVERLAP: Overlap threshold passed to ``get_map``.
            eval_flag: Enables evaluation.
            period: Evaluate when ``epoch % period == 0``.
        """
        super().__init__()

        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period

        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # The curve starts with a 0 entry at epoch 0.
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def _detect_hbbs(self, image):
        """Run the network on one image and return its top detections.

        Returns:
            ``None`` when NMS yields nothing for the image, otherwise a tuple
            ``(labels, confidences, hbbs)`` limited to the ``max_boxes``
            highest-confidence entries, sorted by descending confidence.
        """
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert the image to RGB here so that grayscale images
        #   do not fail at prediction time. Only RGB prediction is
        #   supported; every other image type is converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Add gray bars to the image for a distortion-free
        #   resize; a plain resize can be used as well.
        #---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch dimension (HWC -> 1xCxHxW).
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image through the network.
            #---------------------------------------------------------#
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            #---------------------------------------------------------#
            #   Stack the predicted boxes, then apply non-maximum
            #   suppression.
            #---------------------------------------------------------#
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

            if results[0] is None:
                return None

            # Columns read here: 0-4 rotated box, 5 * 6 confidence, 7 class index.
            top_label = np.array(results[0][:, 7], dtype = 'int32')
            top_conf = results[0][:, 5] * results[0][:, 6]
            top_rboxes = results[0][:, :5]
            # Scale x-like columns by image width and y-like columns by image
            # height (presumably the boxes are normalized -- TODO confirm).
            top_rboxes[:, [0, 2]] *= image_shape[1]
            top_rboxes[:, [1, 3]] *= image_shape[0]
            top_polys = rbox2poly(top_rboxes)
            top_hbbs = poly2hbb(top_polys)

        # Keep the max_boxes most confident detections.
        keep = np.argsort(top_conf)[::-1][:self.max_boxes]
        return top_label[keep], top_conf[keep], top_hbbs[keep]

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Write the detections of ``image`` to ``detection-results/<image_id>.txt``.

        One line per detection: ``class score left top right bottom``. The
        file is always created, and stays empty when nothing is detected.

        Args:
            image_id: Base name of the output txt file.
            image: Image to run detection on.
            class_names: Only detections whose class is in here are written.
            map_out_path: Directory containing ``detection-results``.
        """
        detections = self._detect_hbbs(image)

        # The with-block guarantees the file is closed on every path,
        # including the early return for images without detections.
        with open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8') as f:
            if detections is None:
                return
            top_label, top_conf, top_hbbs = detections

            for i, c in enumerate(top_label):
                predicted_class = self.class_names[int(c)]
                if predicted_class not in class_names:
                    continue

                score = str(top_conf[i])
                # Horizontal box is (center x, center y, width, height).
                xc, yc, w, h = top_hbbs[i]
                left = xc - w/2
                top = yc - h/2
                right = xc + w/2
                bottom = yc + h/2

                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
        return

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mAP with ``model_eval`` when this epoch is due.

        Does nothing unless ``eval_flag`` is set and ``epoch`` is a multiple
        of ``period``. Otherwise writes detection and ground-truth files for
        every validation line, computes the mAP, appends it to
        ``epoch_map.txt``, redraws ``epoch_map.png`` and deletes
        ``map_out_path``.

        Args:
            epoch: Current epoch number.
            model_eval: Model to evaluate; stored as ``self.net``.
        """
        if not (epoch % self.period == 0 and self.eval_flag):
            return

        self.net = model_eval
        # makedirs also creates map_out_path itself as the parent directory.
        os.makedirs(os.path.join(self.map_out_path, "ground-truth"), exist_ok=True)
        os.makedirs(os.path.join(self.map_out_path, "detection-results"), exist_ok=True)

        print("Get map.")
        for annotation_line in tqdm(self.val_lines):
            line = annotation_line.split()
            image_id = os.path.basename(line[0]).split('.')[0]
            #------------------------------#
            #   Parse the annotated boxes
            #------------------------------#
            gt_boxes = np.array([np.array(list(map(float,box.split(',')))) for box in line[1:]])
            #------------------------------#
            #   Convert polygons to
            #   horizontal bounding boxes
            #------------------------------#
            hbbs = np.zeros((gt_boxes.shape[0], 5))
            # An image without annotations yields a 1-D empty array that
            # cannot be indexed by column; leave hbbs empty in that case.
            if gt_boxes.shape[0]:
                hbbs[..., :4] = poly2hbb(gt_boxes[..., :8])
                hbbs[..., 4] = gt_boxes[..., 8]
            #------------------------------#
            #   Write the prediction txt
            #------------------------------#
            # Context manager releases the image file handle after use.
            with Image.open(line[0]) as image:
                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
            #------------------------------#
            #   Write the ground-truth txt
            #------------------------------#
            with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
                for hbb in hbbs:
                    xc, yc, w, h, obj = hbb
                    left = xc - w/2
                    top = yc - h/2
                    right = xc + w/2
                    bottom = yc + h/2
                    obj_name = self.class_names[int(obj)]
                    new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

        print("Calculate Map.")
        try:
            temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
        except Exception:
            # Fall back to the other mAP implementation when the COCO one fails.
            temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
        self.maps.append(temp_map)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
            f.write(str(temp_map))
            f.write("\n")

        plt.figure()
        plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Map %s'%str(self.MINOVERLAP))
        plt.title('A Map Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        plt.cla()
        plt.close("all")

        print("Get map done.")
        shutil.rmtree(self.map_out_path)