提交 6a569885 编写于 作者: C chenguowei01

update dataset using

上级 78b49980
...@@ -43,22 +43,11 @@ def parse_args(): ...@@ -43,22 +43,11 @@ def parse_args():
# params of dataset # params of dataset
parser.add_argument( parser.add_argument(
'--data_dir', '--dataset',
dest='data_dir', dest='dataset',
help='The root directory of dataset', help='The dataset you want to train',
type=str)
parser.add_argument(
'--test_list',
dest='test_list',
help='Val list file of dataset',
type=str, type=str,
default=None) default='OpticDiscSeg')
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
# params of prediction # params of prediction
parser.add_argument( parser.add_argument(
...@@ -142,12 +131,19 @@ def main(args): ...@@ -142,12 +131,19 @@ def main(args):
places = fluid.CUDAPlace(ParallelEnv().dev_id) \ places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = OpticDiscSeg(transforms=test_transforms, mode='test') test_dataset = dataset(transforms=test_transforms, mode='test')
if args.model_name == 'UNet': if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes) model = models.UNet(num_classes=test_dataset.num_classes)
infer( infer(
model, model,
......
...@@ -40,31 +40,15 @@ def parse_args(): ...@@ -40,31 +40,15 @@ def parse_args():
dest='model_name', dest='model_name',
help="Model type for traing, which is one of ('UNet')", help="Model type for traing, which is one of ('UNet')",
type=str, type=str,
default='UNet') default='OpticDiscSeg')
# params of dataset # params of dataset
parser.add_argument( parser.add_argument(
'--data_dir', '--dataset',
dest='data_dir', dest='dataset',
help='The root directory of dataset', help='The dataset you want to train',
type=str)
parser.add_argument(
'--train_list',
dest='train_list',
help='Train list file of dataset',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
type=str, type=str,
default=None) default='OpticDiscSeg')
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
# params of training # params of training
parser.add_argument( parser.add_argument(
...@@ -83,7 +67,7 @@ def parse_args(): ...@@ -83,7 +67,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--batch_size', '--batch_size',
dest='batch_size', dest='batch_size',
help='Mini batch size', help='Mini batch size of one gpu or cpu',
type=int, type=int,
default=2) default=2)
parser.add_argument( parser.add_argument(
...@@ -210,6 +194,12 @@ def main(args): ...@@ -210,6 +194,12 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
# Creat dataset reader # Creat dataset reader
train_transforms = T.Compose([ train_transforms = T.Compose([
...@@ -217,17 +207,18 @@ def main(args): ...@@ -217,17 +207,18 @@ def main(args):
T.RandomHorizontalFlip(), T.RandomHorizontalFlip(),
T.Normalize() T.Normalize()
]) ])
train_dataset = OpticDiscSeg(transforms=train_transforms, mode='train') train_dataset = dataset(transforms=train_transforms, mode='train')
eval_dataset = None eval_dataset = None
if args.do_eval: if args.do_eval:
eval_transforms = T.Compose( eval_transforms = T.Compose(
[T.Resize(args.input_size), [T.Resize(args.input_size),
T.Normalize()]) T.Normalize()])
eval_dataset = OpticDiscSeg(transforms=eval_transforms, mode='eval') eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet': if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes, ignore_index=255) model = models.UNet(
num_classes=train_dataset.num_classes, ignore_index=255)
# Creat optimizer # Creat optimizer
num_steps_each_epoch = len(train_dataset) // args.batch_size num_steps_each_epoch = len(train_dataset) // args.batch_size
...@@ -251,7 +242,7 @@ def main(args): ...@@ -251,7 +242,7 @@ def main(args):
batch_size=args.batch_size, batch_size=args.batch_size,
pretrained_model=args.pretrained_model, pretrained_model=args.pretrained_model,
save_interval_epochs=args.save_interval_epochs, save_interval_epochs=args.save_interval_epochs,
num_classes=args.num_classes, num_classes=train_dataset.num_classes,
num_workers=args.num_workers) num_workers=args.num_workers)
......
...@@ -44,22 +44,11 @@ def parse_args(): ...@@ -44,22 +44,11 @@ def parse_args():
# params of dataset # params of dataset
parser.add_argument( parser.add_argument(
'--data_dir', '--dataset',
dest='data_dir', dest='dataset',
help='The root directory of dataset', help='The dataset you want to evaluation',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
type=str, type=str,
default=None) default='OpticDiscSeg')
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
# params of evaluate # params of evaluate
parser.add_argument( parser.add_argument(
...@@ -140,19 +129,26 @@ def main(args): ...@@ -140,19 +129,26 @@ def main(args):
places = fluid.CUDAPlace(ParallelEnv().dev_id) \ places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = OpticDiscSeg(transforms=eval_transforms, mode='eval') eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet': if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes) model = models.UNet(num_classes=eval_dataset.num_classes)
evaluate( evaluate(
model, model,
eval_dataset, eval_dataset,
places=places, places=places,
model_dir=args.model_dir, model_dir=args.model_dir,
num_classes=args.num_classes, num_classes=eval_dataset.num_classes,
batch_size=args.batch_size) batch_size=args.batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册