提交 5d0a6b72 编写于 作者: C chenguowei01

update benchmark

上级 c5345b50
......@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
from dygraph.models import MODELS
#from dygraph.models import MODELS
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import logger
from dygraph.core import train
......@@ -33,7 +34,7 @@ def parse_args():
'--model_name',
dest='model_name',
help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))),
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
......@@ -161,18 +162,15 @@ def main(args):
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Resize(args.input_size),
[T.Padding((2049, 1025)),
T.Normalize()])
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
'`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
model = manager.MODELS[args.model_name](
num_classes=train_dataset.num_classes)
# Creat optimizer
# todo, may less one than len(loader)
......@@ -195,7 +193,6 @@ def main(args):
save_dir=args.save_dir,
iters=args.iters,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model,
save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters,
......
......@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
from dygraph.models import MODELS
#from dygraph.models import MODELS
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import logger
from dygraph.core import train
......@@ -33,7 +34,7 @@ def parse_args():
'--model_name',
dest='model_name',
help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))),
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
......@@ -166,11 +167,9 @@ def main(args):
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
'`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
model = manager.MODELS[args.model_name](
num_classes=train_dataset.num_classes,
pretrained_model=args.pretrained_model)
# Creat optimizer
# todo, may less one than len(loader)
......@@ -193,7 +192,6 @@ def main(args):
save_dir=args.save_dir,
iters=args.iters,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model,
save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters,
......
......@@ -14,38 +14,5 @@
from .architectures import *
from .unet import UNet
from .hrnet import *
from .deeplab import *
from .fcn import *
# MODELS = {
# "UNet": UNet,
# "HRNet_W18_Small_V1": HRNet_W18_Small_V1,
# "HRNet_W18_Small_V2": HRNet_W18_Small_V2,
# "HRNet_W18": HRNet_W18,
# "HRNet_W30": HRNet_W30,
# "HRNet_W32": HRNet_W32,
# "HRNet_W40": HRNet_W40,
# "HRNet_W44": HRNet_W44,
# "HRNet_W48": HRNet_W48,
# "HRNet_W60": HRNet_W48,
# "HRNet_W64": HRNet_W64,
# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
# "SE_HRNet_W18": SE_HRNet_W18,
# "SE_HRNet_W30": SE_HRNet_W30,
# "SE_HRNet_W32": SE_HRNet_W30,
# "SE_HRNet_W40": SE_HRNet_W40,
# "SE_HRNet_W44": SE_HRNet_W44,
# "SE_HRNet_W48": SE_HRNet_W48,
# "SE_HRNet_W60": SE_HRNet_W60,
# "SE_HRNet_W64": SE_HRNet_W64,
# "DeepLabV3P": DeepLabV3P,
# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd,
# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8,
# "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd,
# "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8,
# "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab,
# "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large,
# "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small
# }
......@@ -19,7 +19,6 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
#from dygraph.models import MODELS
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import logger
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册