Commit 967ebd6f authored by chenguowei01

update datasets

Parent: c4bf7ab3
@@ -15,4 +15,10 @@
 from .dataset import Dataset
 from .optic_disc_seg import OpticDiscSeg
 from .cityscapes import Cityscapes
-from .voc import PascalVoc
+from .voc import PascalVOC
+
+DATASETS = {
+    "OpticDiscSeg": OpticDiscSeg,
+    "Cityscapes": Cityscapes,
+    "PascalVOC": PascalVOC
+}
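The new DATASETS registry in the datasets package replaces the hard-coded dataset choices in the CLI scripts below. A minimal sketch of how a project-local dataset could be hooked in, assuming a hypothetical MyDataset class that follows the same Dataset interface (MyDataset and my_dataset are illustrative names, not part of this commit):

# Sketch only: MyDataset / my_dataset are hypothetical.
from datasets import DATASETS
from my_dataset import MyDataset  # hypothetical module

# Register it under the name accepted by the --dataset argument.
DATASETS["MyDataset"] = MyDataset

# train.py / infer.py / val.py then resolve the class by name:
dataset_cls = DATASETS["MyDataset"]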
@@ -20,7 +20,7 @@ DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
 
-class PascalVoc(Dataset):
+class PascalVOC(Dataset):
     """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
     please run the voc_augment.py in tools.
 
     Args:
@@ -36,7 +36,7 @@ class PascalVoc(Dataset):
                  image_set='train',
                  mode='train',
                  transforms=None,
-                 download=False):
+                 download=True):
        self.data_dir = data_dir
        self.transforms = transforms
        self.mode = mode
......
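With download now defaulting to True, the dataset can fetch the archive at URL into DATA_HOME on first use instead of requiring a manual download. A minimal usage sketch; the keyword names follow the signature visible in this hunk, while the transform composition simply mirrors what train.py/val.py do below and the input size is an arbitrary example:

# Sketch only: constructor arguments taken from the signature above;
# the transforms and input size are illustrative, not prescribed by this commit.
import transforms as T
from datasets import PascalVOC

train_transforms = T.Compose([T.Resize((512, 512)), T.Normalize()])

# download=True (the new default) pulls VOCtrainval_11-May-2012.tar into
# ~/.cache/paddle/dataset when no local data_dir is supplied.
train_dataset = PascalVOC(
    data_dir=None,
    image_set='train',
    mode='train',
    transforms=train_transforms,
    download=True)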
@@ -22,7 +22,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 import cv2
 import tqdm
 
-from datasets import OpticDiscSeg, Cityscapes
+from datasets import DATASETS
 import transforms as T
 from models import MODELS
 import utils
@@ -47,8 +47,8 @@ def parse_args():
     parser.add_argument(
         '--dataset',
         dest='dataset',
-        help=
-        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        help="The dataset you want to test, which is one of {}".format(
+            str(list(DATASETS.keys()))),
         type=str,
         default='OpticDiscSeg')
@@ -88,14 +88,10 @@ def main(args):
         if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
         else fluid.CPUPlace()
 
-    if args.dataset.lower() == 'opticdiscseg':
-        dataset = OpticDiscSeg
-    elif args.dataset.lower() == 'cityscapes':
-        dataset = Cityscapes
-    else:
-        raise Exception(
-            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
-        )
+    if args.dataset not in DATASETS:
+        raise Exception('--dataset is invalid. it should be one of {}'.format(
+            str(list(DATASETS.keys()))))
+    dataset = DATASETS[args.dataset]
 
     with fluid.dygraph.guard(places):
         test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
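One behavioural detail of the new lookup: the old branches compared args.dataset.lower(), so any casing such as 'cityscapes' was accepted, while the registry check matches the DATASETS keys exactly ('OpticDiscSeg', 'Cityscapes', 'PascalVOC'). If case-insensitive matching is still wanted, a small sketch of one way to keep it (not part of this commit):

# Hypothetical helper, not part of this commit: resolve --dataset case-insensitively.
def resolve_dataset(name, datasets):
    for key, cls in datasets.items():
        if key.lower() == name.lower():
            return cls
    raise Exception('--dataset is invalid. It should be one of {}'.format(
        list(datasets.keys())))

dataset = resolve_dataset(args.dataset, DATASETS)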
@@ -43,18 +43,18 @@ def parse_args():
         dest='voc_path',
         help='pascal voc path',
         type=str,
-        default=os.path.join(DATA_HOME + 'VOCdevkit'))
+        default=os.path.join(DATA_HOME, 'VOCdevkit'))
     parser.add_argument(
         '--num_workers',
         dest='num_workers',
         help='How many processes are used for data conversion',
-        type=str,
+        type=int,
         default=cpu_count())
     return parser.parse_args()
 
 
-def conver_to_png(mat_file, sbd_cls_dir, save_dir):
+def mat_to_png(mat_file, sbd_cls_dir, save_dir):
     mat_path = os.path.join(sbd_cls_dir, mat_file)
     mat = loadmat(mat_path)
     mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
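The --voc_path default fix is worth a note: the old expression concatenates the strings before os.path.join ever sees a second argument, so no path separator is inserted. A quick illustration of the difference, using the DATA_HOME defined in this script:

import os

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')

# Old default: plain string concatenation, separator is missing.
print(os.path.join(DATA_HOME + 'VOCdevkit'))   # .../paddle/datasetVOCdevkit
# New default: os.path.join inserts the separator.
print(os.path.join(DATA_HOME, 'VOCdevkit'))    # .../paddle/dataset/VOCdevkit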
@@ -75,27 +75,30 @@ def main():
         sbd_file_list += [line.strip() for line in f]
 
     if not os.path.exists(args.voc_path):
         raise Exception(
-            'Ther is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly'
+            'There is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly'
         )
 
     with open(
             os.path.join(args.voc_path,
-                         'VOC2012/ImageSets/Segmentation/trainval.txt',
-                         'r')) as f:
+                         'VOC2012/ImageSets/Segmentation/trainval.txt'),
+            'r') as f:
         voc_file_list = [line.strip() for line in f]
 
     aug_file_list = list(set(sbd_file_list) - set(voc_file_list))
     with open(
             os.path.join(args.voc_path,
-                         'VOC2012/ImageSets/Segmentation/aug.txt', 'w')) as f:
-        f.writelines(''.join(line, '\n') for line in aug_file_list)
+                         'VOC2012/ImageSets/Segmentation/aug.txt'), 'w') as f:
+        f.writelines(''.join([line, '\n']) for line in aug_file_list)
 
     sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls')
-    save_dir = os.path.join(args.voc_path,
-                            'VOC2012/ImageSets/SegmentationClassAug')
+    save_dir = os.path.join(args.voc_path, 'VOC2012/SegmentationClassAug')
     if not os.path.exists(save_dir):
         os.mkdir(save_dir)
 
     mat_file_list = os.listdir(sbd_cls_dir)
     p = Pool(args.num_workers)
     for f in tqdm.tqdm(mat_file_list):
-        p.apply_async(conver_to_png, args=(f, sbd_cls_dir, save_dir))
+        p.apply_async(mat_to_png, args=(f, sbd_cls_dir, save_dir))
     p.close()
     p.join()
 
 
 if __name__ == '__main__':
......
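mat_to_png (renamed from conver_to_png) converts one SBD .mat annotation into a PNG label mask, and the Pool in main() fans that work out over --num_workers processes. Only the first three lines of the function are visible in this diff; a hedged sketch of how it plausibly finishes, assuming the mask is written with cv2.imwrite and the .mat extension is simply swapped for .png (both assumptions, not the verified committed body):

# Sketch only: the first three statements match the diff; the PNG write and the
# filename handling are assumptions.
import os

import cv2
import numpy as np
from scipy.io import loadmat


def mat_to_png(mat_file, sbd_cls_dir, save_dir):
    mat_path = os.path.join(sbd_cls_dir, mat_file)
    mat = loadmat(mat_path)
    mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # Assumed: save the class mask as <name>.png in SegmentationClassAug.
    png_name = mat_file.replace('.mat', '.png')
    cv2.imwrite(os.path.join(save_dir, png_name), mask)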
@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import DataLoader
 from paddle.incubate.hapi.distributed import DistributedBatchSampler
 
-from datasets import OpticDiscSeg, Cityscapes
+from datasets import DATASETS
 import transforms as T
 from models import MODELS
 import utils.logging as logging
@@ -47,8 +47,8 @@ def parse_args():
     parser.add_argument(
         '--dataset',
         dest='dataset',
-        help=
-        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        help="The dataset you want to train, which is one of {}".format(
+            str(list(DATASETS.keys()))),
         type=str,
         default='OpticDiscSeg')
@@ -134,14 +134,10 @@ def main(args):
         if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
         else fluid.CPUPlace()
 
-    if args.dataset.lower() == 'opticdiscseg':
-        dataset = OpticDiscSeg
-    elif args.dataset.lower() == 'cityscapes':
-        dataset = Cityscapes
-    else:
-        raise Exception(
-            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
-        )
+    if args.dataset not in DATASETS:
+        raise Exception('--dataset is invalid. it should be one of {}'.format(
+            str(list(DATASETS.keys()))))
+    dataset = DATASETS[args.dataset]
 
     with fluid.dygraph.guard(places):
         # Creat dataset reader
......
@@ -25,7 +25,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import DataLoader
 from paddle.fluid.dataloader import BatchSampler
 
-from datasets import OpticDiscSeg, Cityscapes
+from datasets import DATASETS
 import transforms as T
 from models import MODELS
 import utils.logging as logging
@@ -51,8 +51,8 @@ def parse_args():
     parser.add_argument(
         '--dataset',
         dest='dataset',
-        help=
-        "The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        help="The dataset you want to evaluation, which is one of {}".format(
+            str(list(DATASETS.keys()))),
         type=str,
         default='OpticDiscSeg')
@@ -80,14 +80,10 @@ def main(args):
         if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
         else fluid.CPUPlace()
 
-    if args.dataset.lower() == 'opticdiscseg':
-        dataset = OpticDiscSeg
-    elif args.dataset.lower() == 'cityscapes':
-        dataset = Cityscapes
-    else:
-        raise Exception(
-            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
-        )
+    if args.dataset not in DATASETS:
+        raise Exception('--dataset is invalid. it should be one of {}'.format(
+            str(list(DATASETS.keys()))))
+    dataset = DATASETS[args.dataset]
 
     with fluid.dygraph.guard(places):
         eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......