# vis.py
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# GPU memory garbage collection optimization flags
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"

import sys
import time
import argparse
import pprint
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid

from PIL import Image as PILImage
from utils.config import cfg
from metrics import ConfusionMatrix
from reader import SegDataset
from models.model_builder import build_model
from models.model_builder import ModelPhase


def parse_args():
    """Parse the command-line options of the visualization script.

    When the script is launched without any argument the usage text is
    printed and the process exits with status 1.

    Returns:
        The ``argparse.Namespace`` holding ``cfg_file``, ``use_gpu``,
        ``vis_dir``, ``also_save_raw_results``, ``local_test`` and ``opts``
        (the trailing tokens that follow the recognised options).
    """
    cli = argparse.ArgumentParser(description='PaddeSeg visualization tools')

    cli.add_argument(
        '--cfg',
        type=str,
        default=None,
        dest='cfg_file',
        help='Config file for training (and optionally testing)')
    cli.add_argument(
        '--use_gpu', action='store_true', dest='use_gpu', help='Use gpu or cpu')
    cli.add_argument(
        '--vis_dir',
        type=str,
        default='visual',
        dest='vis_dir',
        help='visual save dir')
    cli.add_argument(
        '--also_save_raw_results',
        action='store_true',
        dest='also_save_raw_results',
        help='whether to save raw result')
    cli.add_argument(
        '--local_test',
        action='store_true',
        dest='local_test',
        help='if in local test mode, only visualize 5 images for testing')
    # Everything left over after the options is collected verbatim in `opts`.
    cli.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help='See config.py for all options')

    # A bare invocation: show the usage text and stop.
    launched_bare = len(sys.argv) == 1
    if launched_bare:
        cli.print_help()
        sys.exit(1)

    namespace = cli.parse_args()
    return namespace


def makedirs(directory):
    """Create ``directory`` and any missing parents if it does not exist yet.

    Args:
        directory: Path of the directory to create. An empty path (what
            ``os.path.dirname`` returns for a bare file name) designates the
            current directory and is treated as a no-op.

    Raises:
        OSError: If the directory could not be created and still does not
            exist afterwards.
    """
    if not directory:
        return
    if os.path.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError:
        # Another process/thread may have created the directory between the
        # existence check and the call above; only a real failure propagates.
        if not os.path.isdir(directory):
            raise


def get_color_map(num_classes):
    """Build the color map used to visualize a segmentation mask.

    The three lowest bits of a label are spread over the three channels,
    starting at the most significant channel bit; each following group of
    three label bits fills the next lower channel bit.

    Args:
        num_classes: Number of classes
    Returns:
        A list with one ``[c0, c1, c2]`` entry per class index
    """
    palette = []
    for label in range(num_classes):
        chan0 = chan1 = chan2 = 0
        remaining = label
        bit_pos = 7
        while remaining:
            chan0 |= (remaining & 1) << bit_pos
            chan1 |= ((remaining >> 1) & 1) << bit_pos
            chan2 |= ((remaining >> 2) & 1) << bit_pos
            remaining >>= 3
            bit_pos -= 1
        palette.append([chan0, chan1, chan2])

    return palette


def colorize(image, shape, color_map):
    """
    Convert segment result to color image.

    Each value of ``image`` is used as a row index into ``color_map``. For a
    2-D uint8 label map and a 256-entry color map this yields the same
    H x W x 3 uint8 array as three per-channel ``cv2.LUT`` calls stacked with
    ``np.dstack``, but it also accepts color maps with fewer than 256 entries
    and label maps stored in wider integer dtypes.

    Args:
        image: 2-D integer label map.
        shape: Unused; kept so existing callers keep working.
        color_map: Sequence of ``[c0, c1, c2]`` colors indexed by label.
    Returns:
        uint8 array of shape ``image.shape + (3,)``.
    Raises:
        IndexError: If a label has no entry in ``color_map`` or ``image`` is
            not of an integer dtype.
    """
    lut = np.array(color_map).astype("uint8")
    labels = np.asarray(image)
    # Integer-array indexing replaces every label by its color row at once.
    color_res = lut[labels]

    return color_res


def to_png_fn(fn):
    """
    Return the base name of ``fn`` with its extension replaced by ".png".

    Any directory component of ``fn`` is dropped.
    """
    stem = os.path.splitext(os.path.basename(fn))[0]
    return stem + ".png"


def _restore_label_map(label_map, valid_shape, org_shape):
    """Crop ``label_map`` to ``valid_shape`` and resize it to ``org_shape``.

    Both shapes are (height, width) pairs. Nearest-neighbour interpolation is
    used so the resized map only contains values already present in the input.
    """
    cropped = label_map[0:valid_shape[0], 0:valid_shape[1]]
    # cv2.resize takes dsize as (width, height), hence the swapped order.
    return cv2.resize(
        cropped, (org_shape[1], org_shape[0]),
        interpolation=cv2.INTER_NEAREST)


def visualize(cfg,
              vis_file_list=None,
              use_gpu=False,
              vis_dir="visual",
              also_save_raw_results=False,
              ckpt_dir=None,
              log_writer=None,
              local_test=False,
              **kwargs):
    """Run the model on a file list and write colorized predictions as PNGs.

    Colorized results go to ``<vis_dir>/visual_results``; raw label maps are
    additionally written to ``<vis_dir>/raw_results`` on request.

    Args:
        cfg: Config object; ``DATASET.TEST_FILE_LIST``, ``DATASET.DATA_DIR``
            and ``TEST.TEST_MODEL`` are read from it.
        vis_file_list: File list to visualize; falls back to
            ``cfg.DATASET.TEST_FILE_LIST`` when None.
        use_gpu: Run on ``fluid.CUDAPlace(0)`` instead of the CPU.
        vis_dir: Root directory for the output images.
        also_save_raw_results: Also save the uncolored label maps.
        ckpt_dir: Directory the parameters are loaded from; falls back to
            ``cfg.TEST.TEST_MODEL`` when empty. When ``log_writer`` is given,
            its last path component must parse as an int (used as the epoch).
        log_writer: Optional writer exposing
            ``add_image(tag, img, step, dataformats=...)`` — presumably a
            TensorBoard-style summary writer; confirm against the caller.
        local_test: Stop once at least 5 images have been visualized.
        **kwargs: Ignored; lets callers pass a whole argument namespace.

    Raises:
        ValueError: If ``log_writer`` is given and the last component of
            ``ckpt_dir`` is not an integer.
    """
    if vis_file_list is None:
        vis_file_list = cfg.DATASET.TEST_FILE_LIST
    dataset = SegDataset(
        file_list=vis_file_list,
        mode=ModelPhase.VISUAL,
        data_dir=cfg.DATASET.DATA_DIR)

    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    pred, logit = build_model(test_prog, startup_prog, phase=ModelPhase.VISUAL)
    # Clone forward graph
    test_prog = test_prog.clone(for_test=True)

    # Generate the full colormap for a maximum of 256 classes
    color_map = get_color_map(256)

    # Get device environment
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir

    fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)

    save_dir = os.path.join(vis_dir, 'visual_results')
    makedirs(save_dir)
    if also_save_raw_results:
        raw_save_dir = os.path.join(vis_dir, 'raw_results')
        makedirs(raw_save_dir)

    epoch = None
    if log_writer is not None:
        # Calculate the epoch from the ckpt_dir folder name once, up front.
        # normpath drops a trailing separator that would otherwise leave an
        # empty last component.
        epoch = int(os.path.basename(os.path.normpath(ckpt_dir)))

    fetch_list = [pred.name]
    test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True)
    img_cnt = 0
    for imgs, grts, img_names, valid_shapes, org_shapes in test_reader:
        # Axes 2 and 3 of the fed batch are taken as (height, width);
        # NOTE(review): presumably the batch is NCHW — confirm in the reader.
        pred_shape = (imgs.shape[2], imgs.shape[3])
        batch_pred, = exe.run(
            program=test_prog,
            feed={'image': imgs},
            fetch_list=fetch_list,
            return_numpy=True)

        # TODO: use multi-thread to write images
        for i in range(batch_pred.shape[0]):
            res_map = np.squeeze(batch_pred[i, :, :, :]).astype(np.uint8)
            if (res_map.shape[0] != pred_shape[0]
                    or res_map.shape[1] != pred_shape[1]):
                # dsize is (width, height), so the (H, W) pair is swapped.
                res_map = cv2.resize(
                    res_map, (pred_shape[1], pred_shape[0]),
                    interpolation=cv2.INTER_NEAREST)
            valid_shape = (valid_shapes[i, 0], valid_shapes[i, 1])
            org_shape = (org_shapes[i, 0], org_shapes[i, 1])
            # Crop to the valid region, then go back to the original size.
            res_map = _restore_label_map(res_map, valid_shape, org_shape)

            png_fn = to_png_fn(img_names[i])
            if also_save_raw_results:
                raw_fn = os.path.join(raw_save_dir, png_fn)
                makedirs(os.path.dirname(raw_fn))
                cv2.imwrite(raw_fn, res_map)

            # colorful segment result visualization
            vis_fn = os.path.join(save_dir, png_fn)
            makedirs(os.path.dirname(vis_fn))

            pred_mask = colorize(res_map, org_shapes[i], color_map)
            cv2.imwrite(vis_fn, pred_mask)

            img_cnt += 1
            print("#{} visualize image path: {}".format(img_cnt, vis_fn))

            # Use Tensorboard to visualize image
            if log_writer is None:
                continue

            print("Tensorboard visualization epoch", epoch)
            # Channel order is reversed before handing images to the writer.
            log_writer.add_image(
                "Predict/{}".format(img_names[i]),
                pred_mask[..., ::-1],
                epoch,
                dataformats='HWC')

            # Original image, BGR->RGB
            img_path = os.path.join(cfg.DATASET.DATA_DIR, img_names[i])
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None instead of raising on failure.
                print("Cannot read image for Tensorboard: {}".format(img_path))
            else:
                log_writer.add_image(
                    "Images/{}".format(img_names[i]),
                    img[..., ::-1],
                    epoch,
                    dataformats='HWC')

            # Add ground truth (label) images; only the writer consumes them,
            # so they are prepared here rather than for every image.
            grt = grts[i]
            if grt is not None:
                grt = _restore_label_map(grt, valid_shape, org_shape)
                grt = colorize(grt, org_shapes[i], color_map)
                log_writer.add_image(
                    "Label/{}".format(img_names[i]),
                    grt[..., ::-1],
                    epoch,
                    dataformats='HWC')

        # If in local_test mode, only visualize 5 images just for testing
        # procedure
        if local_test and img_cnt >= 5:
            break


if __name__ == '__main__':
    args = parse_args()
    # Load the config file first, then apply the command-line overrides.
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    # Every parsed option is forwarded by keyword; the ones visualize() does
    # not name (cfg_file, opts) are absorbed by its **kwargs.
    visualize(cfg, **vars(args))