提交 33cd5c87 编写于 作者: C chenguowei01

update benchmark

上级 5f72c538
......@@ -13,22 +13,14 @@
# limitations under the License.
import argparse
import os
import sys
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
from datasets import DATASETS
import transforms as T
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from core import train
......@@ -48,10 +40,16 @@ def parse_args():
parser.add_argument(
'--dataset',
dest='dataset',
help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
help="The dataset you want to train, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training
parser.add_argument(
......@@ -135,36 +133,38 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places):
# Creat dataset reader
train_transforms = T.Compose([
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(args.input_size),
T.RandomHorizontalFlip(),
T.Normalize()
T.RandomDistort(),
T.Normalize(),
])
train_dataset = dataset(transforms=train_transforms, mode='train')
train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Padding((2049, 1025)),
T.Normalize()]
)
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
T.Normalize()])
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
'`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
......@@ -174,17 +174,13 @@ def main(args):
args.batch_size * ParallelEnv().nranks)
decay_step = args.num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, decay_step, end_learning_rate=0.00001, power=0.9)
args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
#parameter_list=filter(lambda p: p.trainable, model.parameters()),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train(
model,
train_dataset,
......
......@@ -13,22 +13,15 @@
# limitations under the License.
import argparse
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
from datasets import DATASETS
import transforms as T
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from core import train, evaluate
from core import train
def parse_args():
......@@ -47,10 +40,16 @@ def parse_args():
parser.add_argument(
'--dataset',
dest='dataset',
help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
help="The dataset you want to train, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='Cityscapes')
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training
parser.add_argument(
......@@ -58,14 +57,14 @@ def parse_args():
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[1024, 512],
default=[512, 512],
type=int)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
help='Number epochs for training',
type=int,
default=500)
default=100)
parser.add_argument(
'--batch_size',
dest='batch_size',
......@@ -107,7 +106,7 @@ def parse_args():
dest='num_workers',
help='Num workers for data loader',
type=int,
default=2)
default=0)
parser.add_argument(
'--do_eval',
dest='do_eval',
......@@ -134,14 +133,10 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places):
# Creat dataset reader
......@@ -152,16 +147,22 @@ def main(args):
T.RandomDistort(),
T.Normalize(),
])
train_dataset = dataset(transforms=train_transforms, mode='train')
train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose([T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
'`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
......@@ -176,7 +177,8 @@ def main(args):
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=5e-4))
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train(
model,
train_dataset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册