提交 4cdc97c3 编写于 作者: C chenguowei01

update citescapes dataset

上级 58dd74fb
...@@ -13,58 +13,61 @@ ...@@ -13,58 +13,61 @@
# limitations under the License. # limitations under the License.
import os import os
import glob
from .dataset import Dataset from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar"
class Cityscapes(Dataset): class Cityscapes(Dataset):
def __init__(self, """Cityscapes dataset `https://www.cityscapes-dataset.com/`.
data_dir=None, The folder structure is as follow:
transforms=None, cityscapes
mode='train', |
download=True): |--leftImg8bit
| |--train
| |--val
| |--test
|
|--gtFine
| |--train
| |--val
| |--test
Make sure there are **labelTrainIds.png in gtFine directory. If not, please run the conver_cityscapes.py in tools.
Args:
data_dir: Cityscapes dataset directory.
mode: Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image.
"""
def __init__(self, data_dir, transforms=None, mode='train'):
self.data_dir = data_dir self.data_dir = data_dir
self.transforms = transforms self.transforms = transforms
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
self.num_classes = 19 self.num_classes = 19
if mode.lower() not in ['train', 'eval', 'test']: if mode.lower() not in ['train', 'val', 'test']:
raise Exception( raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format( "mode should be 'train', 'val' or 'test', but got {}.".format(
mode)) mode))
if self.transforms is None: if self.transforms is None:
raise Exception("transforms is necessary, but it is None.") raise Exception("transforms is necessary, but it is None.")
if self.data_dir is None: img_dir = os.path.join(self.data_dir, 'leftImg8bit')
if not download: grt_dir = os.path.join(self.data_dir, 'gtFine')
raise Exception("data_file not set and auto download disabled.") if not os.path.isdir(self.data_dir) or not os.path.isdir(
self.data_dir = download_file_and_uncompress( img_dir) or not os.path.isdir(grt_dir):
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) raise Exception(
"The dataset is not Found or the folder structure is nonconfoumance."
)
if mode == 'train': grt_files = sorted(
file_list = os.path.join(self.data_dir, 'train.list') glob.glob(
elif mode == 'eval': os.path.join(grt_dir, mode, '*', '*_gtFine_labelTrainIds.png')))
file_list = os.path.join(self.data_dir, 'val.list') img_files = sorted(
else: glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))
file_list = os.path.join(self.data_dir, 'test.list')
with open(file_list, 'r') as f: self.file_list = [[img_path, grt_path]
for line in f: for img_path, grt_path in zip(img_files, grt_files)]
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
...@@ -21,10 +21,12 @@ cityscapes ...@@ -21,10 +21,12 @@ cityscapes
|--leftImg8bit |--leftImg8bit
| |--train | |--train
| |--val | |--val
| |--test
| |
|--gtFine |--gtFine
| |--train | |--train
| |--val | |--val
| |--test
""" """
import os import os
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册