未验证 提交 afa35471 编写于 作者: J Jason 提交者: GitHub

Merge pull request #26 from FlyingQianMM/develop_qh

reimplement detection visualize
......@@ -40,8 +40,16 @@ paddlex.det.draw_pr_curve(eval_details_file=None, gt=None, pred_bbox=None, pred_
**注意:**`eval_details_file`的优先级更高,只要`eval_details_file`不为None,就会从`eval_details_file`提取真值信息和预测结果做分析。当`eval_details_file`为None时,则用`gt``pred_mask``pred_mask`做分析。
### 使用示例
> 示例一:
点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)[数据集](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_ins_det.tar.gz)
点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/insect_epoch_270.zip)[数据集](https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz)
> 方式一:分析训练过程中保存的模型文件夹中的评估结果文件`eval_details.json`,例如[模型](https://bj.bcebos.com/paddlex/models/insect_epoch_270.zip)中的`eval_details.json`。
```
import paddlex as pdx
eval_details_file = 'insect_epoch_270/eval_details.json'
pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
```
> 方式二:分析模型评估函数返回的评估结果。
```
import os
# 选择使用0号卡
......@@ -50,40 +58,18 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from paddlex.det import transforms
import paddlex as pdx
eval_transforms = transforms.Compose([
transforms.Normalize(),
transforms.ResizeByShort(short_size=800, max_size=1333),
transforms.Padding(coarsest_stride=32)
])
eval_dataset = pdx.datasets.CocoDetection(
data_dir='xiaoduxiong_ins_det/JPEGImages',
ann_file='xiaoduxiong_ins_det/val.json',
transforms=eval_transforms)
model = pdx.load_model('xiaoduxiong_epoch_12')
metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=1, return_details=True)
model = pdx.load_model('insect_epoch_270')
eval_dataset = pdx.datasets.VOCDetection(
data_dir='insect_det',
file_list='insect_det/val_list.txt',
label_list='insect_det/labels.txt',
transforms=model.eval_transforms)
metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=8, return_details=True)
gt = evaluate_details['gt']
bbox = evaluate_details['bbox']
mask = evaluate_details['mask']
# 分别可视化bboxmask的准召曲线
pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, pred_mask=mask, save_dir='./xiaoduxiong')
pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, save_dir='./insect')
```
预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/xiaoduxiong_bbox_pr_curve(iou-0.5).png)
预测mask的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/xiaoduxiong_segm_pr_curve(iou-0.5).png)
> 示例二:
使用[yolov3_darknet53.py示例代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/detection/yolov3_darknet53.py)训练完成后,加载模型评估结果文件进行分析:
```
import paddlex as pdx
eval_details_file = 'output/yolov3_darknet53/best_model/eval_details.json'
pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
```
预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
![](./images/insect_bbox_pr_curve(iou-0.5).png)
......
......@@ -17,6 +17,7 @@ import copy
import os.path as osp
import random
import numpy as np
from collections import OrderedDict
import xml.etree.ElementTree as ET
import paddlex.utils.logging as logging
from .dataset import Dataset
......@@ -66,7 +67,7 @@ class VOCDetection(Dataset):
annotations['categories'] = []
annotations['annotations'] = []
cname2cid = {}
cname2cid = OrderedDict()
label_id = 1
with open(label_list, 'r', encoding=get_encoding(label_list)) as fr:
for line in fr.readlines():
......
......@@ -14,9 +14,8 @@
import os
import cv2
import colorsys
import numpy as np
from PIL import Image, ImageDraw
import paddlex.utils.logging as logging
from .detection_eval import fixed_linspace, backup_linspace, loadRes
......@@ -27,13 +26,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
"""
image_name = os.path.split(image)[-1]
image = Image.open(image).convert('RGB')
image = cv2.imread(image)
image = draw_bbox_mask(image, result, threshold=threshold)
if save_dir is not None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
image.save(out_path, quality=95)
cv2.imwrite(out_path, image)
logging.info('The visualized result is saved as {}'.format(out_path))
else:
return image
......@@ -122,49 +121,163 @@ def clip_bbox(bbox):
return xmin, ymin, xmax, ymax
def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
def draw_bbox_mask(image, results, threshold=0.5):
import matplotlib
matplotlib.use('Agg')
import matplotlib as mpl
import matplotlib.figure as mplfigure
import matplotlib.colors as mplc
from matplotlib.backends.backend_agg import FigureCanvasAgg
# refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
def _change_color_brightness(color, brightness_factor):
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
color = mplc.to_rgb(color)
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
modified_lightness = polygon_color[1] + (
brightness_factor * polygon_color[1])
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
modified_color = colorsys.hls_to_rgb(
polygon_color[0], modified_lightness, polygon_color[2])
return modified_color
_SMALL_OBJECT_AREA_THRESH = 1000
# setup figure
width, height = image.shape[1], image.shape[0]
scale = 1
fig = mplfigure.Figure(frameon=False)
dpi = fig.get_dpi()
fig.set_size_inches(
(width * scale + 1e-2) / dpi,
(height * scale + 1e-2) / dpi,
)
canvas = FigureCanvasAgg(fig)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
ax.set_xlim(0.0, width)
ax.set_ylim(height)
default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
linewidth = max(default_font_size / 4, 1)
labels = list()
for dt in np.array(results):
if dt['category'] not in labels:
labels.append(dt['category'])
color_map = get_color_map_list(len(labels))
color_map = get_color_map_list(256)
keep_results = []
areas = []
for dt in np.array(results):
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
if score < threshold:
continue
keep_results.append(dt)
areas.append(bbox[2] * bbox[3])
areas = np.asarray(areas)
sorted_idxs = np.argsort(-areas).tolist()
keep_results = [keep_results[k]
for k in sorted_idxs] if len(keep_results) > 0 else []
for dt in np.array(keep_results):
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
xmin, ymin, w, h = bbox
xmax = xmin + w
ymax = ymin + h
color = tuple(color_map[labels.index(cname)])
color = tuple(color_map[labels.index(cname) + 2])
color = [c / 255. for c in color]
# draw bbox
draw = ImageDraw.Draw(image)
draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill=color)
# draw label
text = "{} {:.2f}".format(cname, score)
tw, th = draw.textsize(text)
draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
ax.add_patch(
mpl.patches.Rectangle(
(xmin, ymin),
w,
h,
fill=False,
edgecolor=color,
linewidth=linewidth * scale,
alpha=0.8,
linestyle="-",
))
# draw mask
if 'mask' in dt:
mask = dt['mask']
color_mask = np.array(color_map[labels.index(
dt['category'])]).astype('float32')
img_array = np.array(image).astype('float32')
idx = np.nonzero(mask)
img_array[idx[0], idx[1], :] *= 1.0 - alpha
img_array[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(img_array.astype('uint8'))
return image
mask = np.ascontiguousarray(mask)
res = cv2.findContours(
mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
hierarchy = res[-1]
alpha = 0.5
if hierarchy is not None:
has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
res = res[-2]
res = [x.flatten() for x in res]
res = [x for x in res if len(x) >= 6]
for segment in res:
segment = segment.reshape(-1, 2)
edge_color = mplc.to_rgb(color) + (1, )
polygon = mpl.patches.Polygon(
segment,
fill=True,
facecolor=mplc.to_rgb(color) + (alpha, ),
edgecolor=edge_color,
linewidth=max(default_font_size // 15 * scale, 1),
)
ax.add_patch(polygon)
# draw label
text_pos = (xmin, ymin)
horiz_align = "left"
instance_area = w * h
if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
or h < 40 * scale):
if ymin >= height - 5:
text_pos = (xmin, ymin)
else:
text_pos = (xmin, ymax)
height_ratio = h / np.sqrt(height * width)
font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
default_font_size)
text = "{} {:.2f}".format(cname, score)
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
color = _change_color_brightness(color, brightness_factor=0.7)
ax.text(
text_pos[0],
text_pos[1],
text,
size=font_size * scale,
family="sans-serif",
bbox={
"facecolor": "black",
"alpha": 0.8,
"pad": 0.7,
"edgecolor": "none"
},
verticalalignment="top",
horizontalalignment=horiz_align,
color=color,
zorder=10,
rotation=0,
)
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
try:
import numexpr as ne
visualized_image = ne.evaluate(
"image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
except ImportError:
alpha = alpha.astype("float32") / 255.0
visualized_image = image * (1 - alpha) + rgb * alpha
visualized_image = visualized_image.astype("uint8")
return visualized_image
def draw_pr_curve(eval_details_file=None,
......@@ -189,6 +302,9 @@ def draw_pr_curve(eval_details_file=None,
raise Exception("There is no predicted bbox.")
if pred_mask is not None and len(pred_mask) == 0:
raise Exception("There is no predicted mask.")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
coco = COCO()
......@@ -221,7 +337,6 @@ def draw_pr_curve(eval_details_file=None,
return mean_s
def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
import matplotlib.pyplot as plt
from pycocotools.cocoeval import COCOeval
coco_dt = loadRes(coco_gt, coco_dt)
np.linspace = fixed_linspace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册