提交 4f685d1a 编写于 作者: L LutaoChu 提交者: wuzewu

support colorful label function (#97)

* modify label tool

* support color label

* vis and check support colorful label

* remove needless modification

* remove needless get_color_map_list

* supprt tensorboard
上级 23c90ac3
......@@ -15,6 +15,7 @@ import imghdr
import logging
from utils.config import cfg
from reader import pil_imread
def init_global_variable():
......@@ -452,7 +453,7 @@ def check_train_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
grt = pil_imread(grt_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
......@@ -502,7 +503,7 @@ def check_val_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
grt = pil_imread(grt_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
......@@ -561,7 +562,7 @@ def check_test_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
grt = pil_imread(grt_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
......
......@@ -27,6 +27,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
import cv2
from PIL import Image
import data_aug as aug
from utils.config import cfg
......@@ -34,6 +35,13 @@ from data_utils import GeneratorEnqueuer
from models.model_builder import ModelPhase
import copy
def pil_imread(file_path):
"""read pseudo-color label"""
im = Image.open(file_path)
return np.asarray(im)
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform.
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
......@@ -179,7 +187,7 @@ class SegDataset(object):
if grt_name is not None:
grt_path = os.path.join(src_dir, grt_name)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
grt = pil_imread(grt_path)
else:
grt = None
......
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import glob
import os
import os.path as osp
import sys
import numpy as np
from PIL import Image
from pdseg.vis import get_color_map_list
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('dir_or_file',
help='input gray label directory or file list path')
parser.add_argument('output_dir',
help='output colorful label directory')
parser.add_argument('--dataset_dir',
help='dataset directory')
parser.add_argument('--file_separator',
help='file list separator')
return parser.parse_args()
def gray2pseudo_color(args):
"""将灰度标注图片转换为伪彩色图片"""
input = args.dir_or_file
output_dir = args.output_dir
if not osp.exists(output_dir):
os.makedirs(output_dir)
print('Creating colorful label directory:', output_dir)
color_map = get_color_map_list(256)
if os.path.isdir(input):
for grt_path in glob.glob(osp.join(input, '*.png')):
print('Converting original label:', grt_path)
basename = osp.basename(grt_path)
im = Image.open(grt_path)
lbl = np.asarray(im)
lbl_pil = Image.fromarray(lbl.astype(np.uint8), mode='P')
lbl_pil.putpalette(color_map)
new_file = osp.join(output_dir, basename)
lbl_pil.save(new_file)
elif os.path.isfile(input):
if args.dataset_dir is None or args.file_separator is None:
print('No dataset_dir or file_separator input!')
sys.exit()
with open(input) as f:
for line in f:
parts = line.strip().split(args.file_separator)
grt_name = parts[1]
grt_path = os.path.join(args.dataset_dir, grt_name)
print('Converting original label:', grt_path)
basename = osp.basename(grt_path)
im = Image.open(grt_path)
lbl = np.asarray(im)
lbl_pil = Image.fromarray(lbl.astype(np.uint8), mode='P')
lbl_pil.putpalette(color_map)
new_file = osp.join(output_dir, basename)
lbl_pil.save(new_file)
else:
print('It\'s neither a dir nor a file')
if __name__ == '__main__':
args = parse_args()
gray2pseudo_color(args)
......@@ -10,9 +10,10 @@ import os.path as osp
import numpy as np
import PIL.Image
import labelme
from pdseg.vis import get_color_map_list
def parse_args():
parser = argparse.ArgumentParser(
......@@ -55,6 +56,8 @@ def main(args):
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
color_map = get_color_map_list(256)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
......@@ -78,6 +81,8 @@ def main(args):
shape = {'label': name, 'points': points, 'shape_type': 'polygon'}
data_shapes.append(shape)
if 'size' not in data:
continue
data_size = data['size']
img_shape = (data_size['height'], data_size['width'], data_size['depth'])
......@@ -91,7 +96,8 @@ def main(args):
out_png_file += '.png'
# Assume label ranges [0, 255] for uint8,
if lbl.min() >= 0 and lbl.max() <= 255:
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='L')
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
lbl_pil.putpalette(color_map)
lbl_pil.save(out_png_file)
else:
raise ValueError(
......
......@@ -10,9 +10,10 @@ import os.path as osp
import numpy as np
import PIL.Image
import labelme
from pdseg.vis import get_color_map_list
def parse_args():
parser = argparse.ArgumentParser(
......@@ -35,7 +36,6 @@ def main(args):
with open(label_file) as f:
data = json.load(f)
for shape in data['shapes']:
points = shape['points']
label = shape['label']
cls_name = label
if not cls_name in class_names:
......@@ -55,6 +55,8 @@ def main(args):
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
color_map = get_color_map_list(256)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
......@@ -77,7 +79,8 @@ def main(args):
out_png_file += '.png'
# Assume label ranges [0, 255] for uint8,
if lbl.min() >= 0 and lbl.max() <= 255:
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='L')
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
lbl_pil.putpalette(color_map)
lbl_pil.save(out_png_file)
else:
raise ValueError(
......
......@@ -18,21 +18,19 @@ from __future__ import division
from __future__ import print_function
import os
# GPU memory garbage collection optimization flags
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
import sys
import time
import argparse
import pprint
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
from PIL import Image as PILImage
from utils.config import cfg
from metrics import ConfusionMatrix
from reader import SegDataset
from models.model_builder import build_model
from models.model_builder import ModelPhase
......@@ -54,11 +52,6 @@ def parse_args():
help='visual save dir',
type=str,
default='visual')
parser.add_argument(
'--also_save_raw_results',
dest='also_save_raw_results',
help='whether to save raw result',
action='store_true')
parser.add_argument(
'--local_test',
dest='local_test',
......@@ -80,7 +73,7 @@ def makedirs(directory):
os.makedirs(directory)
def get_color_map(num_classes):
def get_color_map_list(num_classes):
""" Returns the color map for visualizing the segmentation mask,
which can support arbitrary number of classes.
Args:
......@@ -88,36 +81,20 @@ def get_color_map(num_classes):
Returns:
The color map
"""
#color_map = num_classes * 3 * [0]
color_map = num_classes * [[0, 0, 0]]
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
color_map[i] = [0, 0, 0]
lab = i
while lab:
color_map[i][0] |= (((lab >> 0) & 1) << (7 - j))
color_map[i][1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i][2] |= (((lab >> 2) & 1) << (7 - j))
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
return color_map
def colorize(image, shape, color_map):
"""
Convert segment result to color image.
"""
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(image, color_map[:, 0])
c2 = cv2.LUT(image, color_map[:, 1])
c3 = cv2.LUT(image, color_map[:, 2])
color_res = np.dstack((c1, c2, c3))
return color_res
def to_png_fn(fn):
"""
Append png as filename postfix
......@@ -131,8 +108,7 @@ def to_png_fn(fn):
def visualize(cfg,
vis_file_list=None,
use_gpu=False,
vis_dir="visual",
also_save_raw_results=False,
vis_dir="visual_predict",
ckpt_dir=None,
log_writer=None,
local_test=False,
......@@ -151,7 +127,7 @@ def visualize(cfg,
test_prog = test_prog.clone(for_test=True)
# Generator full colormap for maximum 256 classes
color_map = get_color_map(256)
color_map = get_color_map_list(256)
# Get device environment
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
......@@ -162,11 +138,8 @@ def visualize(cfg,
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
save_dir = os.path.join(vis_dir, 'visual_results')
save_dir = os.path.join('visual', vis_dir)
makedirs(save_dir)
if also_save_raw_results:
raw_save_dir = os.path.join(vis_dir, 'raw_results')
makedirs(raw_save_dir)
fetch_list = [pred.name]
test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True)
......@@ -185,7 +158,6 @@ def visualize(cfg,
# Add more comments
res_map = np.squeeze(pred[i, :, :, :]).astype(np.uint8)
img_name = img_names[i]
grt = grts[i]
res_shape = (res_map.shape[0], res_map.shape[1])
if res_shape[0] != pred_shape[0] or res_shape[1] != pred_shape[1]:
res_map = cv2.resize(
......@@ -197,28 +169,16 @@ def visualize(cfg,
res_map, (org_shape[1], org_shape[0]),
interpolation=cv2.INTER_NEAREST)
if grt is not None:
grt = grt[0:valid_shape[0], 0:valid_shape[1]]
grt = cv2.resize(
grt, (org_shape[1], org_shape[0]),
interpolation=cv2.INTER_NEAREST)
png_fn = to_png_fn(img_names[i])
if also_save_raw_results:
raw_fn = os.path.join(raw_save_dir, png_fn)
dirname = os.path.dirname(raw_save_dir)
makedirs(dirname)
cv2.imwrite(raw_fn, res_map)
png_fn = to_png_fn(img_name)
# colorful segment result visualization
vis_fn = os.path.join(save_dir, png_fn)
dirname = os.path.dirname(vis_fn)
makedirs(dirname)
pred_mask = colorize(res_map, org_shapes[i], color_map)
if grt is not None:
grt = colorize(grt, org_shapes[i], color_map)
cv2.imwrite(vis_fn, pred_mask)
pred_mask = PILImage.fromarray(res_map.astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(vis_fn)
img_cnt += 1
print("#{} visualize image path: {}".format(img_cnt, vis_fn))
......@@ -228,25 +188,33 @@ def visualize(cfg,
# Calulate epoch from ckpt_dir folder name
epoch = int(os.path.split(ckpt_dir)[-1])
print("Tensorboard visualization epoch", epoch)
pred_mask_np = np.array(pred_mask.convert("RGB"))
log_writer.add_image(
"Predict/{}".format(img_names[i]),
pred_mask[..., ::-1],
"Predict/{}".format(img_name),
pred_mask_np,
epoch,
dataformats='HWC')
# Original image
# BGR->RGB
img = cv2.imread(
os.path.join(cfg.DATASET.DATA_DIR, img_names[i]))[..., ::-1]
os.path.join(cfg.DATASET.DATA_DIR, img_name))[..., ::-1]
log_writer.add_image(
"Images/{}".format(img_names[i]),
"Images/{}".format(img_name),
img,
epoch,
dataformats='HWC')
#add ground truth (label) images
# add ground truth (label) images
grt = grts[i]
if grt is not None:
grt = grt[0:valid_shape[0], 0:valid_shape[1]]
grt_pil = PILImage.fromarray(grt.astype(np.uint8), mode='P')
grt_pil.putpalette(color_map)
grt_pil = grt_pil.resize((org_shape[1], org_shape[0]))
grt = np.array(grt_pil.convert("RGB"))
log_writer.add_image(
"Label/{}".format(img_names[i]),
grt[..., ::-1],
"Label/{}".format(img_name),
grt,
epoch,
dataformats='HWC')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册