未验证 提交 ece29fe5 编写于 作者: J Jason 提交者: GitHub

Merge pull request #21 from FlyingQianMM/develop_draw

generate precision-recall curve based on the eval_detail_file
......@@ -3,7 +3,7 @@ PaddleX提供了一系列模型预测和结果分析的可视化函数。
## 目标检测/实例分割预测结果可视化
```
paddlex.det.visualize(image, result, threshold=0.5, save_dir=None)
paddlex.det.visualize(image, result, threshold=0.5, save_dir='./')
```
将目标检测/实例分割模型预测得到的Box框和Mask在原图上进行可视化
......@@ -11,7 +11,7 @@ paddlex.det.visualize(image, result, threshold=0.5, save_dir=None)
> * **image** (str): 原图文件路径。
> * **result** (str): 模型预测结果。
> * **threshold**(float): score阈值,将Box置信度低于该阈值的框过滤不进行可视化。默认0.5
> * **save_dir**(str): 可视化结果保存路径。若为None,则表示不保存,该函数将可视化的结果以np.ndarray的形式返回;若设为目录路径,则将可视化结果保存至该目录下
> * **save_dir**(str): 可视化结果保存路径。若为None,则表示不保存,该函数将可视化的结果以np.ndarray的形式返回;若设为目录路径,则将可视化结果保存至该目录下。默认值为'./'。
### 使用示例
> 点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)和[测试图片](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong.jpeg)
......@@ -23,17 +23,81 @@ pdx.det.visualize('xiaoduxiong.jpeg', result, save_dir='./')
# 预测结果保存在./visualize_xiaoduxiong.jpeg
```
## 目标检测/实例分割准确率-召回率可视化
```
paddlex.det.draw_pr_curve(eval_details_file=None, gt=None, pred_bbox=None, pred_mask=None, iou_thresh=0.5, save_dir='./')
```
将目标检测/实例分割模型评估结果中各个类别的准确率和召回率的对应关系进行可视化,同时可视化召回率和置信度阈值的对应关系。
### 参数
> * **eval_details_file** (str): 模型评估结果的保存路径,包含真值信息和预测结果。默认值为None。
> * **gt** (list): 数据集的真值信息。默认值为None。
> * **pred_bbox** (list): 模型在数据集上的预测框。默认值为None。
> * **pred_mask** (list): 模型在数据集上的预测mask。默认值为None。
> * **iou_thresh** (float): 判断预测框或预测mask为真阳时的IoU阈值。默认值为0.5。
> * **save_dir** (str): 可视化结果保存路径。默认值为'./'。
**注意:**`eval_details_file`的优先级更高,只要`eval_details_file`不为None,就会从`eval_details_file`提取真值信息和预测结果做分析。当`eval_details_file`为None时,则用`gt``pred_mask``pred_mask`做分析。
### 使用示例
> 示例一:
点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)[数据集](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_ins_det.tar.gz)
```
import os
# 选择使用0号卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from paddlex.det import transforms
import paddlex as pdx
eval_transforms = transforms.Compose([
transforms.Normalize(),
transforms.ResizeByShort(short_size=800, max_size=1333),
transforms.Padding(coarsest_stride=32)
])
eval_dataset = pdx.datasets.CocoDetection(
data_dir='xiaoduxiong_ins_det/JPEGImages',
ann_file='xiaoduxiong_ins_det/val.json',
transforms=eval_transforms)
model = pdx.load_model('xiaoduxiong_epoch_12')
metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=1, return_details=True)
gt = evaluate_details['gt']
bbox = evaluate_details['bbox']
mask = evaluate_details['mask']
# 分别可视化bboxmask的准召曲线
pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, pred_mask=mask, save_dir='./xiaoduxiong')
```
预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/xiaoduxiong_bbox_pr_curve(iou-0.5).png)
预测mask的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/xiaoduxiong_segm_pr_curve(iou-0.5).png)
> 示例二:
使用[yolov3_darknet53.py示例代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/detection/yolov3_darknet53.py)训练完成后,加载模型评估结果文件进行分析:
```
import paddlex as pdx
eval_details_file = 'output/yolov3_darknet53/best_model/eval_details.json'
pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
```
预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/insect_bbox_pr_curve(iou-0.5).png)
## 语义分割预测结果可视化
```
paddlex.seg.visualize(image, result, weight=0.6, save_dir=None)
paddlex.seg.visualize(image, result, weight=0.6, save_dir='./')
```
将语义分割模型预测得到的Mask在原图上进行可视化
### 参数
> * **image** (str): 原图文件路径。
> * **result** (str): 模型预测结果。
> * **weight**(float): mask可视化结果与原图权重因子,weight表示原图的权重。默认0.6
> * **save_dir**(str): 可视化结果保存路径。若为None,则表示不保存,该函数将可视化的结果以np.ndarray的形式返回;若设为目录路径,则将可视化结果保存至该目录下
> * **weight**(float): mask可视化结果与原图权重因子,weight表示原图的权重。默认0.6
> * **save_dir**(str): 可视化结果保存路径。若为None,则表示不保存,该函数将可视化的结果以np.ndarray的形式返回;若设为目录路径,则将可视化结果保存至该目录下。默认值为'./'。
### 使用示例
> 点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/cityscape_deeplab.tar.gz)和[测试图片](https://bj.bcebos.com/paddlex/datasets/city.png)
......
......@@ -15,10 +15,14 @@
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import paddlex.utils.logging as logging
from .detection_eval import fixed_linspace, backup_linspace, loadRes
def visualize_detection(image, result, threshold=0.5, save_dir=None):
def visualize_detection(image, result, threshold=0.5, save_dir='./'):
"""
Visualize bbox and mask results
"""
......@@ -31,11 +35,12 @@ def visualize_detection(image, result, threshold=0.5, save_dir=None):
os.makedirs(save_dir)
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
image.save(out_path, quality=95)
logging.info('The visualized result is saved as {}'.format(out_path))
else:
return image
def visualize_segmentation(image, result, weight=0.6, save_dir=None):
def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
"""
Convert segment result to color image, and save added image.
Args:
......@@ -62,6 +67,7 @@ def visualize_segmentation(image, result, weight=0.6, save_dir=None):
image_name = os.path.split(image)[-1]
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
cv2.imwrite(out_path, vis_result)
logging.info('The visualized result is saved as {}'.format(out_path))
else:
return vis_result
......@@ -160,3 +166,129 @@ def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
img_array[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(img_array.astype('uint8'))
return image
def draw_pr_curve(eval_details_file=None,
gt=None,
pred_bbox=None,
pred_mask=None,
iou_thresh=0.5,
save_dir='./'):
if eval_details_file is not None:
import json
with open(eval_details_file, 'r') as f:
eval_details = json.load(f)
pred_bbox = eval_details['bbox']
if 'mask' in eval_details:
pred_mask = eval_details['mask']
gt = eval_details['gt']
if gt is None or pred_bbox is None:
raise Exception(
"gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
)
if pred_bbox is not None and len(pred_bbox) == 0:
raise Exception("There is no predicted bbox.")
if pred_mask is not None and len(pred_mask) == 0:
raise Exception("There is no predicted mask.")
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
coco = COCO()
coco.dataset = gt
coco.createIndex()
def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
p = coco_gt.params
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
if ap == 1:
# dimension of precision: [TxRxKxAxM]
s = coco_gt.eval['precision']
# IoU
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:, :, :, aind, mind]
else:
# dimension of recall: [TxKxAxM]
s = coco_gt.eval['recall']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:, :, aind, mind]
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
return mean_s
def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
from pycocotools.cocoeval import COCOeval
coco_dt = loadRes(coco_gt, coco_dt)
np.linspace = fixed_linspace
coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.params.iouThrs = np.linspace(
iou_thresh, iou_thresh, 1, endpoint=True)
np.linspace = backup_linspace
coco_eval.evaluate()
coco_eval.accumulate()
stats = _summarize(coco_eval, iouThr=iou_thresh)
catIds = coco_gt.getCatIds()
if len(catIds) != coco_eval.eval['precision'].shape[2]:
raise Exception(
"The category number must be same as the third dimension of precisions."
)
x = np.arange(0.0, 1.01, 0.01)
color_map = get_color_map_list(256)[1:256]
plt.subplot(1, 2, 1)
plt.title(style + " precision-recall IoU={}".format(iou_thresh))
plt.xlabel("recall")
plt.ylabel("precision")
plt.xlim(0, 1.01)
plt.ylim(0, 1.01)
plt.grid(linestyle='--', linewidth=1)
plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
my_x_ticks = np.arange(0, 1.01, 0.1)
my_y_ticks = np.arange(0, 1.01, 0.1)
plt.xticks(my_x_ticks, fontsize=5)
plt.yticks(my_y_ticks, fontsize=5)
for idx, catId in enumerate(catIds):
pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
precision = pr_array[pr_array > -1]
ap = np.mean(precision) if precision.size else float('nan')
nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
float(ap * 100))
color = tuple(color_map[idx])
color = [float(c) / 255 for c in color]
color.append(0.75)
plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
plt.legend(loc="lower left", fontsize=5)
plt.subplot(1, 2, 2)
plt.title(style + " score-recall IoU={}".format(iou_thresh))
plt.xlabel('recall')
plt.ylabel('score')
plt.xlim(0, 1.01)
plt.ylim(0, 1.01)
plt.grid(linestyle='--', linewidth=1)
plt.xticks(my_x_ticks, fontsize=5)
plt.yticks(my_y_ticks, fontsize=5)
for idx, catId in enumerate(catIds):
nm = coco_gt.loadCats(catId)[0]['name']
sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
color = tuple(color_map[idx])
color = [float(c) / 255 for c in color]
color.append(0.75)
plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
plt.legend(loc="lower left", fontsize=5)
plt.savefig(
os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
style, iou_thresh)),
dpi=800)
plt.close()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
if pred_mask is not None:
cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')
......@@ -20,3 +20,4 @@ YOLOv3 = cv.models.YOLOv3
MaskRCNN = cv.models.MaskRCNN
transforms = cv.transforms.det_transforms
visualize = cv.models.utils.visualize.visualize_detection
draw_pr_curve = cv.models.utils.visualize.draw_pr_curve
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册