提交 37fe44a9 编写于 作者: K Kentaro Wada

Add colormap arg to draw_label

上级 9a4c5a70
......@@ -26,15 +26,28 @@ def label_colormap(N=256):
return cmap
def _validate_colormap(colormap, n_labels):
if colormap is None:
colormap = label_colormap(n_labels)
else:
assert colormap.shape == (colormap.shape[0], 3), \
'colormap must be sequence of RGB values'
assert 0 <= colormap.min() and colormap.max() <= 1, \
'colormap must ranges 0 to 1'
return colormap
# similar function as skimage.color.label2rgb
def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
def label2rgb(
lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
):
if n_labels is None:
n_labels = len(np.unique(lbl))
cmap = label_colormap(n_labels)
cmap = (cmap * 255).astype(np.uint8)
colormap = _validate_colormap(colormap, n_labels)
colormap = (colormap * 255).astype(np.uint8)
lbl_viz = cmap[lbl]
lbl_viz = colormap[lbl]
lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
if img is not None:
......@@ -48,8 +61,18 @@ def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
return lbl_viz
def draw_label(label, img=None, label_names=None, colormap=None):
def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
"""Draw pixel-wise label with colorization and label names.
label: ndarray, (H, W)
Pixel-wise labels to colorize.
img: ndarray, (H, W, 3), optional
Image on which the colorized label will be drawn.
label_names: iterable
List of label names.
"""
import matplotlib.pyplot as plt
backend_org = plt.rcParams['backend']
plt.switch_backend('agg')
......@@ -62,10 +85,11 @@ def draw_label(label, img=None, label_names=None, colormap=None):
if label_names is None:
label_names = [str(l) for l in range(label.max() + 1)]
if colormap is None:
colormap = label_colormap(len(label_names))
colormap = _validate_colormap(colormap, len(label_names))
label_viz = label2rgb(label, img, n_labels=len(label_names))
label_viz = label2rgb(
label, img, n_labels=len(label_names), colormap=colormap, **kwargs
)
plt.imshow(label_viz)
plt.axis('off')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册