未验证 提交 ec54aeff 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #346 from wuyefeilin/develop

...@@ -13,23 +13,15 @@ ...@@ -13,23 +13,15 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import sys
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes from dygraph.datasets import DATASETS
import transforms as T import dygraph.transforms as T
from models import MODELS from dygraph.models import MODELS
import utils.logging as logging from dygraph.utils import get_environ_info
from utils import get_environ_info from dygraph.core import train
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from core import train
def parse_args(): def parse_args():
...@@ -48,10 +40,16 @@ def parse_args(): ...@@ -48,10 +40,16 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help= help="The dataset you want to train, which is one of {}".format(
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')", str(list(DATASETS.keys()))),
type=str, type=str,
default='OpticDiscSeg') default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training # params of training
parser.add_argument( parser.add_argument(
...@@ -135,36 +133,38 @@ def main(args): ...@@ -135,36 +133,38 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg': if args.dataset not in DATASETS:
dataset = OpticDiscSeg raise Exception('`--dataset` is invalid. it should be one of {}'.format(
elif args.dataset.lower() == 'cityscapes': str(list(DATASETS.keys()))))
dataset = Cityscapes dataset = DATASETS[args.dataset]
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
# Creat dataset reader # Creat dataset reader
train_transforms = T.Compose([ train_transforms = T.Compose([
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25), T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(args.input_size), T.RandomPaddingCrop(args.input_size),
T.RandomHorizontalFlip(), T.RandomDistort(),
T.Normalize() T.Normalize(),
]) ])
train_dataset = dataset(transforms=train_transforms, mode='train') train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None eval_dataset = None
if args.do_eval: if args.do_eval:
eval_transforms = T.Compose( eval_transforms = T.Compose(
[T.Padding((2049, 1025)), [T.Padding((2049, 1025)),
T.Normalize()] T.Normalize()])
) eval_dataset = dataset(
eval_dataset = dataset(transforms=eval_transforms, mode='eval') dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
'--model_name is invalid. it should be one of {}'.format( '`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys())))) str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes) model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
...@@ -174,16 +174,12 @@ def main(args): ...@@ -174,16 +174,12 @@ def main(args):
args.batch_size * ParallelEnv().nranks) args.batch_size * ParallelEnv().nranks)
decay_step = args.num_epochs * num_steps_each_epoch decay_step = args.num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay( lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, decay_step, end_learning_rate=0.00001, power=0.9) args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
lr_decay, lr_decay,
momentum=0.9, momentum=0.9,
parameter_list=model.parameters(), parameter_list=model.parameters(),
#parameter_list=filter(lambda p: p.trainable, model.parameters()),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5)) regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train( train(
model, model,
......
...@@ -13,22 +13,15 @@ ...@@ -13,22 +13,15 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes from dygraph.datasets import DATASETS
import transforms as T import dygraph.transforms as T
from models import MODELS from dygraph.models import MODELS
import utils.logging as logging from dygraph.utils import get_environ_info
from utils import get_environ_info from dygraph.core import train
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from core import train, evaluate
def parse_args(): def parse_args():
...@@ -47,10 +40,16 @@ def parse_args(): ...@@ -47,10 +40,16 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
help= help="The dataset you want to train, which is one of {}".format(
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')", str(list(DATASETS.keys()))),
type=str, type=str,
default='Cityscapes') default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training # params of training
parser.add_argument( parser.add_argument(
...@@ -58,14 +57,14 @@ def parse_args(): ...@@ -58,14 +57,14 @@ def parse_args():
dest="input_size", dest="input_size",
help="The image size for net inputs.", help="The image size for net inputs.",
nargs=2, nargs=2,
default=[1024, 512], default=[512, 512],
type=int) type=int)
parser.add_argument( parser.add_argument(
'--num_epochs', '--num_epochs',
dest='num_epochs', dest='num_epochs',
help='Number epochs for training', help='Number epochs for training',
type=int, type=int,
default=500) default=100)
parser.add_argument( parser.add_argument(
'--batch_size', '--batch_size',
dest='batch_size', dest='batch_size',
...@@ -107,7 +106,7 @@ def parse_args(): ...@@ -107,7 +106,7 @@ def parse_args():
dest='num_workers', dest='num_workers',
help='Num workers for data loader', help='Num workers for data loader',
type=int, type=int,
default=2) default=0)
parser.add_argument( parser.add_argument(
'--do_eval', '--do_eval',
dest='do_eval', dest='do_eval',
...@@ -134,14 +133,10 @@ def main(args): ...@@ -134,14 +133,10 @@ def main(args):
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg': if args.dataset not in DATASETS:
dataset = OpticDiscSeg raise Exception('`--dataset` is invalid. it should be one of {}'.format(
elif args.dataset.lower() == 'cityscapes': str(list(DATASETS.keys()))))
dataset = Cityscapes dataset = DATASETS[args.dataset]
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
# Creat dataset reader # Creat dataset reader
...@@ -152,16 +147,22 @@ def main(args): ...@@ -152,16 +147,22 @@ def main(args):
T.RandomDistort(), T.RandomDistort(),
T.Normalize(), T.Normalize(),
]) ])
train_dataset = dataset(transforms=train_transforms, mode='train') train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None eval_dataset = None
if args.do_eval: if args.do_eval:
eval_transforms = T.Compose([T.Normalize()]) eval_transforms = T.Compose([T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval') eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
'--model_name is invalid. it should be one of {}'.format( '`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys())))) str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes) model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
...@@ -176,7 +177,8 @@ def main(args): ...@@ -176,7 +177,8 @@ def main(args):
lr_decay, lr_decay,
momentum=0.9, momentum=0.9,
parameter_list=model.parameters(), parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=5e-4)) regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train( train(
model, model,
train_dataset, train_dataset,
......
...@@ -20,8 +20,8 @@ import paddle.fluid as fluid ...@@ -20,8 +20,8 @@ import paddle.fluid as fluid
import cv2 import cv2
import tqdm import tqdm
import utils from dygraph import utils
import utils.logging as logging import dygraph.utils.logging as logging
def mkdir(path): def mkdir(path):
......
...@@ -19,10 +19,10 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -19,10 +19,10 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler from paddle.incubate.hapi.distributed import DistributedBatchSampler
import utils.logging as logging import dygraph.utils.logging as logging
from utils import load_pretrained_model from dygraph.utils import load_pretrained_model
from utils import resume from dygraph.utils import resume
from utils import Timer, calculate_eta from dygraph.utils import Timer, calculate_eta
from .val import evaluate from .val import evaluate
......
...@@ -20,9 +20,9 @@ import cv2 ...@@ -20,9 +20,9 @@ import cv2
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid import paddle.fluid as fluid
import utils.logging as logging import dygraph.utils.logging as logging
from utils import ConfusionMatrix from dygraph.utils import ConfusionMatrix
from utils import Timer, calculate_eta from dygraph.utils import Timer, calculate_eta
def evaluate(model, def evaluate(model,
......
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
from PIL import Image from PIL import Image
from .dataset import Dataset from .dataset import Dataset
from utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
from .dataset import Dataset from .dataset import Dataset
from utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
from .dataset import Dataset from .dataset import Dataset
from utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
......
...@@ -17,11 +17,11 @@ import argparse ...@@ -17,11 +17,11 @@ import argparse
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from datasets import DATASETS from dygraph.datasets import DATASETS
import transforms as T import dygraph.transforms as T
from models import MODELS from dygraph.models import MODELS
from utils import get_environ_info from dygraph.utils import get_environ_info
from core import infer from dygraph.core import infer
def parse_args(): def parse_args():
......
...@@ -27,7 +27,7 @@ import numpy as np ...@@ -27,7 +27,7 @@ import numpy as np
from scipy.io import loadmat from scipy.io import loadmat
import tqdm import tqdm
from utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz' URL = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz'
......
...@@ -17,11 +17,11 @@ import argparse ...@@ -17,11 +17,11 @@ import argparse
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from datasets import DATASETS from dygraph.datasets import DATASETS
import transforms as T import dygraph.transforms as T
from models import MODELS from dygraph.models import MODELS
from utils import get_environ_info from dygraph.utils import get_environ_info
from core import train from dygraph.core import train
def parse_args(): def parse_args():
......
...@@ -17,11 +17,11 @@ import argparse ...@@ -17,11 +17,11 @@ import argparse
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from datasets import DATASETS from dygraph.datasets import DATASETS
import transforms as T import dygraph.transforms as T
from models import MODELS from dygraph.models import MODELS
from utils import get_environ_info from dygraph.utils import get_environ_info
from core import evaluate from dygraph.core import evaluate
def parse_args(): def parse_args():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册