diff --git a/dygraph/benchmark/deeplabv3p.py b/dygraph/benchmark/deeplabv3p.py index 92c7e8ba00ef67a51be09b877e816eab118ff7f9..641f9cc6003e4bdf9abb25d2f76868ce42679b85 100644 --- a/dygraph/benchmark/deeplabv3p.py +++ b/dygraph/benchmark/deeplabv3p.py @@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +#from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger from dygraph.core import train @@ -33,7 +34,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for training, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -161,18 +162,15 @@ def main(args): eval_dataset = None if args.do_eval: eval_transforms = T.Compose( - [T.Resize(args.input_size), + [T.Padding((2049, 1025)), T.Normalize()]) eval_dataset = dataset( dataset_root=args.dataset_root, transforms=eval_transforms, mode='val') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=train_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=train_dataset.num_classes) # Creat optimizer # todo, may less one than len(loader) @@ -195,7 +193,6 @@ def main(args): save_dir=args.save_dir, iters=args.iters, batch_size=args.batch_size, - pretrained_model=args.pretrained_model, resume_model=args.resume_model, save_interval_iters=args.save_interval_iters, log_iters=args.log_iters, diff --git a/dygraph/benchmark/hrnet.py b/dygraph/benchmark/hrnet.py index 4de9b06f0135b971a7795f6b3713599e26e798a5..4544875b5fca8d38048a195aec8b79dd4c1c3abe 100644 --- a/dygraph/benchmark/hrnet.py +++ b/dygraph/benchmark/hrnet.py @@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +#from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger from dygraph.core import train @@ -33,7 +34,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for training, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -166,11 +167,9 @@ def main(args): transforms=eval_transforms, mode='val') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=train_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=train_dataset.num_classes, + pretrained_model=args.pretrained_model) # Creat optimizer # todo, may less one than len(loader) @@ -193,7 +192,6 @@ def main(args): save_dir=args.save_dir, iters=args.iters, batch_size=args.batch_size, - pretrained_model=args.pretrained_model, resume_model=args.resume_model, save_interval_iters=args.save_interval_iters, log_iters=args.log_iters, diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py index 6af6df34ab8a0037cfd97536634b5b96489bacf8..caa734fefb7bad40e82a51160495b5889d82aa34 100644 --- a/dygraph/models/__init__.py +++ b/dygraph/models/__init__.py @@ -14,38 +14,5 @@ from .architectures import * from .unet import UNet -from .hrnet import * from .deeplab import * from .fcn import * - -# MODELS = { -# "UNet": UNet, -# "HRNet_W18_Small_V1": HRNet_W18_Small_V1, -# "HRNet_W18_Small_V2": HRNet_W18_Small_V2, -# "HRNet_W18": HRNet_W18, -# "HRNet_W30": HRNet_W30, -# "HRNet_W32": HRNet_W32, -# "HRNet_W40": HRNet_W40, -# "HRNet_W44": HRNet_W44, -# "HRNet_W48": HRNet_W48, -# "HRNet_W60": HRNet_W48, -# "HRNet_W64": HRNet_W64, -# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1, -# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2, -# "SE_HRNet_W18": SE_HRNet_W18, -# "SE_HRNet_W30": SE_HRNet_W30, -# "SE_HRNet_W32": SE_HRNet_W30, -# "SE_HRNet_W40": SE_HRNet_W40, -# "SE_HRNet_W44": SE_HRNet_W44, -# "SE_HRNet_W48": SE_HRNet_W48, -# "SE_HRNet_W60": SE_HRNet_W60, -# "SE_HRNet_W64": SE_HRNet_W64, -# "DeepLabV3P": DeepLabV3P, -# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd, -# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8, -# "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd, -# "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8, -# "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab, -# "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large, -# "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small -# } diff --git a/dygraph/train.py b/dygraph/train.py index eb16e9968a5a5691c12eddead2120df29f4eede4..382a41a06afe03efc28ac8493ad6910cf74333ca 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -19,7 +19,6 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -#from dygraph.models import MODELS from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger