提交 4cdc97c3 编写于 作者: C chenguowei01

update citescapes dataset

上级 58dd74fb
......@@ -13,58 +13,61 @@
# limitations under the License.
import os
import glob
from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar"
class Cityscapes(Dataset):
def __init__(self,
data_dir=None,
transforms=None,
mode='train',
download=True):
"""Cityscapes dataset `https://www.cityscapes-dataset.com/`.
The folder structure is as follow:
cityscapes
|
|--leftImg8bit
| |--train
| |--val
| |--test
|
|--gtFine
| |--train
| |--val
| |--test
Make sure there are **labelTrainIds.png in gtFine directory. If not, please run the conver_cityscapes.py in tools.
Args:
data_dir: Cityscapes dataset directory.
mode: Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image.
"""
def __init__(self, data_dir, transforms=None, mode='train'):
self.data_dir = data_dir
self.transforms = transforms
self.file_list = list()
self.mode = mode
self.num_classes = 19
if mode.lower() not in ['train', 'eval', 'test']:
if mode.lower() not in ['train', 'val', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
"mode should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
img_dir = os.path.join(self.data_dir, 'leftImg8bit')
grt_dir = os.path.join(self.data_dir, 'gtFine')
if not os.path.isdir(self.data_dir) or not os.path.isdir(
img_dir) or not os.path.isdir(grt_dir):
raise Exception(
"The dataset is not Found or the folder structure is nonconfoumance."
)
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train.list')
elif mode == 'eval':
file_list = os.path.join(self.data_dir, 'val.list')
else:
file_list = os.path.join(self.data_dir, 'test.list')
grt_files = sorted(
glob.glob(
os.path.join(grt_dir, mode, '*', '*_gtFine_labelTrainIds.png')))
img_files = sorted(
glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
self.file_list = [[img_path, grt_path]
for img_path, grt_path in zip(img_files, grt_files)]
......@@ -21,10 +21,12 @@ cityscapes
|--leftImg8bit
| |--train
| |--val
| |--test
|
|--gtFine
| |--train
| |--val
| |--test
"""
import os
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册