From 6ae9d4acbb68dee3b1250adc813a2f06b9852374 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Sun, 9 Aug 2020 21:55:40 +0800 Subject: [PATCH] change data_dir to data_root --- dygraph/datasets/ade.py | 27 ++++++++++++++------------- dygraph/datasets/cityscapes.py | 15 ++++++++------- dygraph/datasets/dataset.py | 14 +++++++------- dygraph/datasets/optic_disc_seg.py | 20 ++++++++++---------- dygraph/datasets/voc.py | 24 +++++++++++++----------- dygraph/infer.py | 11 ++++++++++- dygraph/tools/voc_augment.py | 2 +- dygraph/train.py | 16 ++++++++++++++-- dygraph/val.py | 11 ++++++++++- 9 files changed, 87 insertions(+), 53 deletions(-) diff --git a/dygraph/datasets/ade.py b/dygraph/datasets/ade.py index e8a19256..f8f8902b 100644 --- a/dygraph/datasets/ade.py +++ b/dygraph/datasets/ade.py @@ -24,46 +24,47 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/ADEChallengeData2016.zip" class ADE20K(Dataset): """ADE20K dataset `http://sceneparsing.csail.mit.edu/`. Args: - data_dir: The dataset directory. + dataset_root: The dataset directory. mode: Which part of dataset to use.. it is one of ('train', 'val'). Default: 'train'. transforms: Transforms for image. - download: Whether to download dataset if data_dir is None. + download: Whether to download dataset if dataset_root is None. """ def __init__(self, - data_dir=None, + dataset_root=None, mode='train', transforms=None, download=True): - self.data_dir = data_dir + self.dataset_root = dataset_root self.transforms = transforms self.mode = mode self.file_list = list() - self.num_classes = 21 + self.num_classes = 150 if mode.lower() not in ['train', 'val']: raise Exception( - "mode should be one of ('train', 'val') in PascalVOC dataset, but got {}." + "mode should be one of ('train', 'val') in ADE20K dataset, but got {}." .format(mode)) if self.transforms is None: raise Exception("transforms is necessary, but it is None.") - if self.data_dir is None: + if self.dataset_root is None: if not download: - raise Exception("data_dir not set and auto download disabled.") - self.data_dir = download_file_and_uncompress( + raise Exception( + "dataset_root not set and auto download disabled.") + self.dataset_root = download_file_and_uncompress( url=URL, savepath=DATA_HOME, extrapath=DATA_HOME, extraname='ADEChallengeData2016') if mode == 'train': - img_dir = os.path.join(self.data_dir, 'images/training') - grt_dir = os.path.join(self.data_dir, 'annotations/training') + img_dir = os.path.join(self.dataset_root, 'images/training') + grt_dir = os.path.join(self.dataset_root, 'annotations/training') elif mode == 'val': - img_dir = os.path.join(self.data_dir, 'images/validation') - grt_dir = os.path.join(self.data_dir, 'annotations/validation') + img_dir = os.path.join(self.dataset_root, 'images/validation') + grt_dir = os.path.join(self.dataset_root, 'annotations/validation') img_files = os.listdir(img_dir) grt_files = [i.replace('.jpg', '.png') for i in img_files] for i in range(len(img_files)): diff --git a/dygraph/datasets/cityscapes.py b/dygraph/datasets/cityscapes.py index d12e804c..4e735fc2 100644 --- a/dygraph/datasets/cityscapes.py +++ b/dygraph/datasets/cityscapes.py @@ -35,13 +35,13 @@ class Cityscapes(Dataset): Make sure there are **labelTrainIds.png in gtFine directory. If not, please run the conver_cityscapes.py in tools. Args: - data_dir: Cityscapes dataset directory. + dataset_root: Cityscapes dataset directory. mode: Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'. transforms: Transforms for image. """ - def __init__(self, data_dir, transforms=None, mode='train'): - self.data_dir = data_dir + def __init__(self, dataset_root, transforms=None, mode='train'): + self.dataset_root = dataset_root self.transforms = transforms self.file_list = list() self.mode = mode @@ -55,10 +55,11 @@ class Cityscapes(Dataset): if self.transforms is None: raise Exception("transforms is necessary, but it is None.") - img_dir = os.path.join(self.data_dir, 'leftImg8bit') - grt_dir = os.path.join(self.data_dir, 'gtFine') - if not os.path.isdir(self.data_dir) or not os.path.isdir( - img_dir) or not os.path.isdir(grt_dir): + img_dir = os.path.join(self.dataset_root, 'leftImg8bit') + grt_dir = os.path.join(self.dataset_root, 'gtFine') + if self.dataset_root is None or not os.path.isdir( + self.dataset_root) or not os.path.isdir( + img_dir) or not os.path.isdir(grt_dir): raise Exception( "The dataset is not Found or the folder structure is nonconfoumance." ) diff --git a/dygraph/datasets/dataset.py b/dygraph/datasets/dataset.py index 316caa9f..fd21ad84 100644 --- a/dygraph/datasets/dataset.py +++ b/dygraph/datasets/dataset.py @@ -23,7 +23,7 @@ class Dataset(fluid.io.Dataset): """Pass in a custom dataset that conforms to the format. Args: - data_dir: The dataset directory. + dataset_root: The dataset directory. num_classes: Number of classes. mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'. train_list: The train dataset file. When image_set is 'train', train_list is necessary. @@ -43,7 +43,7 @@ class Dataset(fluid.io.Dataset): """ def __init__(self, - data_dir, + dataset_root, num_classes, mode='train', train_list=None, @@ -51,7 +51,7 @@ class Dataset(fluid.io.Dataset): test_list=None, separator=' ', transforms=None): - self.data_dir = data_dir + self.dataset_root = dataset_root self.transforms = transforms self.file_list = list() self.mode = mode @@ -65,7 +65,7 @@ class Dataset(fluid.io.Dataset): if self.transforms is None: raise Exception("transforms is necessary, but it is None.") - self.data_dir = data_dir + self.dataset_root = dataset_root if mode == 'train': if train_list is None: raise Exception( @@ -103,11 +103,11 @@ class Dataset(fluid.io.Dataset): raise Exception( "File list format incorrect! In training or evaluation task it should be" " image_name{}label_name\\n".format(separator)) - image_path = os.path.join(self.data_dir, items[0]) + image_path = os.path.join(self.dataset_root, items[0]) grt_path = None else: - image_path = os.path.join(self.data_dir, items[0]) - grt_path = os.path.join(self.data_dir, items[1]) + image_path = os.path.join(self.dataset_root, items[0]) + grt_path = os.path.join(self.dataset_root, items[1]) self.file_list.append([image_path, grt_path]) def __getitem__(self, idx): diff --git a/dygraph/datasets/optic_disc_seg.py b/dygraph/datasets/optic_disc_seg.py index e1b88176..5c2bc1ea 100644 --- a/dygraph/datasets/optic_disc_seg.py +++ b/dygraph/datasets/optic_disc_seg.py @@ -23,11 +23,11 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" class OpticDiscSeg(Dataset): def __init__(self, - data_dir=None, + dataset_root=None, transforms=None, mode='train', download=True): - self.data_dir = data_dir + self.dataset_root = dataset_root self.transforms = transforms self.file_list = list() self.mode = mode @@ -41,18 +41,18 @@ class OpticDiscSeg(Dataset): if self.transforms is None: raise Exception("transforms is necessary, but it is None.") - if self.data_dir is None: + if self.dataset_root is None: if not download: raise Exception("data_file not set and auto download disabled.") - self.data_dir = download_file_and_uncompress( + self.dataset_root = download_file_and_uncompress( url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) if mode == 'train': - file_list = os.path.join(self.data_dir, 'train_list.txt') + file_list = os.path.join(self.dataset_root, 'train_list.txt') elif mode == 'val': - file_list = os.path.join(self.data_dir, 'val_list.txt') + file_list = os.path.join(self.dataset_root, 'val_list.txt') else: - file_list = os.path.join(self.data_dir, 'test_list.txt') + file_list = os.path.join(self.dataset_root, 'test_list.txt') with open(file_list, 'r') as f: for line in f: @@ -62,9 +62,9 @@ class OpticDiscSeg(Dataset): raise Exception( "File list format incorrect! It should be" " image_name label_name\\n") - image_path = os.path.join(self.data_dir, items[0]) + image_path = os.path.join(self.dataset_root, items[0]) grt_path = None else: - image_path = os.path.join(self.data_dir, items[0]) - grt_path = os.path.join(self.data_dir, items[1]) + image_path = os.path.join(self.dataset_root, items[0]) + grt_path = os.path.join(self.dataset_root, items[1]) self.file_list.append([image_path, grt_path]) diff --git a/dygraph/datasets/voc.py b/dygraph/datasets/voc.py index 56ece84a..05eb9f26 100644 --- a/dygraph/datasets/voc.py +++ b/dygraph/datasets/voc.py @@ -24,18 +24,18 @@ class PascalVOC(Dataset): """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset, please run the voc_augment.py in tools. Args: - data_dir: The dataset directory. + dataset_root: The dataset directory. mode: Which part of dataset to use.. it is one of ('train', 'val', 'test'). Default: 'train'. transforms: Transforms for image. - download: Whether to download dataset if data_dir is None. + download: Whether to download dataset if dataset_root is None. """ def __init__(self, - data_dir=None, + dataset_root=None, mode='train', transforms=None, download=True): - self.data_dir = data_dir + self.dataset_root = dataset_root self.transforms = transforms self.mode = mode self.file_list = list() @@ -49,16 +49,17 @@ class PascalVOC(Dataset): if self.transforms is None: raise Exception("transforms is necessary, but it is None.") - if self.data_dir is None: + if self.dataset_root is None: if not download: - raise Exception("data_dir not set and auto download disabled.") - self.data_dir = download_file_and_uncompress( + raise Exception( + "dataset_root not set and auto download disabled.") + self.dataset_root = download_file_and_uncompress( url=URL, savepath=DATA_HOME, extrapath=DATA_HOME, extraname='VOCdevkit') - image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets', + image_set_dir = os.path.join(self.dataset_root, 'VOC2012', 'ImageSets', 'Segmentation') if mode == 'train': file_list = os.path.join(image_set_dir, 'train.txt') @@ -76,9 +77,10 @@ class PascalVOC(Dataset): "Please make sure voc_augment.py has been properly run when using this mode." ) - img_dir = os.path.join(self.data_dir, 'VOC2012', 'JPEGImages') - grt_dir = os.path.join(self.data_dir, 'VOC2012', 'SegmentationClass') - grt_dir_aug = os.path.join(self.data_dir, 'VOC2012', + img_dir = os.path.join(self.dataset_root, 'VOC2012', 'JPEGImages') + grt_dir = os.path.join(self.dataset_root, 'VOC2012', + 'SegmentationClass') + grt_dir_aug = os.path.join(self.dataset_root, 'VOC2012', 'SegmentationClassAug') with open(file_list, 'r') as f: diff --git a/dygraph/infer.py b/dygraph/infer.py index 6cb8f6d6..364287bd 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -44,6 +44,12 @@ def parse_args(): str(list(DATASETS.keys()))), type=str, default='OpticDiscSeg') + parser.add_argument( + '--dataset_root', + dest='dataset_root', + help="dataset root directory", + type=str, + default=None) # params of prediction parser.add_argument( @@ -88,7 +94,10 @@ def main(args): with fluid.dygraph.guard(places): test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - test_dataset = dataset(transforms=test_transforms, mode='test') + test_dataset = dataset( + dataset_root=args.dataset_root, + transforms=test_transforms, + mode='test') if args.model_name not in MODELS: raise Exception( diff --git a/dygraph/tools/voc_augment.py b/dygraph/tools/voc_augment.py index c92776f7..c4be6ad4 100644 --- a/dygraph/tools/voc_augment.py +++ b/dygraph/tools/voc_augment.py @@ -15,7 +15,7 @@ File: voc_augment.py This file use SBD(Semantic Boundaries Dataset) -to augment the Pascal VOC +to augment the Pascal VOC. """ import os diff --git a/dygraph/train.py b/dygraph/train.py index cf3ad5b6..f6ccc39a 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -51,6 +51,12 @@ def parse_args(): str(list(DATASETS.keys()))), type=str, default='OpticDiscSeg') + parser.add_argument( + '--dataset_root', + dest='dataset_root', + help="dataset root directory", + type=str, + default=None) # params of training parser.add_argument( @@ -146,14 +152,20 @@ def main(args): T.RandomHorizontalFlip(), T.Normalize() ]) - train_dataset = dataset(transforms=train_transforms, mode='train') + train_dataset = dataset( + dataset_root=args.dataset_root, + transforms=train_transforms, + mode='train') eval_dataset = None if args.do_eval: eval_transforms = T.Compose( [T.Resize(args.input_size), T.Normalize()]) - eval_dataset = dataset(transforms=eval_transforms, mode='val') + eval_dataset = dataset( + dataset_root=args.dataset_root, + transforms=eval_transforms, + mode='val') if args.model_name not in MODELS: raise Exception( diff --git a/dygraph/val.py b/dygraph/val.py index a453bd84..e3e1dca6 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -55,6 +55,12 @@ def parse_args(): str(list(DATASETS.keys()))), type=str, default='OpticDiscSeg') + parser.add_argument( + '--dataset_root', + dest='dataset_root', + help="dataset root directory", + type=str, + default=None) # params of evaluate parser.add_argument( @@ -87,7 +93,10 @@ def main(args): with fluid.dygraph.guard(places): eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - eval_dataset = dataset(transforms=eval_transforms, mode='val') + eval_dataset = dataset( + dataset_root=args.dataset_root, + transforms=eval_transforms, + mode='val') if args.model_name not in MODELS: raise Exception( -- GitLab