Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
085dabf3
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
085dabf3
编写于
6月 01, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add infer for save prediction result
上级
84426956
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
177 addition
and
4 deletion
+177
-4
dygraph/infer.py
dygraph/infer.py
+174
-0
dygraph/utils/utils.py
dygraph/utils/utils.py
+3
-4
未找到文件。
dygraph/infer.py
0 → 100644
浏览文件 @
085dabf3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
os.path
as
osp
from
paddle.fluid.dygraph.base
import
to_variable
import
numpy
as
np
import
paddle.fluid
as
fluid
import
cv2
import
tqdm
import
transforms
as
T
import
models
import
utils
import
utils.logging
as
logging
from
utils
import
get_environ_info
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Model training'
)
# params of model
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
"Model type for traing, which is one of ('UNet')"
,
type
=
str
,
default
=
'UNet'
)
# params of dataset
parser
.
add_argument
(
'--data_dir'
,
dest
=
'data_dir'
,
help
=
'The root directory of dataset'
,
type
=
str
)
parser
.
add_argument
(
'--test_list'
,
dest
=
'test_list'
,
help
=
'Val list file of dataset'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--num_classes'
,
dest
=
'num_classes'
,
help
=
'Number of classes'
,
type
=
int
,
default
=
2
)
# params of prediction
parser
.
add_argument
(
"--input_size"
,
dest
=
"input_size"
,
help
=
"The image size for net inputs."
,
nargs
=
2
,
default
=
[
512
,
512
],
type
=
int
)
parser
.
add_argument
(
'--batch_size'
,
dest
=
'batch_size'
,
help
=
'Mini batch size'
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
'--model_dir'
,
dest
=
'model_dir'
,
help
=
'The path of model for evaluation'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--save_dir'
,
dest
=
'save_dir'
,
help
=
'The directory for saving the inference results'
,
type
=
str
,
default
=
'./output/result'
)
return
parser
.
parse_args
()
def
mkdir
(
path
):
sub_dir
=
osp
.
dirname
(
path
)
if
not
osp
.
exists
(
sub_dir
):
os
.
makedirs
(
sub_dir
)
def
infer
(
model
,
data_dir
=
None
,
test_list
=
None
,
model_dir
=
None
,
transforms
=
None
):
ckpt_path
=
osp
.
join
(
model_dir
,
'model'
)
para_state_dict
,
opti_state_dict
=
fluid
.
load_dygraph
(
ckpt_path
)
model
.
set_dict
(
para_state_dict
)
model
.
eval
()
added_saved_dir
=
osp
.
join
(
args
.
save_dir
,
'added'
)
pred_saved_dir
=
osp
.
join
(
args
.
save_dir
,
'prediction'
)
logging
.
info
(
"Start to predict..."
)
with
open
(
test_list
,
'r'
)
as
f
:
files
=
f
.
readlines
()
for
file
in
tqdm
.
tqdm
(
files
):
file
=
file
.
strip
()
im_file
=
osp
.
join
(
data_dir
,
file
)
im
,
im_info
=
transforms
(
im_file
)
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
im
=
to_variable
(
im
)
pred
,
_
=
model
(
im
,
mode
=
'test'
)
pred
=
pred
.
numpy
()
pred
=
np
.
squeeze
(
pred
).
astype
(
'uint8'
)
keys
=
list
(
im_info
.
keys
())
for
k
in
keys
[::
-
1
]:
if
k
==
'shape_before_resize'
:
h
,
w
=
im_info
[
k
][
0
],
im_info
[
k
][
1
]
pred
=
cv2
.
resize
(
pred
,
(
w
,
h
),
cv2
.
INTER_NEAREST
)
elif
k
==
'shape_before_padding'
:
h
,
w
=
im_info
[
k
][
0
],
im_info
[
k
][
1
]
pred
=
pred
[
0
:
h
,
0
:
w
]
# save added image
added_image
=
utils
.
visualize
(
im_file
,
pred
,
weight
=
0.6
)
added_image_path
=
osp
.
join
(
added_saved_dir
,
file
)
mkdir
(
added_image_path
)
cv2
.
imwrite
(
added_image_path
,
added_image
)
# save prediction
pred_im
=
utils
.
visualize
(
im_file
,
pred
,
weight
=
0.0
)
pred_saved_path
=
osp
.
join
(
pred_saved_dir
,
file
)
mkdir
(
pred_saved_path
)
cv2
.
imwrite
(
pred_saved_path
,
pred_im
)
def
arrange_transform
(
transforms
,
mode
=
'train'
):
arrange_transform
=
T
.
ArrangeSegmenter
if
type
(
transforms
.
transforms
[
-
1
]).
__name__
.
startswith
(
'Arrange'
):
transforms
.
transforms
[
-
1
]
=
arrange_transform
(
mode
=
mode
)
else
:
transforms
.
transforms
.
append
(
arrange_transform
(
mode
=
mode
))
def
main
(
args
):
test_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
arrange_transform
(
test_transforms
,
mode
=
'test'
)
if
args
.
model_name
==
'UNet'
:
model
=
models
.
UNet
(
num_classes
=
args
.
num_classes
)
infer
(
model
,
data_dir
=
args
.
data_dir
,
test_list
=
args
.
test_list
,
model_dir
=
args
.
model_dir
,
transforms
=
test_transforms
)
if
__name__
==
'__main__'
:
args
=
parse_args
()
env_info
=
get_environ_info
()
if
env_info
[
'place'
]
==
'cpu'
:
places
=
fluid
.
CPUPlace
()
else
:
places
=
fluid
.
CUDAPlace
(
0
)
with
fluid
.
dygraph
.
guard
(
places
):
main
(
args
)
dygraph/utils/utils.py
浏览文件 @
085dabf3
...
@@ -228,13 +228,12 @@ def visualize(image, result, save_dir=None, weight=0.6):
...
@@ -228,13 +228,12 @@ def visualize(image, result, save_dir=None, weight=0.6):
save_dir: the directory for saving visual image
save_dir: the directory for saving visual image
weight: the image weight of visual image, and the result weight is (1 - weight)
weight: the image weight of visual image, and the result weight is (1 - weight)
"""
"""
label_map
=
result
[
'label_map'
]
color_map
=
get_color_map_list
(
256
)
color_map
=
get_color_map_list
(
256
)
color_map
=
np
.
array
(
color_map
).
astype
(
"uint8"
)
color_map
=
np
.
array
(
color_map
).
astype
(
"uint8"
)
# Use OpenCV LUT for color mapping
# Use OpenCV LUT for color mapping
c1
=
cv2
.
LUT
(
label_map
,
color_map
[:,
0
])
c1
=
cv2
.
LUT
(
result
,
color_map
[:,
0
])
c2
=
cv2
.
LUT
(
label_map
,
color_map
[:,
1
])
c2
=
cv2
.
LUT
(
result
,
color_map
[:,
1
])
c3
=
cv2
.
LUT
(
label_map
,
color_map
[:,
2
])
c3
=
cv2
.
LUT
(
result
,
color_map
[:,
2
])
pseudo_img
=
np
.
dstack
((
c1
,
c2
,
c3
))
pseudo_img
=
np
.
dstack
((
c1
,
c2
,
c3
))
im
=
cv2
.
imread
(
image
)
im
=
cv2
.
imread
(
image
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录