Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
8d785cff
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8d785cff
编写于
8月 06, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update datasets
上级
967ebd6f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
37 addition
and
63 deletion
+37
-63
dygraph/datasets/dataset.py
dygraph/datasets/dataset.py
+16
-23
dygraph/datasets/optic_disc_seg.py
dygraph/datasets/optic_disc_seg.py
+5
-11
dygraph/datasets/voc.py
dygraph/datasets/voc.py
+13
-19
dygraph/infer.py
dygraph/infer.py
+1
-8
dygraph/train.py
dygraph/train.py
+1
-1
dygraph/val.py
dygraph/val.py
+1
-1
未找到文件。
dygraph/datasets/dataset.py
浏览文件 @
8d785cff
...
@@ -25,8 +25,7 @@ class Dataset(fluid.io.Dataset):
...
@@ -25,8 +25,7 @@ class Dataset(fluid.io.Dataset):
Args:
Args:
data_dir: The dataset directory.
data_dir: The dataset directory.
num_classes: Number of classes.
num_classes: Number of classes.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'test'). Default: 'train'.
mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
train_list: The train dataset file. When image_set is 'train', train_list is necessary.
train_list: The train dataset file. When image_set is 'train', train_list is necessary.
The contents of train_list file are as follow:
The contents of train_list file are as follow:
image1.jpg ground_truth1.png
image1.jpg ground_truth1.png
...
@@ -46,7 +45,6 @@ class Dataset(fluid.io.Dataset):
...
@@ -46,7 +45,6 @@ class Dataset(fluid.io.Dataset):
def
__init__
(
self
,
def
__init__
(
self
,
data_dir
,
data_dir
,
num_classes
,
num_classes
,
image_set
=
'train'
,
mode
=
'train'
,
mode
=
'train'
,
train_list
=
None
,
train_list
=
None
,
val_list
=
None
,
val_list
=
None
,
...
@@ -59,21 +57,16 @@ class Dataset(fluid.io.Dataset):
...
@@ -59,21 +57,16 @@ class Dataset(fluid.io.Dataset):
self
.
mode
=
mode
self
.
mode
=
mode
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
if
image_set
.
lower
()
not
in
[
'train'
,
'val'
,
'test'
]:
if
mode
.
lower
()
not
in
[
'train'
,
'val'
,
'test'
]:
raise
Exception
(
raise
Exception
(
"image_set should be one of ('train', 'val', 'test'), but got {}."
"mode should be 'train', 'val' or 'test', but got {}."
.
format
(
.
format
(
image_set
))
if
mode
.
lower
()
not
in
[
'train'
,
'eval'
,
'test'
]:
raise
Exception
(
"mode should be 'train', 'eval' or 'test', but got {}."
.
format
(
mode
))
mode
))
if
self
.
transforms
is
None
:
if
self
.
transforms
is
None
:
raise
Exception
(
"transforms is necessary, but it is None."
)
raise
Exception
(
"transforms is necessary, but it is None."
)
self
.
data_dir
=
data_dir
self
.
data_dir
=
data_dir
if
image_set
==
'train'
:
if
mode
==
'train'
:
if
train_list
is
None
:
if
train_list
is
None
:
raise
Exception
(
raise
Exception
(
'When mode is "train", train_list is necessary, but it is None.'
'When mode is "train", train_list is necessary, but it is None.'
...
@@ -83,10 +76,10 @@ class Dataset(fluid.io.Dataset):
...
@@ -83,10 +76,10 @@ class Dataset(fluid.io.Dataset):
'train_list is not found: {}'
.
format
(
train_list
))
'train_list is not found: {}'
.
format
(
train_list
))
else
:
else
:
file_list
=
train_list
file_list
=
train_list
elif
image_set
==
'e
val'
:
elif
mode
==
'
val'
:
if
val_list
is
None
:
if
val_list
is
None
:
raise
Exception
(
raise
Exception
(
'When mode is "
e
val", val_list is necessary, but it is None.'
'When mode is "val", val_list is necessary, but it is None.'
)
)
elif
not
os
.
path
.
exists
(
val_list
):
elif
not
os
.
path
.
exists
(
val_list
):
raise
Exception
(
'val_list is not found: {}'
.
format
(
val_list
))
raise
Exception
(
'val_list is not found: {}'
.
format
(
val_list
))
...
@@ -106,9 +99,9 @@ class Dataset(fluid.io.Dataset):
...
@@ -106,9 +99,9 @@ class Dataset(fluid.io.Dataset):
for
line
in
f
:
for
line
in
f
:
items
=
line
.
strip
().
split
(
separator
)
items
=
line
.
strip
().
split
(
separator
)
if
len
(
items
)
!=
2
:
if
len
(
items
)
!=
2
:
if
mode
==
'train'
or
mode
==
'
e
val'
:
if
mode
==
'train'
or
mode
==
'val'
:
raise
Exception
(
raise
Exception
(
"File list format incorrect! It should be"
"File list format incorrect! I
n training or evaluation task i
t should be"
" image_name{}label_name
\\
n"
.
format
(
separator
))
" image_name{}label_name
\\
n"
.
format
(
separator
))
image_path
=
os
.
path
.
join
(
self
.
data_dir
,
items
[
0
])
image_path
=
os
.
path
.
join
(
self
.
data_dir
,
items
[
0
])
grt_path
=
None
grt_path
=
None
...
@@ -119,19 +112,19 @@ class Dataset(fluid.io.Dataset):
...
@@ -119,19 +112,19 @@ class Dataset(fluid.io.Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
image_path
,
grt_path
=
self
.
file_list
[
idx
]
image_path
,
grt_path
=
self
.
file_list
[
idx
]
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'test'
:
im
,
im_info
,
label
=
self
.
transforms
(
im
=
image_path
,
label
=
grt_path
)
im
,
im_info
,
_
=
self
.
transforms
(
im
=
image_path
)
return
im
,
label
im
=
im
[
np
.
newaxis
,
...]
elif
self
.
mode
==
'eval'
:
return
im
,
im_info
,
image_path
elif
self
.
mode
==
'val'
:
im
,
im_info
,
_
=
self
.
transforms
(
im
=
image_path
)
im
,
im_info
,
_
=
self
.
transforms
(
im
=
image_path
)
im
=
im
[
np
.
newaxis
,
...]
im
=
im
[
np
.
newaxis
,
...]
label
=
np
.
asarray
(
Image
.
open
(
grt_path
))
label
=
np
.
asarray
(
Image
.
open
(
grt_path
))
label
=
label
[
np
.
newaxis
,
np
.
newaxis
,
:,
:]
label
=
label
[
np
.
newaxis
,
np
.
newaxis
,
:,
:]
return
im
,
im_info
,
label
return
im
,
im_info
,
label
if
self
.
mode
==
'test'
:
else
:
im
,
im_info
,
_
=
self
.
transforms
(
im
=
image_path
)
im
,
im_info
,
label
=
self
.
transforms
(
im
=
image_path
,
label
=
grt_path
)
im
=
im
[
np
.
newaxis
,
...]
return
im
,
label
return
im
,
im_info
,
image_path
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
file_list
)
return
len
(
self
.
file_list
)
dygraph/datasets/optic_disc_seg.py
浏览文件 @
8d785cff
...
@@ -25,7 +25,6 @@ class OpticDiscSeg(Dataset):
...
@@ -25,7 +25,6 @@ class OpticDiscSeg(Dataset):
def
__init__
(
self
,
def
__init__
(
self
,
data_dir
=
None
,
data_dir
=
None
,
transforms
=
None
,
transforms
=
None
,
image_set
=
'train'
,
mode
=
'train'
,
mode
=
'train'
,
download
=
True
):
download
=
True
):
self
.
data_dir
=
data_dir
self
.
data_dir
=
data_dir
...
@@ -34,14 +33,9 @@ class OpticDiscSeg(Dataset):
...
@@ -34,14 +33,9 @@ class OpticDiscSeg(Dataset):
self
.
mode
=
mode
self
.
mode
=
mode
self
.
num_classes
=
2
self
.
num_classes
=
2
if
image_set
.
lower
()
not
in
[
'train'
,
'val'
,
'test'
]:
if
mode
.
lower
()
not
in
[
'train'
,
'val'
,
'test'
]:
raise
Exception
(
raise
Exception
(
"image_set should be one of ('train', 'val', 'test'), but got {}."
"mode should be 'train', 'val' or 'test', but got {}."
.
format
(
.
format
(
image_set
))
if
mode
.
lower
()
not
in
[
'train'
,
'eval'
,
'test'
]:
raise
Exception
(
"mode should be 'train', 'eval' or 'test', but got {}."
.
format
(
mode
))
mode
))
if
self
.
transforms
is
None
:
if
self
.
transforms
is
None
:
...
@@ -53,9 +47,9 @@ class OpticDiscSeg(Dataset):
...
@@ -53,9 +47,9 @@ class OpticDiscSeg(Dataset):
self
.
data_dir
=
download_file_and_uncompress
(
self
.
data_dir
=
download_file_and_uncompress
(
url
=
URL
,
savepath
=
DATA_HOME
,
extrapath
=
DATA_HOME
)
url
=
URL
,
savepath
=
DATA_HOME
,
extrapath
=
DATA_HOME
)
if
image_set
==
'train'
:
if
mode
==
'train'
:
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'train_list.txt'
)
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'train_list.txt'
)
elif
image_set
==
'val'
:
elif
mode
==
'val'
:
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'val_list.txt'
)
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'val_list.txt'
)
else
:
else
:
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'test_list.txt'
)
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'test_list.txt'
)
...
@@ -64,7 +58,7 @@ class OpticDiscSeg(Dataset):
...
@@ -64,7 +58,7 @@ class OpticDiscSeg(Dataset):
for
line
in
f
:
for
line
in
f
:
items
=
line
.
strip
().
split
()
items
=
line
.
strip
().
split
()
if
len
(
items
)
!=
2
:
if
len
(
items
)
!=
2
:
if
mode
==
'train'
or
mode
==
'
e
val'
:
if
mode
==
'train'
or
mode
==
'val'
:
raise
Exception
(
raise
Exception
(
"File list format incorrect! It should be"
"File list format incorrect! It should be"
" image_name label_name
\\
n"
)
" image_name label_name
\\
n"
)
...
...
dygraph/datasets/voc.py
浏览文件 @
8d785cff
...
@@ -25,15 +25,13 @@ class PascalVOC(Dataset):
...
@@ -25,15 +25,13 @@ class PascalVOC(Dataset):
please run the voc_augment.py in tools.
please run the voc_augment.py in tools.
Args:
Args:
data_dir: The dataset directory.
data_dir: The dataset directory.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'trainval', 'trainaug). Default: 'train'.
mode: Which part of dataset to use.. it is one of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
transforms: Transforms for image.
transforms: Transforms for image.
download: Whether to download dataset if data_dir is None.
download: Whether to download dataset if data_dir is None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
data_dir
=
None
,
data_dir
=
None
,
image_set
=
'train'
,
mode
=
'train'
,
mode
=
'train'
,
transforms
=
None
,
transforms
=
None
,
download
=
True
):
download
=
True
):
...
@@ -43,22 +41,17 @@ class PascalVOC(Dataset):
...
@@ -43,22 +41,17 @@ class PascalVOC(Dataset):
self
.
file_list
=
list
()
self
.
file_list
=
list
()
self
.
num_classes
=
21
self
.
num_classes
=
21
if
image_set
.
lower
()
not
in
[
'train'
,
'val'
,
'trainval'
,
'trainaug
'
]:
if
mode
.
lower
()
not
in
[
'train'
,
'trainval'
,
'trainaug'
,
'val
'
]:
raise
Exception
(
raise
Exception
(
"image_set should be one of ('train', 'val', 'trainval', 'trainaug'), but got {}."
"mode should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}."
.
format
(
image_set
))
.
format
(
mode
))
if
mode
.
lower
()
not
in
[
'train'
,
'eval'
,
'test'
]:
raise
Exception
(
"mode should be 'train', 'eval' or 'test', but got {}."
.
format
(
mode
))
if
self
.
transforms
is
None
:
if
self
.
transforms
is
None
:
raise
Exception
(
"transforms is necessary, but it is None."
)
raise
Exception
(
"transforms is necessary, but it is None."
)
if
self
.
data_dir
is
None
:
if
self
.
data_dir
is
None
:
if
not
download
:
if
not
download
:
raise
Exception
(
"data_
file
not set and auto download disabled."
)
raise
Exception
(
"data_
dir
not set and auto download disabled."
)
self
.
data_dir
=
download_file_and_uncompress
(
self
.
data_dir
=
download_file_and_uncompress
(
url
=
URL
,
url
=
URL
,
savepath
=
DATA_HOME
,
savepath
=
DATA_HOME
,
...
@@ -68,19 +61,19 @@ class PascalVOC(Dataset):
...
@@ -68,19 +61,19 @@ class PascalVOC(Dataset):
image_set_dir
=
os
.
path
.
join
(
self
.
data_dir
,
'VOC2012'
,
'ImageSets'
,
image_set_dir
=
os
.
path
.
join
(
self
.
data_dir
,
'VOC2012'
,
'ImageSets'
,
'Segmentation'
)
'Segmentation'
)
if
image_set
==
'train'
:
if
mode
==
'train'
:
file_list
=
os
.
path
.
join
(
image_set_dir
,
'train.txt'
)
file_list
=
os
.
path
.
join
(
image_set_dir
,
'train.txt'
)
elif
image_set
==
'val'
:
elif
mode
==
'val'
:
file_list
=
os
.
path
.
join
(
image_set_dir
,
'val.txt'
)
file_list
=
os
.
path
.
join
(
image_set_dir
,
'val.txt'
)
elif
image_set
==
'trainval'
:
elif
mode
==
'trainval'
:
file_list
=
os
.
path
.
join
(
image_set_dir
,
'trainval.txt'
)
file_list
=
os
.
path
.
join
(
image_set_dir
,
'trainval.txt'
)
elif
image_set
==
'trainaug'
:
elif
mode
==
'trainaug'
:
file_list
=
os
.
path
.
join
(
image_set_dir
,
'train.txt'
)
file_list
=
os
.
path
.
join
(
image_set_dir
,
'train.txt'
)
file_list_aug
=
os
.
path
.
join
(
image_set_dir
,
'aug.txt'
)
file_list_aug
=
os
.
path
.
join
(
image_set_dir
,
'aug.txt'
)
if
not
os
.
path
.
exists
(
file_list_aug
):
if
not
os
.
path
.
exists
(
file_list_aug
):
raise
Exception
(
raise
Exception
(
"When
image_set
is 'trainaug', Pascal Voc dataset should be augmented, "
"When
mode
is 'trainaug', Pascal Voc dataset should be augmented, "
"Please make sure voc_augment.py has been properly run when using this mode."
"Please make sure voc_augment.py has been properly run when using this mode."
)
)
...
@@ -95,10 +88,11 @@ class PascalVOC(Dataset):
...
@@ -95,10 +88,11 @@ class PascalVOC(Dataset):
image_path
=
os
.
path
.
join
(
img_dir
,
''
.
join
([
line
,
'.jpg'
]))
image_path
=
os
.
path
.
join
(
img_dir
,
''
.
join
([
line
,
'.jpg'
]))
grt_path
=
os
.
path
.
join
(
grt_dir
,
''
.
join
([
line
,
'.png'
]))
grt_path
=
os
.
path
.
join
(
grt_dir
,
''
.
join
([
line
,
'.png'
]))
self
.
file_list
.
append
([
image_path
,
grt_path
])
self
.
file_list
.
append
([
image_path
,
grt_path
])
if
image_set
==
'trainaug'
:
if
mode
==
'trainaug'
:
with
open
(
file_list_aug
,
'r'
)
as
f
:
with
open
(
file_list_aug
,
'r'
)
as
f
:
for
line
in
f
:
for
line
in
f
:
line
=
line
.
strip
()
line
=
line
.
strip
()
image_path
=
os
.
path
.
join
(
img_dir
,
''
.
join
([
line
,
'.jpg'
]))
image_path
=
os
.
path
.
join
(
img_dir
,
''
.
join
([
line
,
'.jpg'
]))
grt_path
=
os
.
path
.
join
(
grt_dir
,
''
.
join
([
line
,
'.png'
]))
grt_path
=
os
.
path
.
join
(
grt_dir_aug
,
''
.
join
([
line
,
'.png'
]))
self
.
file_list
.
append
([
image_path
,
grt_path
])
self
.
file_list
.
append
([
image_path
,
grt_path
])
dygraph/infer.py
浏览文件 @
8d785cff
...
@@ -13,20 +13,13 @@
...
@@ -13,20 +13,13 @@
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
os
from
paddle.fluid.dygraph.base
import
to_variable
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
import
cv2
import
tqdm
from
datasets
import
DATASETS
from
datasets
import
DATASETS
import
transforms
as
T
import
transforms
as
T
from
models
import
MODELS
from
models
import
MODELS
import
utils
import
utils.logging
as
logging
from
utils
import
get_environ_info
from
utils
import
get_environ_info
from
core
import
infer
from
core
import
infer
...
@@ -43,7 +36,7 @@ def parse_args():
...
@@ -43,7 +36,7 @@ def parse_args():
type
=
str
,
type
=
str
,
default
=
'UNet'
)
default
=
'UNet'
)
# params of
dataset
# params of
infer
parser
.
add_argument
(
parser
.
add_argument
(
'--dataset'
,
'--dataset'
,
dest
=
'dataset'
,
dest
=
'dataset'
,
...
...
dygraph/train.py
浏览文件 @
8d785cff
...
@@ -153,7 +153,7 @@ def main(args):
...
@@ -153,7 +153,7 @@ def main(args):
eval_transforms
=
T
.
Compose
(
eval_transforms
=
T
.
Compose
(
[
T
.
Resize
(
args
.
input_size
),
[
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
T
.
Normalize
()])
eval_dataset
=
dataset
(
transforms
=
eval_transforms
,
mode
=
'
e
val'
)
eval_dataset
=
dataset
(
transforms
=
eval_transforms
,
mode
=
'val'
)
if
args
.
model_name
not
in
MODELS
:
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
raise
Exception
(
...
...
dygraph/val.py
浏览文件 @
8d785cff
...
@@ -87,7 +87,7 @@ def main(args):
...
@@ -87,7 +87,7 @@ def main(args):
with
fluid
.
dygraph
.
guard
(
places
):
with
fluid
.
dygraph
.
guard
(
places
):
eval_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
eval_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
eval_dataset
=
dataset
(
transforms
=
eval_transforms
,
mode
=
'
e
val'
)
eval_dataset
=
dataset
(
transforms
=
eval_transforms
,
mode
=
'val'
)
if
args
.
model_name
not
in
MODELS
:
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
raise
Exception
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录