提交 28b3923d 编写于 作者: K Kentaro Wada

Merge semantic_segmentation contrib

上级 b0235b64
......@@ -24,69 +24,6 @@ from labelme.utils import label2rgb
from labelme.utils import label_colormap
# TODO(wkentaro): Move to labelme/utils.py
# contrib
# -----------------------------------------------------------------------------
def labelme_shapes_to_label(img_shape, shapes, label_name_to_value):
lbl = np.zeros(img_shape[:2], dtype=np.int32)
for shape in shapes:
polygons = shape['points']
label_name = shape['label']
if label_name in label_name_to_value:
label_value = label_name_to_value[label_name]
else:
label_value = len(label_name_to_value)
label_name_to_value[label_name] = label_value
mask = labelme.utils.polygons_to_mask(img_shape[:2], polygons)
lbl[mask] = label_value
return lbl
def draw_label(label, img, label_names, colormap=None):
plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
wspace=0, hspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
if colormap is None:
colormap = label_colormap(len(label_names))
label_viz = label2rgb(
label, img, n_labels=len(label_names), alpha=.5)
plt.imshow(label_viz)
plt.axis('off')
plt_handlers = []
plt_titles = []
for label_value, label_name in enumerate(label_names):
if label_value not in label:
continue
if label_name.startswith('_'):
continue
fc = colormap[label_value]
p = plt.Rectangle((0, 0), 1, 1, fc=fc)
plt_handlers.append(p)
plt_titles.append(label_name)
plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
f = io.BytesIO()
plt.savefig(f, bbox_inches='tight', pad_inches=0)
plt.cla()
plt.close()
out_size = (img.shape[1], img.shape[0])
out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
out = np.asarray(out)
return out
# -----------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
......@@ -142,18 +79,19 @@ def main():
img = skimage.io.imread(img_file)
skimage.io.imsave(out_img_file, img)
lbl = labelme_shapes_to_label(
lbl = labelme.utils.shapes_to_label(
img_shape=img.shape,
shapes=data['shapes'],
label_name_to_value=class_name_to_id,
)
lbl_pil = PIL.Image.fromarray(lbl)
# Only works with uint8 label
# lbl_pil = PIL.Image.fromarray(lbl, mode='P')
# lbl_pil.putpalette((colormap * 255).flatten())
lbl_pil.save(out_lbl_file)
viz = draw_label(
viz = labelme.utils.draw_label(
lbl, img, class_names, colormap=colormap)
skimage.io.imsave(out_viz_file, viz)
......
......@@ -95,6 +95,10 @@ def draw_label(label, img, label_names, colormap=None):
plt_handlers = []
plt_titles = []
for label_value, label_name in enumerate(label_names):
if label_value not in label:
continue
if label_name.startswith('_'):
continue
fc = colormap[label_value]
p = plt.Rectangle((0, 0), 1, 1, fc=fc)
plt_handlers.append(p)
......@@ -114,22 +118,29 @@ def draw_label(label, img, label_names, colormap=None):
return out
def labelme_shapes_to_label(img_shape, shapes):
label_name_to_val = {'background': 0}
def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
lbl = np.zeros(img_shape[:2], dtype=np.int32)
for shape in shapes:
polygons = shape['points']
label_name = shape['label']
if label_name in label_name_to_val:
label_value = label_name_to_val[label_name]
else:
label_value = len(label_name_to_val)
label_name_to_val[label_name] = label_value
label_value = label_name_to_value[label_name]
mask = polygons_to_mask(img_shape[:2], polygons)
lbl[mask] = label_value
return lbl
def labelme_shapes_to_label(img_shape, shapes):
warnings.warn('labelme_shapes_to_label is deprecated, so please use '
'shapes_to_label.')
lbl_names = [None] * (max(label_name_to_val.values()) + 1)
for label_name, label_value in label_name_to_val.items():
lbl_names[label_value] = label_name
label_name_to_value = {}
for shape in shapes:
label_name = shape['label']
if label_name in label_name_to_value:
label_value = label_name_to_value[label_name]
else:
label_value = len(label_name_to_value)
label_name_to_value[label_name] = label_value
return lbl, lbl_names
lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
return lbl, label_name_to_value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册