diff --git a/examples/semantic_segmentation/labelme2voc.py b/examples/semantic_segmentation/labelme2voc.py index 953946e1c386fa973180e338f80178bfb4c32261..eeaba21402aa767c45ff4cd0d4140aae707f455d 100755 --- a/examples/semantic_segmentation/labelme2voc.py +++ b/examples/semantic_segmentation/labelme2voc.py @@ -24,69 +24,6 @@ from labelme.utils import label2rgb from labelme.utils import label_colormap -# TODO(wkentaro): Move to labelme/utils.py -# contrib -# ----------------------------------------------------------------------------- - - -def labelme_shapes_to_label(img_shape, shapes, label_name_to_value): - lbl = np.zeros(img_shape[:2], dtype=np.int32) - for shape in shapes: - polygons = shape['points'] - label_name = shape['label'] - if label_name in label_name_to_value: - label_value = label_name_to_value[label_name] - else: - label_value = len(label_name_to_value) - label_name_to_value[label_name] = label_value - mask = labelme.utils.polygons_to_mask(img_shape[:2], polygons) - lbl[mask] = label_value - - return lbl - - -def draw_label(label, img, label_names, colormap=None): - plt.subplots_adjust(left=0, right=1, top=1, bottom=0, - wspace=0, hspace=0) - plt.margins(0, 0) - plt.gca().xaxis.set_major_locator(plt.NullLocator()) - plt.gca().yaxis.set_major_locator(plt.NullLocator()) - - if colormap is None: - colormap = label_colormap(len(label_names)) - - label_viz = label2rgb( - label, img, n_labels=len(label_names), alpha=.5) - plt.imshow(label_viz) - plt.axis('off') - - plt_handlers = [] - plt_titles = [] - for label_value, label_name in enumerate(label_names): - if label_value not in label: - continue - if label_name.startswith('_'): - continue - fc = colormap[label_value] - p = plt.Rectangle((0, 0), 1, 1, fc=fc) - plt_handlers.append(p) - plt_titles.append(label_name) - plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5) - - f = io.BytesIO() - plt.savefig(f, bbox_inches='tight', pad_inches=0) - plt.cla() - plt.close() - - out_size = (img.shape[1], img.shape[0]) - out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB') - out = np.asarray(out) - return out - - -# ----------------------------------------------------------------------------- - - def main(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -142,18 +79,19 @@ def main(): img = skimage.io.imread(img_file) skimage.io.imsave(out_img_file, img) - lbl = labelme_shapes_to_label( + lbl = labelme.utils.shapes_to_label( img_shape=img.shape, shapes=data['shapes'], label_name_to_value=class_name_to_id, ) + lbl_pil = PIL.Image.fromarray(lbl) # Only works with uint8 label # lbl_pil = PIL.Image.fromarray(lbl, mode='P') # lbl_pil.putpalette((colormap * 255).flatten()) lbl_pil.save(out_lbl_file) - viz = draw_label( + viz = labelme.utils.draw_label( lbl, img, class_names, colormap=colormap) skimage.io.imsave(out_viz_file, viz) diff --git a/labelme/utils.py b/labelme/utils.py index 025f9b9fb4eb240f72d3fde014031af2a7cf210c..7bf9e87b974859c6840043e850ddc3ecfeaddb0e 100644 --- a/labelme/utils.py +++ b/labelme/utils.py @@ -95,6 +95,10 @@ def draw_label(label, img, label_names, colormap=None): plt_handlers = [] plt_titles = [] for label_value, label_name in enumerate(label_names): + if label_value not in label: + continue + if label_name.startswith('_'): + continue fc = colormap[label_value] p = plt.Rectangle((0, 0), 1, 1, fc=fc) plt_handlers.append(p) @@ -114,22 +118,29 @@ def draw_label(label, img, label_names, colormap=None): return out -def labelme_shapes_to_label(img_shape, shapes): - label_name_to_val = {'background': 0} +def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'): lbl = np.zeros(img_shape[:2], dtype=np.int32) for shape in shapes: polygons = shape['points'] label_name = shape['label'] - if label_name in label_name_to_val: - label_value = label_name_to_val[label_name] - else: - label_value = len(label_name_to_val) - label_name_to_val[label_name] = label_value + label_value = label_name_to_value[label_name] mask = polygons_to_mask(img_shape[:2], polygons) lbl[mask] = label_value + return lbl + + +def labelme_shapes_to_label(img_shape, shapes): + warnings.warn('labelme_shapes_to_label is deprecated, so please use ' + 'shapes_to_label.') - lbl_names = [None] * (max(label_name_to_val.values()) + 1) - for label_name, label_value in label_name_to_val.items(): - lbl_names[label_value] = label_name + label_name_to_value = {} + for shape in shapes: + label_name = shape['label'] + if label_name in label_name_to_value: + label_value = label_name_to_value[label_name] + else: + label_value = len(label_name_to_value) + label_name_to_value[label_name] = label_value - return lbl, lbl_names + lbl = shapes_to_label(img_shape, shapes, label_name_to_value) + return lbl, label_name_to_value