提交 6b4a7f02 编写于 作者: C chenguowei01

add cityscapes dataset

上级 d2ed4fbf
......@@ -22,7 +22,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2
import tqdm
from datasets import OpticDiscSeg
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils
......@@ -45,7 +45,8 @@ def parse_args():
parser.add_argument(
'--dataset',
dest='dataset',
help='The dataset you want to train',
help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str,
default='OpticDiscSeg')
......@@ -134,9 +135,12 @@ def main(args):
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
......@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
......@@ -38,13 +38,14 @@ def parse_args():
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
type=str,
default='OpticDiscSeg')
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help='The dataset you want to train',
help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str,
default='OpticDiscSeg')
......@@ -194,9 +195,12 @@ def main(args):
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places):
# Creat dataset reader
......
......@@ -23,7 +23,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
......@@ -46,7 +46,8 @@ def parse_args():
parser.add_argument(
'--dataset',
dest='dataset',
help='The dataset you want to evaluation',
help=
"The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str,
default='OpticDiscSeg')
......@@ -132,9 +133,12 @@ def main(args):
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册