提交 6a569885 编写于 作者: C chenguowei01

update dataset using

上级 78b49980
......@@ -43,22 +43,11 @@ def parse_args():
# params of dataset
parser.add_argument(
'--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument(
'--test_list',
dest='test_list',
help='Val list file of dataset',
'--dataset',
dest='dataset',
help='The dataset you want to train',
type=str,
default=None)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
default='OpticDiscSeg')
# params of prediction
parser.add_argument(
......@@ -142,12 +131,19 @@ def main(args):
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = OpticDiscSeg(transforms=test_transforms, mode='test')
test_dataset = dataset(transforms=test_transforms, mode='test')
if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes)
model = models.UNet(num_classes=test_dataset.num_classes)
infer(
model,
......
......@@ -40,31 +40,15 @@ def parse_args():
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
type=str,
default='UNet')
default='OpticDiscSeg')
# params of dataset
parser.add_argument(
'--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument(
'--train_list',
dest='train_list',
help='Train list file of dataset',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
'--dataset',
dest='dataset',
help='The dataset you want to train',
type=str,
default=None)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
default='OpticDiscSeg')
# params of training
parser.add_argument(
......@@ -83,7 +67,7 @@ def parse_args():
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
help='Mini batch size of one gpu or cpu',
type=int,
default=2)
parser.add_argument(
......@@ -210,6 +194,12 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
with fluid.dygraph.guard(places):
# Creat dataset reader
train_transforms = T.Compose([
......@@ -217,17 +207,18 @@ def main(args):
T.RandomHorizontalFlip(),
T.Normalize()
])
train_dataset = OpticDiscSeg(transforms=train_transforms, mode='train')
train_dataset = dataset(transforms=train_transforms, mode='train')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = OpticDiscSeg(transforms=eval_transforms, mode='eval')
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes, ignore_index=255)
model = models.UNet(
num_classes=train_dataset.num_classes, ignore_index=255)
# Creat optimizer
num_steps_each_epoch = len(train_dataset) // args.batch_size
......@@ -251,7 +242,7 @@ def main(args):
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
save_interval_epochs=args.save_interval_epochs,
num_classes=args.num_classes,
num_classes=train_dataset.num_classes,
num_workers=args.num_workers)
......
......@@ -44,22 +44,11 @@ def parse_args():
# params of dataset
parser.add_argument(
'--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
'--dataset',
dest='dataset',
help='The dataset you want to evaluation',
type=str,
default=None)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
default='OpticDiscSeg')
# params of evaluate
parser.add_argument(
......@@ -140,19 +129,26 @@ def main(args):
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg',)")
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = OpticDiscSeg(transforms=eval_transforms, mode='eval')
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes)
model = models.UNet(num_classes=eval_dataset.num_classes)
evaluate(
model,
eval_dataset,
places=places,
model_dir=args.model_dir,
num_classes=args.num_classes,
num_classes=eval_dataset.num_classes,
batch_size=args.batch_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册