提交 6b4a7f02 编写于 作者: C chenguowei01

add cityscapes dataset

上级 d2ed4fbf
...@@ -22,7 +22,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -22,7 +22,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2 import cv2
import tqdm import tqdm
from datasets import OpticDiscSeg from datasets import OpticDiscSeg, Cityscapes
import transforms as T import transforms as T
import models import models
import utils import utils
...@@ -45,7 +45,8 @@ def parse_args(): ...@@ -45,7 +45,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help='The dataset you want to train', help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
...@@ -134,9 +135,12 @@ def main(args): ...@@ -134,9 +135,12 @@ def main(args):
if args.dataset.lower() == 'opticdiscseg': if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else: else:
raise Exception( raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)") "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
...@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -20,7 +20,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg from datasets import OpticDiscSeg, Cityscapes
import transforms as T import transforms as T
import models import models
import utils.logging as logging import utils.logging as logging
...@@ -38,13 +38,14 @@ def parse_args(): ...@@ -38,13 +38,14 @@ def parse_args():
dest='model_name', dest='model_name',
help="Model type for traing, which is one of ('UNet')", help="Model type for traing, which is one of ('UNet')",
type=str, type=str,
default='OpticDiscSeg') default='UNet')
# params of dataset # params of dataset
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help='The dataset you want to train', help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
...@@ -194,9 +195,12 @@ def main(args): ...@@ -194,9 +195,12 @@ def main(args):
if args.dataset.lower() == 'opticdiscseg': if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else: else:
raise Exception( raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)") "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
# Creat dataset reader # Creat dataset reader
......
...@@ -23,7 +23,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -23,7 +23,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from paddle.fluid.dataloader import BatchSampler from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg from datasets import OpticDiscSeg, Cityscapes
import transforms as T import transforms as T
import models import models
import utils.logging as logging import utils.logging as logging
...@@ -46,7 +46,8 @@ def parse_args(): ...@@ -46,7 +46,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help='The dataset you want to evaluation', help=
"The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
...@@ -132,9 +133,12 @@ def main(args): ...@@ -132,9 +133,12 @@ def main(args):
if args.dataset.lower() == 'opticdiscseg': if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else: else:
raise Exception( raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)") "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册