提交 6ae9d4ac 编写于 作者: C chenguowei01

change data_dir to data_root

上级 4cdc97c3
......@@ -24,46 +24,47 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/ADEChallengeData2016.zip"
class ADE20K(Dataset):
"""ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
Args:
data_dir: The dataset directory.
dataset_root: The dataset directory.
mode: Which part of dataset to use.. it is one of ('train', 'val'). Default: 'train'.
transforms: Transforms for image.
download: Whether to download dataset if data_dir is None.
download: Whether to download dataset if dataset_root is None.
"""
def __init__(self,
data_dir=None,
dataset_root=None,
mode='train',
transforms=None,
download=True):
self.data_dir = data_dir
self.dataset_root = dataset_root
self.transforms = transforms
self.mode = mode
self.file_list = list()
self.num_classes = 21
self.num_classes = 150
if mode.lower() not in ['train', 'val']:
raise Exception(
"mode should be one of ('train', 'val') in PascalVOC dataset, but got {}."
"mode should be one of ('train', 'val') in ADE20K dataset, but got {}."
.format(mode))
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None:
if self.dataset_root is None:
if not download:
raise Exception("data_dir not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
raise Exception(
"dataset_root not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
extraname='ADEChallengeData2016')
if mode == 'train':
img_dir = os.path.join(self.data_dir, 'images/training')
grt_dir = os.path.join(self.data_dir, 'annotations/training')
img_dir = os.path.join(self.dataset_root, 'images/training')
grt_dir = os.path.join(self.dataset_root, 'annotations/training')
elif mode == 'val':
img_dir = os.path.join(self.data_dir, 'images/validation')
grt_dir = os.path.join(self.data_dir, 'annotations/validation')
img_dir = os.path.join(self.dataset_root, 'images/validation')
grt_dir = os.path.join(self.dataset_root, 'annotations/validation')
img_files = os.listdir(img_dir)
grt_files = [i.replace('.jpg', '.png') for i in img_files]
for i in range(len(img_files)):
......
......@@ -35,13 +35,13 @@ class Cityscapes(Dataset):
Make sure there are **labelTrainIds.png in gtFine directory. If not, please run the conver_cityscapes.py in tools.
Args:
data_dir: Cityscapes dataset directory.
dataset_root: Cityscapes dataset directory.
mode: Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image.
"""
def __init__(self, data_dir, transforms=None, mode='train'):
self.data_dir = data_dir
def __init__(self, dataset_root, transforms=None, mode='train'):
self.dataset_root = dataset_root
self.transforms = transforms
self.file_list = list()
self.mode = mode
......@@ -55,10 +55,11 @@ class Cityscapes(Dataset):
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
img_dir = os.path.join(self.data_dir, 'leftImg8bit')
grt_dir = os.path.join(self.data_dir, 'gtFine')
if not os.path.isdir(self.data_dir) or not os.path.isdir(
img_dir) or not os.path.isdir(grt_dir):
img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
grt_dir = os.path.join(self.dataset_root, 'gtFine')
if self.dataset_root is None or not os.path.isdir(
self.dataset_root) or not os.path.isdir(
img_dir) or not os.path.isdir(grt_dir):
raise Exception(
"The dataset is not Found or the folder structure is nonconfoumance."
)
......
......@@ -23,7 +23,7 @@ class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format.
Args:
data_dir: The dataset directory.
dataset_root: The dataset directory.
num_classes: Number of classes.
mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
train_list: The train dataset file. When image_set is 'train', train_list is necessary.
......@@ -43,7 +43,7 @@ class Dataset(fluid.io.Dataset):
"""
def __init__(self,
data_dir,
dataset_root,
num_classes,
mode='train',
train_list=None,
......@@ -51,7 +51,7 @@ class Dataset(fluid.io.Dataset):
test_list=None,
separator=' ',
transforms=None):
self.data_dir = data_dir
self.dataset_root = dataset_root
self.transforms = transforms
self.file_list = list()
self.mode = mode
......@@ -65,7 +65,7 @@ class Dataset(fluid.io.Dataset):
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir
self.dataset_root = dataset_root
if mode == 'train':
if train_list is None:
raise Exception(
......@@ -103,11 +103,11 @@ class Dataset(fluid.io.Dataset):
raise Exception(
"File list format incorrect! In training or evaluation task it should be"
" image_name{}label_name\\n".format(separator))
image_path = os.path.join(self.data_dir, items[0])
image_path = os.path.join(self.dataset_root, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
image_path = os.path.join(self.dataset_root, items[0])
grt_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
......
......@@ -23,11 +23,11 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
class OpticDiscSeg(Dataset):
def __init__(self,
data_dir=None,
dataset_root=None,
transforms=None,
mode='train',
download=True):
self.data_dir = data_dir
self.dataset_root = dataset_root
self.transforms = transforms
self.file_list = list()
self.mode = mode
......@@ -41,18 +41,18 @@ class OpticDiscSeg(Dataset):
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None:
if self.dataset_root is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
file_list = os.path.join(self.dataset_root, 'train_list.txt')
elif mode == 'val':
file_list = os.path.join(self.data_dir, 'val_list.txt')
file_list = os.path.join(self.dataset_root, 'val_list.txt')
else:
file_list = os.path.join(self.data_dir, 'test_list.txt')
file_list = os.path.join(self.dataset_root, 'test_list.txt')
with open(file_list, 'r') as f:
for line in f:
......@@ -62,9 +62,9 @@ class OpticDiscSeg(Dataset):
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
image_path = os.path.join(self.dataset_root, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
image_path = os.path.join(self.dataset_root, items[0])
grt_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, grt_path])
......@@ -24,18 +24,18 @@ class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools.
Args:
data_dir: The dataset directory.
dataset_root: The dataset directory.
mode: Which part of dataset to use.. it is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image.
download: Whether to download dataset if data_dir is None.
download: Whether to download dataset if dataset_root is None.
"""
def __init__(self,
data_dir=None,
dataset_root=None,
mode='train',
transforms=None,
download=True):
self.data_dir = data_dir
self.dataset_root = dataset_root
self.transforms = transforms
self.mode = mode
self.file_list = list()
......@@ -49,16 +49,17 @@ class PascalVOC(Dataset):
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None:
if self.dataset_root is None:
if not download:
raise Exception("data_dir not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
raise Exception(
"dataset_root not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
extraname='VOCdevkit')
image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets',
image_set_dir = os.path.join(self.dataset_root, 'VOC2012', 'ImageSets',
'Segmentation')
if mode == 'train':
file_list = os.path.join(image_set_dir, 'train.txt')
......@@ -76,9 +77,10 @@ class PascalVOC(Dataset):
"Please make sure voc_augment.py has been properly run when using this mode."
)
img_dir = os.path.join(self.data_dir, 'VOC2012', 'JPEGImages')
grt_dir = os.path.join(self.data_dir, 'VOC2012', 'SegmentationClass')
grt_dir_aug = os.path.join(self.data_dir, 'VOC2012',
img_dir = os.path.join(self.dataset_root, 'VOC2012', 'JPEGImages')
grt_dir = os.path.join(self.dataset_root, 'VOC2012',
'SegmentationClass')
grt_dir_aug = os.path.join(self.dataset_root, 'VOC2012',
'SegmentationClassAug')
with open(file_list, 'r') as f:
......
......@@ -44,6 +44,12 @@ def parse_args():
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of prediction
parser.add_argument(
......@@ -88,7 +94,10 @@ def main(args):
with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = dataset(transforms=test_transforms, mode='test')
test_dataset = dataset(
dataset_root=args.dataset_root,
transforms=test_transforms,
mode='test')
if args.model_name not in MODELS:
raise Exception(
......
......@@ -15,7 +15,7 @@
File: voc_augment.py
This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html>
to augment the Pascal VOC
to augment the Pascal VOC.
"""
import os
......
......@@ -51,6 +51,12 @@ def parse_args():
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training
parser.add_argument(
......@@ -146,14 +152,20 @@ def main(args):
T.RandomHorizontalFlip(),
T.Normalize()
])
train_dataset = dataset(transforms=train_transforms, mode='train')
train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='val')
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
......
......@@ -55,6 +55,12 @@ def parse_args():
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of evaluate
parser.add_argument(
......@@ -87,7 +93,10 @@ def main(args):
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='val')
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册