未验证 提交 afa35471 编写于 作者: J Jason 提交者: GitHub

Merge pull request #26 from FlyingQianMM/develop_qh

reimplement detection visualize
...@@ -40,8 +40,16 @@ paddlex.det.draw_pr_curve(eval_details_file=None, gt=None, pred_bbox=None, pred_ ...@@ -40,8 +40,16 @@ paddlex.det.draw_pr_curve(eval_details_file=None, gt=None, pred_bbox=None, pred_
**注意:**`eval_details_file`的优先级更高,只要`eval_details_file`不为None,就会从`eval_details_file`提取真值信息和预测结果做分析。当`eval_details_file`为None时,则用`gt``pred_mask``pred_mask`做分析。 **注意:**`eval_details_file`的优先级更高,只要`eval_details_file`不为None,就会从`eval_details_file`提取真值信息和预测结果做分析。当`eval_details_file`为None时,则用`gt``pred_mask``pred_mask`做分析。
### 使用示例 ### 使用示例
> 示例一: 点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/insect_epoch_270.zip)[数据集](https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz)
点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)[数据集](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_ins_det.tar.gz)
> 方式一:分析训练过程中保存的模型文件夹中的评估结果文件`eval_details.json`,例如[模型](https://bj.bcebos.com/paddlex/models/insect_epoch_270.zip)中的`eval_details.json`。
```
import paddlex as pdx
eval_details_file = 'insect_epoch_270/eval_details.json'
pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
```
> 方式二:分析模型评估函数返回的评估结果。
``` ```
import os import os
# 选择使用0号卡 # 选择使用0号卡
...@@ -50,40 +58,18 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0' ...@@ -50,40 +58,18 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from paddlex.det import transforms from paddlex.det import transforms
import paddlex as pdx import paddlex as pdx
eval_transforms = transforms.Compose([ model = pdx.load_model('insect_epoch_270')
transforms.Normalize(), eval_dataset = pdx.datasets.VOCDetection(
transforms.ResizeByShort(short_size=800, max_size=1333), data_dir='insect_det',
transforms.Padding(coarsest_stride=32) file_list='insect_det/val_list.txt',
]) label_list='insect_det/labels.txt',
transforms=model.eval_transforms)
eval_dataset = pdx.datasets.CocoDetection( metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=8, return_details=True)
data_dir='xiaoduxiong_ins_det/JPEGImages',
ann_file='xiaoduxiong_ins_det/val.json',
transforms=eval_transforms)
model = pdx.load_model('xiaoduxiong_epoch_12')
metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=1, return_details=True)
gt = evaluate_details['gt'] gt = evaluate_details['gt']
bbox = evaluate_details['bbox'] bbox = evaluate_details['bbox']
mask = evaluate_details['mask'] pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, save_dir='./insect')
# 分别可视化bboxmask的准召曲线
pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, pred_mask=mask, save_dir='./xiaoduxiong')
``` ```
预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/xiaoduxiong_bbox_pr_curve(iou-0.5).png)
预测mask的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/xiaoduxiong_segm_pr_curve(iou-0.5).png)
> 示例二:
使用[yolov3_darknet53.py示例代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/detection/yolov3_darknet53.py)训练完成后,加载模型评估结果文件进行分析:
```
import paddlex as pdx
eval_details_file = 'output/yolov3_darknet53/best_model/eval_details.json'
pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
```
预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下: 预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/insect_bbox_pr_curve(iou-0.5).png) ![](./images/insect_bbox_pr_curve(iou-0.5).png)
......
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import os.path as osp import os.path as osp
import random import random
import numpy as np import numpy as np
from collections import OrderedDict
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
from .dataset import Dataset from .dataset import Dataset
...@@ -66,7 +67,7 @@ class VOCDetection(Dataset): ...@@ -66,7 +67,7 @@ class VOCDetection(Dataset):
annotations['categories'] = [] annotations['categories'] = []
annotations['annotations'] = [] annotations['annotations'] = []
cname2cid = {} cname2cid = OrderedDict()
label_id = 1 label_id = 1
with open(label_list, 'r', encoding=get_encoding(label_list)) as fr: with open(label_list, 'r', encoding=get_encoding(label_list)) as fr:
for line in fr.readlines(): for line in fr.readlines():
......
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
import os import os
import cv2 import cv2
import colorsys
import numpy as np import numpy as np
from PIL import Image, ImageDraw
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
from .detection_eval import fixed_linspace, backup_linspace, loadRes from .detection_eval import fixed_linspace, backup_linspace, loadRes
...@@ -27,13 +26,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'): ...@@ -27,13 +26,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
""" """
image_name = os.path.split(image)[-1] image_name = os.path.split(image)[-1]
image = Image.open(image).convert('RGB') image = cv2.imread(image)
image = draw_bbox_mask(image, result, threshold=threshold) image = draw_bbox_mask(image, result, threshold=threshold)
if save_dir is not None: if save_dir is not None:
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name)) out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
image.save(out_path, quality=95) cv2.imwrite(out_path, image)
logging.info('The visualized result is saved as {}'.format(out_path)) logging.info('The visualized result is saved as {}'.format(out_path))
else: else:
return image return image
...@@ -122,49 +121,163 @@ def clip_bbox(bbox): ...@@ -122,49 +121,163 @@ def clip_bbox(bbox):
return xmin, ymin, xmax, ymax return xmin, ymin, xmax, ymax
def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7): def draw_bbox_mask(image, results, threshold=0.5):
import matplotlib
matplotlib.use('Agg')
import matplotlib as mpl
import matplotlib.figure as mplfigure
import matplotlib.colors as mplc
from matplotlib.backends.backend_agg import FigureCanvasAgg
# refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
def _change_color_brightness(color, brightness_factor):
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
color = mplc.to_rgb(color)
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
modified_lightness = polygon_color[1] + (
brightness_factor * polygon_color[1])
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
modified_color = colorsys.hls_to_rgb(
polygon_color[0], modified_lightness, polygon_color[2])
return modified_color
_SMALL_OBJECT_AREA_THRESH = 1000
# setup figure
width, height = image.shape[1], image.shape[0]
scale = 1
fig = mplfigure.Figure(frameon=False)
dpi = fig.get_dpi()
fig.set_size_inches(
(width * scale + 1e-2) / dpi,
(height * scale + 1e-2) / dpi,
)
canvas = FigureCanvasAgg(fig)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
ax.set_xlim(0.0, width)
ax.set_ylim(height)
default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
linewidth = max(default_font_size / 4, 1)
labels = list() labels = list()
for dt in np.array(results): for dt in np.array(results):
if dt['category'] not in labels: if dt['category'] not in labels:
labels.append(dt['category']) labels.append(dt['category'])
color_map = get_color_map_list(len(labels)) color_map = get_color_map_list(256)
keep_results = []
areas = []
for dt in np.array(results): for dt in np.array(results):
cname, bbox, score = dt['category'], dt['bbox'], dt['score'] cname, bbox, score = dt['category'], dt['bbox'], dt['score']
if score < threshold: if score < threshold:
continue continue
keep_results.append(dt)
areas.append(bbox[2] * bbox[3])
areas = np.asarray(areas)
sorted_idxs = np.argsort(-areas).tolist()
keep_results = [keep_results[k]
for k in sorted_idxs] if len(keep_results) > 0 else []
for dt in np.array(keep_results):
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
xmin, ymin, w, h = bbox xmin, ymin, w, h = bbox
xmax = xmin + w xmax = xmin + w
ymax = ymin + h ymax = ymin + h
color = tuple(color_map[labels.index(cname)]) color = tuple(color_map[labels.index(cname) + 2])
color = [c / 255. for c in color]
# draw bbox # draw bbox
draw = ImageDraw.Draw(image) ax.add_patch(
draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), mpl.patches.Rectangle(
(xmin, ymin)], (xmin, ymin),
width=2, w,
fill=color) h,
fill=False,
# draw label edgecolor=color,
text = "{} {:.2f}".format(cname, score) linewidth=linewidth * scale,
tw, th = draw.textsize(text) alpha=0.8,
draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], linestyle="-",
fill=color) ))
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
# draw mask # draw mask
if 'mask' in dt: if 'mask' in dt:
mask = dt['mask'] mask = dt['mask']
color_mask = np.array(color_map[labels.index( mask = np.ascontiguousarray(mask)
dt['category'])]).astype('float32') res = cv2.findContours(
img_array = np.array(image).astype('float32') mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
idx = np.nonzero(mask) hierarchy = res[-1]
img_array[idx[0], idx[1], :] *= 1.0 - alpha alpha = 0.5
img_array[idx[0], idx[1], :] += alpha * color_mask if hierarchy is not None:
image = Image.fromarray(img_array.astype('uint8')) has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
return image res = res[-2]
res = [x.flatten() for x in res]
res = [x for x in res if len(x) >= 6]
for segment in res:
segment = segment.reshape(-1, 2)
edge_color = mplc.to_rgb(color) + (1, )
polygon = mpl.patches.Polygon(
segment,
fill=True,
facecolor=mplc.to_rgb(color) + (alpha, ),
edgecolor=edge_color,
linewidth=max(default_font_size // 15 * scale, 1),
)
ax.add_patch(polygon)
# draw label
text_pos = (xmin, ymin)
horiz_align = "left"
instance_area = w * h
if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
or h < 40 * scale):
if ymin >= height - 5:
text_pos = (xmin, ymin)
else:
text_pos = (xmin, ymax)
height_ratio = h / np.sqrt(height * width)
font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
default_font_size)
text = "{} {:.2f}".format(cname, score)
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
color = _change_color_brightness(color, brightness_factor=0.7)
ax.text(
text_pos[0],
text_pos[1],
text,
size=font_size * scale,
family="sans-serif",
bbox={
"facecolor": "black",
"alpha": 0.8,
"pad": 0.7,
"edgecolor": "none"
},
verticalalignment="top",
horizontalalignment=horiz_align,
color=color,
zorder=10,
rotation=0,
)
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
try:
import numexpr as ne
visualized_image = ne.evaluate(
"image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
except ImportError:
alpha = alpha.astype("float32") / 255.0
visualized_image = image * (1 - alpha) + rgb * alpha
visualized_image = visualized_image.astype("uint8")
return visualized_image
def draw_pr_curve(eval_details_file=None, def draw_pr_curve(eval_details_file=None,
...@@ -189,6 +302,9 @@ def draw_pr_curve(eval_details_file=None, ...@@ -189,6 +302,9 @@ def draw_pr_curve(eval_details_file=None,
raise Exception("There is no predicted bbox.") raise Exception("There is no predicted bbox.")
if pred_mask is not None and len(pred_mask) == 0: if pred_mask is not None and len(pred_mask) == 0:
raise Exception("There is no predicted mask.") raise Exception("There is no predicted mask.")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pycocotools.coco import COCO from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
coco = COCO() coco = COCO()
...@@ -221,7 +337,6 @@ def draw_pr_curve(eval_details_file=None, ...@@ -221,7 +337,6 @@ def draw_pr_curve(eval_details_file=None,
return mean_s return mean_s
def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'): def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
import matplotlib.pyplot as plt
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
coco_dt = loadRes(coco_gt, coco_dt) coco_dt = loadRes(coco_gt, coco_dt)
np.linspace = fixed_linspace np.linspace = fixed_linspace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册