提交 967ebd6f 编写于 作者: C chenguowei01

update datasets

上级 c4bf7ab3
...@@ -15,4 +15,10 @@ ...@@ -15,4 +15,10 @@
from .dataset import Dataset from .dataset import Dataset
from .optic_disc_seg import OpticDiscSeg from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes from .cityscapes import Cityscapes
from .voc import PascalVoc from .voc import PascalVOC
DATASETS = {
"OpticDiscSeg": OpticDiscSeg,
"Cityscapes": Cityscapes,
"PascalVOC": PascalVOC
}
...@@ -20,7 +20,7 @@ DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') ...@@ -20,7 +20,7 @@ DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar" URL = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
class PascalVoc(Dataset): class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset, """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools. please run the voc_augment.py in tools.
Args: Args:
...@@ -36,7 +36,7 @@ class PascalVoc(Dataset): ...@@ -36,7 +36,7 @@ class PascalVoc(Dataset):
image_set='train', image_set='train',
mode='train', mode='train',
transforms=None, transforms=None,
download=False): download=True):
self.data_dir = data_dir self.data_dir = data_dir
self.transforms = transforms self.transforms = transforms
self.mode = mode self.mode = mode
......
...@@ -22,7 +22,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -22,7 +22,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2 import cv2
import tqdm import tqdm
from datasets import OpticDiscSeg, Cityscapes from datasets import DATASETS
import transforms as T import transforms as T
from models import MODELS from models import MODELS
import utils import utils
...@@ -47,8 +47,8 @@ def parse_args(): ...@@ -47,8 +47,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help= help="The dataset you want to test, which is one of {}".format(
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')", str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
...@@ -88,14 +88,10 @@ def main(args): ...@@ -88,14 +88,10 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg': if args.dataset not in DATASETS:
dataset = OpticDiscSeg raise Exception('--dataset is invalid. it should be one of {}'.format(
elif args.dataset.lower() == 'cityscapes': str(list(DATASETS.keys()))))
dataset = Cityscapes dataset = DATASETS[args.dataset]
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
...@@ -43,18 +43,18 @@ def parse_args(): ...@@ -43,18 +43,18 @@ def parse_args():
dest='voc_path', dest='voc_path',
help='pascal voc path', help='pascal voc path',
type=str, type=str,
default=os.path.join(DATA_HOME + 'VOCdevkit')) default=os.path.join(DATA_HOME, 'VOCdevkit'))
parser.add_argument( parser.add_argument(
'--num_workers', '--num_workers',
dest='num_workers', dest='num_workers',
help='How many processes are used for data conversion', help='How many processes are used for data conversion',
type=str, type=int,
default=cpu_count()) default=cpu_count())
return parser.parse_args() return parser.parse_args()
def conver_to_png(mat_file, sbd_cls_dir, save_dir): def mat_to_png(mat_file, sbd_cls_dir, save_dir):
mat_path = os.path.join(sbd_cls_dir, mat_file) mat_path = os.path.join(sbd_cls_dir, mat_file)
mat = loadmat(mat_path) mat = loadmat(mat_path)
mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8) mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
...@@ -75,27 +75,30 @@ def main(): ...@@ -75,27 +75,30 @@ def main():
sbd_file_list += [line.strip() for line in f] sbd_file_list += [line.strip() for line in f]
if not os.path.exists(args.voc_path): if not os.path.exists(args.voc_path):
raise Exception( raise Exception(
'Ther is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly' 'There is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly'
) )
with open( with open(
os.path.join(args.voc_path, os.path.join(args.voc_path,
'VOC2012/ImageSets/Segmentation/trainval.txt', 'VOC2012/ImageSets/Segmentation/trainval.txt'),
'r')) as f: 'r') as f:
voc_file_list = [line.strip() for line in f] voc_file_list = [line.strip() for line in f]
aug_file_list = list(set(sbd_file_list) - set(voc_file_list)) aug_file_list = list(set(sbd_file_list) - set(voc_file_list))
with open( with open(
os.path.join(args.voc_path, os.path.join(args.voc_path,
'VOC2012/ImageSets/Segmentation/aug.txt', 'w')) as f: 'VOC2012/ImageSets/Segmentation/aug.txt'), 'w') as f:
f.writelines(''.join(line, '\n') for line in aug_file_list) f.writelines(''.join([line, '\n']) for line in aug_file_list)
sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls') sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls')
save_dir = os.path.join(args.voc_path, save_dir = os.path.join(args.voc_path, 'VOC2012/SegmentationClassAug')
'VOC2012/ImageSets/SegmentationClassAug') if not os.path.exists(save_dir):
os.mkdir(save_dir)
mat_file_list = os.listdir(sbd_cls_dir) mat_file_list = os.listdir(sbd_cls_dir)
p = Pool(args.num_workers) p = Pool(args.num_workers)
for f in tqdm.tqdm(mat_file_list): for f in tqdm.tqdm(mat_file_list):
p.apply_async(conver_to_png, args=(f, sbd_cls_dir, save_dir)) p.apply_async(mat_to_png, args=(f, sbd_cls_dir, save_dir))
p.close()
p.join()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes from datasets import DATASETS
import transforms as T import transforms as T
from models import MODELS from models import MODELS
import utils.logging as logging import utils.logging as logging
...@@ -47,8 +47,8 @@ def parse_args(): ...@@ -47,8 +47,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help= help="The dataset you want to train, which is one of {}".format(
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')", str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
...@@ -134,14 +134,10 @@ def main(args): ...@@ -134,14 +134,10 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg': if args.dataset not in DATASETS:
dataset = OpticDiscSeg raise Exception('--dataset is invalid. it should be one of {}'.format(
elif args.dataset.lower() == 'cityscapes': str(list(DATASETS.keys()))))
dataset = Cityscapes dataset = DATASETS[args.dataset]
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
# Creat dataset reader # Creat dataset reader
......
...@@ -25,7 +25,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -25,7 +25,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from paddle.fluid.dataloader import BatchSampler from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes from datasets import DATASETS
import transforms as T import transforms as T
from models import MODELS from models import MODELS
import utils.logging as logging import utils.logging as logging
...@@ -51,8 +51,8 @@ def parse_args(): ...@@ -51,8 +51,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help= help="The dataset you want to evaluation, which is one of {}".format(
"The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')", str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
...@@ -80,14 +80,10 @@ def main(args): ...@@ -80,14 +80,10 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg': if args.dataset not in DATASETS:
dataset = OpticDiscSeg raise Exception('--dataset is invalid. it should be one of {}'.format(
elif args.dataset.lower() == 'cityscapes': str(list(DATASETS.keys()))))
dataset = Cityscapes dataset = DATASETS[args.dataset]
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册