Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
5fabe443
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5fabe443
编写于
8月 29, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add label visualization
上级
01d3f4eb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
67 addition
and
80 deletion
+67
-80
pdseg/reader.py
pdseg/reader.py
+51
-78
pdseg/vis.py
pdseg/vis.py
+16
-2
未找到文件。
pdseg/reader.py
浏览文件 @
5fabe443
...
...
@@ -106,19 +106,21 @@ class SegDataset(object):
def
batch
(
self
,
reader
,
batch_size
,
is_test
=
False
,
drop_last
=
False
):
def
batch_reader
(
is_test
=
False
,
drop_last
=
drop_last
):
if
is_test
:
imgs
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[]
for
img
,
img_name
,
valid_shape
,
org_shape
in
reader
():
imgs
,
grts
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[]
for
img
,
grt
,
img_name
,
valid_shape
,
org_shape
in
reader
():
imgs
.
append
(
img
)
grts
.
append
(
grt
)
img_names
.
append
(
img_name
)
valid_shapes
.
append
(
valid_shape
)
org_shapes
.
append
(
org_shape
)
if
len
(
imgs
)
==
batch_size
:
yield
np
.
array
(
imgs
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
imgs
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[]
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
imgs
,
grts
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[]
if
not
drop_last
and
len
(
imgs
)
>
0
:
yield
np
.
array
(
imgs
),
img_names
,
np
.
array
(
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
else
:
imgs
,
labs
,
ignore
=
[],
[],
[]
...
...
@@ -146,93 +148,64 @@ class SegDataset(object):
# reserver alpha channel
cv2_imread_flag
=
cv2
.
IMREAD_UNCHANGED
if
mode
==
ModelPhase
.
TRAIN
or
mode
==
ModelPhase
.
EVAL
:
parts
=
line
.
strip
().
split
(
cfg
.
DATASET
.
SEPARATOR
)
if
len
(
parts
)
!=
2
:
parts
=
line
.
strip
().
split
(
cfg
.
DATASET
.
SEPARATOR
)
if
len
(
parts
)
!=
2
:
if
mode
==
ModelPhase
.
TRAIN
or
mode
==
ModelPhase
.
EVAL
:
raise
Exception
(
"File list format incorrect! It should be"
" image_name{}label_name
\\
n"
.
format
(
cfg
.
DATASET
.
SEPARATOR
))
img_name
,
grt_name
=
parts
[
0
],
None
else
:
img_name
,
grt_name
=
parts
[
0
],
parts
[
1
]
img_path
=
os
.
path
.
join
(
src_dir
,
img_name
)
grt_path
=
os
.
path
.
join
(
src_dir
,
grt_name
)
img
=
cv2_imread
(
img_path
,
cv2_imread_flag
)
img_path
=
os
.
path
.
join
(
src_dir
,
img_name
)
img
=
cv2_imread
(
img_path
,
cv2_imread_flag
)
if
grt_name
is
not
None
:
grt_path
=
os
.
path
.
join
(
src_dir
,
grt_name
)
grt
=
cv2_imread
(
grt_path
,
cv2
.
IMREAD_GRAYSCALE
)
else
:
grt
=
None
if
img
is
None
or
grt
is
None
:
raise
Exception
(
"Empty image, src_dir: {}, img: {} & lab: {}"
.
format
(
src_dir
,
img_path
,
grt_path
))
if
img
is
None
:
raise
Exception
(
"Empty image, src_dir: {}, img: {} & lab: {}"
.
format
(
src_dir
,
img_path
,
grt_path
))
img_height
=
img
.
shape
[
0
]
img_width
=
img
.
shape
[
1
]
img_height
=
img
.
shape
[
0
]
img_width
=
img
.
shape
[
1
]
if
grt
is
not
None
:
grt_height
=
grt
.
shape
[
0
]
grt_width
=
grt
.
shape
[
1
]
if
img_height
!=
grt_height
or
img_width
!=
grt_width
:
raise
Exception
(
"source img and label img must has the same size"
)
if
len
(
img
.
shape
)
<
3
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
img_channels
=
img
.
shape
[
2
]
if
img_channels
<
3
:
raise
Exception
(
"PaddleSeg only supports gray, rgb or rgba image"
)
if
img_channels
!=
cfg
.
DATASET
.
DATA_DIM
:
raise
Exception
(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.
format
(
img_channels
,
cfg
.
DATASET
.
DATADIM
,
img_name
))
if
img_channels
!=
len
(
cfg
.
MEAN
):
raise
Exception
(
"img name {}, img chns {} mean size {}, size unequal"
.
format
(
img_name
,
img_channels
,
len
(
cfg
.
MEAN
)))
if
img_channels
!=
len
(
cfg
.
STD
):
raise
Exception
(
"img name {}, img chns {} std size {}, size unequal"
.
format
(
img_name
,
img_channels
,
len
(
cfg
.
STD
)))
# visualization mode
elif
mode
==
ModelPhase
.
VISUAL
:
if
cfg
.
DATASET
.
SEPARATOR
in
line
:
parts
=
line
.
strip
().
split
(
cfg
.
DATASET
.
SEPARATOR
)
img_name
=
parts
[
0
]
else
:
img_name
=
line
.
strip
()
img_path
=
os
.
path
.
join
(
src_dir
,
img_name
)
img
=
cv2_imread
(
img_path
,
cv2_imread_flag
)
if
img
is
None
:
raise
Exception
(
"empty image, src_dir:{}, img: {}"
.
format
(
src_dir
,
img_name
))
# Convert grayscale image to BGR 3 channel image
if
len
(
img
.
shape
)
<
3
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
img_height
=
img
.
shape
[
0
]
img_width
=
img
.
shape
[
1
]
img_channels
=
img
.
shape
[
2
]
if
img_channels
<
3
:
raise
Exception
(
"this repo only recept gray, rgb or rgba image"
)
if
img_channels
!=
cfg
.
DATASET
.
DATA_DIM
:
raise
Exception
(
"data dim must equal to image channels"
)
if
img_channels
!=
len
(
cfg
.
MEAN
):
raise
Exception
(
"img name {}, img chns {} mean size {}, size unequal"
.
format
(
img_name
,
img_channels
,
len
(
cfg
.
MEAN
)))
if
img_channels
!=
len
(
cfg
.
STD
):
else
:
if
mode
==
ModelPhase
.
TRAIN
or
mode
==
ModelPhase
.
EVAL
:
raise
Exception
(
"
img name {}, img chns {} std size {}, size unequal
"
.
format
(
img_name
,
img_channels
,
len
(
cfg
.
STD
)
))
"
Empty image, src_dir: {}, img: {} & lab: {}
"
.
format
(
src_dir
,
img_path
,
grt_path
))
grt
=
None
grt_name
=
None
else
:
raise
ValueError
(
"mode error: {}"
.
format
(
mode
))
if
len
(
img
.
shape
)
<
3
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
img_channels
=
img
.
shape
[
2
]
if
img_channels
<
3
:
raise
Exception
(
"PaddleSeg only supports gray, rgb or rgba image"
)
if
img_channels
!=
cfg
.
DATASET
.
DATA_DIM
:
raise
Exception
(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.
format
(
img_channels
,
cfg
.
DATASET
.
DATADIM
,
img_name
))
if
img_channels
!=
len
(
cfg
.
MEAN
):
raise
Exception
(
"img name {}, img chns {} mean size {}, size unequal"
.
format
(
img_name
,
img_channels
,
len
(
cfg
.
MEAN
)))
if
img_channels
!=
len
(
cfg
.
STD
):
raise
Exception
(
"img name {}, img chns {} std size {}, size unequal"
.
format
(
img_name
,
img_channels
,
len
(
cfg
.
STD
)))
return
img
,
grt
,
img_name
,
grt_name
...
...
@@ -329,4 +302,4 @@ class SegDataset(object):
elif
ModelPhase
.
is_eval
(
mode
):
return
(
img
,
grt
,
ignore
)
elif
ModelPhase
.
is_visual
(
mode
):
return
(
img
,
img_name
,
valid_shape
,
org_shape
)
return
(
img
,
grt
,
img_name
,
valid_shape
,
org_shape
)
pdseg/vis.py
浏览文件 @
5fabe443
...
...
@@ -171,7 +171,7 @@ def visualize(cfg,
fetch_list
=
[
pred
.
name
]
test_reader
=
dataset
.
batch
(
dataset
.
generator
,
batch_size
=
1
,
is_test
=
True
)
img_cnt
=
0
for
imgs
,
img_names
,
valid_shapes
,
org_shapes
in
test_reader
:
for
imgs
,
grts
,
img_names
,
valid_shapes
,
org_shapes
in
test_reader
:
pred_shape
=
(
imgs
.
shape
[
2
],
imgs
.
shape
[
3
])
pred
,
=
exe
.
run
(
program
=
test_prog
,
...
...
@@ -185,6 +185,7 @@ def visualize(cfg,
# Add more comments
res_map
=
np
.
squeeze
(
pred
[
i
,
:,
:,
:]).
astype
(
np
.
uint8
)
img_name
=
img_names
[
i
]
grt
=
grts
[
i
]
res_shape
=
(
res_map
.
shape
[
0
],
res_map
.
shape
[
1
])
if
res_shape
[
0
]
!=
pred_shape
[
0
]
or
res_shape
[
1
]
!=
pred_shape
[
1
]:
res_map
=
cv2
.
resize
(
...
...
@@ -196,6 +197,11 @@ def visualize(cfg,
res_map
,
(
org_shape
[
1
],
org_shape
[
0
]),
interpolation
=
cv2
.
INTER_NEAREST
)
if
grt
is
not
None
:
grt
=
cv2
.
resize
(
grt
,
(
org_shape
[
1
],
org_shape
[
0
]),
interpolation
=
cv2
.
INTER_NEAREST
)
png_fn
=
to_png_fn
(
img_names
[
i
])
if
also_save_raw_results
:
raw_fn
=
os
.
path
.
join
(
raw_save_dir
,
png_fn
)
...
...
@@ -209,6 +215,8 @@ def visualize(cfg,
makedirs
(
dirname
)
pred_mask
=
colorize
(
res_map
,
org_shapes
[
i
],
color_map
)
if
grt
is
not
None
:
grt
=
colorize
(
grt
,
org_shapes
[
i
],
color_map
)
cv2
.
imwrite
(
vis_fn
,
pred_mask
)
img_cnt
+=
1
...
...
@@ -233,7 +241,13 @@ def visualize(cfg,
img
,
epoch
,
dataformats
=
'HWC'
)
#TODO: add ground truth (label) images
#add ground truth (label) images
if
grt
is
not
None
:
log_writer
.
add_image
(
"Label/{}"
.
format
(
img_names
[
i
]),
grt
[...,
::
-
1
],
epoch
,
dataformats
=
'HWC'
)
# If in local_test mode, only visualize 5 images just for testing
# procedure
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录