提交 9e2bf52f 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修改训练过程中的eval

上级 915b2ccf
...@@ -25,7 +25,7 @@ if __name__ == "__main__": ...@@ -25,7 +25,7 @@ if __name__ == "__main__":
# map_mode为3代表仅仅计算VOC_map。 # map_mode为3代表仅仅计算VOC_map。
# map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行 # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
#-------------------------------------------------------------------------------------------------------------------# #-------------------------------------------------------------------------------------------------------------------#
map_mode = 3 map_mode = 0
#--------------------------------------------------------------------------------------# #--------------------------------------------------------------------------------------#
# 此处的classes_path用于指定需要测量VOC_map的类别 # 此处的classes_path用于指定需要测量VOC_map的类别
# 一般情况下与训练和预测所用的classes_path一致即可 # 一般情况下与训练和预测所用的classes_path一致即可
......
...@@ -41,7 +41,7 @@ if __name__ == "__main__": ...@@ -41,7 +41,7 @@ if __name__ == "__main__":
# Cuda 是否使用Cuda # Cuda 是否使用Cuda
# 没有GPU可以设置成False # 没有GPU可以设置成False
#---------------------------------# #---------------------------------#
Cuda = False Cuda = True
#---------------------------------------------------------------------# #---------------------------------------------------------------------#
# distributed 用于指定是否使用单机多卡分布式运行 # distributed 用于指定是否使用单机多卡分布式运行
# 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。
......
...@@ -7,7 +7,7 @@ matplotlib.use('Agg') ...@@ -7,7 +7,7 @@ matplotlib.use('Agg')
import scipy.signal import scipy.signal
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from utils.utils_rbox import rbox2poly, poly2hbb
import shutil import shutil
import numpy as np import numpy as np
...@@ -145,21 +145,31 @@ class EvalCallback(): ...@@ -145,21 +145,31 @@ class EvalCallback():
if results[0] is None: if results[0] is None:
return return
top_label = np.array(results[0][:, 6], dtype = 'int32') top_label = np.array(results[0][:, 7], dtype = 'int32')
top_conf = results[0][:, 4] * results[0][:, 5] top_conf = results[0][:, 5] * results[0][:, 6]
top_boxes = results[0][:, :4] top_rboxes = results[0][:, :5]
top_polys = rbox2poly(top_rboxes)
#---------------------------------------------------------#
# 将归一化的预测结果变为真实的预测框
#---------------------------------------------------------#
top_polys[..., [0, 2, 4, 6]] *= image_shape[1]
top_polys[..., [1, 3, 5, 7]] *= image_shape[0]
top_hbbs = poly2hbb(top_polys)
top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
top_boxes = top_boxes[top_100] top_hbbs = top_hbbs[top_100]
top_conf = top_conf[top_100] top_conf = top_conf[top_100]
top_label = top_label[top_100] top_label = top_label[top_100]
for i, c in list(enumerate(top_label)): for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)] predicted_class = self.class_names[int(c)]
box = top_boxes[i] hbb = top_hbbs[i]
score = str(top_conf[i]) score = str(top_conf[i])
top, left, bottom, right = box xc, yc, w, h = hbb
left = xc - w/2
top = yc - h/2
right = xc + w/2
bottom = yc + h/2
if predicted_class not in class_names: if predicted_class not in class_names:
continue continue
...@@ -190,6 +200,12 @@ class EvalCallback(): ...@@ -190,6 +200,12 @@ class EvalCallback():
#------------------------------# #------------------------------#
gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
#------------------------------# #------------------------------#
# 将polygon转换为hbb
#------------------------------#
hbbs = np.zeros((gt_boxes.shape[0], 6))
hbbs[:, :5] = poly2hbb(gt_boxes[:, :8])
hbbs[:, 5] = gt_boxes[:, 8]
#------------------------------#
# 获得预测txt # 获得预测txt
#------------------------------# #------------------------------#
self.get_map_txt(image_id, image, self.class_names, self.map_out_path) self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
...@@ -198,8 +214,12 @@ class EvalCallback(): ...@@ -198,8 +214,12 @@ class EvalCallback():
# 获得真实框txt # 获得真实框txt
#------------------------------# #------------------------------#
with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
for box in gt_boxes: for hbb in hbbs:
left, top, right, bottom, obj = box xc, yc, w, h, obj = hbb
left = xc - w/2
top = yc - h/2
right = xc + w/2
bottom = yc + h/2
obj_name = self.class_names[obj] obj_name = self.class_names[obj]
new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册