# culane.py
from ppdet.core.workspace import register, serializable
import cv2
import os
import tarfile
import numpy as np
import os.path as osp
from ppdet.data.source.dataset import DetDataset
from imgaug.augmentables.lines import LineStringsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from ppdet.data.culane_utils import lane_to_linestrings
import pickle as pkl
from ppdet.utils.logger import setup_logger
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
from .dataset import DetDataset, _make_dataset, _is_valid_file
from ppdet.utils.download import download_dataset

logger = setup_logger(__name__)


@register
@serializable
class CULaneDataSet(DetDataset):
    """CULane lane-detection dataset.

    Parses a CULane-style list file into per-image records (image path,
    optional segmentation-mask path, optional lane-existence flags, and
    lane point lists read from the sibling ``*.lines.txt`` files), and
    serves samples with the top ``cut_height`` pixels cropped off.  A
    ``set_images`` call switches the dataset into predict mode, where it
    enumerates plain image files instead of annotations.

    Args:
        dataset_dir (str): root directory of the CULane dataset.
        cut_height (int): number of pixel rows cropped from the image top.
        list_path (str): list file path, relative to ``dataset_dir``.
        split (str): split name; ``'train'`` in the name enables training
            behavior (mask loading, lane y-shift).
        data_fields (list): sample fields to produce (read-only default).
        video_file (str|None): optional video source (kept for interface
            compatibility; not used in this class).
        frame_rate (int): optional frame rate matching ``video_file``.
    """

    def __init__(
            self,
            dataset_dir,
            cut_height,
            list_path,
            split='train',
            data_fields=['image'],
            video_file=None,
            frame_rate=-1, ):
        super(CULaneDataSet, self).__init__(
            dataset_dir=dataset_dir,
            cut_height=cut_height,
            split=split,
            data_fields=data_fields)
        self.dataset_dir = dataset_dir
        self.list_path = osp.join(dataset_dir, list_path)
        self.cut_height = cut_height
        self.data_fields = data_fields
        self.split = split
        self.training = 'train' in split
        self.data_infos = []
        self.video_file = video_file
        self.frame_rate = frame_rate
        self._imid2path = {}
        # Set by set_images(); non-None means predict mode.
        self.predict_dir = None

    def __len__(self):
        return len(self.data_infos)

    def check_or_download_dataset(self):
        """Download and extract the CULane archives if ``dataset_dir`` is missing."""
        if not osp.exists(self.dataset_dir):
            download_dataset("dataset", dataset="culane")
            # extract .tar.gz files dropped into self.dataset_dir
            for fname in os.listdir(self.dataset_dir):
                # skip hidden files such as .DS_Store
                if fname.startswith('.'):
                    continue
                if '.tar.gz' in fname:
                    # Log only for files we actually decompress (the old code
                    # logged every directory entry, including non-archives).
                    logger.info("Decompressing {}...".format(fname))
                    # NOTE(review): extractall on a downloaded archive is
                    # vulnerable to path traversal; consider filter='data'
                    # once the minimum Python is 3.12.
                    with tarfile.open(osp.join(self.dataset_dir, fname)) as tf:
                        tf.extractall(path=self.dataset_dir)
        logger.info("Dataset files are ready.")

    def parse_dataset(self):
        """Load (or restore from cache) all annotations for the split.

        Sets ``self.data_infos`` and ``self.max_lanes``.  No-op in
        predict mode.
        """
        logger.info('Loading CULane annotations...')
        if self.predict_dir is not None:
            logger.info('switch to predict mode')
            return
        # Parsing the list file is slow; cache the result per split.
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/culane_paddle_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
                self.max_lanes = max(
                    (len(anno['lanes']) for anno in self.data_infos),
                    default=0)
                return

        with open(self.list_path) as list_file:
            for line in list_file:
                infos = self.load_annotation(line.split())
                self.data_infos.append(infos)

        # BUGFIX: compute max_lanes on the cold path too; previously it was
        # only set when the cache was hit, so the first run left the
        # attribute undefined.
        self.max_lanes = max(
            (len(anno['lanes']) for anno in self.data_infos), default=0)

        # cache data infos to file
        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)

    def load_annotation(self, line):
        """Parse one whitespace-split row of the list file.

        Args:
            line (list[str]): ``[img_path, mask_path?, exist_flags...]``.

        Returns:
            dict: record with 'img_name', 'img_path', optional
            'mask_path' / 'lane_exist', and parsed 'lanes'
            (each lane a y-sorted list of (x, y) tuples).
        """
        infos = {}
        img_line = line[0]
        # List-file paths may start with '/'; strip it so the join below
        # stays inside dataset_dir.
        if img_line.startswith('/'):
            img_line = img_line[1:]
        img_path = os.path.join(self.dataset_dir, img_line)
        infos['img_name'] = img_line
        infos['img_path'] = img_path
        if len(line) > 1:
            mask_line = line[1]
            if mask_line.startswith('/'):
                mask_line = mask_line[1:]
            infos['mask_path'] = os.path.join(self.dataset_dir, mask_line)

        if len(line) > 2:
            # Per-lane existence flags (present only in the train list).
            infos['lane_exist'] = np.array([int(flag) for flag in line[2:]])

        # The annotation sits next to the image: xxx.jpg -> xxx.lines.txt
        anno_path = img_path[:-3] + 'lines.txt'
        with open(anno_path, 'r') as anno_file:
            data = [list(map(float, row.split())) for row in anno_file]
        # Rows are flat x y x y ... sequences; drop negative (occluded /
        # out-of-image) points.
        lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
                  if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
        lanes = [list(set(lane)) for lane in lanes]  # remove duplicated points
        # Keep only lanes with MORE than 2 distinct points (the old comment
        # said "less than 2" but the condition has always been > 2).
        lanes = [lane for lane in lanes if len(lane) > 2]

        lanes = [sorted(
            lane, key=lambda p: p[1]) for lane in lanes]  # sort by y
        infos['lanes'] = lanes

        return infos

    def set_images(self, images):
        """Switch to predict mode over user-supplied image files/directories."""
        self.predict_dir = images
        self.data_infos = self._load_images()

    def _find_images(self):
        """Expand ``self.predict_dir`` (path or list of paths) to image files."""
        predict_dir = self.predict_dir
        # BUGFIX: a plain str is a Sequence, so the old check let a single
        # directory string fall through and get iterated char-by-char.
        if isinstance(predict_dir, str) or not isinstance(predict_dir,
                                                          Sequence):
            predict_dir = [predict_dir]
        images = []
        for im_dir in predict_dir:
            if os.path.isdir(im_dir):
                # BUGFIX: im_dir is already a complete path; the old code
                # re-joined it onto self.predict_dir, producing a bogus path.
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build predict-mode records (empty 'lanes') from found images."""
        images = self._find_images()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                    "Image {} not found".format(image)
            # sample_num > 0 caps the number of predict samples.
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {
                'im_id': np.array([ct]),
                "img_path": os.path.abspath(image),
                "img_name": os.path.basename(image),
                "lanes": []
            }
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        """Return the predict-mode image-id -> path mapping."""
        return self._imid2path

    def __getitem__(self, idx):
        """Return one sample dict, with the top ``cut_height`` rows cropped.

        In training mode the segmentation mask is loaded and lane points
        are shifted by ``cut_height``; the mask/lanes are wrapped in
        imgaug containers for downstream augmentation.
        """
        data_info = self.data_infos[idx]
        img = cv2.imread(data_info['img_path'])
        # cv2.imread returns None on failure; fail loudly instead of the
        # cryptic TypeError on the slice below.
        assert img is not None, \
            "Failed to read image {}".format(data_info['img_path'])
        img = img[self.cut_height:, :, :]
        # Shallow copy: we only rebind keys, never mutate cached values.
        sample = data_info.copy()
        sample.update({'image': img})
        img_org = sample['image']

        if self.training:
            label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cut_height:, :]
            sample.update({'mask': label})
            if self.cut_height != 0:
                # Lane y-coordinates must follow the crop.
                new_lanes = []
                for lane in sample['lanes']:
                    shifted = []
                    for p in lane:
                        shifted.append((p[0], p[1] - self.cut_height))
                    new_lanes.append(shifted)
                sample.update({'lanes': new_lanes})

            sample['mask'] = SegmentationMapsOnImage(
                sample['mask'], shape=img_org.shape)

        sample['full_img_path'] = data_info['img_path']
        sample['img_name'] = data_info['img_name']
        sample['im_id'] = np.array([idx])

        sample['image'] = sample['image'].copy().astype(np.uint8)
        sample['lanes'] = lane_to_linestrings(sample['lanes'])
        sample['lanes'] = LineStringsOnImage(
            sample['lanes'], shape=img_org.shape)
        sample['seg'] = np.zeros(img_org.shape)

        return sample