# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
C
chenguowei01 已提交
16
import glob
C
chenguowei01 已提交
17

C
chenguowei01 已提交
18
from .dataset import Dataset
W
wuzewu 已提交
19 20
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
C
chenguowei01 已提交
21 22


W
wuzewu 已提交
23
@manager.DATASETS.add_component
class Cityscapes(Dataset):
    """Cityscapes dataset `https://www.cityscapes-dataset.com/`.

    The folder structure is as follows:

        cityscapes
        |
        |--leftImg8bit
        |  |--train
        |  |--val
        |  |--test
        |
        |--gtFine
        |  |--train
        |  |--val
        |  |--test

    Make sure the *_gtFine_labelTrainIds.png files exist in the gtFine
    directory. If not, please run the Cityscapes conversion script
    (convert_cityscapes.py) in tools to generate them.

    Args:
        dataset_root (str): Cityscapes dataset directory.
        transforms (list): Transforms for image. Required; it is wrapped
            with ``Compose``.
        mode (str): Which part of dataset to use. It is one of
            ('train', 'val', 'test'), case-insensitive. Default: 'train'.

    Raises:
        Exception: If ``mode`` is not one of the supported splits, if
            ``transforms`` is None, if ``dataset_root`` or its
            ``leftImg8bit``/``gtFine`` subdirectories are missing, or if the
            number of images and label files found for ``mode`` differ.
    """

    def __init__(self, dataset_root, transforms=None, mode='train'):
        self.dataset_root = dataset_root
        self.num_classes = 19

        if mode.lower() not in ['train', 'val', 'test']:
            raise Exception(
                "mode should be 'train', 'val' or 'test', but got {}.".format(
                    mode))
        # Use the normalized split name everywhere below. Otherwise a value
        # like 'Train' passes validation but globs a non-existent directory.
        mode = mode.lower()
        self.mode = mode

        # Check the raw argument. Compose(...) always returns an object, so
        # checking the wrapped value for None could never trigger.
        if transforms is None:
            raise Exception("`transforms` is necessary, but it is None.")
        self.transforms = Compose(transforms)

        not_found_msg = (
            "The dataset is not Found or the folder structure is nonconfoumance."
        )
        # Check the root before joining paths onto it. os.path.join(None, ...)
        # would raise a TypeError before the intended error.
        if self.dataset_root is None or not os.path.isdir(self.dataset_root):
            raise Exception(not_found_msg)
        img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
        grt_dir = os.path.join(self.dataset_root, 'gtFine')
        if not (os.path.isdir(img_dir) and os.path.isdir(grt_dir)):
            raise Exception(not_found_msg)

        grt_files = sorted(
            glob.glob(
                os.path.join(grt_dir, mode, '*', '*_gtFine_labelTrainIds.png')))
        img_files = sorted(
            glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))

        # Pairing relies on both sorted lists lining up one-to-one. If some
        # label files are missing, zip() would silently truncate the lists
        # and pair images with the wrong labels.
        if len(img_files) != len(grt_files):
            raise Exception(
                "Found {} images but {} label files for mode '{}'. Make sure "
                "the *_gtFine_labelTrainIds.png files have been generated."
                .format(len(img_files), len(grt_files), mode))

        self.file_list = [[img_path, grt_path]
                          for img_path, grt_path in zip(img_files, grt_files)]