提交 8edd2182 编写于 作者: K Kentaro Wada

Replace labelme.utils.draw with imgviz

上级 a98437ad
......@@ -9,6 +9,7 @@ import os
import os.path as osp
import sys
import imgviz
try:
import lxml.builder
import lxml.etree
......@@ -18,8 +19,6 @@ except ImportError:
import numpy as np
import PIL.Image
import labelme
def main():
parser = argparse.ArgumentParser(
......@@ -110,7 +109,7 @@ def main():
xmin, xmax = sorted([xmin, xmax])
ymin, ymax = sorted([ymin, ymax])
bboxes.append([xmin, ymin, xmax, ymax])
bboxes.append([ymin, xmin, ymax, xmax])
labels.append(class_id)
xml.append(
......@@ -130,10 +129,14 @@ def main():
if not args.noviz:
captions = [class_names[l] for l in labels]
viz = labelme.utils.draw_instances(
img, bboxes, labels, captions=captions
viz = imgviz.instances2rgb(
image=img,
labels=labels,
bboxes=bboxes,
captions=captions,
font_size=15,
)
PIL.Image.fromarray(viz).save(out_viz_file)
imgviz.io.imsave(out_viz_file, viz)
with open(out_xml_file, 'wb') as f:
f.write(lxml.etree.tostring(xml, pretty_print=True))
......
......@@ -9,6 +9,7 @@ import os
import os.path as osp
import sys
import imgviz
import numpy as np
import PIL.Image
......@@ -65,8 +66,6 @@ def main():
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
colormap = labelme.utils.label_colormap(255)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
......@@ -112,10 +111,14 @@ def main():
labelme.utils.lblsave(out_clsp_file, cls)
np.save(out_cls_file, cls)
if not args.noviz:
clsv = labelme.utils.draw_label(
cls, img, class_names, colormap=colormap
clsv = imgviz.label2rgb(
label=cls,
img=imgviz.rgb2gray(img),
label_names=class_names,
font_size=15,
loc='rb',
)
PIL.Image.fromarray(clsv).save(out_clsv_file)
imgviz.io.imsave(out_clsv_file, clsv)
# instance label
labelme.utils.lblsave(out_insp_file, ins)
......@@ -123,8 +126,14 @@ def main():
if not args.noviz:
instance_ids = np.unique(ins)
instance_names = [str(i) for i in range(max(instance_ids) + 1)]
insv = labelme.utils.draw_label(ins, img, instance_names)
PIL.Image.fromarray(insv).save(out_insv_file)
insv = imgviz.label2rgb(
label=ins,
img=imgviz.rgb2gray(img),
label_names=instance_names,
font_size=15,
loc='rb',
)
imgviz.io.imsave(out_insv_file, insv)
if __name__ == '__main__':
......
......@@ -9,6 +9,7 @@ import os
import os.path as osp
import sys
import imgviz
import numpy as np
import PIL.Image
......@@ -59,8 +60,6 @@ def main():
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
colormap = labelme.utils.label_colormap(255)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
......@@ -94,9 +93,14 @@ def main():
np.save(out_lbl_file, lbl)
if not args.noviz:
viz = labelme.utils.draw_label(
lbl, img, class_names, colormap=colormap)
PIL.Image.fromarray(viz).save(out_viz_file)
viz = imgviz.label2rgb(
label=lbl,
img=imgviz.rgb2gray(img),
font_size=15,
label_names=class_names,
loc='rb',
)
imgviz.io.imsave(out_viz_file, viz)
if __name__ == '__main__':
......
......@@ -6,6 +6,7 @@ import json
import os
import sys
import imgviz
import matplotlib.pyplot as plt
from labelme import utils
......@@ -45,7 +46,13 @@ def main():
label_names = [None] * (max(label_name_to_value.values()) + 1)
for name, value in label_name_to_value.items():
label_names[value] = name
lbl_viz = utils.draw_label(lbl, img, label_names)
lbl_viz = imgviz.label2rgb(
label=lbl,
img=imgviz.rgb2gray(img),
label_names=label_names,
font_size=30,
loc='rb',
)
plt.subplot(121)
plt.imshow(img)
......
......@@ -4,6 +4,7 @@ import json
import os
import os.path as osp
import imgviz
import PIL.Image
import yaml
......@@ -54,7 +55,9 @@ def main():
label_names = [None] * (max(label_name_to_value.values()) + 1)
for name, value in label_name_to_value.items():
label_names[value] = name
lbl_viz = utils.draw_label(lbl, img, label_names)
lbl_viz = imgviz.label2rgb(
label=lbl, img=img, label_names=label_names, loc='rb'
)
PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))
utils.lblsave(osp.join(out_dir, 'label.png'), lbl)
......
......@@ -13,11 +13,6 @@ from .shape import polygons_to_mask
from .shape import shape_to_mask
from .shape import shapes_to_label
from .draw import draw_instances
from .draw import draw_label
from .draw import label_colormap
from .draw import label2rgb
from .qt import newIcon
from .qt import newButton
from .qt import newAction
......
......@@ -3,17 +3,17 @@ import os.path as osp
import numpy as np
import PIL.Image
from labelme.utils.draw import label_colormap
def lblsave(filename, lbl):
import imgviz
if osp.splitext(filename)[1] != '.png':
filename += '.png'
# Assume label ranses [-1, 254] for int32,
# and [0, 255] for uint8 as VOC.
if lbl.min() >= -1 and lbl.max() < 255:
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
colormap = label_colormap(255)
colormap = imgviz.label_colormap()
lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
lbl_pil.save(filename)
else:
......
import io
import os.path as osp
import numpy as np
import PIL.Image
import PIL.ImageDraw
import PIL.ImageFont
def label_colormap(N=256):
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)
cmap = np.zeros((N, 3))
for i in range(0, N):
id = i
r, g, b = 0, 0, 0
for j in range(0, 8):
r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
id = (id >> 3)
cmap[i, 0] = r
cmap[i, 1] = g
cmap[i, 2] = b
cmap = cmap.astype(np.float32) / 255
return cmap
def _validate_colormap(colormap, n_labels):
if colormap is None:
colormap = label_colormap(n_labels)
else:
assert colormap.shape == (colormap.shape[0], 3), \
'colormap must be sequence of RGB values'
assert 0 <= colormap.min() and colormap.max() <= 1, \
'colormap must ranges 0 to 1'
return colormap
# similar function as skimage.color.label2rgb
def label2rgb(
lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
):
if n_labels is None:
n_labels = len(np.unique(lbl))
colormap = _validate_colormap(colormap, n_labels)
colormap = (colormap * 255).astype(np.uint8)
lbl_viz = colormap[lbl]
lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
if img is not None:
img_gray = PIL.Image.fromarray(img).convert('LA')
img_gray = np.asarray(img_gray.convert('RGB'))
# img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
lbl_viz = lbl_viz.astype(np.uint8)
return lbl_viz
def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
"""Draw pixel-wise label with colorization and label names.
label: ndarray, (H, W)
Pixel-wise labels to colorize.
img: ndarray, (H, W, 3), optional
Image on which the colorized label will be drawn.
label_names: iterable
List of label names.
"""
import matplotlib.pyplot as plt
backend_org = plt.rcParams['backend']
plt.switch_backend('agg')
plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
wspace=0, hspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
if label_names is None:
label_names = [str(l) for l in range(label.max() + 1)]
colormap = _validate_colormap(colormap, len(label_names))
label_viz = label2rgb(
label, img, n_labels=len(label_names), colormap=colormap, **kwargs
)
plt.imshow(label_viz)
plt.axis('off')
plt_handlers = []
plt_titles = []
for label_value, label_name in enumerate(label_names):
if label_value not in label:
continue
fc = colormap[label_value]
p = plt.Rectangle((0, 0), 1, 1, fc=fc)
plt_handlers.append(p)
plt_titles.append('{value}: {name}'
.format(value=label_value, name=label_name))
plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
f = io.BytesIO()
plt.savefig(f, bbox_inches='tight', pad_inches=0)
plt.cla()
plt.close()
plt.switch_backend(backend_org)
out_size = (label_viz.shape[1], label_viz.shape[0])
out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
out = np.asarray(out)
return out
def draw_instances(
image=None,
bboxes=None,
labels=None,
masks=None,
captions=None,
):
import matplotlib
# TODO(wkentaro)
assert image is not None
assert bboxes is not None
assert labels is not None
assert masks is None
assert captions is not None
viz = PIL.Image.fromarray(image)
draw = PIL.ImageDraw.ImageDraw(viz)
font_path = osp.join(
osp.dirname(matplotlib.__file__),
'mpl-data/fonts/ttf/DejaVuSans.ttf'
)
font = PIL.ImageFont.truetype(font_path)
colormap = label_colormap(255)
for bbox, label, caption in zip(bboxes, labels, captions):
color = colormap[label]
color = tuple((color * 255).astype(np.uint8).tolist())
xmin, ymin, xmax, ymax = bbox
draw.rectangle((xmin, ymin, xmax, ymax), outline=color)
draw.text((xmin, ymin), caption, font=font)
return np.asarray(viz)
......@@ -29,6 +29,7 @@ del here
install_requires = [
'imgviz',
'matplotlib',
'numpy',
'Pillow>=2.8.0',
......
import numpy as np
from labelme.utils import draw as draw_module
from labelme.utils import shape as shape_module
from .util import get_img_and_lbl
# -----------------------------------------------------------------------------
def test_label_colormap():
N = 255
colormap = draw_module.label_colormap(N=N)
assert colormap.shape == (N, 3)
def test_label2rgb():
img, lbl, label_names = get_img_and_lbl()
n_labels = len(label_names)
viz = draw_module.label2rgb(lbl=lbl, n_labels=n_labels)
assert lbl.shape == viz.shape[:2]
assert viz.dtype == np.uint8
viz = draw_module.label2rgb(lbl=lbl, img=img, n_labels=n_labels)
assert img.shape[:2] == lbl.shape == viz.shape[:2]
assert viz.dtype == np.uint8
def test_draw_label():
img, lbl, label_names = get_img_and_lbl()
viz = draw_module.draw_label(lbl, img, label_names=label_names)
assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2]
assert viz.dtype == np.uint8
def test_draw_instances():
img, lbl, label_names = get_img_and_lbl()
labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0}
labels, masks = zip(*labels_and_masks.items())
masks = np.asarray(masks)
bboxes = shape_module.masks_to_bboxes(masks)
captions = [label_names[l] for l in labels]
viz = draw_module.draw_instances(img, bboxes, labels, captions=captions)
assert viz.shape[:2] == img.shape[:2]
assert viz.dtype == np.uint8
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册