提交 5fabe443 编写于 作者: W wuzewu

Add label visualization

上级 01d3f4eb
...@@ -106,19 +106,21 @@ class SegDataset(object): ...@@ -106,19 +106,21 @@ class SegDataset(object):
def batch(self, reader, batch_size, is_test=False, drop_last=False): def batch(self, reader, batch_size, is_test=False, drop_last=False):
def batch_reader(is_test=False, drop_last=drop_last): def batch_reader(is_test=False, drop_last=drop_last):
if is_test: if is_test:
imgs, img_names, valid_shapes, org_shapes = [], [], [], [] imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []
for img, img_name, valid_shape, org_shape in reader(): for img, grt, img_name, valid_shape, org_shape in reader():
imgs.append(img) imgs.append(img)
grts.append(grt)
img_names.append(img_name) img_names.append(img_name)
valid_shapes.append(valid_shape) valid_shapes.append(valid_shape)
org_shapes.append(org_shape) org_shapes.append(org_shape)
if len(imgs) == batch_size: if len(imgs) == batch_size:
yield np.array(imgs), img_names, np.array( yield np.array(imgs), np.array(
valid_shapes), np.array(org_shapes) grts), img_names, np.array(valid_shapes), np.array(
imgs, img_names, valid_shapes, org_shapes = [], [], [], [] org_shapes)
imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []
if not drop_last and len(imgs) > 0: if not drop_last and len(imgs) > 0:
yield np.array(imgs), img_names, np.array( yield np.array(imgs), np.array(grts), img_names, np.array(
valid_shapes), np.array(org_shapes) valid_shapes), np.array(org_shapes)
else: else:
imgs, labs, ignore = [], [], [] imgs, labs, ignore = [], [], []
...@@ -146,93 +148,64 @@ class SegDataset(object): ...@@ -146,93 +148,64 @@ class SegDataset(object):
# reserver alpha channel # reserver alpha channel
cv2_imread_flag = cv2.IMREAD_UNCHANGED cv2_imread_flag = cv2.IMREAD_UNCHANGED
if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL: parts = line.strip().split(cfg.DATASET.SEPARATOR)
parts = line.strip().split(cfg.DATASET.SEPARATOR) if len(parts) != 2:
if len(parts) != 2: if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
raise Exception("File list format incorrect! It should be" raise Exception("File list format incorrect! It should be"
" image_name{}label_name\\n".format( " image_name{}label_name\\n".format(
cfg.DATASET.SEPARATOR)) cfg.DATASET.SEPARATOR))
img_name, grt_name = parts[0], None
else:
img_name, grt_name = parts[0], parts[1] img_name, grt_name = parts[0], parts[1]
img_path = os.path.join(src_dir, img_name)
grt_path = os.path.join(src_dir, grt_name)
img = cv2_imread(img_path, cv2_imread_flag) img_path = os.path.join(src_dir, img_name)
img = cv2_imread(img_path, cv2_imread_flag)
if grt_name is not None:
grt_path = os.path.join(src_dir, grt_name)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
else:
grt = None
if img is None or grt is None: if img is None:
raise Exception( raise Exception(
"Empty image, src_dir: {}, img: {} & lab: {}".format( "Empty image, src_dir: {}, img: {} & lab: {}".format(
src_dir, img_path, grt_path)) src_dir, img_path, grt_path))
img_height = img.shape[0] img_height = img.shape[0]
img_width = img.shape[1] img_width = img.shape[1]
if grt is not None:
grt_height = grt.shape[0] grt_height = grt.shape[0]
grt_width = grt.shape[1] grt_width = grt.shape[1]
if img_height != grt_height or img_width != grt_width: if img_height != grt_height or img_width != grt_width:
raise Exception( raise Exception(
"source img and label img must has the same size") "source img and label img must has the same size")
else:
if len(img.shape) < 3: if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img_channels = img.shape[2]
if img_channels < 3:
raise Exception(
"PaddleSeg only supports gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.format(img_channels, cfg.DATASET.DATADIM, img_name))
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".
format(img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
raise Exception(
"img name {}, img chns {} std size {}, size unequal".format(
img_name, img_channels, len(cfg.STD)))
# visualization mode
elif mode == ModelPhase.VISUAL:
if cfg.DATASET.SEPARATOR in line:
parts = line.strip().split(cfg.DATASET.SEPARATOR)
img_name = parts[0]
else:
img_name = line.strip()
img_path = os.path.join(src_dir, img_name)
img = cv2_imread(img_path, cv2_imread_flag)
if img is None:
raise Exception("empty image, src_dir:{}, img: {}".format(
src_dir, img_name))
# Convert grayscale image to BGR 3 channel image
if len(img.shape) < 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img_height = img.shape[0]
img_width = img.shape[1]
img_channels = img.shape[2]
if img_channels < 3:
raise Exception("this repo only recept gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception("data dim must equal to image channels")
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".
format(img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
raise Exception( raise Exception(
"img name {}, img chns {} std size {}, size unequal".format( "Empty image, src_dir: {}, img: {} & lab: {}".format(
img_name, img_channels, len(cfg.STD))) src_dir, img_path, grt_path))
grt = None if len(img.shape) < 3:
grt_name = None img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else:
raise ValueError("mode error: {}".format(mode)) img_channels = img.shape[2]
if img_channels < 3:
raise Exception("PaddleSeg only supports gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.format(img_channels, cfg.DATASET.DATADIM, img_name))
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".format(
img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
raise Exception(
"img name {}, img chns {} std size {}, size unequal".format(
img_name, img_channels, len(cfg.STD)))
return img, grt, img_name, grt_name return img, grt, img_name, grt_name
...@@ -329,4 +302,4 @@ class SegDataset(object): ...@@ -329,4 +302,4 @@ class SegDataset(object):
elif ModelPhase.is_eval(mode): elif ModelPhase.is_eval(mode):
return (img, grt, ignore) return (img, grt, ignore)
elif ModelPhase.is_visual(mode): elif ModelPhase.is_visual(mode):
return (img, img_name, valid_shape, org_shape) return (img, grt, img_name, valid_shape, org_shape)
...@@ -171,7 +171,7 @@ def visualize(cfg, ...@@ -171,7 +171,7 @@ def visualize(cfg,
fetch_list = [pred.name] fetch_list = [pred.name]
test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True) test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True)
img_cnt = 0 img_cnt = 0
for imgs, img_names, valid_shapes, org_shapes in test_reader: for imgs, grts, img_names, valid_shapes, org_shapes in test_reader:
pred_shape = (imgs.shape[2], imgs.shape[3]) pred_shape = (imgs.shape[2], imgs.shape[3])
pred, = exe.run( pred, = exe.run(
program=test_prog, program=test_prog,
...@@ -185,6 +185,7 @@ def visualize(cfg, ...@@ -185,6 +185,7 @@ def visualize(cfg,
# Add more comments # Add more comments
res_map = np.squeeze(pred[i, :, :, :]).astype(np.uint8) res_map = np.squeeze(pred[i, :, :, :]).astype(np.uint8)
img_name = img_names[i] img_name = img_names[i]
grt = grts[i]
res_shape = (res_map.shape[0], res_map.shape[1]) res_shape = (res_map.shape[0], res_map.shape[1])
if res_shape[0] != pred_shape[0] or res_shape[1] != pred_shape[1]: if res_shape[0] != pred_shape[0] or res_shape[1] != pred_shape[1]:
res_map = cv2.resize( res_map = cv2.resize(
...@@ -196,6 +197,11 @@ def visualize(cfg, ...@@ -196,6 +197,11 @@ def visualize(cfg,
res_map, (org_shape[1], org_shape[0]), res_map, (org_shape[1], org_shape[0]),
interpolation=cv2.INTER_NEAREST) interpolation=cv2.INTER_NEAREST)
if grt is not None:
grt = cv2.resize(
grt, (org_shape[1], org_shape[0]),
interpolation=cv2.INTER_NEAREST)
png_fn = to_png_fn(img_names[i]) png_fn = to_png_fn(img_names[i])
if also_save_raw_results: if also_save_raw_results:
raw_fn = os.path.join(raw_save_dir, png_fn) raw_fn = os.path.join(raw_save_dir, png_fn)
...@@ -209,6 +215,8 @@ def visualize(cfg, ...@@ -209,6 +215,8 @@ def visualize(cfg,
makedirs(dirname) makedirs(dirname)
pred_mask = colorize(res_map, org_shapes[i], color_map) pred_mask = colorize(res_map, org_shapes[i], color_map)
if grt is not None:
grt = colorize(grt, org_shapes[i], color_map)
cv2.imwrite(vis_fn, pred_mask) cv2.imwrite(vis_fn, pred_mask)
img_cnt += 1 img_cnt += 1
...@@ -233,7 +241,13 @@ def visualize(cfg, ...@@ -233,7 +241,13 @@ def visualize(cfg,
img, img,
epoch, epoch,
dataformats='HWC') dataformats='HWC')
#TODO: add ground truth (label) images #add ground truth (label) images
if grt is not None:
log_writer.add_image(
"Label/{}".format(img_names[i]),
grt[..., ::-1],
epoch,
dataformats='HWC')
# If in local_test mode, only visualize 5 images just for testing # If in local_test mode, only visualize 5 images just for testing
# procedure # procedure
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册