Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
3639c2de
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
288
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3639c2de
编写于
8月 05, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update dataset, add voc and voc aug
上级
2dd6872e
变更
7
隐藏空白更改
内联
并排
Showing 7 changed files with 263 additions and 20 deletions
+263
-20
dygraph/datasets/__init__.py
dygraph/datasets/__init__.py
+1
-0
dygraph/datasets/cityscapes.py
dygraph/datasets/cityscapes.py
+2
-3
dygraph/datasets/dataset.py
dygraph/datasets/dataset.py
+41
-9
dygraph/datasets/optic_disc_seg.py
dygraph/datasets/optic_disc_seg.py
+10
-5
dygraph/datasets/voc.py
dygraph/datasets/voc.py
+104
-0
dygraph/tools/voc_augment.py
dygraph/tools/voc_augment.py
+102
-0
dygraph/utils/download.py
dygraph/utils/download.py
+3
-3
未找到文件。
dygraph/datasets/__init__.py
浏览文件 @
3639c2de
...
...
@@ -15,3 +15,4 @@
from
.dataset
import
Dataset
from
.optic_disc_seg
import
OpticDiscSeg
from
.cityscapes
import
Cityscapes
from
.voc
import
PascalVoc
dygraph/datasets/cityscapes.py
浏览文件 @
3639c2de
#
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -39,9 +39,8 @@ class Cityscapes(Dataset):
mode
))
if
self
.
transforms
is
None
:
raise
Exception
(
"transform is necessary, but it is None."
)
raise
Exception
(
"transform
s
is necessary, but it is None."
)
self
.
data_dir
=
data_dir
if
self
.
data_dir
is
None
:
if
not
download
:
raise
Exception
(
"data_file not set and auto download disabled."
)
...
...
dygraph/datasets/dataset.py
浏览文件 @
3639c2de
#
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -20,43 +20,74 @@ from PIL import Image
class
Dataset
(
fluid
.
io
.
Dataset
):
"""Pass in a custom dataset that conforms to the format.
Args:
data_dir: The dataset directory.
num_classes: Number of classes.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
train_list: The train dataset file. When image_set is 'train', train_list is necessary.
The contents of train_list file are as follow:
image1.jpg ground_truth1.png
image2.jpg ground_truth2.png
val_list: The evaluation dataset file. When image_set is 'val', val_list is necessary.
The contents is the same as train_list
test_list: The test dataset file. When image_set is 'test', test_list is necessary.
The annotation file is not necessary in test_list file.
separator: The separator of dataset list. Default: ' '.
transforms: Transforms for image.
Examples:
todo
"""
def
__init__
(
self
,
data_dir
,
num_classes
,
image_set
=
'train'
,
mode
=
'train'
,
train_list
=
None
,
val_list
=
None
,
test_list
=
None
,
separator
=
' '
,
transforms
=
None
,
mode
=
'train'
):
transforms
=
None
):
self
.
data_dir
=
data_dir
self
.
transforms
=
transforms
self
.
file_list
=
list
()
self
.
mode
=
mode
self
.
num_classes
=
num_classes
if
image_set
.
lower
()
not
in
[
'train'
,
'val'
,
'test'
]:
raise
Exception
(
"image_set should be one of ('train', 'val', 'test'), but got {}."
.
format
(
image_set
))
if
mode
.
lower
()
not
in
[
'train'
,
'eval'
,
'test'
]:
raise
Exception
(
"mode should be 'train', 'eval' or 'test', but got {}."
.
format
(
mode
))
if
self
.
transforms
is
None
:
raise
Exception
(
"transform is necessary, but it is None."
)
raise
Exception
(
"transform
s
is necessary, but it is None."
)
self
.
data_dir
=
data_dir
if
mode
==
'train'
:
if
image_set
==
'train'
:
if
train_list
is
None
:
raise
Exception
(
'When mode is "train", train_list is need, but it is None.'
)
'When mode is "train", train_list is necessary, but it is None.'
)
elif
not
os
.
path
.
exists
(
train_list
):
raise
Exception
(
'train_list is not found: {}'
.
format
(
train_list
))
else
:
file_list
=
train_list
elif
mode
==
'eval'
:
elif
image_set
==
'eval'
:
if
val_list
is
None
:
raise
Exception
(
'When mode is "eval", val_list is need, but it is None.'
)
'When mode is "eval", val_list is necessary, but it is None.'
)
elif
not
os
.
path
.
exists
(
val_list
):
raise
Exception
(
'val_list is not found: {}'
.
format
(
val_list
))
else
:
...
...
@@ -64,7 +95,8 @@ class Dataset(fluid.io.Dataset):
else
:
if
test_list
is
None
:
raise
Exception
(
'When mode is "test", test_list is need, but it is None.'
)
'When mode is "test", test_list is necessary, but it is None.'
)
elif
not
os
.
path
.
exists
(
test_list
):
raise
Exception
(
'test_list is not found: {}'
.
format
(
test_list
))
else
:
...
...
dygraph/datasets/optic_disc_seg.py
浏览文件 @
3639c2de
#
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -25,6 +25,7 @@ class OpticDiscSeg(Dataset):
def
__init__
(
self
,
data_dir
=
None
,
transforms
=
None
,
image_set
=
'train'
,
mode
=
'train'
,
download
=
True
):
self
.
data_dir
=
data_dir
...
...
@@ -33,24 +34,28 @@ class OpticDiscSeg(Dataset):
self
.
mode
=
mode
self
.
num_classes
=
2
if
image_set
.
lower
()
not
in
[
'train'
,
'val'
,
'test'
]:
raise
Exception
(
"image_set should be one of ('train', 'val', 'test'), but got {}."
.
format
(
image_set
))
if
mode
.
lower
()
not
in
[
'train'
,
'eval'
,
'test'
]:
raise
Exception
(
"mode should be 'train', 'eval' or 'test', but got {}."
.
format
(
mode
))
if
self
.
transforms
is
None
:
raise
Exception
(
"transform is necessary, but it is None."
)
raise
Exception
(
"transform
s
is necessary, but it is None."
)
self
.
data_dir
=
data_dir
if
self
.
data_dir
is
None
:
if
not
download
:
raise
Exception
(
"data_file not set and auto download disabled."
)
self
.
data_dir
=
download_file_and_uncompress
(
url
=
URL
,
savepath
=
DATA_HOME
,
extrapath
=
DATA_HOME
)
if
mode
==
'train'
:
if
image_set
==
'train'
:
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'train_list.txt'
)
elif
mode
==
'e
val'
:
elif
image_set
==
'
val'
:
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'val_list.txt'
)
else
:
file_list
=
os
.
path
.
join
(
self
.
data_dir
,
'test_list.txt'
)
...
...
dygraph/datasets/voc.py
0 → 100644
浏览文件 @
3639c2de
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
.dataset
import
Dataset
from
utils.download
import
download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"


class PascalVoc(Dataset):
    """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
    please run the voc_augment.py in tools.

    Args:
        data_dir: The dataset directory. If None and download is True, the
            dataset is downloaded to DATA_HOME.
        image_set: Which part of dataset to use. One of
            ('train', 'val', 'trainval', 'trainaug'). Default: 'train'.
        mode: Dataset usage. One of ('train', 'eval', 'test'). Default: 'train'.
        transforms: Transforms for image. Required (must not be None).
        download: Whether to download dataset if data_dir is None.

    Raises:
        Exception: If image_set/mode is invalid, transforms is None, data_dir
            is None with download disabled, or 'trainaug' is requested before
            voc_augment.py has produced aug.txt.
    """

    def __init__(self,
                 data_dir=None,
                 image_set='train',
                 mode='train',
                 transforms=None,
                 download=False):
        self.data_dir = data_dir
        self.transforms = transforms
        self.mode = mode
        self.file_list = list()
        self.num_classes = 21

        if image_set.lower() not in ['train', 'val', 'trainval', 'trainaug']:
            raise Exception(
                "image_set should be one of ('train', 'val', 'trainval', 'trainaug'), but got {}."
                .format(image_set))

        if mode.lower() not in ['train', 'eval', 'test']:
            raise Exception(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise Exception("transforms is necessary, but it is None.")

        if self.data_dir is None:
            if not download:
                raise Exception(
                    "data_file not set and auto download disabled.")
            self.data_dir = download_file_and_uncompress(
                url=URL,
                savepath=DATA_HOME,
                extrapath=DATA_HOME,
                extraname='VOCdevkit')
        # FIX: removed stray debug print(self.data_dir) left in the original.

        image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets',
                                     'Segmentation')
        if image_set == 'train':
            file_list = os.path.join(image_set_dir, 'train.txt')
        elif image_set == 'val':
            file_list = os.path.join(image_set_dir, 'val.txt')
        elif image_set == 'trainval':
            file_list = os.path.join(image_set_dir, 'trainval.txt')
        elif image_set == 'trainaug':
            file_list = os.path.join(image_set_dir, 'train.txt')
            file_list_aug = os.path.join(image_set_dir, 'aug.txt')
            if not os.path.exists(file_list_aug):
                raise Exception(
                    "When image_set is 'trainaug', Pascal Voc dataset should be augmented, "
                    "Please make sure voc_augment.py has been properly run when using this mode."
                )

        img_dir = os.path.join(self.data_dir, 'VOC2012', 'JPEGImages')
        grt_dir = os.path.join(self.data_dir, 'VOC2012', 'SegmentationClass')
        grt_dir_aug = os.path.join(self.data_dir, 'VOC2012',
                                   'SegmentationClassAug')

        with open(file_list, 'r') as f:
            for line in f:
                line = line.strip()
                image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
                self.file_list.append([image_path, grt_path])
        if image_set == 'trainaug':
            with open(file_list_aug, 'r') as f:
                for line in f:
                    line = line.strip()
                    image_path = os.path.join(img_dir, ''.join([line,
                                                                '.jpg']))
                    # FIX: the original built aug annotation paths from
                    # grt_dir (SegmentationClass) and left grt_dir_aug
                    # unused; augmented labels written by voc_augment.py
                    # live in SegmentationClassAug.
                    grt_path = os.path.join(grt_dir_aug,
                                            ''.join([line, '.png']))
                    self.file_list.append([image_path, grt_path])
dygraph/tools/voc_augment.py
0 → 100644
浏览文件 @
3639c2de
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: voc_augment.py
This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html>
to augment the Pascal VOC
"""
import
os
import
argparse
from
multiprocessing
import
Pool
,
cpu_count
import
cv2
import
numpy
as
np
from
scipy.io
import
loadmat
import
tqdm
from
utils.download
import
download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz'


def parse_args(args=None):
    """Parse command-line arguments for the SBD -> Pascal VOC conversion.

    Args:
        args: Optional list of argument strings to parse; defaults to
            sys.argv[1:] (standard argparse behavior). Added for testability;
            calling with no arguments is unchanged.

    Returns:
        argparse.Namespace with voc_path (str) and num_workers (int).
    """
    parser = argparse.ArgumentParser(
        description=
        'Convert SBD to Pascal Voc annotations to augment the train dataset of Pascal Voc'
    )
    parser.add_argument(
        '--voc_path',
        dest='voc_path',
        help='pascal voc path',
        type=str,
        # FIX: original used os.path.join(DATA_HOME + 'VOCdevkit'), which
        # concatenates without a path separator ('...datasetVOCdevkit').
        default=os.path.join(DATA_HOME, 'VOCdevkit'))
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='How many processes are used for data conversion',
        # FIX: original declared type=str; the value is passed to
        # multiprocessing.Pool, which requires an int.
        type=int,
        default=cpu_count())
    return parser.parse_args(args)
def conver_to_png(mat_file, sbd_cls_dir, save_dir):
    """Convert one SBD .mat class annotation into a .png label image.

    Args:
        mat_file: File name of the annotation, e.g. '2008_000002.mat'.
        sbd_cls_dir: Directory containing the SBD 'cls' .mat files.
        save_dir: Directory where the converted .png file is written.
    """
    mat_path = os.path.join(sbd_cls_dir, mat_file)
    mat = loadmat(mat_path)
    # SBD stores the per-pixel class ids under GTcls/Segmentation.
    mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # FIX: the original used mat_file.replace('mat', 'png'), which replaces
    # every occurrence of the substring 'mat' in the name, not just the
    # extension; replace only the extension instead.
    png_name = os.path.splitext(mat_file)[0] + '.png'
    save_file = os.path.join(save_dir, png_name)
    cv2.imwrite(save_file, mask)
def main():
    """Download SBD, diff its file list against VOC trainval, write aug.txt,
    and convert the extra SBD .mat annotations to .png labels in parallel."""
    args = parse_args()
    sbd_path = download_file_and_uncompress(
        url=URL,
        savepath=DATA_HOME,
        extrapath=DATA_HOME,
        extraname='benchmark_RELEASE')
    with open(os.path.join(sbd_path, 'dataset/train.txt'), 'r') as f:
        sbd_file_list = [line.strip() for line in f]
    with open(os.path.join(sbd_path, 'dataset/val.txt'), 'r') as f:
        sbd_file_list += [line.strip() for line in f]
    if not os.path.exists(args.voc_path):
        # FIX: original message had a typo ('Ther') and never interpolated
        # the path (no .format call).
        raise Exception(
            'There is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly'
            .format(args.voc_path))
    # FIX: original passed 'r' as a path component to os.path.join instead
    # of as the mode argument of open().
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/trainval.txt'),
            'r') as f:
        voc_file_list = [line.strip() for line in f]

    # Images present in SBD but absent from VOC trainval are the augmentation.
    aug_file_list = list(set(sbd_file_list) - set(voc_file_list))
    # FIX: same misplaced mode bug as above, with 'w'.
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/aug.txt'),
            'w') as f:
        # FIX: ''.join(line, '\n') raises TypeError (join takes a single
        # iterable argument); concatenate instead.
        f.writelines(line + '\n' for line in aug_file_list)

    sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls')
    # FIX: original wrote to 'VOC2012/ImageSets/SegmentationClassAug', but
    # the PascalVoc dataset reads labels from 'VOC2012/SegmentationClassAug'.
    save_dir = os.path.join(args.voc_path, 'VOC2012/SegmentationClassAug')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    mat_file_list = os.listdir(sbd_cls_dir)
    p = Pool(args.num_workers)
    for f in tqdm.tqdm(mat_file_list):
        p.apply_async(conver_to_png, args=(f, sbd_cls_dir, save_dir))
    # FIX: original never closed/joined the pool, so pending conversions
    # could be killed when the process exits.
    p.close()
    p.join()


if __name__ == '__main__':
    main()
dygraph/utils/download.py
浏览文件 @
3639c2de
...
...
@@ -85,8 +85,8 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress):
for
total_num
,
index
,
rootpath
in
handler
(
filepath
,
extrapath
):
if
print_progress
:
done
=
int
(
50
*
float
(
index
)
/
total_num
)
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
done
,
float
(
100
*
index
)
/
total_num
))
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
done
,
float
(
100
*
index
)
/
total_num
))
if
print_progress
:
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
50
,
100
),
end
=
True
)
...
...
@@ -132,4 +132,4 @@ def download_file_and_uncompress(url,
print_progress
)
savename
=
os
.
path
.
join
(
extrapath
,
savename
)
shutil
.
move
(
savename
,
extraname
)
return
save
name
return
extra
name
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录