reader.py 5.0 KB
Newer Older
1 2 3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
D
Dun 已提交
4 5
import cv2
import numpy as np
6 7
import os
import six
D
Dun 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37

# Default reader settings.
#   shuffle     -- when true, the label file list is shuffled on every reset.
#   min_resize  -- lower bound of the random scale drawn for each sample.
#   max_resize  -- upper bound of the random scale drawn for each sample.
#   crop_size   -- side length of the square samples that are returned;
#                  -1 makes the reader return the images uncropped.
default_config = dict(
    shuffle=True,
    min_resize=0.5,
    max_resize=2,
    crop_size=769,
)


def slice_with_pad(a, s, value=0):
    """Cut a window out of ``a``, padding wherever the window leaves the array.

    Args:
        a: numpy array to cut from.
        s: sequence of ``[start, stop]`` pairs, one per leading axis of
            ``a``. The bounds are absolute positions rather than
            Python-style negative indices: a negative ``start`` or a
            ``stop`` past the end of the axis requests padding on that
            side. Axes that have no entry in ``s`` are kept whole.
        value: constant written into the padded cells.

    Returns:
        A new array whose extent along every axis listed in ``s`` is
        ``stop - start`` (for ``start <= stop``), even when the window
        lies partly or completely outside ``a``.
    """
    pads = []
    index = []
    for axis, size in enumerate(a.shape):
        if axis >= len(s):
            # No window given for this axis: take all of it, no padding.
            pads.append([0, 0])
            index.append(slice(0, size, 1))
            continue
        start, stop = s[axis]
        # Number of requested positions that fall before index 0 / at or
        # after the end of the axis; those are produced by padding.
        pad_before = max(0, min(stop, 0) - start)
        pad_after = max(0, stop - max(start, size))
        # Clamp both bounds into [0, size] so that a negative bound is never
        # interpreted by numpy as "count from the end".
        lo = min(max(start, 0), size)
        hi = min(max(stop, 0), size)
        pads.append([pad_before, pad_after])
        index.append(slice(lo, hi, 1))
    # Multi-axis indexing needs a tuple; a list of slices is rejected by
    # current NumPy releases.
    a = a[tuple(index)]
    return np.pad(a, pad_width=pads, mode='constant', constant_values=value)


class CityscapeDataset:
    """Reader that serves randomly scaled square crops of image/label pairs.

    Label files are collected from ``<dataset_dir>/gtFine/<subset>`` (every
    file whose path contains ``labelTrainIds``); the matching image is looked
    up under ``<dataset_dir>/leftImg8bit/<subset>`` by replacing
    ``gtFine_labelTrainIds`` with ``leftImg8bit`` in the file name.
    """

    def __init__(self, dataset_dir, subset='train', config=None):
        """Collect the label files of ``subset`` and position at the first one.

        Args:
            dataset_dir: root directory holding ``gtFine`` and ``leftImg8bit``.
            subset: name of the sub-directory to read (default ``'train'``).
            config: dict with the keys ``shuffle``, ``min_resize``,
                ``max_resize`` and ``crop_size``. ``None`` selects the
                module-level ``default_config``.
        """
        label_dirname = os.path.join(dataset_dir, 'gtFine/' + subset)
        # Walk the tree in-process instead of shelling out to
        # ``find | grep | sort``: this works for paths containing spaces or
        # shell metacharacters and on systems without those tools, and a
        # missing directory simply yields an empty list.
        label_files = []
        for root, _, filenames in os.walk(label_dirname):
            for filename in filenames:
                path = os.path.join(root, filename)
                if 'labelTrainIds' in path:
                    label_files.append(path)
        label_files.sort()
        self.label_files = label_files
        self.label_dirname = label_dirname
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        # Resolved here rather than in the signature so the module-level
        # default is not bound as a shared default argument.
        self.config = default_config if config is None else config
        self.reset()
        print("total number", len(label_files))

    def reset(self, shuffle=False):
        """Rewind to the first file, reshuffling the file list if requested.

        Args:
            shuffle: force a shuffle for this call; the list is also
                shuffled whenever ``config["shuffle"]`` is true.
        """
        self.index = 0
        if shuffle or self.config["shuffle"]:
            np.random.shuffle(self.label_files)

    def next_img(self):
        """Advance to the next file, wrapping around via ``reset``."""
        self.index += 1
        if self.index >= len(self.label_files):
            self.reset()

    def get_img(self):
        """Load the current sample and return a random crop of it.

        Files that cannot be read are reported and skipped.

        Returns:
            ``(img, label, name)``. With ``config["crop_size"] == -1`` the
            image and label are returned as loaded and ``name`` is the label
            path. Otherwise a square window of side
            ``int(crop_size / scale)`` (``scale`` drawn uniformly from
            ``[min_resize, max_resize)``) is cut at a random position,
            padded with 128 (image) / 255 (label) where it leaves the
            picture, and resized to ``crop_size`` x ``crop_size``; ``name``
            is the label path followed by the crop parameters.

        Raises:
            RuntimeError: if the dataset is empty, or if as many consecutive
                loads as there are label files have failed.
        """
        shape = self.config["crop_size"]
        total = len(self.label_files)
        if total == 0:
            raise RuntimeError(
                "no label files found under %s" % self.label_dirname)
        failures = 0
        while True:
            ln = self.label_files[self.index]
            img_name = os.path.join(
                self.dataset_dir,
                'leftImg8bit/' + self.subset + ln[len(self.label_dirname):])
            img_name = img_name.replace('gtFine_labelTrainIds', 'leftImg8bit')
            label = cv2.imread(ln)
            img = cv2.imread(img_name)
            if img is not None and label is not None:
                break
            if img is None:
                print("load img failed:", img_name)
            else:
                print("load label failed:", ln)
            failures += 1
            # Stop instead of spinning forever when nothing is readable.
            if failures >= total:
                raise RuntimeError(
                    "could not load any image/label pair from %s" %
                    self.dataset_dir)
            self.next_img()
        if shape == -1:
            return img, label, ln
        # Kept as a 1-element array so the name suffix below is unchanged.
        random_scale = np.random.rand(1) * (self.config['max_resize'] -
                                            self.config['min_resize']
                                            ) + self.config['min_resize']
        # Index the element explicitly: int() on a 1-element array is
        # deprecated in NumPy.
        crop_size = int(shape / random_scale[0])
        half = crop_size // 2

        # Pick the window centre at least ``half`` away from the border when
        # the picture is large enough; the offset is the window's first
        # row/column. "x" runs along axis 0 and "y" along axis 1.
        offset_x = np.random.randint(half, max(half + 1,
                                               img.shape[0] - half)) - half
        offset_y = np.random.randint(half, max(half + 1,
                                               img.shape[1] - half)) - half
        window = [[offset_x, offset_x + crop_size],
                  [offset_y, offset_y + crop_size]]
        img_crop = slice_with_pad(img, window, 128)
        img = cv2.resize(img_crop, (shape, shape))
        label_crop = slice_with_pad(label, window, 255)
        # Nearest-neighbour keeps label values from being blended.
        label = cv2.resize(
            label_crop, (shape, shape), interpolation=cv2.INTER_NEAREST)
        return img, label, ln + str(
            (offset_x, offset_y, crop_size, random_scale))

    def get_batch(self, batch_size=1):
        """Collect ``batch_size`` consecutive samples.

        Returns:
            ``(imgs, labels, names)``: two arrays built with ``np.array``
            from the per-sample images and labels, and the list of names.
        """
        imgs = []
        labels = []
        names = []
        while len(imgs) < batch_size:
            img, label, ln = self.get_img()
            imgs.append(img)
            labels.append(label)
            names.append(ln)
            self.next_img()
        return np.array(imgs), np.array(labels), names

    def get_batch_generator(self, batch_size, total_step):
        """Return an iterator over ``total_step`` preprocessed batches.

        Each item is ``(step, imgs, labels, names)``. ``labels`` is the
        first channel of the label batch as int32. ``imgs`` has its last
        axis reversed (presumably BGR -> RGB, given the images come from
        ``cv2.imread``), is transposed from axis order (0, 1, 2, 3) to
        (0, 3, 1, 2), and is scaled as float32 from [0, 255] to [-1, 1].

        If the optional ``prefetch_generator`` package is importable the
        iterator is wrapped in its ``BackgroundGenerator``; otherwise a hint
        is printed and the plain generator is returned.
        """

        def do_get_batch():
            for i in range(total_step):
                imgs, labels, names = self.get_batch(batch_size)
                labels = labels.astype(np.int32)[:, :, :, 0]
                imgs = imgs[:, :, :, ::-1].transpose(
                    0, 3, 1, 2).astype(np.float32) / (255.0 / 2) - 1
                yield i, imgs, labels, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
        except ImportError:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        else:
            batches = BackgroundGenerator(batches, 100)
        return batches