From 5d0a6b72ec96f80fb97ae80df485689b9ed72016 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 26 Aug 2020 18:07:02 +0800 Subject: [PATCH] update benchmark --- dygraph/benchmark/deeplabv3p.py | 15 ++++++--------- dygraph/benchmark/hrnet.py | 14 ++++++-------- dygraph/models/__init__.py | 33 --------------------------------- dygraph/train.py | 1 - 4 files changed, 12 insertions(+), 51 deletions(-) diff --git a/dygraph/benchmark/deeplabv3p.py b/dygraph/benchmark/deeplabv3p.py index 92c7e8ba..641f9cc6 100644 --- a/dygraph/benchmark/deeplabv3p.py +++ b/dygraph/benchmark/deeplabv3p.py @@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +#from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger from dygraph.core import train @@ -33,7 +34,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for training, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -161,18 +162,15 @@ def main(args): eval_dataset = None if args.do_eval: eval_transforms = T.Compose( - [T.Resize(args.input_size), + [T.Padding((2049, 1025)), T.Normalize()]) eval_dataset = dataset( dataset_root=args.dataset_root, transforms=eval_transforms, mode='val') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=train_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=train_dataset.num_classes) # Creat optimizer # todo, may less one than len(loader) @@ -195,7 +193,6 @@ def main(args): save_dir=args.save_dir, iters=args.iters, batch_size=args.batch_size, - pretrained_model=args.pretrained_model, resume_model=args.resume_model, save_interval_iters=args.save_interval_iters, log_iters=args.log_iters, diff --git a/dygraph/benchmark/hrnet.py b/dygraph/benchmark/hrnet.py index 4de9b06f..4544875b 100644 --- a/dygraph/benchmark/hrnet.py +++ b/dygraph/benchmark/hrnet.py @@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +#from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger from dygraph.core import train @@ -33,7 +34,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for training, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -166,11 +167,9 @@ def main(args): transforms=eval_transforms, mode='val') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=train_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=train_dataset.num_classes, + pretrained_model=args.pretrained_model) # Creat optimizer # todo, may less one than len(loader) @@ -193,7 +192,6 @@ def main(args): save_dir=args.save_dir, iters=args.iters, batch_size=args.batch_size, - pretrained_model=args.pretrained_model, resume_model=args.resume_model, save_interval_iters=args.save_interval_iters, log_iters=args.log_iters, diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py index 6af6df34..caa734fe 100644 --- a/dygraph/models/__init__.py +++ b/dygraph/models/__init__.py @@ -14,38 +14,5 @@ from .architectures import * from .unet import UNet -from .hrnet import * from .deeplab import * from .fcn import * - -# MODELS = { -# "UNet": UNet, -# "HRNet_W18_Small_V1": HRNet_W18_Small_V1, -# "HRNet_W18_Small_V2": HRNet_W18_Small_V2, -# "HRNet_W18": HRNet_W18, -# "HRNet_W30": HRNet_W30, -# "HRNet_W32": HRNet_W32, -# "HRNet_W40": HRNet_W40, -# "HRNet_W44": HRNet_W44, -# "HRNet_W48": HRNet_W48, -# "HRNet_W60": HRNet_W48, -# "HRNet_W64": HRNet_W64, -# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1, -# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2, -# "SE_HRNet_W18": SE_HRNet_W18, -# "SE_HRNet_W30": SE_HRNet_W30, -# "SE_HRNet_W32": SE_HRNet_W30, -# "SE_HRNet_W40": SE_HRNet_W40, -# "SE_HRNet_W44": SE_HRNet_W44, -# "SE_HRNet_W48": SE_HRNet_W48, -# "SE_HRNet_W60": SE_HRNet_W60, -# "SE_HRNet_W64": SE_HRNet_W64, -# "DeepLabV3P": DeepLabV3P, -# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd, -# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8, -# "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd, -# "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8, -# "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab, -# "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large, -# "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small -# } diff --git a/dygraph/train.py b/dygraph/train.py index eb16e996..382a41a0 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -19,7 +19,6 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -#from dygraph.models import MODELS from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger -- GitLab