提交 8d785cff 编写于 作者: C chenguowei01

update datasets

上级 967ebd6f
...@@ -25,8 +25,7 @@ class Dataset(fluid.io.Dataset): ...@@ -25,8 +25,7 @@ class Dataset(fluid.io.Dataset):
Args: Args:
data_dir: The dataset directory. data_dir: The dataset directory.
num_classes: Number of classes. num_classes: Number of classes.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'test'). Default: 'train'. mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
train_list: The train dataset file. When image_set is 'train', train_list is necessary. train_list: The train dataset file. When image_set is 'train', train_list is necessary.
The contents of train_list file are as follow: The contents of train_list file are as follow:
image1.jpg ground_truth1.png image1.jpg ground_truth1.png
...@@ -46,7 +45,6 @@ class Dataset(fluid.io.Dataset): ...@@ -46,7 +45,6 @@ class Dataset(fluid.io.Dataset):
def __init__(self, def __init__(self,
data_dir, data_dir,
num_classes, num_classes,
image_set='train',
mode='train', mode='train',
train_list=None, train_list=None,
val_list=None, val_list=None,
...@@ -59,21 +57,16 @@ class Dataset(fluid.io.Dataset): ...@@ -59,21 +57,16 @@ class Dataset(fluid.io.Dataset):
self.mode = mode self.mode = mode
self.num_classes = num_classes self.num_classes = num_classes
if image_set.lower() not in ['train', 'val', 'test']: if mode.lower() not in ['train', 'val', 'test']:
raise Exception( raise Exception(
"image_set should be one of ('train', 'val', 'test'), but got {}." "mode should be 'train', 'val' or 'test', but got {}.".format(
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode)) mode))
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir self.data_dir = data_dir
if image_set == 'train': if mode == 'train':
if train_list is None: if train_list is None:
raise Exception( raise Exception(
'When mode is "train", train_list is necessary, but it is None.' 'When mode is "train", train_list is necessary, but it is None.'
...@@ -83,10 +76,10 @@ class Dataset(fluid.io.Dataset): ...@@ -83,10 +76,10 @@ class Dataset(fluid.io.Dataset):
'train_list is not found: {}'.format(train_list)) 'train_list is not found: {}'.format(train_list))
else: else:
file_list = train_list file_list = train_list
elif image_set == 'eval': elif mode == 'val':
if val_list is None: if val_list is None:
raise Exception( raise Exception(
'When mode is "eval", val_list is necessary, but it is None.' 'When mode is "val", val_list is necessary, but it is None.'
) )
elif not os.path.exists(val_list): elif not os.path.exists(val_list):
raise Exception('val_list is not found: {}'.format(val_list)) raise Exception('val_list is not found: {}'.format(val_list))
...@@ -106,9 +99,9 @@ class Dataset(fluid.io.Dataset): ...@@ -106,9 +99,9 @@ class Dataset(fluid.io.Dataset):
for line in f: for line in f:
items = line.strip().split(separator) items = line.strip().split(separator)
if len(items) != 2: if len(items) != 2:
if mode == 'train' or mode == 'eval': if mode == 'train' or mode == 'val':
raise Exception( raise Exception(
"File list format incorrect! It should be" "File list format incorrect! In training or evaluation task it should be"
" image_name{}label_name\\n".format(separator)) " image_name{}label_name\\n".format(separator))
image_path = os.path.join(self.data_dir, items[0]) image_path = os.path.join(self.data_dir, items[0])
grt_path = None grt_path = None
...@@ -119,19 +112,19 @@ class Dataset(fluid.io.Dataset): ...@@ -119,19 +112,19 @@ class Dataset(fluid.io.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx] image_path, grt_path = self.file_list[idx]
if self.mode == 'train': if self.mode == 'test':
im, im_info, label = self.transforms(im=image_path, label=grt_path) im, im_info, _ = self.transforms(im=image_path)
return im, label im = im[np.newaxis, ...]
elif self.mode == 'eval': return im, im_info, image_path
elif self.mode == 'val':
im, im_info, _ = self.transforms(im=image_path) im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...] im = im[np.newaxis, ...]
label = np.asarray(Image.open(grt_path)) label = np.asarray(Image.open(grt_path))
label = label[np.newaxis, np.newaxis, :, :] label = label[np.newaxis, np.newaxis, :, :]
return im, im_info, label return im, im_info, label
if self.mode == 'test': else:
im, im_info, _ = self.transforms(im=image_path) im, im_info, label = self.transforms(im=image_path, label=grt_path)
im = im[np.newaxis, ...] return im, label
return im, im_info, image_path
def __len__(self): def __len__(self):
return len(self.file_list) return len(self.file_list)
...@@ -25,7 +25,6 @@ class OpticDiscSeg(Dataset): ...@@ -25,7 +25,6 @@ class OpticDiscSeg(Dataset):
def __init__(self, def __init__(self,
data_dir=None, data_dir=None,
transforms=None, transforms=None,
image_set='train',
mode='train', mode='train',
download=True): download=True):
self.data_dir = data_dir self.data_dir = data_dir
...@@ -34,14 +33,9 @@ class OpticDiscSeg(Dataset): ...@@ -34,14 +33,9 @@ class OpticDiscSeg(Dataset):
self.mode = mode self.mode = mode
self.num_classes = 2 self.num_classes = 2
if image_set.lower() not in ['train', 'val', 'test']: if mode.lower() not in ['train', 'val', 'test']:
raise Exception( raise Exception(
"image_set should be one of ('train', 'val', 'test'), but got {}." "mode should be 'train', 'val' or 'test', but got {}.".format(
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode)) mode))
if self.transforms is None: if self.transforms is None:
...@@ -53,9 +47,9 @@ class OpticDiscSeg(Dataset): ...@@ -53,9 +47,9 @@ class OpticDiscSeg(Dataset):
self.data_dir = download_file_and_uncompress( self.data_dir = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if image_set == 'train': if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt') file_list = os.path.join(self.data_dir, 'train_list.txt')
elif image_set == 'val': elif mode == 'val':
file_list = os.path.join(self.data_dir, 'val_list.txt') file_list = os.path.join(self.data_dir, 'val_list.txt')
else: else:
file_list = os.path.join(self.data_dir, 'test_list.txt') file_list = os.path.join(self.data_dir, 'test_list.txt')
...@@ -64,7 +58,7 @@ class OpticDiscSeg(Dataset): ...@@ -64,7 +58,7 @@ class OpticDiscSeg(Dataset):
for line in f: for line in f:
items = line.strip().split() items = line.strip().split()
if len(items) != 2: if len(items) != 2:
if mode == 'train' or mode == 'eval': if mode == 'train' or mode == 'val':
raise Exception( raise Exception(
"File list format incorrect! It should be" "File list format incorrect! It should be"
" image_name label_name\\n") " image_name label_name\\n")
......
...@@ -25,15 +25,13 @@ class PascalVOC(Dataset): ...@@ -25,15 +25,13 @@ class PascalVOC(Dataset):
please run the voc_augment.py in tools. please run the voc_augment.py in tools.
Args: Args:
data_dir: The dataset directory. data_dir: The dataset directory.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'trainval', 'trainaug). Default: 'train'. mode: Which part of dataset to use.. it is one of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
transforms: Transforms for image. transforms: Transforms for image.
download: Whether to download dataset if data_dir is None. download: Whether to download dataset if data_dir is None.
""" """
def __init__(self, def __init__(self,
data_dir=None, data_dir=None,
image_set='train',
mode='train', mode='train',
transforms=None, transforms=None,
download=True): download=True):
...@@ -43,22 +41,17 @@ class PascalVOC(Dataset): ...@@ -43,22 +41,17 @@ class PascalVOC(Dataset):
self.file_list = list() self.file_list = list()
self.num_classes = 21 self.num_classes = 21
if image_set.lower() not in ['train', 'val', 'trainval', 'trainaug']: if mode.lower() not in ['train', 'trainval', 'trainaug', 'val']:
raise Exception( raise Exception(
"image_set should be one of ('train', 'val', 'trainval', 'trainaug'), but got {}." "mode should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}."
.format(image_set)) .format(mode))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None: if self.data_dir is None:
if not download: if not download:
raise Exception("data_file not set and auto download disabled.") raise Exception("data_dir not set and auto download disabled.")
self.data_dir = download_file_and_uncompress( self.data_dir = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=DATA_HOME,
...@@ -68,19 +61,19 @@ class PascalVOC(Dataset): ...@@ -68,19 +61,19 @@ class PascalVOC(Dataset):
image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets', image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets',
'Segmentation') 'Segmentation')
if image_set == 'train': if mode == 'train':
file_list = os.path.join(image_set_dir, 'train.txt') file_list = os.path.join(image_set_dir, 'train.txt')
elif image_set == 'val': elif mode == 'val':
file_list = os.path.join(image_set_dir, 'val.txt') file_list = os.path.join(image_set_dir, 'val.txt')
elif image_set == 'trainval': elif mode == 'trainval':
file_list = os.path.join(image_set_dir, 'trainval.txt') file_list = os.path.join(image_set_dir, 'trainval.txt')
elif image_set == 'trainaug': elif mode == 'trainaug':
file_list = os.path.join(image_set_dir, 'train.txt') file_list = os.path.join(image_set_dir, 'train.txt')
file_list_aug = os.path.join(image_set_dir, 'aug.txt') file_list_aug = os.path.join(image_set_dir, 'aug.txt')
if not os.path.exists(file_list_aug): if not os.path.exists(file_list_aug):
raise Exception( raise Exception(
"When image_set is 'trainaug', Pascal Voc dataset should be augmented, " "When mode is 'trainaug', Pascal Voc dataset should be augmented, "
"Please make sure voc_augment.py has been properly run when using this mode." "Please make sure voc_augment.py has been properly run when using this mode."
) )
...@@ -95,10 +88,11 @@ class PascalVOC(Dataset): ...@@ -95,10 +88,11 @@ class PascalVOC(Dataset):
image_path = os.path.join(img_dir, ''.join([line, '.jpg'])) image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
grt_path = os.path.join(grt_dir, ''.join([line, '.png'])) grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
self.file_list.append([image_path, grt_path]) self.file_list.append([image_path, grt_path])
if image_set == 'trainaug': if mode == 'trainaug':
with open(file_list_aug, 'r') as f: with open(file_list_aug, 'r') as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
image_path = os.path.join(img_dir, ''.join([line, '.jpg'])) image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
grt_path = os.path.join(grt_dir, ''.join([line, '.png'])) grt_path = os.path.join(grt_dir_aug, ''.join([line,
'.png']))
self.file_list.append([image_path, grt_path]) self.file_list.append([image_path, grt_path])
...@@ -13,20 +13,13 @@ ...@@ -13,20 +13,13 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2
import tqdm
from datasets import DATASETS from datasets import DATASETS
import transforms as T import transforms as T
from models import MODELS from models import MODELS
import utils
import utils.logging as logging
from utils import get_environ_info from utils import get_environ_info
from core import infer from core import infer
...@@ -43,7 +36,7 @@ def parse_args(): ...@@ -43,7 +36,7 @@ def parse_args():
type=str, type=str,
default='UNet') default='UNet')
# params of dataset # params of infer
parser.add_argument( parser.add_argument(
'--dataset', '--dataset',
dest='dataset', dest='dataset',
......
...@@ -153,7 +153,7 @@ def main(args): ...@@ -153,7 +153,7 @@ def main(args):
eval_transforms = T.Compose( eval_transforms = T.Compose(
[T.Resize(args.input_size), [T.Resize(args.input_size),
T.Normalize()]) T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval') eval_dataset = dataset(transforms=eval_transforms, mode='val')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
......
...@@ -87,7 +87,7 @@ def main(args): ...@@ -87,7 +87,7 @@ def main(args):
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval') eval_dataset = dataset(transforms=eval_transforms, mode='val')
if args.model_name not in MODELS: if args.model_name not in MODELS:
raise Exception( raise Exception(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册