diff --git a/dygraph/infer.py b/dygraph/infer.py index 2deac245ab4601536abf055e48055534022af7e3..40315023871a3d0b3c784dcc45b19d9b6fdea978 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -43,22 +43,11 @@ def parse_args(): # params of dataset parser.add_argument( - '--data_dir', - dest='data_dir', - help='The root directory of dataset', - type=str) - parser.add_argument( - '--test_list', - dest='test_list', - help='Val list file of dataset', + '--dataset', + dest='dataset', + help='The dataset you want to train', type=str, - default=None) - parser.add_argument( - '--num_classes', - dest='num_classes', - help='Number of classes', - type=int, - default=2) + default='OpticDiscSeg') # params of prediction parser.add_argument( @@ -142,12 +131,19 @@ def main(args): places = fluid.CUDAPlace(ParallelEnv().dev_id) \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ else fluid.CPUPlace() + + if args.dataset.lower() == 'opticdiscseg': + dataset = OpticDiscSeg + else: + raise Exception( + "The --dataset set wrong. It should be one of ('OpticDiscSeg',)") + with fluid.dygraph.guard(places): test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - test_dataset = OpticDiscSeg(transforms=test_transforms, mode='test') + test_dataset = dataset(transforms=test_transforms, mode='test') if args.model_name == 'UNet': - model = models.UNet(num_classes=args.num_classes) + model = models.UNet(num_classes=test_dataset.num_classes) infer( model, diff --git a/dygraph/train.py b/dygraph/train.py index 17ba2d68730e39c480c084c0de1afc790582ba41..b9d248b9daaec621f8fe00a2579d6b2346461d9a 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -40,31 +40,15 @@ def parse_args(): dest='model_name', help="Model type for traing, which is one of ('UNet')", type=str, - default='UNet') + default='OpticDiscSeg') # params of dataset parser.add_argument( - '--data_dir', - dest='data_dir', - help='The root directory of dataset', - type=str) - parser.add_argument( - '--train_list', - dest='train_list', - help='Train list file of dataset', - type=str) - parser.add_argument( - '--val_list', - dest='val_list', - help='Val list file of dataset', + '--dataset', + dest='dataset', + help='The dataset you want to train', type=str, - default=None) - parser.add_argument( - '--num_classes', - dest='num_classes', - help='Number of classes', - type=int, - default=2) + default='OpticDiscSeg') # params of training parser.add_argument( @@ -83,7 +67,7 @@ def parse_args(): parser.add_argument( '--batch_size', dest='batch_size', - help='Mini batch size', + help='Mini batch size of one gpu or cpu', type=int, default=2) parser.add_argument( @@ -210,6 +194,12 @@ def main(args): if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ else fluid.CPUPlace() + if args.dataset.lower() == 'opticdiscseg': + dataset = OpticDiscSeg + else: + raise Exception( + "The --dataset set wrong. It should be one of ('OpticDiscSeg',)") + with fluid.dygraph.guard(places): # Creat dataset reader train_transforms = T.Compose([ @@ -217,17 +207,18 @@ def main(args): T.RandomHorizontalFlip(), T.Normalize() ]) - train_dataset = OpticDiscSeg(transforms=train_transforms, mode='train') + train_dataset = dataset(transforms=train_transforms, mode='train') eval_dataset = None if args.do_eval: eval_transforms = T.Compose( [T.Resize(args.input_size), T.Normalize()]) - eval_dataset = OpticDiscSeg(transforms=eval_transforms, mode='eval') + eval_dataset = dataset(transforms=eval_transforms, mode='eval') if args.model_name == 'UNet': - model = models.UNet(num_classes=args.num_classes, ignore_index=255) + model = models.UNet( + num_classes=train_dataset.num_classes, ignore_index=255) # Creat optimizer num_steps_each_epoch = len(train_dataset) // args.batch_size @@ -251,7 +242,7 @@ def main(args): batch_size=args.batch_size, pretrained_model=args.pretrained_model, save_interval_epochs=args.save_interval_epochs, - num_classes=args.num_classes, + num_classes=train_dataset.num_classes, num_workers=args.num_workers) diff --git a/dygraph/val.py b/dygraph/val.py index 37b60921c5fc5a4b0a17b85ff413479dae479195..ccf89f44d99a8eae37213da67f6020acf62b3b94 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -44,22 +44,11 @@ def parse_args(): # params of dataset parser.add_argument( - '--data_dir', - dest='data_dir', - help='The root directory of dataset', - type=str) - parser.add_argument( - '--val_list', - dest='val_list', - help='Val list file of dataset', + '--dataset', + dest='dataset', + help='The dataset you want to evaluation', type=str, - default=None) - parser.add_argument( - '--num_classes', - dest='num_classes', - help='Number of classes', - type=int, - default=2) + default='OpticDiscSeg') # params of evaluate parser.add_argument( @@ -140,19 +129,26 @@ def main(args): places = fluid.CUDAPlace(ParallelEnv().dev_id) \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ else fluid.CPUPlace() + + if args.dataset.lower() == 'opticdiscseg': + dataset = OpticDiscSeg + else: + raise Exception( + "The --dataset set wrong. It should be one of ('OpticDiscSeg',)") + with fluid.dygraph.guard(places): eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - eval_dataset = OpticDiscSeg(transforms=eval_transforms, mode='eval') + eval_dataset = dataset(transforms=eval_transforms, mode='eval') if args.model_name == 'UNet': - model = models.UNet(num_classes=args.num_classes) + model = models.UNet(num_classes=eval_dataset.num_classes) evaluate( model, eval_dataset, places=places, model_dir=args.model_dir, - num_classes=args.num_classes, + num_classes=eval_dataset.num_classes, batch_size=args.batch_size)