提交 3639c2de 编写于 作者: C chenguowei01

update dataset, add voc and voc aug

上级 2dd6872e
......@@ -15,3 +15,4 @@
from .dataset import Dataset
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
from .voc import PascalVoc
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -39,9 +39,8 @@ class Cityscapes(Dataset):
mode))
if self.transforms is None:
raise Exception("transform is necessary, but it is None.")
raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,43 +20,74 @@ from PIL import Image
class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format.
Args:
data_dir: The dataset directory.
num_classes: Number of classes.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. It is one of ('train', 'eval', 'test'). Default: 'train'.
train_list: The train dataset file. When image_set is 'train', train_list is necessary.
The contents of train_list file are as follow:
image1.jpg ground_truth1.png
image2.jpg ground_truth2.png
val_list: The evaluation dataset file. When image_set is 'val', val_list is necessary.
The contents are the same as train_list.
test_list: The test dataset file. When image_set is 'test', test_list is necessary.
The annotation file is not necessary in test_list file.
separator: The separator of dataset list. Default: ' '.
transforms: Transforms for image.
Examples:
todo
"""
def __init__(self,
data_dir,
num_classes,
image_set='train',
mode='train',
train_list=None,
val_list=None,
test_list=None,
separator=' ',
transforms=None,
mode='train'):
transforms=None):
self.data_dir = data_dir
self.transforms = transforms
self.file_list = list()
self.mode = mode
self.num_classes = num_classes
if image_set.lower() not in ['train', 'val', 'test']:
raise Exception(
"image_set should be one of ('train', 'val', 'test'), but got {}."
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("transform is necessary, but it is None.")
raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir
if mode == 'train':
if image_set == 'train':
if train_list is None:
raise Exception(
'When mode is "train", train_list is need, but it is None.')
'When mode is "train", train_list is necessary, but it is None.'
)
elif not os.path.exists(train_list):
raise Exception(
'train_list is not found: {}'.format(train_list))
else:
file_list = train_list
elif mode == 'eval':
elif image_set == 'eval':
if val_list is None:
raise Exception(
'When mode is "eval", val_list is need, but it is None.')
'When mode is "eval", val_list is necessary, but it is None.'
)
elif not os.path.exists(val_list):
raise Exception('val_list is not found: {}'.format(val_list))
else:
......@@ -64,7 +95,8 @@ class Dataset(fluid.io.Dataset):
else:
if test_list is None:
raise Exception(
'When mode is "test", test_list is need, but it is None.')
'When mode is "test", test_list is necessary, but it is None.'
)
elif not os.path.exists(test_list):
raise Exception('test_list is not found: {}'.format(test_list))
else:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -25,6 +25,7 @@ class OpticDiscSeg(Dataset):
def __init__(self,
data_dir=None,
transforms=None,
image_set='train',
mode='train',
download=True):
self.data_dir = data_dir
......@@ -33,24 +34,28 @@ class OpticDiscSeg(Dataset):
self.mode = mode
self.num_classes = 2
if image_set.lower() not in ['train', 'val', 'test']:
raise Exception(
"image_set should be one of ('train', 'val', 'test'), but got {}."
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("transform is necessary, but it is None.")
raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if mode == 'train':
if image_set == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
elif mode == 'eval':
elif image_set == 'val':
file_list = os.path.join(self.data_dir, 'val_list.txt')
else:
file_list = os.path.join(self.data_dir, 'test_list.txt')
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
class PascalVoc(Dataset):
    """Pascal VOC 2012 segmentation dataset `http://host.robots.ox.ac.uk/pascal/VOC/`.

    If you want to augment the dataset with the SBD extra annotations, run
    voc_augment.py in tools first.

    Args:
        data_dir: The dataset directory. If None and download is True, the
            archive is downloaded and extracted under DATA_HOME.
        image_set: Which part of the dataset to use, one of
            ('train', 'val', 'trainval', 'trainaug'). Default: 'train'.
        mode: Dataset usage, one of ('train', 'eval', 'test'). Default: 'train'.
        transforms: Transforms for image. Must not be None.
        download: Whether to download the dataset if data_dir is None.
            Default: False.

    Raises:
        Exception: If image_set or mode is invalid, transforms is None, the
            data directory is unset with downloading disabled, or image_set is
            'trainaug' but the augmentation list (aug.txt) does not exist.
    """

    def __init__(self,
                 data_dir=None,
                 image_set='train',
                 mode='train',
                 transforms=None,
                 download=False):
        self.data_dir = data_dir
        self.transforms = transforms
        self.mode = mode
        self.file_list = list()
        self.num_classes = 21  # 20 object classes + background

        if image_set.lower() not in ['train', 'val', 'trainval', 'trainaug']:
            raise Exception(
                "image_set should be one of ('train', 'val', 'trainval', 'trainaug'), but got {}."
                .format(image_set))

        if mode.lower() not in ['train', 'eval', 'test']:
            raise Exception(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise Exception("transforms is necessary, but it is None.")

        if self.data_dir is None:
            if not download:
                raise Exception("data_file not set and auto download disabled.")
            self.data_dir = download_file_and_uncompress(
                url=URL,
                savepath=DATA_HOME,
                extrapath=DATA_HOME,
                extraname='VOCdevkit')

        image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets',
                                     'Segmentation')
        if image_set == 'train':
            file_list = os.path.join(image_set_dir, 'train.txt')
        elif image_set == 'val':
            file_list = os.path.join(image_set_dir, 'val.txt')
        elif image_set == 'trainval':
            file_list = os.path.join(image_set_dir, 'trainval.txt')
        elif image_set == 'trainaug':
            # 'trainaug' = the original train split plus the SBD-augmented
            # extra images listed in aug.txt (written by voc_augment.py).
            file_list = os.path.join(image_set_dir, 'train.txt')
            file_list_aug = os.path.join(image_set_dir, 'aug.txt')
            if not os.path.exists(file_list_aug):
                raise Exception(
                    "When image_set is 'trainaug', Pascal Voc dataset should be augmented, "
                    "Please make sure voc_augment.py has been properly run when using this mode."
                )

        img_dir = os.path.join(self.data_dir, 'VOC2012', 'JPEGImages')
        grt_dir = os.path.join(self.data_dir, 'VOC2012', 'SegmentationClass')
        grt_dir_aug = os.path.join(self.data_dir, 'VOC2012',
                                   'SegmentationClassAug')

        with open(file_list, 'r') as f:
            for line in f:
                line = line.strip()
                image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
                self.file_list.append([image_path, grt_path])

        if image_set == 'trainaug':
            with open(file_list_aug, 'r') as f:
                for line in f:
                    line = line.strip()
                    image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                    # BUG FIX: the augmented ground truths are produced by
                    # voc_augment.py into SegmentationClassAug; the original
                    # code read them from SegmentationClass (grt_dir), where
                    # they do not exist.
                    grt_path = os.path.join(grt_dir_aug,
                                            ''.join([line, '.png']))
                    self.file_list.append([image_path, grt_path])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: voc_augment.py
This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html>
to augment the Pascal VOC
"""
import os
import argparse
from multiprocessing import Pool, cpu_count
import cv2
import numpy as np
from scipy.io import loadmat
import tqdm
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz'


def parse_args(argv=None):
    """Parse command line arguments for the SBD -> VOC conversion script.

    Args:
        argv: Optional list of argument strings; defaults to sys.argv[1:].
            Exposed as a parameter so the parser can be driven in tests.

    Returns:
        argparse.Namespace with ``voc_path`` (str) and ``num_workers`` (int).
    """
    parser = argparse.ArgumentParser(
        description=
        'Convert SBD to Pascal Voc annotations to augment the train dataset of Pascal Voc'
    )
    parser.add_argument(
        '--voc_path',
        dest='voc_path',
        help='pascal voc path',
        type=str,
        # BUG FIX: was os.path.join(DATA_HOME + 'VOCdevkit'), which glued the
        # two parts together without a path separator.
        default=os.path.join(DATA_HOME, 'VOCdevkit'))
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='How many processes are used for data conversion',
        # BUG FIX: was type=str; the value is passed to Pool(processes=...),
        # which requires an int.
        type=int,
        default=cpu_count())
    return parser.parse_args(argv)
def conver_to_png(mat_file, sbd_cls_dir, save_dir):
    """Convert one SBD .mat class annotation to a Pascal VOC style .png label.

    Args:
        mat_file: File name (not full path) of the .mat annotation.
        sbd_cls_dir: Directory containing the SBD 'cls' .mat files.
        save_dir: Directory the .png label image is written to.
    """
    mat_path = os.path.join(sbd_cls_dir, mat_file)
    mat = loadmat(mat_path)
    # GTcls holds per-pixel class ids; uint8 is enough for the 21 VOC classes.
    mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # BUG FIX: mat_file.replace('mat', 'png') rewrote any 'mat' substring in
    # the stem as well (e.g. '2008_mat_1.mat'); replace only the extension.
    save_file = os.path.join(save_dir, os.path.splitext(mat_file)[0] + '.png')
    cv2.imwrite(save_file, mask)
def main():
    """Download SBD, convert its annotations to VOC-style .png labels and
    write the list of augmentation-only image ids to aug.txt."""
    args = parse_args()
    sbd_path = download_file_and_uncompress(
        url=URL,
        savepath=DATA_HOME,
        extrapath=DATA_HOME,
        extraname='benchmark_RELEASE')

    # Collect every image id SBD annotates (its train and val lists).
    with open(os.path.join(sbd_path, 'dataset/train.txt'), 'r') as f:
        sbd_file_list = [line.strip() for line in f]
    with open(os.path.join(sbd_path, 'dataset/val.txt'), 'r') as f:
        sbd_file_list += [line.strip() for line in f]

    if not os.path.exists(args.voc_path):
        # BUG FIX: the '{}' placeholder was never formatted; also fixed the
        # "Ther" typo.
        raise Exception(
            'There is no voc_path: {}. Please ensure that the Pascal VOC dataset '
            'has been downloaded correctly'.format(args.voc_path))

    # BUG FIX: the file modes 'r'/'w' were accidentally passed to
    # os.path.join() instead of open(), producing a bogus path and crashing.
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/trainval.txt'),
            'r') as f:
        voc_file_list = [line.strip() for line in f]

    # Only ids that SBD has but VOC trainval does not are genuinely new;
    # sorted() makes aug.txt deterministic across runs.
    aug_file_list = sorted(set(sbd_file_list) - set(voc_file_list))
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/aug.txt'), 'w') as f:
        # BUG FIX: ''.join(line, '\n') raised TypeError (str.join takes a
        # single iterable argument).
        f.writelines(line + '\n' for line in aug_file_list)

    sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls')
    # BUG FIX: labels were written under VOC2012/ImageSets/SegmentationClassAug,
    # while the PascalVoc dataset reads them from VOC2012/SegmentationClassAug.
    save_dir = os.path.join(args.voc_path, 'VOC2012/SegmentationClassAug')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    mat_file_list = os.listdir(sbd_cls_dir)
    p = Pool(args.num_workers)
    for f in tqdm.tqdm(mat_file_list):
        p.apply_async(conver_to_png, args=(f, sbd_cls_dir, save_dir))
    # BUG FIX: without close()/join() the script could exit before the worker
    # processes finished writing the png labels.
    p.close()
    p.join()


if __name__ == '__main__':
    main()
......@@ -85,8 +85,8 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress):
for total_num, index, rootpath in handler(filepath, extrapath):
if print_progress:
done = int(50 * float(index) / total_num)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * index) / total_num))
progress(
"[%-50s] %.2f%%" % ('=' * done, float(100 * index) / total_num))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
......@@ -132,4 +132,4 @@ def download_file_and_uncompress(url,
print_progress)
savename = os.path.join(extrapath, savename)
shutil.move(savename, extraname)
return savename
return extraname
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册