data_reader.py 9.9 KB
Newer Older
1 2 3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
X
xiaohang 已提交
4 5
import os
import cv2
W
wanghaoshuang 已提交
6
import tarfile
X
xiaohang 已提交
7
import numpy as np
X
xiaohang 已提交
8
from PIL import Image
W
wanghaoshuang 已提交
9
from os import path
W
wanghaoshuang 已提交
10 11
from paddle.dataset.image import load_image
import paddle
L
LiufangSang 已提交
12
import random
W
wanghaoshuang 已提交
13

W
whs 已提交
14 15 16 17 18
# Python 2 compatibility shim: rebind ``input`` to ``raw_input`` when that
# name exists. On interpreters without ``raw_input`` the lookup raises
# NameError, which is ignored so the built-in ``input`` stays in effect.
try:
    input = raw_input
except NameError:
    pass

19 20
# Token ids added around a label by DataGenerator when its model is not
# "crnn_ctc": SOS is prepended to form the second sample field and EOS is
# appended to form the third.
SOS = 0
EOS = 1
21
# Number of label classes of this dataset; returned by num_classes().
NUM_CLASSES = 95
W
wanghaoshuang 已提交
22
# Image shape returned by data_shape(), whose docstring describes it as a
# dummy shape for this dataset.
DATA_SHAPE = [1, 48, 512]
X
xiaohang 已提交
23

W
whs 已提交
24 25
# MD5 string and URL of the dataset archive; download_data() passes both to
# paddle.dataset.common.download().
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
W
wanghaoshuang 已提交
26 27 28 29 30 31 32 33
# Names used to locate the downloaded dataset.
# CACHE_DIR_NAME and SAVED_FILE_NAME are handed to
# paddle.dataset.common.download() by download_data().
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
# Directory (next to the downloaded archive) that download_data() returns and
# expects to exist after extraction.
DATA_DIR_NAME = "data"
# Sub-paths joined onto the data directory by train() and test() when the
# caller does not supply explicit image directories / list files.
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"

X
xiaohang 已提交
34 35

class DataGenerator(object):
    '''
    Factory of reader callables for training, testing and inference.

    Every list-file line is expected to hold four space separated fields,
    ``h w img_name labels``, where ``labels`` is a comma separated list of
    integers.

    :param model: Name of the model the samples are produced for. With
    "crnn_ctc" a sample is ``(img, label)``; with any other value it is
    ``(img, [SOS] + label, label + [EOS])``.
    :type model: str
    '''

    def __init__(self, model="crnn_ctc"):
        self.model = model

    @staticmethod
    def _open_gray(img_path):
        '''Open an image file and return it converted to mode 'L'.'''
        # The context manager releases the underlying file handle as soon as
        # the pixel data has been converted into a new in-memory image.
        with Image.open(img_path) as raw_img:
            return raw_img.convert('L')

    @staticmethod
    def _to_array(img):
        '''Turn an image into an array shifted by -127.5 with a new leading axis.'''
        return (np.array(img) - 127.5)[np.newaxis, ...]

    @staticmethod
    def _parse_label(field):
        '''Parse a comma separated label field into a list of ints.'''
        return [int(c) for c in field.split(',')]

    def _pack_sample(self, img, label):
        '''Build the sample tuple matching ``self.model``.'''
        if self.model == "crnn_ctc":
            return img, label
        return img, [SOS] + label, label + [EOS]

    @staticmethod
    def _shuffle_data(input_file_path, output_file_path, shuffle, batchsize):
        '''Copy the list file to ``output_file_path``, optionally shuffled.

        * ``shuffle`` false: lines are copied in their original order.
        * ``batchsize == 1``: lines are shuffled individually.
        * otherwise: lines are sorted by their leading integer field (ties
          broken randomly), a random number (1-100) of leading lines is
          dropped, the rest is cut into groups of ``batchsize`` lines and the
          groups are shuffled.

        Blank lines of the input are ignored.
        '''
        with open(input_file_path, 'r') as input_file:
            lines = [line.strip() for line in input_file if line.strip()]

        if shuffle and batchsize == 1:
            random.shuffle(lines)
        elif shuffle:
            # Partial shuffle: prefix every line with a sortable key made of
            # its first field plus a random fraction, then sort on it.
            keyed = sorted(
                "%04d%.4f " % (int(line.split(' ')[0]), random.random()) + line
                for line in lines)
            # Drop a random number of leading entries so that the batch
            # boundaries do not fall on the same lines at every call.
            del keyed[0:random.randint(1, 100)]

            # Group consecutive lines into batches and shuffle whole batches.
            batches = [
                keyed[i:i + batchsize]
                for i in range(0, len(keyed), batchsize)
            ]
            random.shuffle(batches)

            # Strip the sort key again (everything up to the first space).
            lines = [
                entry.split(' ', 1)[1] for batch in batches for entry in batch
            ]

        with open(output_file_path, 'w') as output_file:
            output_file.writelines("{}\n".format(line) for line in lines)

    def train_reader(self,
                     img_root_dir,
                     img_label_list,
                     batchsize,
                     cycle,
                     shuffle=True):
        '''
        Reader interface for training.

        The (optionally shuffled) list is written to "tmp.txt" in the current
        working directory and read back from there.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for training.
        :type img_label_list: str

        :param batchsize: Number of samples in every yielded batch.
        :type batchsize: int

        :param cycle: If number of iterations is greater than dataset_size / batch_size
        it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :param shuffle: Whether the list is shuffled before batching.
        :type shuffle: bool

        :return: A callable returning a generator of batches. Each batch is a
        list of ``batchsize`` samples and each sample is a list laid out
        according to the model. All images of a batch are resized to the size
        of the first image of that batch. The generator raises ValueError
        when the list holds fewer than ``batchsize`` lines.
        '''
        to_file = "tmp.txt"
        self._shuffle_data(img_label_list, to_file, shuffle, batchsize)
        print("finish batch shuffle")
        with open(to_file, 'r') as shuffled_file:
            img_label_lines = shuffled_file.readlines()

        def reader():
            sizes = len(img_label_lines) // batchsize
            if sizes == 0:
                raise ValueError('Batch size is bigger than the dataset size.')
            while True:
                for i in range(sizes):
                    result = []
                    sz = None
                    start = i * batchsize
                    for line in img_label_lines[start:start + batchsize]:
                        # h, w, img_name, labels
                        items = line.split(' ')
                        label = self._parse_label(items[-1])
                        img = self._open_gray(
                            os.path.join(img_root_dir, items[2]))
                        if sz is None:
                            # The first image fixes the size of the batch.
                            sz = img.size
                        img = img.resize((sz[0], sz[1]))
                        result.append(
                            list(self._pack_sample(self._to_array(img), label)))
                    yield result
                if not cycle:
                    break

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for testing.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for testing.
        :type img_label_list: str

        :return: A callable returning a generator that yields one sample
        tuple per non-blank line of the list file.
        '''

        def reader():
            with open(img_label_list) as list_file:
                for line in list_file:
                    if not line.strip():
                        continue
                    # h, w, img_name, labels
                    items = line.split(' ')
                    label = self._parse_label(items[-1])
                    img = self._to_array(
                        self._open_gray(os.path.join(img_root_dir, items[2])))
                    yield self._pack_sample(img, label)

        return reader

    def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
        '''A reader interface for inference.

        :param img_root_dir: The root path of the images for inference.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        inference. It should be the path of <image_path> file if img_root_dir
        was None. If img_label_list was set to None, it will read image path
        from stdin.
        :type img_label_list: str

        :param cycle: If number of iterations is greater than dataset_size /
        batch_size it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :return: A callable returning a generator of ``(img, label)`` tuples.
        When no label is available (image paths read from stdin or from a
        path-only list file) the placeholder label ``[[0]]`` is yielded.
        '''

        def reader():
            def yield_img_and_label(lines):
                for line in lines:
                    if not line.strip():
                        continue
                    if img_root_dir is not None:
                        # h, w, img_name, labels
                        fields = line.split(' ')
                        img_path = os.path.join(img_root_dir, fields[2])
                        img = self._to_array(self._open_gray(img_path))
                        label = self._parse_label(fields[3])
                    else:
                        # The line holds nothing but an image path, so there
                        # is no label field to parse.
                        img_path = line.strip("\t\n\r")
                        img = self._to_array(self._open_gray(img_path))
                        label = [[0]]
                    yield img, label

            if img_label_list is not None:
                with open(img_label_list) as f:
                    lines = f.readlines()
                if not lines:
                    # Nothing to yield; avoids spinning forever when cycling.
                    return
                while True:
                    for img, label in yield_img_and_label(lines):
                        yield img, label
                    if not cycle:
                        break
            else:
                while True:
                    img_path = input("Please input the path of image: ")
                    img = self._to_array(self._open_gray(img_path))
                    yield img, [[0]]

        return reader

W
wanghaoshuang 已提交
221 222

def num_classes():
    '''Return the number of classes of this dataset (``NUM_CLASSES``).
    '''
    class_count = NUM_CLASSES
    return class_count


def data_shape():
    '''Return the image shape of this dataset (``DATA_SHAPE``).

    The shape is only a dummy value for this dataset.
    '''
    shape = DATA_SHAPE
    return shape


234 235 236 237 238 239
def train(batch_size,
          train_images_dir=None,
          train_list_file=None,
          cycle=False,
          model="crnn_ctc"):
    '''Create the batched reader used for training.

    :param batch_size: Number of samples per batch.
    :type batch_size: int

    :param train_images_dir: Root directory of the training images. When
    None, the dataset is fetched with download_data() and its training image
    directory is used.
    :type train_images_dir: str

    :param train_list_file: Path of the training list file. When None, the
    dataset is fetched with download_data() and its training list is used.
    :type train_list_file: str

    :param cycle: Whether the reader restarts after the last batch.
    :type cycle: bool

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: The callable produced by DataGenerator.train_reader().
    '''
    generator = DataGenerator(model)
    # Resolve the dataset directory whenever EITHER path is missing, so that
    # supplying only the image directory no longer leaves data_dir undefined.
    if train_images_dir is None or train_list_file is None:
        data_dir = download_data()
        if train_images_dir is None:
            train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
        if train_list_file is None:
            train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    # Shuffling is switched off when the 'ce_mode' environment variable is set.
    shuffle = 'ce_mode' not in os.environ
    return generator.train_reader(
        train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)
W
wanghaoshuang 已提交
250 251


252 253 254 255 256
def test(batch_size=1,
         test_images_dir=None,
         test_list_file=None,
         model="crnn_ctc"):
    '''Create the batched reader used for testing.

    :param batch_size: Number of samples per batch.
    :type batch_size: int

    :param test_images_dir: Root directory of the test images. When None,
    the dataset is fetched with download_data() and its test image directory
    is used.
    :type test_images_dir: str

    :param test_list_file: Path of the test list file. When None, the
    dataset is fetched with download_data() and its test list is used.
    :type test_list_file: str

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: The result of paddle.batch() applied to
    DataGenerator.test_reader().
    '''
    generator = DataGenerator(model)
    # Resolve the dataset directory whenever EITHER path is missing, so that
    # supplying only the image directory no longer leaves data_dir undefined.
    if test_images_dir is None or test_list_file is None:
        data_dir = download_data()
        if test_images_dir is None:
            test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
        if test_list_file is None:
            test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)
W
wanghaoshuang 已提交
264 265


266 267 268
def inference(batch_size=1,
              infer_images_dir=None,
              infer_list_file=None,
              cycle=False,
              model="crnn_ctc"):
    '''Create the batched reader used for inference.

    A DataGenerator for ``model`` supplies the sample reader via
    infer_reader(infer_images_dir, infer_list_file, cycle); the result is
    wrapped with paddle.batch() using ``batch_size``.
    '''
    sample_reader = DataGenerator(model).infer_reader(
        infer_images_dir, infer_list_file, cycle)
    return paddle.batch(sample_reader, batch_size)
W
wanghaoshuang 已提交
275 276


W
wanghaoshuang 已提交
277 278 279 280 281 282 283 284 285 286 287
def download_data():
    '''Download train and test data.

    The archive is obtained through paddle.dataset.common.download() and,
    unless the data directory already exists next to it, unpacked into the
    directory holding the archive.

    :return: Path of the data directory (``DATA_DIR_NAME`` beside the
    downloaded archive).
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    cache_dir = path.dirname(tar_file)
    data_dir = path.join(cache_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # NOTE(review): extractall() trusts the member paths stored in the
        # archive; presumably DATA_MD5 is used by paddle to validate the
        # download - confirm before pointing DATA_URL at another source.
        # The with-block closes the archive even when extraction fails.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=cache_dir)
    return data_dir