From 4f78214cb416dc00266bce0dadcda38a175c92f7 Mon Sep 17 00:00:00 2001 From: LaraStuStu Date: Sat, 28 Mar 2020 14:01:37 +0800 Subject: [PATCH] Create test_draw.py --- .../labelme_tests/utils_tests/test_draw.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 DataAnnotation/labelme/tests/labelme_tests/utils_tests/test_draw.py diff --git a/DataAnnotation/labelme/tests/labelme_tests/utils_tests/test_draw.py b/DataAnnotation/labelme/tests/labelme_tests/utils_tests/test_draw.py new file mode 100644 index 0000000..b563fe3 --- /dev/null +++ b/DataAnnotation/labelme/tests/labelme_tests/utils_tests/test_draw.py @@ -0,0 +1,48 @@ +import numpy as np + +from labelme.utils import draw as draw_module +from labelme.utils import shape as shape_module + +from .util import get_img_and_lbl + + +# ----------------------------------------------------------------------------- + + +def test_label_colormap(): + N = 255 + colormap = draw_module.label_colormap(N=N) + assert colormap.shape == (N, 3) + + +def test_label2rgb(): + img, lbl, label_names = get_img_and_lbl() + n_labels = len(label_names) + + viz = draw_module.label2rgb(lbl=lbl, n_labels=n_labels) + assert lbl.shape == viz.shape[:2] + assert viz.dtype == np.uint8 + + viz = draw_module.label2rgb(lbl=lbl, img=img, n_labels=n_labels) + assert img.shape[:2] == lbl.shape == viz.shape[:2] + assert viz.dtype == np.uint8 + + +def test_draw_label(): + img, lbl, label_names = get_img_and_lbl() + + viz = draw_module.draw_label(lbl, img, label_names=label_names) + assert viz.shape[:2] == img.shape[:2] == lbl.shape[:2] + assert viz.dtype == np.uint8 + + +def test_draw_instances(): + img, lbl, label_names = get_img_and_lbl() + labels_and_masks = {l: lbl == l for l in np.unique(lbl) if l != 0} + labels, masks = zip(*labels_and_masks.items()) + masks = np.asarray(masks) + bboxes = shape_module.masks_to_bboxes(masks) + captions = [label_names[l] for l in labels] + viz = draw_module.draw_instances(img, bboxes, labels, captions=captions) + assert viz.shape[:2] == img.shape[:2] + assert viz.dtype == np.uint8 -- GitLab