提交 52d9fcbd 编写于 作者: F FlyingQianMM

Reimplement detection visualization

上级 15a73dd4
......@@ -15,7 +15,10 @@
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
import matplotlib as mpl
import matplotlib.figure as mplfigure
import matplotlib.colors as mplc
from matplotlib.backends.backend_agg import FigureCanvasAgg
def visualize_detection(image, result, threshold=0.5, save_dir=None):
    """Draw detection results on an image file, then save or return the result.

    Args:
        image (str): Path of the image file; it is loaded with ``cv2.imread``.
        result: Detection results, forwarded unchanged to ``draw_bbox_mask``.
        threshold (float): Score threshold forwarded to ``draw_bbox_mask``.
        save_dir (str|None): Output directory. When given, the rendering is
            written to ``<save_dir>/visualize_<image file name>`` and nothing
            is returned. When ``None``, the rendering is returned instead.

    Returns:
        The value produced by ``draw_bbox_mask`` when ``save_dir`` is
        ``None``; otherwise ``None``.

    Raises:
        ValueError: If ``cv2.imread`` cannot load ``image``.
    """
    image_path = image
    image_name = os.path.split(image_path)[-1]

    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than deep inside the drawing code.
    if image is None:
        raise ValueError(
            'Failed to read image file: {}'.format(image_path))

    image = draw_bbox_mask(image, result, threshold=threshold)
    if save_dir is None:
        return image

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
    cv2.imwrite(out_path, image)
......@@ -117,46 +120,141 @@ def clip_bbox(bbox):
return xmin, ymin, xmax, ymax
def draw_bbox_mask(image, results, threshold=0.5):
    """Render detection boxes, instance masks and labels onto an image.

    Everything is drawn on an off-screen matplotlib canvas sized like the
    image, and the canvas is then alpha-blended over ``image``. The approach
    follows detectron2's visualizer:
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py

    Args:
        image (np.ndarray): Image to draw on; ``shape[0]`` and ``shape[1]``
            are taken as height and width. Presumably an H x W x 3 array as
            returned by ``cv2.imread`` -- confirm against other callers.
        results (list[dict]): Detections. Each entry provides 'category',
            'score' and 'bbox' given as ``(xmin, ymin, width, height)``, and
            may provide 'mask' (non-zero where the instance is).
        threshold (float): Detections whose score is below this are skipped.

    Returns:
        np.ndarray: The blended image as ``uint8``.
    """
    _SMALL_OBJECT_AREA_THRESH = 1000

    # Set up an off-screen figure whose canvas matches the image size.
    width, height = image.shape[1], image.shape[0]
    scale = 1
    fig = mplfigure.Figure(frameon=False)
    dpi = fig.get_dpi()
    # The 1e-2 margin keeps the pixel size from being truncated below the
    # image size when inches are converted back to pixels.
    fig.set_size_inches(
        (width * scale + 1e-2) / dpi,
        (height * scale + 1e-2) / dpi,
    )
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.axis("off")
    ax.set_xlim(0.0, width)
    # Image coordinates grow downwards. Both limits are given explicitly:
    # set_ylim(height) alone would leave the upper limit at matplotlib's
    # default of 1.0 and shift every overlay slightly.
    ax.set_ylim(height, 0.0)

    default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
    linewidth = max(default_font_size / 4, 1)

    # Colour slot per category, in order of first appearance over ALL
    # results, so a category's colour does not depend on the threshold.
    label_index = {}
    for dt in results:
        label_index.setdefault(dt['category'], len(label_index))
    color_map = get_color_map_list(256)

    # Keep detections that pass the threshold and draw large boxes first so
    # smaller ones end up on top.
    kept = [dt for dt in results if not dt['score'] < threshold]
    kept.sort(key=lambda dt: dt['bbox'][2] * dt['bbox'][3], reverse=True)

    for dt in kept:
        cname, bbox, score = dt['category'], dt['bbox'], dt['score']
        xmin, ymin, w, h = bbox
        ymax = ymin + h

        # NOTE(review): the +2 offset skips the first two palette entries;
        # the reason is not visible here. The modulo keeps the lookup in
        # range when there are more categories than palette entries.
        palette_entry = color_map[(label_index[cname] + 2) % len(color_map)]
        color = [c / 255. for c in palette_entry]

        # draw bbox
        ax.add_patch(
            mpl.patches.Rectangle(
                (xmin, ymin),
                w,
                h,
                fill=False,
                edgecolor=color,
                linewidth=linewidth * scale,
                alpha=0.5,
                linestyle="-",
            ))

        # draw mask
        if 'mask' in dt:
            mask = np.ascontiguousarray(dt['mask'])
            # The return arity of findContours differs between OpenCV
            # versions; contours and hierarchy are always the last two items.
            found = cv2.findContours(
                mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
            hierarchy = found[-1]
            if hierarchy is not None:
                mask_alpha = 0.75
                face_color = mplc.to_rgb(color) + (mask_alpha, )
                edge_color = mplc.to_rgb(color) + (1, )
                for contour in found[-2]:
                    points = contour.flatten()
                    # A polygon needs at least three (x, y) points.
                    if len(points) < 6:
                        continue
                    ax.add_patch(
                        mpl.patches.Polygon(
                            points.reshape(-1, 2),
                            fill=True,
                            facecolor=face_color,
                            edgecolor=edge_color,
                            linewidth=max(default_font_size // 15 * scale, 1),
                        ))

        # draw label
        text_pos = (xmin, ymin)
        horiz_align = "left"
        instance_area = w * h
        if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
                or h < 40 * scale):
            # Small objects get the label below the box to avoid covering
            # them, unless the box bottom is at the image's lower edge, where
            # a label underneath would fall outside the picture.
            if ymax >= height - 5:
                text_pos = (xmin, ymin)
            else:
                text_pos = (xmin, ymax)
        height_ratio = h / np.sqrt(height * width)
        font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
                     default_font_size)
        text = "{} {:.2f}".format(cname, score)
        # Brighten the label colour so it stays readable on the dark box.
        text_color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        text_color[np.argmax(text_color)] = max(0.8, np.max(text_color))
        ax.text(
            text_pos[0],
            text_pos[1],
            text,
            size=font_size * scale,
            family="sans-serif",
            bbox={
                "facecolor": "black",
                "alpha": 0.8,
                "pad": 0.7,
                "edgecolor": "none"
            },
            verticalalignment="top",
            horizontalalignment=horiz_align,
            color=text_color,
            zorder=10,
            rotation=0,
        )

    # Rasterise the overlay and blend it over the input using its alpha.
    s, (width, height) = canvas.print_to_buffer()
    buffer = np.frombuffer(s, dtype="uint8")
    img_rgba = buffer.reshape(height, width, 4)
    rgb, alpha = np.split(img_rgba, [3], axis=2)
    # NOTE(review): the overlay channels are in RGB order; if ``image`` comes
    # from cv2.imread (BGR) the drawn colours appear red/blue swapped --
    # confirm the intended channel order before changing this.

    try:
        # numexpr is optional and only speeds up the blend.
        import numexpr as ne
    except ImportError:
        ne = None

    if ne is not None:
        # The names in the expression are resolved from the local variables
        # ``image``, ``alpha`` and ``rgb`` of this function.
        visualized_image = ne.evaluate(
            "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
    else:
        alpha = alpha.astype("float32") / 255.0
        visualized_image = image * (1 - alpha) + rgb * alpha

    return visualized_image.astype("uint8")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册