# gray2pseudo_color.py
# Committed by LutaoChu
# -*- coding: utf-8 -*-
from __future__ import print_function

import argparse
import os
import os.path as osp
import sys
import numpy as np
from PIL import Image


def parse_args():
    """Parse the command-line arguments of this script.

    Two positional arguments are required (the input directory or file list
    path, and the output directory); ``--dataset_dir`` and
    ``--file_separator`` are optional and default to ``None``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (argument name, help text), registered in this exact order.
    argument_specs = (
        ('dir_or_file', 'input gray label directory or file list path'),
        ('output_dir', 'output colorful label directory'),
        ('--dataset_dir', 'dataset directory'),
        ('--file_separator', 'file list separator'),
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for arg_name, help_text in argument_specs:
        cli.add_argument(arg_name, help=help_text)
    return cli.parse_args()


# (original file lines 27-48)
def get_color_map_list(num_classes):
    """Build a flat RGB palette for visualizing a segmentation mask.

    The bits of each class index are dealt out three at a time: bit 0 goes to
    red, bit 1 to green and bit 2 to blue, filling each channel from its most
    significant bit downwards.

    Args:
        num_classes: Number of classes

    Returns:
        A flat list ``[r0, g0, b0, r1, g1, b1, ...]`` holding
        ``3 * num_classes`` integers.
    """
    palette = []
    for class_id in range(num_classes):
        red = green = blue = 0
        bit_pos = 7  # next channel bit to fill, starting at the MSB
        remaining = class_id
        while remaining:
            red |= (remaining & 1) << bit_pos
            green |= ((remaining >> 1) & 1) << bit_pos
            blue |= ((remaining >> 2) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette.extend((red, green, blue))

    return palette


# Committed by LutaoChu (original file lines 49-58)
def _save_pseudo_color(src_path, dst_path, color_map):
    """Load one gray label image and save it at dst_path with color_map applied."""
    # Read the pixels inside the `with` so the source file handle is closed
    # right away instead of lingering until garbage collection.
    with Image.open(src_path) as im:
        lbl = np.asarray(im).astype(np.uint8)

    lbl_pil = Image.fromarray(lbl, mode='P')
    lbl_pil.putpalette(color_map)

    dst_dir = osp.dirname(dst_path)
    if dst_dir and not osp.exists(dst_dir):
        os.makedirs(dst_dir)

    lbl_pil.save(dst_path)
    print('New label path:', dst_path)


def gray2pseudo_color(args):
    """Convert gray-scale label images into pseudo-color (palettized) images.

    ``args.dir_or_file`` selects one of two modes:

    * Directory: the tree is walked recursively and every file is converted
      into the same relative location under ``args.output_dir``. Files that
      cannot be converted are skipped with a message on stderr.
    * File list: each non-empty line is split on ``args.file_separator``; the
      second field is the label path relative to ``args.dataset_dir`` and is
      written to the same relative path under ``args.output_dir``. Lines
      without a second field are skipped with a message on stderr.

    Args:
        args: Namespace with ``dir_or_file``, ``output_dir``, ``dataset_dir``
            and ``file_separator`` attributes (see ``parse_args``).

    Raises:
        SystemExit: In file-list mode when ``dataset_dir`` or
            ``file_separator`` is missing.
    """
    src = args.dir_or_file  # not named `input`, to avoid shadowing the builtin
    output_dir = args.output_dir
    if not osp.exists(output_dir):
        os.makedirs(output_dir)
        print('Creating colorful label directory:', output_dir)

    color_map = get_color_map_list(256)
    if osp.isdir(src):
        for fpath, _dirs, fnames in os.walk(src):
            # relpath strips only the leading input directory. For the top
            # level it returns '.', which is mapped to '' so that files
            # directly under the input directory land directly in output_dir.
            rel_dir = osp.relpath(fpath, src)
            if rel_dir == os.curdir:
                rel_dir = ''
            for fname in fnames:
                grt_path = osp.join(fpath, fname)
                new_grt_path = osp.join(output_dir, rel_dir, fname)
                try:
                    _save_pseudo_color(grt_path, new_grt_path, color_map)
                except Exception as err:
                    # Best effort: the tree may contain non-image files.
                    # Report and move on, but let KeyboardInterrupt and
                    # SystemExit propagate.
                    print('Skipping', grt_path, '-', err, file=sys.stderr)
    elif osp.isfile(src):
        # An empty separator would make str.split raise ValueError.
        if args.dataset_dir is None or not args.file_separator:
            print('No dataset_dir or file_separator input!')
            sys.exit(1)
        with open(src) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(args.file_separator)
                if len(parts) < 2:
                    print('Skipping line without a label path:', line,
                          file=sys.stderr)
                    continue
                grt_name = parts[1]
                grt_path = osp.join(args.dataset_dir, grt_name)
                new_grt_path = osp.join(output_dir, grt_name)
                _save_pseudo_color(grt_path, new_grt_path, color_map)
    else:
        print('It\'s neither a dir nor a file')


if __name__ == '__main__':
    # Script entry point: parse the command line and run the conversion.
    gray2pseudo_color(parse_args())