提交 5fabe443 编写于 作者: W wuzewu

Add label visualization

上级 01d3f4eb
......@@ -106,19 +106,21 @@ class SegDataset(object):
def batch(self, reader, batch_size, is_test=False, drop_last=False):
def batch_reader(is_test=False, drop_last=drop_last):
if is_test:
imgs, img_names, valid_shapes, org_shapes = [], [], [], []
for img, img_name, valid_shape, org_shape in reader():
imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []
for img, grt, img_name, valid_shape, org_shape in reader():
imgs.append(img)
grts.append(grt)
img_names.append(img_name)
valid_shapes.append(valid_shape)
org_shapes.append(org_shape)
if len(imgs) == batch_size:
yield np.array(imgs), img_names, np.array(
valid_shapes), np.array(org_shapes)
imgs, img_names, valid_shapes, org_shapes = [], [], [], []
yield np.array(imgs), np.array(
grts), img_names, np.array(valid_shapes), np.array(
org_shapes)
imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []
if not drop_last and len(imgs) > 0:
yield np.array(imgs), img_names, np.array(
yield np.array(imgs), np.array(grts), img_names, np.array(
valid_shapes), np.array(org_shapes)
else:
imgs, labs, ignore = [], [], []
......@@ -146,93 +148,64 @@ class SegDataset(object):
# reserver alpha channel
cv2_imread_flag = cv2.IMREAD_UNCHANGED
if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
parts = line.strip().split(cfg.DATASET.SEPARATOR)
if len(parts) != 2:
parts = line.strip().split(cfg.DATASET.SEPARATOR)
if len(parts) != 2:
if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
raise Exception("File list format incorrect! It should be"
" image_name{}label_name\\n".format(
cfg.DATASET.SEPARATOR))
img_name, grt_name = parts[0], None
else:
img_name, grt_name = parts[0], parts[1]
img_path = os.path.join(src_dir, img_name)
grt_path = os.path.join(src_dir, grt_name)
img = cv2_imread(img_path, cv2_imread_flag)
img_path = os.path.join(src_dir, img_name)
img = cv2_imread(img_path, cv2_imread_flag)
if grt_name is not None:
grt_path = os.path.join(src_dir, grt_name)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
else:
grt = None
if img is None or grt is None:
raise Exception(
"Empty image, src_dir: {}, img: {} & lab: {}".format(
src_dir, img_path, grt_path))
if img is None:
raise Exception(
"Empty image, src_dir: {}, img: {} & lab: {}".format(
src_dir, img_path, grt_path))
img_height = img.shape[0]
img_width = img.shape[1]
img_height = img.shape[0]
img_width = img.shape[1]
if grt is not None:
grt_height = grt.shape[0]
grt_width = grt.shape[1]
if img_height != grt_height or img_width != grt_width:
raise Exception(
"source img and label img must has the same size")
if len(img.shape) < 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img_channels = img.shape[2]
if img_channels < 3:
raise Exception(
"PaddleSeg only supports gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.format(img_channels, cfg.DATASET.DATADIM, img_name))
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".
format(img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
raise Exception(
"img name {}, img chns {} std size {}, size unequal".format(
img_name, img_channels, len(cfg.STD)))
# visualization mode
elif mode == ModelPhase.VISUAL:
if cfg.DATASET.SEPARATOR in line:
parts = line.strip().split(cfg.DATASET.SEPARATOR)
img_name = parts[0]
else:
img_name = line.strip()
img_path = os.path.join(src_dir, img_name)
img = cv2_imread(img_path, cv2_imread_flag)
if img is None:
raise Exception("empty image, src_dir:{}, img: {}".format(
src_dir, img_name))
# Convert grayscale image to BGR 3 channel image
if len(img.shape) < 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img_height = img.shape[0]
img_width = img.shape[1]
img_channels = img.shape[2]
if img_channels < 3:
raise Exception("this repo only recept gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception("data dim must equal to image channels")
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".
format(img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
else:
if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
raise Exception(
"img name {}, img chns {} std size {}, size unequal".format(
img_name, img_channels, len(cfg.STD)))
"Empty image, src_dir: {}, img: {} & lab: {}".format(
src_dir, img_path, grt_path))
grt = None
grt_name = None
else:
raise ValueError("mode error: {}".format(mode))
if len(img.shape) < 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img_channels = img.shape[2]
if img_channels < 3:
raise Exception("PaddleSeg only supports gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.format(img_channels, cfg.DATASET.DATADIM, img_name))
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".format(
img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
raise Exception(
"img name {}, img chns {} std size {}, size unequal".format(
img_name, img_channels, len(cfg.STD)))
return img, grt, img_name, grt_name
......@@ -329,4 +302,4 @@ class SegDataset(object):
elif ModelPhase.is_eval(mode):
return (img, grt, ignore)
elif ModelPhase.is_visual(mode):
return (img, img_name, valid_shape, org_shape)
return (img, grt, img_name, valid_shape, org_shape)
......@@ -171,7 +171,7 @@ def visualize(cfg,
fetch_list = [pred.name]
test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True)
img_cnt = 0
for imgs, img_names, valid_shapes, org_shapes in test_reader:
for imgs, grts, img_names, valid_shapes, org_shapes in test_reader:
pred_shape = (imgs.shape[2], imgs.shape[3])
pred, = exe.run(
program=test_prog,
......@@ -185,6 +185,7 @@ def visualize(cfg,
# Add more comments
res_map = np.squeeze(pred[i, :, :, :]).astype(np.uint8)
img_name = img_names[i]
grt = grts[i]
res_shape = (res_map.shape[0], res_map.shape[1])
if res_shape[0] != pred_shape[0] or res_shape[1] != pred_shape[1]:
res_map = cv2.resize(
......@@ -196,6 +197,11 @@ def visualize(cfg,
res_map, (org_shape[1], org_shape[0]),
interpolation=cv2.INTER_NEAREST)
if grt is not None:
grt = cv2.resize(
grt, (org_shape[1], org_shape[0]),
interpolation=cv2.INTER_NEAREST)
png_fn = to_png_fn(img_names[i])
if also_save_raw_results:
raw_fn = os.path.join(raw_save_dir, png_fn)
......@@ -209,6 +215,8 @@ def visualize(cfg,
makedirs(dirname)
pred_mask = colorize(res_map, org_shapes[i], color_map)
if grt is not None:
grt = colorize(grt, org_shapes[i], color_map)
cv2.imwrite(vis_fn, pred_mask)
img_cnt += 1
......@@ -233,7 +241,13 @@ def visualize(cfg,
img,
epoch,
dataformats='HWC')
#TODO: add ground truth (label) images
#add ground truth (label) images
if grt is not None:
log_writer.add_image(
"Label/{}".format(img_names[i]),
grt[..., ::-1],
epoch,
dataformats='HWC')
# If in local_test mode, only visualize 5 images just for testing
# procedure
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册