提交 5d0a6b72 编写于 作者: C chenguowei01

update benchmark

上级 c5345b50
...@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS from dygraph.datasets import DATASETS
import dygraph.transforms as T import dygraph.transforms as T
from dygraph.models import MODELS #from dygraph.models import MODELS
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info from dygraph.utils import get_environ_info
from dygraph.utils import logger from dygraph.utils import logger
from dygraph.core import train from dygraph.core import train
...@@ -33,7 +34,7 @@ def parse_args(): ...@@ -33,7 +34,7 @@ def parse_args():
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help='Model type for training, which is one of {}'.format( help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))), str(list(manager.MODELS.components_dict.keys()))),
type=str, type=str,
default='UNet') default='UNet')
...@@ -161,18 +162,15 @@ def main(args): ...@@ -161,18 +162,15 @@ def main(args):
eval_dataset = None eval_dataset = None
if args.do_eval: if args.do_eval:
eval_transforms = T.Compose( eval_transforms = T.Compose(
[T.Resize(args.input_size), [T.Padding((2049, 1025)),
T.Normalize()]) T.Normalize()])
eval_dataset = dataset( eval_dataset = dataset(
dataset_root=args.dataset_root, dataset_root=args.dataset_root,
transforms=eval_transforms, transforms=eval_transforms,
mode='val') mode='val')
if args.model_name not in MODELS: model = manager.MODELS[args.model_name](
raise Exception( num_classes=train_dataset.num_classes)
'`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
# Creat optimizer # Creat optimizer
# todo, may less one than len(loader) # todo, may less one than len(loader)
...@@ -195,7 +193,6 @@ def main(args): ...@@ -195,7 +193,6 @@ def main(args):
save_dir=args.save_dir, save_dir=args.save_dir,
iters=args.iters, iters=args.iters,
batch_size=args.batch_size, batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model, resume_model=args.resume_model,
save_interval_iters=args.save_interval_iters, save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters, log_iters=args.log_iters,
......
...@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS from dygraph.datasets import DATASETS
import dygraph.transforms as T import dygraph.transforms as T
from dygraph.models import MODELS #from dygraph.models import MODELS
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info from dygraph.utils import get_environ_info
from dygraph.utils import logger from dygraph.utils import logger
from dygraph.core import train from dygraph.core import train
...@@ -33,7 +34,7 @@ def parse_args(): ...@@ -33,7 +34,7 @@ def parse_args():
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help='Model type for training, which is one of {}'.format( help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))), str(list(manager.MODELS.components_dict.keys()))),
type=str, type=str,
default='UNet') default='UNet')
...@@ -166,11 +167,9 @@ def main(args): ...@@ -166,11 +167,9 @@ def main(args):
transforms=eval_transforms, transforms=eval_transforms,
mode='val') mode='val')
if args.model_name not in MODELS: model = manager.MODELS[args.model_name](
raise Exception( num_classes=train_dataset.num_classes,
'`--model_name` is invalid. it should be one of {}'.format( pretrained_model=args.pretrained_model)
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
# Creat optimizer # Creat optimizer
# todo, may less one than len(loader) # todo, may less one than len(loader)
...@@ -193,7 +192,6 @@ def main(args): ...@@ -193,7 +192,6 @@ def main(args):
save_dir=args.save_dir, save_dir=args.save_dir,
iters=args.iters, iters=args.iters,
batch_size=args.batch_size, batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model, resume_model=args.resume_model,
save_interval_iters=args.save_interval_iters, save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters, log_iters=args.log_iters,
......
...@@ -14,38 +14,5 @@ ...@@ -14,38 +14,5 @@
from .architectures import * from .architectures import *
from .unet import UNet from .unet import UNet
from .hrnet import *
from .deeplab import * from .deeplab import *
from .fcn import * from .fcn import *
# MODELS = {
# "UNet": UNet,
# "HRNet_W18_Small_V1": HRNet_W18_Small_V1,
# "HRNet_W18_Small_V2": HRNet_W18_Small_V2,
# "HRNet_W18": HRNet_W18,
# "HRNet_W30": HRNet_W30,
# "HRNet_W32": HRNet_W32,
# "HRNet_W40": HRNet_W40,
# "HRNet_W44": HRNet_W44,
# "HRNet_W48": HRNet_W48,
# "HRNet_W60": HRNet_W48,
# "HRNet_W64": HRNet_W64,
# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
# "SE_HRNet_W18": SE_HRNet_W18,
# "SE_HRNet_W30": SE_HRNet_W30,
# "SE_HRNet_W32": SE_HRNet_W30,
# "SE_HRNet_W40": SE_HRNet_W40,
# "SE_HRNet_W44": SE_HRNet_W44,
# "SE_HRNet_W48": SE_HRNet_W48,
# "SE_HRNet_W60": SE_HRNet_W60,
# "SE_HRNet_W64": SE_HRNet_W64,
# "DeepLabV3P": DeepLabV3P,
# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd,
# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8,
# "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd,
# "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8,
# "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab,
# "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large,
# "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small
# }
...@@ -19,7 +19,6 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -19,7 +19,6 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS from dygraph.datasets import DATASETS
import dygraph.transforms as T import dygraph.transforms as T
#from dygraph.models import MODELS
from dygraph.cvlibs import manager from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info from dygraph.utils import get_environ_info
from dygraph.utils import logger from dygraph.utils import logger
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册