# visualizer.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import pycocotools.mask as mask_util
from PIL import Image, ImageDraw

from .colormap import colormap

__all__ = ['visualize_results']


def visualize_results(image,
K
Kaipeng Deng 已提交
30
                      im_id,
31 32 33
                      catid2name,
                      threshold=0.5,
                      bbox_results=None,
K
Kaipeng Deng 已提交
34 35
                      mask_results=None,
                      is_bbox_normalized=False):
36 37 38 39
    """
    Visualize bbox and mask results
    """
    if mask_results:
K
Kaipeng Deng 已提交
40
        image = draw_mask(image, im_id, mask_results, threshold)
41
    if bbox_results:
K
Kaipeng Deng 已提交
42 43
        image = draw_bbox(image, im_id, catid2name, bbox_results,
                          threshold, is_bbox_normalized)
Y
Yang Zhang 已提交
44
    return image


def draw_mask(image, im_id, segms, threshold, alpha=0.7):
    """
    Draw mask on image

    Args:
        image: source image; converted to a float32 array for blending.
            Assumed to convert to a height x width x 3 array -- TODO confirm
            callers always pass RGB images.
        im_id: only entries whose ``'image_id'`` equals this value are drawn.
        segms: iterable of dicts with ``'image_id'``, ``'segmentation'``
            (decoded with ``pycocotools.mask.decode``) and ``'score'`` keys.
        threshold: entries with a score below this value are skipped.
        alpha: weight of the mask colour in the blend; the original pixel
            keeps a weight of ``1 - alpha``.

    Returns:
        A new ``PIL.Image`` built from the blended uint8 array.
    """
    w_ratio = .4
    color_list = colormap(rgb=True)
    img_array = np.array(image).astype('float32')

    mask_color_id = 0
    for dt in segms:
        if im_id != dt['image_id']:
            continue
        segm, score = dt['segmentation'], dt['score']
        if score < threshold:
            continue

        mask = mask_util.decode(segm)
        base_color = color_list[mask_color_id % len(color_list), 0:3]
        mask_color_id += 1
        # Lighten the palette colour towards white. This creates a new
        # array instead of writing into ``color_list``, so a palette entry
        # reused after the index wraps around is not lightened twice.
        color_mask = base_color * (1 - w_ratio) + w_ratio * 255

        idx = np.nonzero(mask)
        img_array[idx[0], idx[1], :] *= 1.0 - alpha
        img_array[idx[0], idx[1], :] += alpha * color_mask
    return Image.fromarray(img_array.astype('uint8'))


def _text_size(draw, text):
    """
    Return the (width, height) needed to draw ``text`` with ``draw``.

    ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is
    preferred when it exists and ``textsize`` is only a fallback for old
    Pillow versions.
    """
    if hasattr(draw, 'textbbox'):
        # Extents are measured from the (0, 0) origin so that the label
        # background covers the glyphs drawn from that same origin.
        _, _, right, bottom = draw.textbbox((0, 0), text)
        return right, bottom
    return draw.textsize(text)


def draw_bbox(image, im_id, catid2name, bboxes, threshold,
              is_bbox_normalized=False):
    """
    Draw bbox on image

    Args:
        image: ``PIL.Image`` to draw on; it is modified in place.
        im_id: only entries whose ``'image_id'`` equals this value are drawn.
        catid2name: lookup from category id to the label text.
        bboxes: iterable of dicts with ``'image_id'``, ``'category_id'``,
            ``'bbox'`` (``xmin, ymin, width, height``) and ``'score'`` keys.
        threshold: entries with a score below this value are skipped.
        is_bbox_normalized: when True, box values are multiplied by the
            image width / height before drawing.

    Returns:
        The same ``image`` object, with boxes and labels drawn on it.
    """
    draw = ImageDraw.Draw(image)

    catid2color = {}
    color_list = colormap(rgb=True)[:40]
    for dt in bboxes:
        if im_id != dt['image_id']:
            continue
        catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
        if score < threshold:
            continue

        xmin, ymin, w, h = bbox
        if is_bbox_normalized:
            im_width, im_height = image.size
            xmin *= im_width
            ymin *= im_height
            w *= im_width
            h *= im_height
        xmax = xmin + w
        ymax = ymin + h

        if catid not in catid2color:
            # One randomly chosen palette colour per category, reused for
            # every box of that category within this call.
            idx = np.random.randint(len(color_list))
            # Palette entries may be NumPy scalars; PIL fill colours are
            # passed as plain Python ints.
            catid2color[catid] = tuple(int(c) for c in color_list[idx])
        color = catid2color[catid]

        # draw bbox
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
             (xmin, ymin)],
            width=2,
            fill=color)

        # draw label
        text = "{} {:.2f}".format(catid2name[catid], score)
        tw, th = _text_size(draw, text)
        draw.rectangle([(xmin + 1, ymin - th),
                        (xmin + tw + 1, ymin)],
                       fill=color)
        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))

    return image