Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
b4e23302
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b4e23302
编写于
6月 02, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm ArrageSegmeter
上级
578f83f0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
8 addition
and
72 deletion
+8
-72
dygraph/infer.py
dygraph/infer.py
+1
-10
dygraph/train.py
dygraph/train.py
+1
-11
dygraph/transforms/transforms.py
dygraph/transforms/transforms.py
+4
-40
dygraph/val.py
dygraph/val.py
+2
-11
未找到文件。
dygraph/infer.py
浏览文件 @
b4e23302
...
...
@@ -111,7 +111,7 @@ def infer(model, data_dir=None, test_list=None, model_dir=None,
for
file
in
tqdm
.
tqdm
(
files
):
file
=
file
.
strip
()
im_file
=
osp
.
join
(
data_dir
,
file
)
im
,
im_info
=
transforms
(
im_file
)
im
,
im_info
,
_
=
transforms
(
im_file
)
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
im
=
to_variable
(
im
)
...
...
@@ -140,17 +140,8 @@ def infer(model, data_dir=None, test_list=None, model_dir=None,
cv2
.
imwrite
(
pred_saved_path
,
pred_im
)
def
arrange_transform
(
transforms
,
mode
=
'train'
):
arrange_transform
=
T
.
ArrangeSegmenter
if
type
(
transforms
.
transforms
[
-
1
]).
__name__
.
startswith
(
'Arrange'
):
transforms
.
transforms
[
-
1
]
=
arrange_transform
(
mode
=
mode
)
else
:
transforms
.
transforms
.
append
(
arrange_transform
(
mode
=
mode
))
def
main
(
args
):
test_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
arrange_transform
(
test_transforms
,
mode
=
'test'
)
if
args
.
model_name
==
'UNet'
:
model
=
models
.
UNet
(
num_classes
=
args
.
num_classes
)
...
...
dygraph/train.py
浏览文件 @
b4e23302
...
...
@@ -143,7 +143,7 @@ def train(model,
for
epoch
in
range
(
num_epochs
):
for
step
,
data
in
enumerate
(
data_generator
()):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
labels
=
np
.
array
([
d
[
1
]
for
d
in
data
]).
astype
(
'int64'
)
labels
=
np
.
array
([
d
[
2
]
for
d
in
data
]).
astype
(
'int64'
)
images
=
to_variable
(
images
)
labels
=
to_variable
(
labels
)
loss
=
model
(
images
,
labels
,
mode
=
'train'
)
...
...
@@ -175,21 +175,12 @@ def train(model,
model
.
train
()
def
arrange_transform
(
transforms
,
mode
=
'train'
):
arrange_transform
=
T
.
ArrangeSegmenter
if
type
(
transforms
.
transforms
[
-
1
]).
__name__
.
startswith
(
'Arrange'
):
transforms
.
transforms
[
-
1
]
=
arrange_transform
(
mode
=
mode
)
else
:
transforms
.
transforms
.
append
(
arrange_transform
(
mode
=
mode
))
def
main
(
args
):
# Creat dataset reader
train_transforms
=
T
.
Compose
(
[
T
.
Resize
(
args
.
input_size
),
T
.
RandomHorizontalFlip
(),
T
.
Normalize
()])
arrange_transform
(
train_transforms
,
mode
=
'train'
)
train_dataset
=
Dataset
(
data_dir
=
args
.
data_dir
,
file_list
=
args
.
train_list
,
...
...
@@ -200,7 +191,6 @@ def main(args):
shuffle
=
True
)
if
args
.
val_list
is
not
None
:
eval_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
arrange_transform
(
eval_transforms
,
mode
=
'eval'
)
eval_dataset
=
Dataset
(
data_dir
=
args
.
data_dir
,
file_list
=
args
.
val_list
,
...
...
dygraph/transforms/transforms.py
浏览文件 @
b4e23302
...
...
@@ -74,7 +74,10 @@ class Compose:
im_info
=
outputs
[
1
]
if
len
(
outputs
)
==
3
:
label
=
outputs
[
2
]
return
outputs
im
=
permute
(
im
)
if
len
(
outputs
)
==
3
:
label
=
label
[
np
.
newaxis
,
:,
:]
return
(
im
,
im_info
,
label
)
class
RandomHorizontalFlip
:
...
...
@@ -873,42 +876,3 @@ class RandomDistort:
return
(
im
,
im_info
)
else
:
return
(
im
,
im_info
,
label
)
class
ArrangeSegmenter
:
"""获取训练/验证/预测所需的信息。
Args:
mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
Raises:
ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内
"""
def
__init__
(
self
,
mode
):
if
mode
not
in
[
'train'
,
'eval'
,
'test'
,
'quant'
]:
raise
ValueError
(
"mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
)
self
.
mode
=
mode
def
__call__
(
self
,
im
,
im_info
,
label
=
None
):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict): 存储与图像相关的信息。
label (np.ndarray): 标注图像np.ndarray数据。
Returns:
tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为
'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
"""
im
=
permute
(
im
)
if
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
:
label
=
label
[
np
.
newaxis
,
:,
:]
return
(
im
,
label
)
elif
self
.
mode
==
'test'
:
return
(
im
,
im_info
)
else
:
return
(
im
,
)
dygraph/val.py
浏览文件 @
b4e23302
...
...
@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.base import to_variable
import
numpy
as
np
import
paddle.fluid
as
fluid
from
datasets
.dataset
import
Dataset
from
datasets
import
Dataset
import
transforms
as
T
import
models
import
utils.logging
as
logging
...
...
@@ -112,7 +112,7 @@ def evaluate(model,
eval_dataset
.
num_samples
,
total_steps
))
for
step
,
data
in
enumerate
(
data_generator
()):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
labels
=
np
.
array
([
d
[
1
]
for
d
in
data
]).
astype
(
'int64'
)
labels
=
np
.
array
([
d
[
2
]
for
d
in
data
]).
astype
(
'int64'
)
images
=
to_variable
(
images
)
pred
,
_
=
model
(
images
,
labels
,
mode
=
'eval'
)
...
...
@@ -134,17 +134,8 @@ def evaluate(model,
logging
.
info
(
"[EVAL] Kappa:{:.4f} "
.
format
(
conf_mat
.
kappa
()))
def
arrange_transform
(
transforms
,
mode
=
'train'
):
arrange_transform
=
T
.
ArrangeSegmenter
if
type
(
transforms
.
transforms
[
-
1
]).
__name__
.
startswith
(
'Arrange'
):
transforms
.
transforms
[
-
1
]
=
arrange_transform
(
mode
=
mode
)
else
:
transforms
.
transforms
.
append
(
arrange_transform
(
mode
=
mode
))
def
main
(
args
):
eval_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
arrange_transform
(
eval_transforms
,
mode
=
'eval'
)
eval_dataset
=
Dataset
(
data_dir
=
args
.
data_dir
,
file_list
=
args
.
val_list
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录