提交 8d785cff 编写于 作者: C chenguowei01

update datasets

上级 967ebd6f
......@@ -25,8 +25,7 @@ class Dataset(fluid.io.Dataset):
Args:
data_dir: The dataset directory.
num_classes: Number of classes.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'test'). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
mode: Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
train_list: The train dataset file. When mode is 'train', train_list is necessary.
The contents of train_list file are as follow:
image1.jpg ground_truth1.png
......@@ -46,7 +45,6 @@ class Dataset(fluid.io.Dataset):
def __init__(self,
data_dir,
num_classes,
image_set='train',
mode='train',
train_list=None,
val_list=None,
......@@ -59,21 +57,16 @@ class Dataset(fluid.io.Dataset):
self.mode = mode
self.num_classes = num_classes
if image_set.lower() not in ['train', 'val', 'test']:
if mode.lower() not in ['train', 'val', 'test']:
raise Exception(
"image_set should be one of ('train', 'val', 'test'), but got {}."
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
"mode should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
self.data_dir = data_dir
if image_set == 'train':
if mode == 'train':
if train_list is None:
raise Exception(
'When mode is "train", train_list is necessary, but it is None.'
......@@ -83,10 +76,10 @@ class Dataset(fluid.io.Dataset):
'train_list is not found: {}'.format(train_list))
else:
file_list = train_list
elif image_set == 'eval':
elif mode == 'val':
if val_list is None:
raise Exception(
'When mode is "eval", val_list is necessary, but it is None.'
'When mode is "val", val_list is necessary, but it is None.'
)
elif not os.path.exists(val_list):
raise Exception('val_list is not found: {}'.format(val_list))
......@@ -106,9 +99,9 @@ class Dataset(fluid.io.Dataset):
for line in f:
items = line.strip().split(separator)
if len(items) != 2:
if mode == 'train' or mode == 'eval':
if mode == 'train' or mode == 'val':
raise Exception(
"File list format incorrect! It should be"
"File list format incorrect! In training or evaluation task it should be"
" image_name{}label_name\\n".format(separator))
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
......@@ -119,19 +112,19 @@ class Dataset(fluid.io.Dataset):
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
if self.mode == 'train':
im, im_info, label = self.transforms(im=image_path, label=grt_path)
return im, label
elif self.mode == 'eval':
if self.mode == 'test':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
return im, im_info, image_path
elif self.mode == 'val':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
label = np.asarray(Image.open(grt_path))
label = label[np.newaxis, np.newaxis, :, :]
return im, im_info, label
if self.mode == 'test':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
return im, im_info, image_path
else:
im, im_info, label = self.transforms(im=image_path, label=grt_path)
return im, label
def __len__(self):
return len(self.file_list)
......@@ -25,7 +25,6 @@ class OpticDiscSeg(Dataset):
def __init__(self,
data_dir=None,
transforms=None,
image_set='train',
mode='train',
download=True):
self.data_dir = data_dir
......@@ -34,14 +33,9 @@ class OpticDiscSeg(Dataset):
self.mode = mode
self.num_classes = 2
if image_set.lower() not in ['train', 'val', 'test']:
if mode.lower() not in ['train', 'val', 'test']:
raise Exception(
"image_set should be one of ('train', 'val', 'test'), but got {}."
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
"mode should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
......@@ -53,9 +47,9 @@ class OpticDiscSeg(Dataset):
self.data_dir = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if image_set == 'train':
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
elif image_set == 'val':
elif mode == 'val':
file_list = os.path.join(self.data_dir, 'val_list.txt')
else:
file_list = os.path.join(self.data_dir, 'test_list.txt')
......@@ -64,7 +58,7 @@ class OpticDiscSeg(Dataset):
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
if mode == 'train' or mode == 'val':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
......
......@@ -25,15 +25,13 @@ class PascalVOC(Dataset):
please run the voc_augment.py in tools.
Args:
data_dir: The dataset directory.
image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'trainval', 'trainaug). Default: 'train'.
mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'.
mode: Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image.
download: Whether to download dataset if data_dir is None.
"""
def __init__(self,
data_dir=None,
image_set='train',
mode='train',
transforms=None,
download=True):
......@@ -43,22 +41,17 @@ class PascalVOC(Dataset):
self.file_list = list()
self.num_classes = 21
if image_set.lower() not in ['train', 'val', 'trainval', 'trainaug']:
if mode.lower() not in ['train', 'trainval', 'trainaug', 'val']:
raise Exception(
"image_set should be one of ('train', 'val', 'trainval', 'trainaug'), but got {}."
.format(image_set))
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
"mode should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}."
.format(mode))
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
raise Exception("data_dir not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
......@@ -68,19 +61,19 @@ class PascalVOC(Dataset):
image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets',
'Segmentation')
if image_set == 'train':
if mode == 'train':
file_list = os.path.join(image_set_dir, 'train.txt')
elif image_set == 'val':
elif mode == 'val':
file_list = os.path.join(image_set_dir, 'val.txt')
elif image_set == 'trainval':
elif mode == 'trainval':
file_list = os.path.join(image_set_dir, 'trainval.txt')
elif image_set == 'trainaug':
elif mode == 'trainaug':
file_list = os.path.join(image_set_dir, 'train.txt')
file_list_aug = os.path.join(image_set_dir, 'aug.txt')
if not os.path.exists(file_list_aug):
raise Exception(
"When image_set is 'trainaug', Pascal Voc dataset should be augmented, "
"When mode is 'trainaug', Pascal Voc dataset should be augmented, "
"Please make sure voc_augment.py has been properly run when using this mode."
)
......@@ -95,10 +88,11 @@ class PascalVOC(Dataset):
image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
self.file_list.append([image_path, grt_path])
if image_set == 'trainaug':
if mode == 'trainaug':
with open(file_list_aug, 'r') as f:
for line in f:
line = line.strip()
image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
grt_path = os.path.join(grt_dir_aug, ''.join([line,
'.png']))
self.file_list.append([image_path, grt_path])
......@@ -13,20 +13,13 @@
# limitations under the License.
import argparse
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2
import tqdm
from datasets import DATASETS
import transforms as T
from models import MODELS
import utils
import utils.logging as logging
from utils import get_environ_info
from core import infer
......@@ -43,7 +36,7 @@ def parse_args():
type=str,
default='UNet')
# params of dataset
# params of infer
parser.add_argument(
'--dataset',
dest='dataset',
......
......@@ -153,7 +153,7 @@ def main(args):
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
eval_dataset = dataset(transforms=eval_transforms, mode='val')
if args.model_name not in MODELS:
raise Exception(
......
......@@ -87,7 +87,7 @@ def main(args):
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
eval_dataset = dataset(transforms=eval_transforms, mode='val')
if args.model_name not in MODELS:
raise Exception(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册