提交 6ae9d4ac 编写于 作者: C chenguowei01

change data_dir to data_root

上级 4cdc97c3
...@@ -24,46 +24,47 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/ADEChallengeData2016.zip" ...@@ -24,46 +24,47 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/ADEChallengeData2016.zip"
class ADE20K(Dataset): class ADE20K(Dataset):
"""ADE20K dataset `http://sceneparsing.csail.mit.edu/`. """ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
Args: Args:
data_dir: The dataset directory. dataset_root: The dataset directory.
mode: Which part of dataset to use.. it is one of ('train', 'val'). Default: 'train'. mode: Which part of dataset to use.. it is one of ('train', 'val'). Default: 'train'.
transforms: Transforms for image. transforms: Transforms for image.
download: Whether to download dataset if data_dir is None. download: Whether to download dataset if dataset_root is None.
""" """
def __init__(self, def __init__(self,
data_dir=None, dataset_root=None,
mode='train', mode='train',
transforms=None, transforms=None,
download=True): download=True):
self.data_dir = data_dir self.dataset_root = dataset_root
self.transforms = transforms self.transforms = transforms
self.mode = mode self.mode = mode
self.file_list = list() self.file_list = list()
self.num_classes = 21 self.num_classes = 150
if mode.lower() not in ['train', 'val']: if mode.lower() not in ['train', 'val']:
raise Exception( raise Exception(
"mode should be one of ('train', 'val') in PascalVOC dataset, but got {}." "mode should be one of ('train', 'val') in ADE20K dataset, but got {}."
.format(mode)) .format(mode))
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None: if self.dataset_root is None:
if not download: if not download:
raise Exception("data_dir not set and auto download disabled.") raise Exception(
self.data_dir = download_file_and_uncompress( "dataset_root not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=DATA_HOME,
extrapath=DATA_HOME, extrapath=DATA_HOME,
extraname='ADEChallengeData2016') extraname='ADEChallengeData2016')
if mode == 'train': if mode == 'train':
img_dir = os.path.join(self.data_dir, 'images/training') img_dir = os.path.join(self.dataset_root, 'images/training')
grt_dir = os.path.join(self.data_dir, 'annotations/training') grt_dir = os.path.join(self.dataset_root, 'annotations/training')
elif mode == 'val': elif mode == 'val':
img_dir = os.path.join(self.data_dir, 'images/validation') img_dir = os.path.join(self.dataset_root, 'images/validation')
grt_dir = os.path.join(self.data_dir, 'annotations/validation') grt_dir = os.path.join(self.dataset_root, 'annotations/validation')
img_files = os.listdir(img_dir) img_files = os.listdir(img_dir)
grt_files = [i.replace('.jpg', '.png') for i in img_files] grt_files = [i.replace('.jpg', '.png') for i in img_files]
for i in range(len(img_files)): for i in range(len(img_files)):
......
...@@ -35,13 +35,13 @@ class Cityscapes(Dataset): ...@@ -35,13 +35,13 @@ class Cityscapes(Dataset):
Make sure there are **labelTrainIds.png in gtFine directory. If not, please run the conver_cityscapes.py in tools. Make sure there are **labelTrainIds.png in gtFine directory. If not, please run the conver_cityscapes.py in tools.
Args: Args:
data_dir: Cityscapes dataset directory. dataset_root: Cityscapes dataset directory.
mode: Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'. mode: Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image. transforms: Transforms for image.
""" """
def __init__(self, data_dir, transforms=None, mode='train'): def __init__(self, dataset_root, transforms=None, mode='train'):
self.data_dir = data_dir self.dataset_root = dataset_root
self.transforms = transforms self.transforms = transforms
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
...@@ -55,10 +55,11 @@ class Cityscapes(Dataset): ...@@ -55,10 +55,11 @@ class Cityscapes(Dataset):
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
img_dir = os.path.join(self.data_dir, 'leftImg8bit') img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
grt_dir = os.path.join(self.data_dir, 'gtFine') grt_dir = os.path.join(self.dataset_root, 'gtFine')
if not os.path.isdir(self.data_dir) or not os.path.isdir( if self.dataset_root is None or not os.path.isdir(
img_dir) or not os.path.isdir(grt_dir): self.dataset_root) or not os.path.isdir(
img_dir) or not os.path.isdir(grt_dir):
raise Exception( raise Exception(
"The dataset is not Found or the folder structure is nonconfoumance." "The dataset is not Found or the folder structure is nonconfoumance."
) )
......
...@@ -23,7 +23,7 @@ class Dataset(fluid.io.Dataset): ...@@ -23,7 +23,7 @@ class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format. """Pass in a custom dataset that conforms to the format.
Args: Args:
data_dir: The dataset directory. dataset_root: The dataset directory.
num_classes: Number of classes. num_classes: Number of classes.
mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'. mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
train_list: The train dataset file. When image_set is 'train', train_list is necessary. train_list: The train dataset file. When image_set is 'train', train_list is necessary.
...@@ -43,7 +43,7 @@ class Dataset(fluid.io.Dataset): ...@@ -43,7 +43,7 @@ class Dataset(fluid.io.Dataset):
""" """
def __init__(self, def __init__(self,
data_dir, dataset_root,
num_classes, num_classes,
mode='train', mode='train',
train_list=None, train_list=None,
...@@ -51,7 +51,7 @@ class Dataset(fluid.io.Dataset): ...@@ -51,7 +51,7 @@ class Dataset(fluid.io.Dataset):
test_list=None, test_list=None,
separator=' ', separator=' ',
transforms=None): transforms=None):
self.data_dir = data_dir self.dataset_root = dataset_root
self.transforms = transforms self.transforms = transforms
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
...@@ -65,7 +65,7 @@ class Dataset(fluid.io.Dataset): ...@@ -65,7 +65,7 @@ class Dataset(fluid.io.Dataset):
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir self.dataset_root = dataset_root
if mode == 'train': if mode == 'train':
if train_list is None: if train_list is None:
raise Exception( raise Exception(
...@@ -103,11 +103,11 @@ class Dataset(fluid.io.Dataset): ...@@ -103,11 +103,11 @@ class Dataset(fluid.io.Dataset):
raise Exception( raise Exception(
"File list format incorrect! In training or evaluation task it should be" "File list format incorrect! In training or evaluation task it should be"
" image_name{}label_name\\n".format(separator)) " image_name{}label_name\\n".format(separator))
image_path = os.path.join(self.data_dir, items[0]) image_path = os.path.join(self.dataset_root, items[0])
grt_path = None grt_path = None
else: else:
image_path = os.path.join(self.data_dir, items[0]) image_path = os.path.join(self.dataset_root, items[0])
grt_path = os.path.join(self.data_dir, items[1]) grt_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, grt_path]) self.file_list.append([image_path, grt_path])
def __getitem__(self, idx): def __getitem__(self, idx):
......
...@@ -23,11 +23,11 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" ...@@ -23,11 +23,11 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
class OpticDiscSeg(Dataset): class OpticDiscSeg(Dataset):
def __init__(self, def __init__(self,
data_dir=None, dataset_root=None,
transforms=None, transforms=None,
mode='train', mode='train',
download=True): download=True):
self.data_dir = data_dir self.dataset_root = dataset_root
self.transforms = transforms self.transforms = transforms
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
...@@ -41,18 +41,18 @@ class OpticDiscSeg(Dataset): ...@@ -41,18 +41,18 @@ class OpticDiscSeg(Dataset):
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None: if self.dataset_root is None:
if not download: if not download:
raise Exception("data_file not set and auto download disabled.") raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if mode == 'train': if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt') file_list = os.path.join(self.dataset_root, 'train_list.txt')
elif mode == 'val': elif mode == 'val':
file_list = os.path.join(self.data_dir, 'val_list.txt') file_list = os.path.join(self.dataset_root, 'val_list.txt')
else: else:
file_list = os.path.join(self.data_dir, 'test_list.txt') file_list = os.path.join(self.dataset_root, 'test_list.txt')
with open(file_list, 'r') as f: with open(file_list, 'r') as f:
for line in f: for line in f:
...@@ -62,9 +62,9 @@ class OpticDiscSeg(Dataset): ...@@ -62,9 +62,9 @@ class OpticDiscSeg(Dataset):
raise Exception( raise Exception(
"File list format incorrect! It should be" "File list format incorrect! It should be"
" image_name label_name\\n") " image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0]) image_path = os.path.join(self.dataset_root, items[0])
grt_path = None grt_path = None
else: else:
image_path = os.path.join(self.data_dir, items[0]) image_path = os.path.join(self.dataset_root, items[0])
grt_path = os.path.join(self.data_dir, items[1]) grt_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, grt_path]) self.file_list.append([image_path, grt_path])
...@@ -24,18 +24,18 @@ class PascalVOC(Dataset): ...@@ -24,18 +24,18 @@ class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset, """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools. please run the voc_augment.py in tools.
Args: Args:
data_dir: The dataset directory. dataset_root: The dataset directory.
mode: Which part of dataset to use.. it is one of ('train', 'val', 'test'). Default: 'train'. mode: Which part of dataset to use.. it is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image. transforms: Transforms for image.
download: Whether to download dataset if data_dir is None. download: Whether to download dataset if dataset_root is None.
""" """
def __init__(self, def __init__(self,
data_dir=None, dataset_root=None,
mode='train', mode='train',
transforms=None, transforms=None,
download=True): download=True):
self.data_dir = data_dir self.dataset_root = dataset_root
self.transforms = transforms self.transforms = transforms
self.mode = mode self.mode = mode
self.file_list = list() self.file_list = list()
...@@ -49,16 +49,17 @@ class PascalVOC(Dataset): ...@@ -49,16 +49,17 @@ class PascalVOC(Dataset):
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None: if self.dataset_root is None:
if not download: if not download:
raise Exception("data_dir not set and auto download disabled.") raise Exception(
self.data_dir = download_file_and_uncompress( "dataset_root not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=DATA_HOME,
extrapath=DATA_HOME, extrapath=DATA_HOME,
extraname='VOCdevkit') extraname='VOCdevkit')
image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets', image_set_dir = os.path.join(self.dataset_root, 'VOC2012', 'ImageSets',
'Segmentation') 'Segmentation')
if mode == 'train': if mode == 'train':
file_list = os.path.join(image_set_dir, 'train.txt') file_list = os.path.join(image_set_dir, 'train.txt')
...@@ -76,9 +77,10 @@ class PascalVOC(Dataset): ...@@ -76,9 +77,10 @@ class PascalVOC(Dataset):
"Please make sure voc_augment.py has been properly run when using this mode." "Please make sure voc_augment.py has been properly run when using this mode."
) )
img_dir = os.path.join(self.data_dir, 'VOC2012', 'JPEGImages') img_dir = os.path.join(self.dataset_root, 'VOC2012', 'JPEGImages')
grt_dir = os.path.join(self.data_dir, 'VOC2012', 'SegmentationClass') grt_dir = os.path.join(self.dataset_root, 'VOC2012',
grt_dir_aug = os.path.join(self.data_dir, 'VOC2012', 'SegmentationClass')
grt_dir_aug = os.path.join(self.dataset_root, 'VOC2012',
'SegmentationClassAug') 'SegmentationClassAug')
with open(file_list, 'r') as f: with open(file_list, 'r') as f:
......
...@@ -44,6 +44,12 @@ def parse_args(): ...@@ -44,6 +44,12 @@ def parse_args():
str(list(DATASETS.keys()))), str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of prediction # params of prediction
parser.add_argument( parser.add_argument(
...@@ -88,7 +94,10 @@ def main(args): ...@@ -88,7 +94,10 @@ def main(args):
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = dataset(transforms=test_transforms, mode='test') test_dataset = dataset(
dataset_root=args.dataset_root,
transforms=test_transforms,
mode='test')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
File: voc_augment.py File: voc_augment.py
This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html> This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html>
to augment the Pascal VOC to augment the Pascal VOC.
""" """
import os import os
......
...@@ -51,6 +51,12 @@ def parse_args(): ...@@ -51,6 +51,12 @@ def parse_args():
str(list(DATASETS.keys()))), str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training # params of training
parser.add_argument( parser.add_argument(
...@@ -146,14 +152,20 @@ def main(args): ...@@ -146,14 +152,20 @@ def main(args):
T.RandomHorizontalFlip(), T.RandomHorizontalFlip(),
T.Normalize() T.Normalize()
]) ])
train_dataset = dataset(transforms=train_transforms, mode='train') train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None eval_dataset = None
if args.do_eval: if args.do_eval:
eval_transforms = T.Compose( eval_transforms = T.Compose(
[T.Resize(args.input_size), [T.Resize(args.input_size),
T.Normalize()]) T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='val') eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
......
...@@ -55,6 +55,12 @@ def parse_args(): ...@@ -55,6 +55,12 @@ def parse_args():
str(list(DATASETS.keys()))), str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of evaluate # params of evaluate
parser.add_argument( parser.add_argument(
...@@ -87,7 +93,10 @@ def main(args): ...@@ -87,7 +93,10 @@ def main(args):
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='val') eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册