diff --git a/paddlex/cv/models/utils/visualize.py b/paddlex/cv/models/utils/visualize.py index 3a0d3ff4406f3ac3d4e457a23341fc4e65c9ccb5..d608936c832900b45e5bfc79d29d341659896799 100644 --- a/paddlex/cv/models/utils/visualize.py +++ b/paddlex/cv/models/utils/visualize.py @@ -15,7 +15,11 @@ import os import cv2 import numpy as np +#import matplotlib +#matplotlib.use('Agg') +import matplotlib.pyplot as plt from PIL import Image, ImageDraw +from .detection_eval import fixed_linspace, backup_linspace, loadRes def visualize_detection(image, result, threshold=0.5, save_dir=None): @@ -160,3 +164,127 @@ def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7): img_array[idx[0], idx[1], :] += alpha * color_mask image = Image.fromarray(img_array.astype('uint8')) return image + + +def draw_pr_curve(eval_details_file=None, + gt=None, + pred_bbox=None, + pred_mask=None, + iou_thresh=0.5, + save_dir='./'): + if eval_details_file is not None: + import json + with open(eval_details_file, 'r') as f: + eval_details = json.load(f) + pred_bbox = eval_details['bbox'] + if 'mask' in eval_details: + pred_mask = eval_details['mask'] + gt = eval_details['gt'] + if gt is None or pred_bbox is None: + raise Exception( + "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask." + ) + if pred_bbox is not None and len(pred_bbox) == 0: + raise Exception("There is no predicted bbox.") + if pred_mask is not None and len(pred_mask) == 0: + raise Exception("There is no predicted mask.") + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + coco = COCO() + coco.dataset = gt + coco.createIndex() + + def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100): + p = coco_gt.params + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = coco_gt.eval['precision'] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = coco_gt.eval['recall'] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + return mean_s + + def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'): + from pycocotools.cocoeval import COCOeval + coco_dt = loadRes(coco_gt, coco_dt) + np.linspace = fixed_linspace + coco_eval = COCOeval(coco_gt, coco_dt, style) + coco_eval.params.iouThrs = np.linspace( + iou_thresh, iou_thresh, 1, endpoint=True) + np.linspace = backup_linspace + coco_eval.evaluate() + coco_eval.accumulate() + stats = _summarize(coco_eval, iouThr=iou_thresh) + catIds = coco_gt.getCatIds() + if len(catIds) != coco_eval.eval['precision'].shape[2]: + raise Exception( + "The category number must be same as the third dimension of precisions." + ) + x = np.arange(0.0, 1.01, 0.01) + color_map = get_color_map_list(256)[1:256] + + plt.subplot(1, 2, 1) + plt.title(style + " precision-recall IoU={}".format(iou_thresh)) + plt.xlabel("recall") + plt.ylabel("precision") + plt.xlim(0, 1.01) + plt.ylim(0, 1.01) + plt.grid(linestyle='--', linewidth=1) + plt.plot([0, 1], [0, 1], 'r--', linewidth=1) + my_x_ticks = np.arange(0, 1.01, 0.1) + my_y_ticks = np.arange(0, 1.01, 0.1) + plt.xticks(my_x_ticks, fontsize=5) + plt.yticks(my_y_ticks, fontsize=5) + for idx, catId in enumerate(catIds): + pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2] + precision = pr_array[pr_array > -1] + ap = np.mean(precision) if precision.size else float('nan') + nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format( + float(ap * 100)) + color = tuple(color_map[idx]) + color = [float(c) / 255 for c in color] + color.append(0.75) + plt.plot(x, pr_array, color=color, label=nm, linewidth=1) + plt.legend(loc="lower left", fontsize=5) + + plt.subplot(1, 2, 2) + plt.title(style + " score-recall IoU={}".format(iou_thresh)) + plt.xlabel('recall') + plt.ylabel('score') + plt.xlim(0, 1.01) + plt.ylim(0, 1.01) + plt.grid(linestyle='--', linewidth=1) + plt.xticks(my_x_ticks, fontsize=5) + plt.yticks(my_y_ticks, fontsize=5) + for idx, catId in enumerate(catIds): + nm = coco_gt.loadCats(catId)[0]['name'] + sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2] + color = tuple(color_map[idx]) + color = [float(c) / 255 for c in color] + color.append(0.75) + plt.plot(x, sr_array, color=color, label=nm, linewidth=1) + plt.legend(loc="lower right", fontsize=5) + plt.savefig( + os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format( + style, iou_thresh)), + dpi=800) + plt.close() + + cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox') + if pred_mask is not None: + cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')