提交 52d9fcbd 编写于 作者: F FlyingQianMM

reimplement detection visualize

上级 15a73dd4
...@@ -15,7 +15,10 @@ ...@@ -15,7 +15,10 @@
import os import os
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image, ImageDraw import matplotlib as mpl
import matplotlib.figure as mplfigure
import matplotlib.colors as mplc
from matplotlib.backends.backend_agg import FigureCanvasAgg
def visualize_detection(image, result, threshold=0.5, save_dir=None): def visualize_detection(image, result, threshold=0.5, save_dir=None):
...@@ -24,13 +27,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir=None): ...@@ -24,13 +27,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir=None):
""" """
image_name = os.path.split(image)[-1] image_name = os.path.split(image)[-1]
image = Image.open(image).convert('RGB') image = cv2.imread(image)
image = draw_bbox_mask(image, result, threshold=threshold) image = draw_bbox_mask(image, result, threshold=threshold)
if save_dir is not None: if save_dir is not None:
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name)) out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
image.save(out_path, quality=95) cv2.imwrite(out_path, image)
else: else:
return image return image
...@@ -117,46 +120,141 @@ def clip_bbox(bbox): ...@@ -117,46 +120,141 @@ def clip_bbox(bbox):
return xmin, ymin, xmax, ymax return xmin, ymin, xmax, ymax
def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7): def draw_bbox_mask(image, results, threshold=0.5):
# refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
_SMALL_OBJECT_AREA_THRESH = 1000
# setup figure
width, height = image.shape[1], image.shape[0]
scale = 1
fig = mplfigure.Figure(frameon=False)
dpi = fig.get_dpi()
fig.set_size_inches(
(width * scale + 1e-2) / dpi,
(height * scale + 1e-2) / dpi,
)
canvas = FigureCanvasAgg(fig)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
ax.set_xlim(0.0, width)
ax.set_ylim(height)
default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
linewidth = max(default_font_size / 4, 1)
labels = list() labels = list()
for dt in np.array(results): for dt in np.array(results):
if dt['category'] not in labels: if dt['category'] not in labels:
labels.append(dt['category']) labels.append(dt['category'])
color_map = get_color_map_list(len(labels)) color_map = get_color_map_list(256)
keep_results = []
areas = []
for dt in np.array(results): for dt in np.array(results):
cname, bbox, score = dt['category'], dt['bbox'], dt['score'] cname, bbox, score = dt['category'], dt['bbox'], dt['score']
if score < threshold: if score < threshold:
continue continue
keep_results.append(dt)
areas.append(bbox[2] * bbox[3])
areas = np.asarray(areas)
sorted_idxs = np.argsort(-areas).tolist()
keep_results = [keep_results[k]
for k in sorted_idxs] if len(keep_results) > 0 else []
for dt in np.array(keep_results):
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
xmin, ymin, w, h = bbox xmin, ymin, w, h = bbox
xmax = xmin + w xmax = xmin + w
ymax = ymin + h ymax = ymin + h
color = tuple(color_map[labels.index(cname)]) color = tuple(color_map[labels.index(cname) + 2])
color = [c / 255. for c in color]
# draw bbox # draw bbox
draw = ImageDraw.Draw(image) ax.add_patch(
draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), mpl.patches.Rectangle(
(xmin, ymin)], (xmin, ymin),
width=2, w,
fill=color) h,
fill=False,
# draw label edgecolor=color,
text = "{} {:.2f}".format(cname, score) linewidth=linewidth * scale,
tw, th = draw.textsize(text) alpha=0.5,
draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], linestyle="-",
fill=color) ))
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
# draw mask # draw mask
if 'mask' in dt: if 'mask' in dt:
mask = dt['mask'] mask = dt['mask']
color_mask = np.array(color_map[labels.index( mask = np.ascontiguousarray(mask)
dt['category'])]).astype('float32') res = cv2.findContours(
img_array = np.array(image).astype('float32') mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
idx = np.nonzero(mask) hierarchy = res[-1]
img_array[idx[0], idx[1], :] *= 1.0 - alpha alpha = 0.75
img_array[idx[0], idx[1], :] += alpha * color_mask if hierarchy is not None:
image = Image.fromarray(img_array.astype('uint8')) has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
return image res = res[-2]
res = [x.flatten() for x in res]
res = [x for x in res if len(x) >= 6]
for segment in res:
segment = segment.reshape(-1, 2)
edge_color = mplc.to_rgb(color) + (1, )
polygon = mpl.patches.Polygon(
segment,
fill=True,
facecolor=mplc.to_rgb(color) + (alpha, ),
edgecolor=edge_color,
linewidth=max(default_font_size // 15 * scale, 1),
)
ax.add_patch(polygon)
# draw label
text_pos = (xmin, ymin)
horiz_align = "left"
instance_area = w * h
if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
or h < 40 * scale):
if ymin >= height - 5:
text_pos = (xmin, ymin)
else:
text_pos = (xmin, ymax)
height_ratio = h / np.sqrt(height * width)
font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
default_font_size)
text = "{} {:.2f}".format(cname, score)
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
ax.text(
text_pos[0],
text_pos[1],
text,
size=font_size * scale,
family="sans-serif",
bbox={
"facecolor": "black",
"alpha": 0.8,
"pad": 0.7,
"edgecolor": "none"
},
verticalalignment="top",
horizontalalignment=horiz_align,
color=color,
zorder=10,
rotation=0,
)
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
try:
import numexpr as ne
visualized_image = ne.evaluate(
"image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
except ImportError:
alpha = alpha.astype("float32") / 255.0
visualized_image = image * (1 - alpha) + rgb * alpha
visualized_image = visualized_image.astype("uint8")
return visualized_image
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册