提交 44a55f79 编写于 作者: F FlyingQianMM

visualize precision-recall curve

上级 c6616e00
......@@ -15,7 +15,11 @@
import os
import cv2
import numpy as np
#import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from .detection_eval import fixed_linspace, backup_linspace, loadRes
def visualize_detection(image, result, threshold=0.5, save_dir=None):
......@@ -160,3 +164,127 @@ def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
img_array[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(img_array.astype('uint8'))
return image
def draw_pr_curve(eval_details_file=None,
gt=None,
pred_bbox=None,
pred_mask=None,
iou_thresh=0.5,
save_dir='./'):
if eval_details_file is not None:
import json
with open(eval_details_file, 'r') as f:
eval_details = json.load(f)
pred_bbox = eval_details['bbox']
if 'mask' in eval_details:
pred_mask = eval_details['mask']
gt = eval_details['gt']
if gt is None or pred_bbox is None:
raise Exception(
"gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
)
if pred_bbox is not None and len(pred_bbox) == 0:
raise Exception("There is no predicted bbox.")
if pred_mask is not None and len(pred_mask) == 0:
raise Exception("There is no predicted mask.")
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
coco = COCO()
coco.dataset = gt
coco.createIndex()
def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
p = coco_gt.params
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
if ap == 1:
# dimension of precision: [TxRxKxAxM]
s = coco_gt.eval['precision']
# IoU
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:, :, :, aind, mind]
else:
# dimension of recall: [TxKxAxM]
s = coco_gt.eval['recall']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:, :, aind, mind]
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
return mean_s
def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
from pycocotools.cocoeval import COCOeval
coco_dt = loadRes(coco_gt, coco_dt)
np.linspace = fixed_linspace
coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.params.iouThrs = np.linspace(
iou_thresh, iou_thresh, 1, endpoint=True)
np.linspace = backup_linspace
coco_eval.evaluate()
coco_eval.accumulate()
stats = _summarize(coco_eval, iouThr=iou_thresh)
catIds = coco_gt.getCatIds()
if len(catIds) != coco_eval.eval['precision'].shape[2]:
raise Exception(
"The category number must be same as the third dimension of precisions."
)
x = np.arange(0.0, 1.01, 0.01)
color_map = get_color_map_list(256)[1:256]
plt.subplot(1, 2, 1)
plt.title(style + " precision-recall IoU={}".format(iou_thresh))
plt.xlabel("recall")
plt.ylabel("precision")
plt.xlim(0, 1.01)
plt.ylim(0, 1.01)
plt.grid(linestyle='--', linewidth=1)
plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
my_x_ticks = np.arange(0, 1.01, 0.1)
my_y_ticks = np.arange(0, 1.01, 0.1)
plt.xticks(my_x_ticks, fontsize=5)
plt.yticks(my_y_ticks, fontsize=5)
for idx, catId in enumerate(catIds):
pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
precision = pr_array[pr_array > -1]
ap = np.mean(precision) if precision.size else float('nan')
nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
float(ap * 100))
color = tuple(color_map[idx])
color = [float(c) / 255 for c in color]
color.append(0.75)
plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
plt.legend(loc="lower left", fontsize=5)
plt.subplot(1, 2, 2)
plt.title(style + " score-recall IoU={}".format(iou_thresh))
plt.xlabel('recall')
plt.ylabel('score')
plt.xlim(0, 1.01)
plt.ylim(0, 1.01)
plt.grid(linestyle='--', linewidth=1)
plt.xticks(my_x_ticks, fontsize=5)
plt.yticks(my_y_ticks, fontsize=5)
for idx, catId in enumerate(catIds):
nm = coco_gt.loadCats(catId)[0]['name']
sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
color = tuple(color_map[idx])
color = [float(c) / 255 for c in color]
color.append(0.75)
plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
plt.legend(loc="lower right", fontsize=5)
plt.savefig(
os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
style, iou_thresh)),
dpi=800)
plt.close()
cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
if pred_mask is not None:
cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册