Unverified commit 000f6ddc authored by wuzewu, committed by GitHub

Merge pull request #340 from wuyefeilin/dygraph

Update dataset, add Pascal VOC and ADE20K
@@ -117,7 +117,8 @@ def train(model,
                        avg_loss * nranks, lr, avg_train_batch_cost,
                        avg_train_reader_cost, eta))
                if use_vdl:
-                    log_writer.add_scalar('Train/loss', avg_loss, num_steps)
+                    log_writer.add_scalar('Train/loss', avg_loss * nranks,
+                                          num_steps)
                    log_writer.add_scalar('Train/lr', lr, num_steps)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, num_steps)
...
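The scaled value keeps the VisualDL 'Train/loss' curve consistent with the console log a few lines above, which already prints avg_loss * nranks.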
@@ -15,3 +15,12 @@
 from .dataset import Dataset
 from .optic_disc_seg import OpticDiscSeg
 from .cityscapes import Cityscapes
+from .voc import PascalVOC
+from .ade import ADE20K
+
+DATASETS = {
+    "OpticDiscSeg": OpticDiscSeg,
+    "Cityscapes": Cityscapes,
+    "PascalVOC": PascalVOC,
+    "ADE20K": ADE20K
+}
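The DATASETS registry is what lets train.py, val.py and infer.py below swap their hard-coded if/elif dataset selection for a dictionary lookup. A minimal sketch of the pattern (the --dataset flag handling is abbreviated):

    from datasets import DATASETS

    name = 'PascalVOC'  # normally taken from the --dataset flag
    if name not in DATASETS:
        raise Exception('`--dataset` is invalid. It should be one of {}'.format(
            list(DATASETS.keys())))
    dataset_cls = DATASETS[name]  # a class; instantiated later with dataset_root/transforms/mode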
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from PIL import Image
from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
class ADE20K(Dataset):
    """ADE20K dataset `http://sceneparsing.csail.mit.edu/`.

    Args:
        dataset_root: The dataset directory.
        mode: Which part of dataset to use. It is one of ('train', 'val'). Default: 'train'.
        transforms: Transforms for image.
        download: Whether to download dataset if `dataset_root` is None.
    """

    def __init__(self,
                 dataset_root=None,
                 mode='train',
                 transforms=None,
                 download=True):
        self.dataset_root = dataset_root
        self.transforms = transforms
        self.mode = mode
        self.file_list = list()
        self.num_classes = 150

        if mode.lower() not in ['train', 'val']:
            raise Exception(
                "`mode` should be one of ('train', 'val') in ADE20K dataset, but got {}."
                .format(mode))

        if self.transforms is None:
            raise Exception("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            if not download:
                raise Exception(
                    "`dataset_root` not set and auto download disabled.")
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=DATA_HOME,
                extrapath=DATA_HOME,
                extraname='ADEChallengeData2016')
        elif not os.path.exists(self.dataset_root):
            raise Exception('`dataset_root` does not exist: {}.'.format(
                self.dataset_root))

        if mode == 'train':
            img_dir = os.path.join(self.dataset_root, 'images/training')
            grt_dir = os.path.join(self.dataset_root, 'annotations/training')
        elif mode == 'val':
            img_dir = os.path.join(self.dataset_root, 'images/validation')
            grt_dir = os.path.join(self.dataset_root, 'annotations/validation')
        img_files = os.listdir(img_dir)
        # Annotations share the image stem, with a .png extension.
        grt_files = [i.replace('.jpg', '.png') for i in img_files]
        for i in range(len(img_files)):
            img_path = os.path.join(img_dir, img_files[i])
            grt_path = os.path.join(grt_dir, grt_files[i])
            self.file_list.append([img_path, grt_path])

    def __getitem__(self, idx):
        image_path, grt_path = self.file_list[idx]
        if self.mode == 'test':
            im, im_info, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            return im, im_info, image_path
        elif self.mode == 'val':
            im, im_info, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            label = np.asarray(Image.open(grt_path))
            # Shift labels from 1..150 down to 0..149; for uint8 annotations
            # the unlabeled value 0 wraps to 255, the ignored index.
            label = label - 1
            label = label[np.newaxis, np.newaxis, :, :]
            return im, im_info, label
        else:
            im, im_info, label = self.transforms(im=image_path, label=grt_path)
            label = label - 1
            return im, label
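A minimal usage sketch for the new class; T.Compose, T.Resize and T.Normalize come from this repo's transforms module, and the (512, 512) resize target is an assumption:

    import transforms as T
    from datasets import ADE20K

    # With dataset_root=None, the archive is downloaded and unpacked into
    # ~/.cache/paddle/dataset/ADEChallengeData2016 on first use.
    train_dataset = ADE20K(
        transforms=T.Compose([T.Resize((512, 512)), T.Normalize()]),
        mode='train')
    im, label = train_dataset[0]  # labels are shifted by -1, so classes run 0..149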
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -13,59 +13,62 @@
 # limitations under the License.

 import os
+import glob

 from .dataset import Dataset
-from utils.download import download_file_and_uncompress
-
-DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
-URL = "https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar"


 class Cityscapes(Dataset):
-    def __init__(self,
-                 data_dir=None,
-                 transforms=None,
-                 mode='train',
-                 download=True):
-        self.data_dir = data_dir
+    """Cityscapes dataset `https://www.cityscapes-dataset.com/`.
+    The folder structure is as follows:
+        cityscapes
+        |
+        |--leftImg8bit
+        |  |--train
+        |  |--val
+        |  |--test
+        |
+        |--gtFine
+        |  |--train
+        |  |--val
+        |  |--test
+    Make sure there are **_labelTrainIds.png files in the gtFine directory.
+    If not, please run convert_cityscapes.py in tools first.
+
+    Args:
+        dataset_root: Cityscapes dataset directory.
+        mode: Which part of dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
+        transforms: Transforms for image.
+    """
+
+    def __init__(self, dataset_root, transforms=None, mode='train'):
+        self.dataset_root = dataset_root
         self.transforms = transforms
         self.file_list = list()
         self.mode = mode
         self.num_classes = 19

-        if mode.lower() not in ['train', 'eval', 'test']:
+        if mode.lower() not in ['train', 'val', 'test']:
             raise Exception(
-                "mode should be 'train', 'eval' or 'test', but got {}.".format(
+                "mode should be 'train', 'val' or 'test', but got {}.".format(
                     mode))

         if self.transforms is None:
-            raise Exception("transform is necessary, but it is None.")
+            raise Exception("`transforms` is necessary, but it is None.")

-        self.data_dir = data_dir
-        if self.data_dir is None:
-            if not download:
-                raise Exception("data_file not set and auto download disabled.")
-            self.data_dir = download_file_and_uncompress(
-                url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
+        img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
+        grt_dir = os.path.join(self.dataset_root, 'gtFine')
+        if self.dataset_root is None or not os.path.isdir(
+                self.dataset_root) or not os.path.isdir(
+                    img_dir) or not os.path.isdir(grt_dir):
+            raise Exception(
+                "The dataset is not found or the folder structure is nonconformant."
+            )

-        if mode == 'train':
-            file_list = os.path.join(self.data_dir, 'train.list')
-        elif mode == 'eval':
-            file_list = os.path.join(self.data_dir, 'val.list')
-        else:
-            file_list = os.path.join(self.data_dir, 'test.list')
-
-        with open(file_list, 'r') as f:
-            for line in f:
-                items = line.strip().split()
-                if len(items) != 2:
-                    if mode == 'train' or mode == 'eval':
-                        raise Exception(
-                            "File list format incorrect! It should be"
-                            " image_name label_name\\n")
-                    image_path = os.path.join(self.data_dir, items[0])
-                    grt_path = None
-                else:
-                    image_path = os.path.join(self.data_dir, items[0])
-                    grt_path = os.path.join(self.data_dir, items[1])
-                self.file_list.append([image_path, grt_path])
+        grt_files = sorted(
+            glob.glob(
+                os.path.join(grt_dir, mode, '*',
+                             '*_gtFine_labelTrainIds.png')))
+        img_files = sorted(
+            glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))
+
+        self.file_list = [[img_path, grt_path]
+                          for img_path, grt_path in zip(img_files, grt_files)]
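Note that image/label pairing now depends on the two sorted() globs staying aligned, which holds because matching files share the same city and frame-id prefix. A minimal usage sketch under an assumed local path:

    import transforms as T
    from datasets import Cityscapes

    val_dataset = Cityscapes(
        dataset_root='data/cityscapes',  # must contain leftImg8bit/ and gtFine/
        transforms=T.Compose([T.Normalize()]),
        mode='val')
    print(len(val_dataset), val_dataset.num_classes)  # 500 19 for the standard val split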
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -20,53 +20,83 @@ from PIL import Image


 class Dataset(fluid.io.Dataset):
+    """Pass in a custom dataset that conforms to the format.
+
+    Args:
+        dataset_root: The dataset directory.
+        num_classes: Number of classes.
+        mode: Which part of dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
+        train_list: The train dataset file. When mode is 'train', train_list is necessary.
+            The contents of the train_list file are as follows:
+            image1.jpg ground_truth1.png
+            image2.jpg ground_truth2.png
+        val_list: The evaluation dataset file. When mode is 'val', val_list is necessary.
+            The contents are the same as train_list.
+        test_list: The test dataset file. When mode is 'test', test_list is necessary.
+            The annotation file is not necessary in the test_list file.
+        separator: The separator of dataset list. Default: ' '.
+        transforms: Transforms for image.
+
+    Examples:
+        todo
+    """
+
     def __init__(self,
-                 data_dir,
+                 dataset_root,
                  num_classes,
+                 mode='train',
                  train_list=None,
                  val_list=None,
                  test_list=None,
                  separator=' ',
-                 transforms=None,
-                 mode='train'):
-        self.data_dir = data_dir
+                 transforms=None):
+        self.dataset_root = dataset_root
         self.transforms = transforms
         self.file_list = list()
         self.mode = mode
         self.num_classes = num_classes

-        if mode.lower() not in ['train', 'eval', 'test']:
+        if mode.lower() not in ['train', 'val', 'test']:
             raise Exception(
-                "mode should be 'train', 'eval' or 'test', but got {}.".format(
+                "mode should be 'train', 'val' or 'test', but got {}.".format(
                     mode))

         if self.transforms is None:
-            raise Exception("transform is necessary, but it is None.")
+            raise Exception("`transforms` is necessary, but it is None.")

-        self.data_dir = data_dir
+        self.dataset_root = dataset_root
+        if not os.path.exists(self.dataset_root):
+            raise Exception('`dataset_root` does not exist: {}.'.format(
+                self.dataset_root))
+
         if mode == 'train':
             if train_list is None:
                 raise Exception(
-                    'When mode is "train", train_list is need, but it is None.')
+                    'When `mode` is "train", `train_list` is necessary, but it is None.'
+                )
             elif not os.path.exists(train_list):
                 raise Exception(
-                    'train_list is not found: {}'.format(train_list))
+                    '`train_list` is not found: {}'.format(train_list))
             else:
                 file_list = train_list
-        elif mode == 'eval':
+        elif mode == 'val':
             if val_list is None:
                 raise Exception(
-                    'When mode is "eval", val_list is need, but it is None.')
+                    'When `mode` is "val", `val_list` is necessary, but it is None.'
+                )
             elif not os.path.exists(val_list):
-                raise Exception('val_list is not found: {}'.format(val_list))
+                raise Exception('`val_list` is not found: {}'.format(val_list))
             else:
                 file_list = val_list
         else:
             if test_list is None:
                 raise Exception(
-                    'When mode is "test", test_list is need, but it is None.')
+                    'When `mode` is "test", `test_list` is necessary, but it is None.'
+                )
             elif not os.path.exists(test_list):
-                raise Exception('test_list is not found: {}'.format(test_list))
+                raise Exception(
+                    '`test_list` is not found: {}'.format(test_list))
             else:
                 file_list = test_list
...
@@ -74,32 +104,32 @@ class Dataset(fluid.io.Dataset):
             for line in f:
                 items = line.strip().split(separator)
                 if len(items) != 2:
-                    if mode == 'train' or mode == 'eval':
+                    if mode == 'train' or mode == 'val':
                         raise Exception(
-                            "File list format incorrect! It should be"
+                            "File list format incorrect! In training or evaluation task it should be"
                             " image_name{}label_name\\n".format(separator))
-                    image_path = os.path.join(self.data_dir, items[0])
+                    image_path = os.path.join(self.dataset_root, items[0])
                     grt_path = None
                 else:
-                    image_path = os.path.join(self.data_dir, items[0])
-                    grt_path = os.path.join(self.data_dir, items[1])
+                    image_path = os.path.join(self.dataset_root, items[0])
+                    grt_path = os.path.join(self.dataset_root, items[1])
                 self.file_list.append([image_path, grt_path])

     def __getitem__(self, idx):
         image_path, grt_path = self.file_list[idx]
-        if self.mode == 'train':
-            im, im_info, label = self.transforms(im=image_path, label=grt_path)
-            return im, label
-        elif self.mode == 'eval':
+        if self.mode == 'test':
+            im, im_info, _ = self.transforms(im=image_path)
+            im = im[np.newaxis, ...]
+            return im, im_info, image_path
+        elif self.mode == 'val':
             im, im_info, _ = self.transforms(im=image_path)
             im = im[np.newaxis, ...]
             label = np.asarray(Image.open(grt_path))
             label = label[np.newaxis, np.newaxis, :, :]
             return im, im_info, label
-        if self.mode == 'test':
-            im, im_info, _ = self.transforms(im=image_path)
-            im = im[np.newaxis, ...]
-            return im, im_info, image_path
+        else:
+            im, im_info, label = self.transforms(im=image_path, label=grt_path)
+            return im, label

     def __len__(self):
         return len(self.file_list)
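A minimal sketch of driving the generic Dataset with a custom list file; all paths are assumptions, and the list format follows the docstring above:

    import transforms as T
    from datasets import Dataset

    # data/my_dataset/train.list contains lines such as:
    #     images/0001.jpg annotations/0001.png
    train_dataset = Dataset(
        dataset_root='data/my_dataset',
        num_classes=2,
        mode='train',
        train_list='data/my_dataset/train.list',
        separator=' ',
        transforms=T.Compose([T.Resize((512, 512)), T.Normalize()]))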
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -23,49 +23,52 @@ URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"

 class OpticDiscSeg(Dataset):
     def __init__(self,
-                 data_dir=None,
+                 dataset_root=None,
                  transforms=None,
                  mode='train',
                  download=True):
-        self.data_dir = data_dir
+        self.dataset_root = dataset_root
         self.transforms = transforms
         self.file_list = list()
         self.mode = mode
         self.num_classes = 2

-        if mode.lower() not in ['train', 'eval', 'test']:
+        if mode.lower() not in ['train', 'val', 'test']:
             raise Exception(
-                "mode should be 'train', 'eval' or 'test', but got {}.".format(
+                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                     mode))

         if self.transforms is None:
-            raise Exception("transform is necessary, but it is None.")
+            raise Exception("`transforms` is necessary, but it is None.")

-        self.data_dir = data_dir
-        if self.data_dir is None:
+        if self.dataset_root is None:
             if not download:
-                raise Exception("data_file not set and auto download disabled.")
-            self.data_dir = download_file_and_uncompress(
+                raise Exception(
+                    "`dataset_root` not set and auto download disabled.")
+            self.dataset_root = download_file_and_uncompress(
                 url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
+        elif not os.path.exists(self.dataset_root):
+            raise Exception('`dataset_root` does not exist: {}.'.format(
+                self.dataset_root))

         if mode == 'train':
-            file_list = os.path.join(self.data_dir, 'train_list.txt')
-        elif mode == 'eval':
-            file_list = os.path.join(self.data_dir, 'val_list.txt')
+            file_list = os.path.join(self.dataset_root, 'train_list.txt')
+        elif mode == 'val':
+            file_list = os.path.join(self.dataset_root, 'val_list.txt')
         else:
-            file_list = os.path.join(self.data_dir, 'test_list.txt')
+            file_list = os.path.join(self.dataset_root, 'test_list.txt')

         with open(file_list, 'r') as f:
             for line in f:
                 items = line.strip().split()
                 if len(items) != 2:
-                    if mode == 'train' or mode == 'eval':
+                    if mode == 'train' or mode == 'val':
                         raise Exception(
                             "File list format incorrect! It should be"
                             " image_name label_name\\n")
-                    image_path = os.path.join(self.data_dir, items[0])
+                    image_path = os.path.join(self.dataset_root, items[0])
                     grt_path = None
                 else:
-                    image_path = os.path.join(self.data_dir, items[0])
-                    grt_path = os.path.join(self.data_dir, items[1])
+                    image_path = os.path.join(self.dataset_root, items[0])
+                    grt_path = os.path.join(self.dataset_root, items[1])
                 self.file_list.append([image_path, grt_path])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
class PascalVOC(Dataset):
    """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
    please run voc_augment.py in tools first.

    Args:
        dataset_root: The dataset directory.
        mode: Which part of dataset to use. It is one of ('train', 'trainval', 'trainaug', 'val'). Default: 'train'.
        transforms: Transforms for image.
        download: Whether to download dataset if dataset_root is None.
    """

    def __init__(self,
                 dataset_root=None,
                 mode='train',
                 transforms=None,
                 download=True):
        self.dataset_root = dataset_root
        self.transforms = transforms
        self.mode = mode
        self.file_list = list()
        self.num_classes = 21

        if mode.lower() not in ['train', 'trainval', 'trainaug', 'val']:
            raise Exception(
                "`mode` should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}."
                .format(mode))

        if self.transforms is None:
            raise Exception("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            if not download:
                raise Exception(
                    "`dataset_root` not set and auto download disabled.")
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=DATA_HOME,
                extrapath=DATA_HOME,
                extraname='VOCdevkit')
        elif not os.path.exists(self.dataset_root):
            raise Exception('`dataset_root` does not exist: {}.'.format(
                self.dataset_root))

        image_set_dir = os.path.join(self.dataset_root, 'VOC2012', 'ImageSets',
                                     'Segmentation')
        if mode == 'train':
            file_list = os.path.join(image_set_dir, 'train.txt')
        elif mode == 'val':
            file_list = os.path.join(image_set_dir, 'val.txt')
        elif mode == 'trainval':
            file_list = os.path.join(image_set_dir, 'trainval.txt')
        elif mode == 'trainaug':
            file_list = os.path.join(image_set_dir, 'train.txt')
            file_list_aug = os.path.join(image_set_dir, 'aug.txt')
            if not os.path.exists(file_list_aug):
                raise Exception(
                    "When `mode` is 'trainaug', the Pascal VOC dataset should be augmented. "
                    "Please make sure voc_augment.py has been properly run before using this mode."
                )

        img_dir = os.path.join(self.dataset_root, 'VOC2012', 'JPEGImages')
        grt_dir = os.path.join(self.dataset_root, 'VOC2012',
                               'SegmentationClass')
        grt_dir_aug = os.path.join(self.dataset_root, 'VOC2012',
                                   'SegmentationClassAug')

        with open(file_list, 'r') as f:
            for line in f:
                line = line.strip()
                image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
                self.file_list.append([image_path, grt_path])
        if mode == 'trainaug':
            # The SBD extras take their labels from SegmentationClassAug.
            with open(file_list_aug, 'r') as f:
                for line in f:
                    line = line.strip()
                    image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                    grt_path = os.path.join(grt_dir_aug, ''.join(
                        [line, '.png']))
                    self.file_list.append([image_path, grt_path])
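A minimal usage sketch; the transform choices are assumptions:

    import transforms as T
    from datasets import PascalVOC

    # 'trainaug' pairs the VOC2012 train list with the SBD extras listed in
    # aug.txt, so tools/voc_augment.py must have been run first.
    train_dataset = PascalVOC(
        transforms=T.Compose([T.RandomHorizontalFlip(), T.Normalize()]),
        mode='trainaug')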
@@ -13,20 +13,13 @@
 # limitations under the License.

 import argparse
-import os

-from paddle.fluid.dygraph.base import to_variable
-import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv
-import cv2
-import tqdm

-from datasets import OpticDiscSeg, Cityscapes
+from datasets import DATASETS
 import transforms as T
 from models import MODELS
-import utils
-import utils.logging as logging
 from utils import get_environ_info
 from core import infer
...
@@ -43,14 +36,20 @@ def parse_args():
         type=str,
         default='UNet')

-    # params of dataset
+    # params of infer
     parser.add_argument(
         '--dataset',
         dest='dataset',
-        help=
-        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        help="The dataset you want to test, which is one of {}".format(
+            str(list(DATASETS.keys()))),
         type=str,
         default='OpticDiscSeg')
+    parser.add_argument(
+        '--dataset_root',
+        dest='dataset_root',
+        help="dataset root directory",
+        type=str,
+        default=None)

     # params of prediction
     parser.add_argument(
...
@@ -88,22 +87,21 @@ def main(args):
         if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
         else fluid.CPUPlace()

-    if args.dataset.lower() == 'opticdiscseg':
-        dataset = OpticDiscSeg
-    elif args.dataset.lower() == 'cityscapes':
-        dataset = Cityscapes
-    else:
-        raise Exception(
-            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
-        )
+    if args.dataset not in DATASETS:
+        raise Exception('`--dataset` is invalid. It should be one of {}'.format(
+            str(list(DATASETS.keys()))))
+    dataset = DATASETS[args.dataset]

     with fluid.dygraph.guard(places):
         test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
-        test_dataset = dataset(transforms=test_transforms, mode='test')
+        test_dataset = dataset(
+            dataset_root=args.dataset_root,
+            transforms=test_transforms,
+            mode='test')

         if args.model_name not in MODELS:
             raise Exception(
-                '--model_name is invalid. it should be one of {}'.format(
+                '`--model_name` is invalid. It should be one of {}'.format(
                     str(list(MODELS.keys()))))
         model = MODELS[args.model_name](num_classes=test_dataset.num_classes)
...
@@ -645,12 +645,14 @@ class FuseLayers(fluid.dygraph.Layer):
         residual_func_idx = 0
         for i in range(self._actual_ch):
             residual = input[i]
+            residual_shape = residual.shape[-2:]
             for j in range(len(self._in_channels)):
                 if j > i:
                     y = self.residual_func_list[residual_func_idx](input[j])
                     residual_func_idx += 1

-                    y = fluid.layers.resize_bilinear(input=y, scale=2**(j - i))
+                    y = fluid.layers.resize_bilinear(
+                        input=y, out_shape=residual_shape)
                     residual = fluid.layers.elementwise_add(
                         x=residual, y=y, act=None)
                 elif j < i:
...
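Resizing to the residual's recorded shape instead of a fixed 2**(j - i) factor removes a shape mismatch in elementwise_add when spatial dimensions are odd: repeated stride-2 stages round sizes up, so scaling back up by the same factor can overshoot the original. A sketch of the arithmetic, assuming ceil-division downsampling:

    import math

    h = 65
    h1 = math.ceil(h / 2)   # 33 after one stride-2 stage
    h2 = math.ceil(h1 / 2)  # 17 after two stages
    print(h2 * 2 ** 2)      # 68 != 65, so a fixed-scale upsample cannot be added back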
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: convert_cityscapes.py
This file is based on https://github.com/mcordts/cityscapesScripts to generate **labelTrainIds.png for training.
Before running, you should download the cityscapes form https://www.cityscapes-dataset.com/ and make the folder
structure as follow:
cityscapes
|
|--leftImg8bit
| |--train
| |--val
| |--test
|
|--gtFine
| |--train
| |--val
| |--test
"""
import os
import argparse
from multiprocessing import Pool, cpu_count
import glob
from cityscapesscripts.preparation.json2labelImg import json2labelImg
def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate **labelTrainIds.png for training')
    parser.add_argument(
        '--cityscapes_path',
        dest='cityscapes_path',
        help='cityscapes path',
        type=str)
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='How many processes are used for data conversion',
        type=int,
        default=cpu_count())
    return parser.parse_args()


def gen_labelTrainIds(json_file):
    label_file = json_file.replace("_polygons.json", "_labelTrainIds.png")
    json2labelImg(json_file, label_file, "trainIds")


def main():
    args = parse_args()
    fine_path = os.path.join(args.cityscapes_path, 'gtFine')
    json_files = glob.glob(os.path.join(fine_path, '*', '*', '*_polygons.json'))

    print('generating **_labelTrainIds.png')
    p = Pool(args.num_workers)
    for f in json_files:
        p.apply_async(gen_labelTrainIds, args=(f, ))
    p.close()
    p.join()


if __name__ == '__main__':
    main()
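Assuming the repository layout, a typical invocation is: python tools/convert_cityscapes.py --cityscapes_path data/cityscapes --num_workers 8 (paths assumed). Afterwards the Cityscapes class above can glob the generated *_gtFine_labelTrainIds.png files.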
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: voc_augment.py
This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html>
to augment the Pascal VOC.
"""
import os
import argparse
from multiprocessing import Pool, cpu_count
import cv2
import numpy as np
from scipy.io import loadmat
import tqdm
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz'
def parse_args():
    parser = argparse.ArgumentParser(
        description=
        'Convert SBD annotations to Pascal VOC annotations to augment the train dataset of Pascal VOC'
    )
    parser.add_argument(
        '--voc_path',
        dest='voc_path',
        help='pascal voc path',
        type=str,
        default=os.path.join(DATA_HOME, 'VOCdevkit'))
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='How many processes are used for data conversion',
        type=int,
        default=cpu_count())
    return parser.parse_args()


def mat_to_png(mat_file, sbd_cls_dir, save_dir):
    mat_path = os.path.join(sbd_cls_dir, mat_file)
    mat = loadmat(mat_path)
    mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    save_file = os.path.join(save_dir, mat_file.replace('mat', 'png'))
    cv2.imwrite(save_file, mask)


def main():
    args = parse_args()
    sbd_path = download_file_and_uncompress(
        url=URL,
        savepath=DATA_HOME,
        extrapath=DATA_HOME,
        extraname='benchmark_RELEASE')
    with open(os.path.join(sbd_path, 'dataset/train.txt'), 'r') as f:
        sbd_file_list = [line.strip() for line in f]
    with open(os.path.join(sbd_path, 'dataset/val.txt'), 'r') as f:
        sbd_file_list += [line.strip() for line in f]
    if not os.path.exists(args.voc_path):
        raise Exception(
            'There is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly'
            .format(args.voc_path))
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/trainval.txt'),
            'r') as f:
        voc_file_list = [line.strip() for line in f]

    # aug.txt holds only the SBD images that VOC2012 does not already cover.
    aug_file_list = list(set(sbd_file_list) - set(voc_file_list))
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/aug.txt'), 'w') as f:
        f.writelines(''.join([line, '\n']) for line in aug_file_list)

    sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls')
    save_dir = os.path.join(args.voc_path, 'VOC2012/SegmentationClassAug')
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    mat_file_list = os.listdir(sbd_cls_dir)
    p = Pool(args.num_workers)
    for f in tqdm.tqdm(mat_file_list):
        p.apply_async(mat_to_png, args=(f, sbd_cls_dir, save_dir))
    p.close()
    p.join()


if __name__ == '__main__':
    main()
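With the defaults above, running python tools/voc_augment.py downloads SBD into ~/.cache/paddle/dataset/benchmark_RELEASE, writes aug.txt plus SegmentationClassAug/ under VOC2012, and thereby enables PascalVOC's 'trainaug' mode (the invocation path is an assumption).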
@@ -13,21 +13,14 @@
 # limitations under the License.

 import argparse
-import os

 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv
-from paddle.fluid.io import DataLoader
-from paddle.incubate.hapi.distributed import DistributedBatchSampler

-from datasets import OpticDiscSeg, Cityscapes
+from datasets import DATASETS
 import transforms as T
 from models import MODELS
-import utils.logging as logging
 from utils import get_environ_info
-from utils import load_pretrained_model
-from utils import resume
-from utils import Timer, calculate_eta
 from core import train
...
@@ -47,10 +40,16 @@ def parse_args():
     parser.add_argument(
         '--dataset',
         dest='dataset',
-        help=
-        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        help="The dataset you want to train, which is one of {}".format(
+            str(list(DATASETS.keys()))),
         type=str,
         default='OpticDiscSeg')
+    parser.add_argument(
+        '--dataset_root',
+        dest='dataset_root',
+        help="dataset root directory",
+        type=str,
+        default=None)

     # params of training
     parser.add_argument(
...
@@ -134,14 +133,10 @@ def main(args):
         if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
         else fluid.CPUPlace()

-    if args.dataset.lower() == 'opticdiscseg':
-        dataset = OpticDiscSeg
-    elif args.dataset.lower() == 'cityscapes':
-        dataset = Cityscapes
-    else:
-        raise Exception(
-            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
-        )
+    if args.dataset not in DATASETS:
+        raise Exception('`--dataset` is invalid. It should be one of {}'.format(
+            str(list(DATASETS.keys()))))
+    dataset = DATASETS[args.dataset]

     with fluid.dygraph.guard(places):
         # Create dataset reader
...
@@ -150,18 +145,24 @@ def main(args):
             T.RandomHorizontalFlip(),
             T.Normalize()
         ])
-        train_dataset = dataset(transforms=train_transforms, mode='train')
+        train_dataset = dataset(
+            dataset_root=args.dataset_root,
+            transforms=train_transforms,
+            mode='train')

         eval_dataset = None
         if args.do_eval:
             eval_transforms = T.Compose(
                 [T.Resize(args.input_size),
                  T.Normalize()])
-            eval_dataset = dataset(transforms=eval_transforms, mode='eval')
+            eval_dataset = dataset(
+                dataset_root=args.dataset_root,
+                transforms=eval_transforms,
+                mode='val')

         if args.model_name not in MODELS:
             raise Exception(
-                '--model_name is invalid. it should be one of {}'.format(
+                '`--model_name` is invalid. It should be one of {}'.format(
                     str(list(MODELS.keys()))))
         model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
...
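Putting the pieces together, a training run on a locally prepared dataset now looks roughly like: python train.py --model_name UNet --dataset Cityscapes --dataset_root data/cityscapes --do_eval (flag names other than --dataset and --dataset_root are inferred from the surrounding script; the path is an assumption).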
@@ -85,8 +85,8 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress):
     for total_num, index, rootpath in handler(filepath, extrapath):
         if print_progress:
             done = int(50 * float(index) / total_num)
-            progress("[%-50s] %.2f%%" %
-                     ('=' * done, float(100 * index) / total_num))
+            progress(
+                "[%-50s] %.2f%%" % ('=' * done, float(100 * index) / total_num))

     if print_progress:
         progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
...
@@ -132,4 +132,4 @@ def download_file_and_uncompress(url,
                                  print_progress)
         savename = os.path.join(extrapath, savename)
         shutil.move(savename, extraname)
-    return savename
+    return extraname
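Returning extraname instead of savename matters here: after shutil.move the extracted files live at the extraname location, so the old return value pointed at a path that no longer existed. The new ADE20K and PascalVOC datasets consume this return value as their dataset_root.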
@@ -88,6 +88,7 @@ def resume(model, optimizer, resume_model):
     if resume_model is not None:
         logging.info('Resume model from {}'.format(resume_model))
         if os.path.exists(resume_model):
+            resume_model = os.path.normpath(resume_model)
             ckpt_path = os.path.join(resume_model, 'model')
             para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
             model.set_dict(para_state_dict)
...
@@ -13,25 +13,14 @@
 # limitations under the License.

 import argparse
-import os
-import math

-import numpy as np
-import tqdm
-import cv2
-from paddle.fluid.dygraph.base import to_variable
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv
-from paddle.fluid.io import DataLoader
-from paddle.fluid.dataloader import BatchSampler

-from datasets import OpticDiscSeg, Cityscapes
+from datasets import DATASETS
 import transforms as T
 from models import MODELS
-import utils.logging as logging
 from utils import get_environ_info
-from utils import ConfusionMatrix
-from utils import Timer, calculate_eta
 from core import evaluate
...
@@ -51,10 +40,16 @@ def parse_args():
     parser.add_argument(
         '--dataset',
         dest='dataset',
-        help=
-        "The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        help="The dataset you want to evaluate, which is one of {}".format(
+            str(list(DATASETS.keys()))),
         type=str,
         default='OpticDiscSeg')
+    parser.add_argument(
+        '--dataset_root',
+        dest='dataset_root',
+        help="dataset root directory",
+        type=str,
+        default=None)

     # params of evaluate
     parser.add_argument(
...
@@ -80,22 +75,21 @@ def main(args):
         if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
         else fluid.CPUPlace()

-    if args.dataset.lower() == 'opticdiscseg':
-        dataset = OpticDiscSeg
-    elif args.dataset.lower() == 'cityscapes':
-        dataset = Cityscapes
-    else:
-        raise Exception(
-            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
-        )
+    if args.dataset not in DATASETS:
+        raise Exception('`--dataset` is invalid. It should be one of {}'.format(
+            str(list(DATASETS.keys()))))
+    dataset = DATASETS[args.dataset]

     with fluid.dygraph.guard(places):
         eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
-        eval_dataset = dataset(transforms=eval_transforms, mode='eval')
+        eval_dataset = dataset(
+            dataset_root=args.dataset_root,
+            transforms=eval_transforms,
+            mode='val')

         if args.model_name not in MODELS:
             raise Exception(
-                '--model_name is invalid. it should be one of {}'.format(
+                '`--model_name` is invalid. It should be one of {}'.format(
                     str(list(MODELS.keys()))))
         model = MODELS[args.model_name](num_classes=eval_dataset.num_classes)
...