# data_reader.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15 16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
X
xiaohang 已提交
17 18
import os
import cv2
W
wanghaoshuang 已提交
19
import tarfile
X
xiaohang 已提交
20
import numpy as np
X
xiaohang 已提交
21
from PIL import Image
W
wanghaoshuang 已提交
22
from os import path
W
wanghaoshuang 已提交
23 24
from paddle.dataset.image import load_image
import paddle
L
LiufangSang 已提交
25
import random
W
wanghaoshuang 已提交
26

W
whs 已提交
27 28 29 30 31
# Python 2 compatibility shim: when `raw_input` exists, rebind `input` to it
# so that the interactive prompt in DataGenerator.infer_reader returns the
# typed text unchanged. On interpreters without `raw_input` the lookup raises
# NameError and the builtin `input` is kept.
try:
    input = raw_input
except NameError:
    pass

32 33
# Token ids added around a label for every model other than "crnn_ctc":
# SOS is prepended to the decoder input, EOS is appended to the target.
SOS = 0
EOS = 1
# Value reported by num_classes().
NUM_CLASSES = 95
# Value reported by data_shape(); index 1 is the height every image is
# resized to by the readers.
DATA_SHAPE = [1, 48, 512]

# Arguments handed to paddle.dataset.common.download() by download_data().
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
# Layout of the extracted archive, relative to the download directory.
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"

X
xiaohang 已提交
47 48

class DataGenerator(object):
    '''Factory of reader creators for the OCR image/label lists.

    Every reader decodes an image to grayscale, resizes it to a height of
    DATA_SHAPE[1] pixels, subtracts 127.5 from the pixel values and prepends
    a channel axis.

    :param model: Name of the model the samples are prepared for. With
    "crnn_ctc" a sample is (image, label); with any other value train_reader
    and test_reader produce (image, [SOS] + label, label + [EOS]).
    :type model: str
    '''

    def __init__(self, model="crnn_ctc"):
        self.model = model

    @staticmethod
    def _to_array(img, width=None):
        '''Resize a PIL image to the dataset height and turn it into an array.

        :param img: Grayscale PIL image.

        :param width: Target width in pixels. The image keeps its own width
        when this is None.
        :type width: int

        :return: Array of shape (1, height, width) with 127.5 subtracted
        from every pixel value.
        '''
        if width is None:
            width = img.size[0]
        resized = img.resize((width, DATA_SHAPE[1]))
        return (np.array(resized) - 127.5)[np.newaxis, ...]

    def _pack(self, img, label):
        '''Arrange one decoded sample in the layout expected by self.model.

        :return: [img, label] for "crnn_ctc", otherwise
        [img, [SOS] + label, label + [EOS]].
        :rtype: list
        '''
        if self.model == "crnn_ctc":
            return [img, label]
        return [img, [SOS] + label, label + [EOS]]

    @staticmethod
    def _shuffle_lines(lines, shuffle, batchsize):
        '''Return the list-file lines in the order they will be batched.

        :param lines: Stripped, non-empty "h w img_name labels" lines.
        :type lines: list

        :param shuffle: Keep the input order when False.
        :type shuffle: bool

        :param batchsize: Number of samples per batch. A value of 1 gives a
        plain random shuffle. Larger values give a partial shuffle that
        keeps lines with an equal first field in the same batches.
        :type batchsize: int

        :return: A new list of lines. In the partial shuffle between 1 and
        100 lines are dropped (see the comment in the body).
        :rtype: list
        '''
        if not shuffle:
            return list(lines)
        if batchsize == 1:
            shuffled = list(lines)
            random.shuffle(shuffled)
            return shuffled

        # Partial shuffle: prefix every line with a sort key made of its
        # zero-padded first field plus a random fraction. Sorting then groups
        # equal first fields together while randomising the order within
        # each group.
        keyed = sorted(
            "%04d%.4f " % (int(line.split(' ')[0]), random.random()) + line
            for line in lines)
        # Drop a random number of leading lines so that the batch boundaries
        # do not fall on the same samples in every run. Lists with fewer
        # lines than the drawn number end up empty.
        del keyed[0:random.randint(1, 100)]

        # Cut into batches and shuffle whole batches, so that the samples
        # inside a batch stay together.
        batches = [
            keyed[start:start + batchsize]
            for start in range(0, len(keyed), batchsize)
        ]
        random.shuffle(batches)

        # Strip the sort key again; split(' ', 1) removes only the prefix and
        # leaves the rest of the line untouched.
        return [entry.split(' ', 1)[1] for batch in batches for entry in batch]

    def train_reader(self,
                     img_root_dir,
                     img_label_list,
                     batchsize,
                     cycle,
                     shuffle=True):
        '''
        Reader interface for training.

        The list file is read and shuffled once, in memory, when this method
        is called. The returned reader yields whole batches.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        training. Each line is "h w img_name labels" separated by single
        spaces, with the labels as comma-separated integers.
        :type img_label_list: str

        :param batchsize: Number of samples in every yielded batch. Lines
        that do not fill a complete batch are not used.
        :type batchsize: int

        :param cycle: If number of iterations is greater than dataset_size / batch_size
        it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :param shuffle: Whether to shuffle the list before batching.
        :type shuffle: bool

        :return: A reader creator. The reader yields lists of samples and
        raises ValueError when not even one full batch is available.
        '''
        with open(img_label_list, 'r') as list_file:
            stripped = [line.strip() for line in list_file]
        # Blank lines carry no sample and would fail label parsing later.
        stripped = [line for line in stripped if line]

        img_label_lines = self._shuffle_lines(stripped, shuffle, batchsize)
        print("finish batch shuffle")

        def reader():
            sizes = len(img_label_lines) // batchsize
            if sizes == 0:
                raise ValueError('Batch size is bigger than the dataset size.')
            while True:
                for i in range(sizes):
                    result = []
                    # All images of a batch are resized to the width of the
                    # first image so that they share one shape.
                    batch_width = None
                    first = i * batchsize
                    for line in img_label_lines[first:first + batchsize]:
                        # h, w, img_name, labels
                        items = line.split(' ')
                        label = [int(c) for c in items[-1].split(',')]
                        img = Image.open(
                            os.path.join(img_root_dir, items[2])).convert('L')
                        if batch_width is None:
                            batch_width = img.size[0]
                        result.append(
                            self._pack(self._to_array(img, batch_width), label))
                    yield result
                if not cycle:
                    break

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for testing.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for testing.
        :type img_label_list: str

        :return: A reader creator. The reader yields one sample tuple per
        line of the list file; every image keeps its own width.
        '''

        def reader():
            with open(img_label_list) as list_file:
                for line in list_file:
                    if not line.strip():
                        continue
                    # h, w, img_name, labels
                    items = line.split(' ')
                    label = [int(c) for c in items[-1].split(',')]
                    img = Image.open(
                        os.path.join(img_root_dir, items[2])).convert('L')
                    yield tuple(self._pack(self._to_array(img), label))

        return reader

    def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
        '''A reader interface for inference.

        :param img_root_dir: The root path of the images for inference.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        inference. It should be the path of <image_path> file if img_root_dir
        was None. If img_label_list was set to None, it will read image path
        from stdin.
        :type img_label_list: str

        :param cycle: If number of iterations is greater than dataset_size /
        batch_size it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :return: A reader creator. The reader yields (image, label) pairs.
        When no label is available (path-only list file or stdin input) the
        placeholder [[0]] is yielded as the label.
        '''

        def reader():
            def yield_img_and_label(lines):
                for line in lines:
                    if img_root_dir is not None:
                        # h, w, img_name, labels
                        items = line.split(' ')
                        img_path = os.path.join(img_root_dir, items[2])
                        label = [int(c) for c in items[3].split(',')]
                    else:
                        # The line holds nothing but the image path, so there
                        # is no label to parse; use the same placeholder as
                        # the stdin branch below.
                        img_path = line.strip("\t\n\r")
                        label = [[0]]
                    img = Image.open(img_path).convert('L')
                    yield self._to_array(img), label

            if img_label_list is not None:
                with open(img_label_list) as f:
                    lines = f.readlines()
                # The emptiness check keeps cycle=True from spinning forever
                # on an empty list file.
                while lines:
                    for img, label in yield_img_and_label(lines):
                        yield img, label
                    if not cycle:
                        break
            else:
                while True:
                    img_path = input("Please input the path of image: ")
                    img = Image.open(img_path).convert('L')
                    yield self._to_array(img), [[0]]

        return reader

W
wanghaoshuang 已提交
238 239

def num_classes():
    '''Return the number of label classes of this dataset (NUM_CLASSES).'''
    return NUM_CLASSES


def data_shape():
    '''Return the nominal image shape of this dataset (DATA_SHAPE).

    The shape is a dummy value for this dataset. The module-level list
    itself is returned, not a copy.
    '''
    return DATA_SHAPE


251 252 253 254 255 256
def train(batch_size,
          train_images_dir=None,
          train_list_file=None,
          cycle=False,
          model="crnn_ctc"):
    '''Create the batched training reader.

    :param batch_size: Number of samples per yielded batch.
    :type batch_size: int

    :param train_images_dir: Root directory of the training images. The
    downloaded dataset is used when None.
    :type train_images_dir: str

    :param train_list_file: Path of the training list file. The list of the
    downloaded dataset is used when None.
    :type train_list_file: str

    :param cycle: Whether the reader restarts after the last batch.
    :type cycle: bool

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: The reader creator built by DataGenerator.train_reader.
    '''
    generator = DataGenerator(model)
    # Fetch the default dataset when either path is missing, so that passing
    # only an image directory no longer fails on an unbound data_dir.
    if train_images_dir is None or train_list_file is None:
        data_dir = download_data()
        if train_images_dir is None:
            train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
        if train_list_file is None:
            train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    # Shuffling is switched off when the 'ce_mode' environment variable is
    # set, whatever its value.
    shuffle = 'ce_mode' not in os.environ
    return generator.train_reader(
        train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)
W
wanghaoshuang 已提交
267 268


269 270 271 272 273
def test(batch_size=1,
         test_images_dir=None,
         test_list_file=None,
         model="crnn_ctc"):
    '''Create the batched test reader.

    :param batch_size: Number of samples per batch, passed to paddle.batch.
    :type batch_size: int

    :param test_images_dir: Root directory of the test images. The
    downloaded dataset is used when None.
    :type test_images_dir: str

    :param test_list_file: Path of the test list file. The list of the
    downloaded dataset is used when None.
    :type test_list_file: str

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: The result of paddle.batch applied to the test reader.
    '''
    generator = DataGenerator(model)
    # Fetch the default dataset when either path is missing, so that passing
    # only an image directory no longer fails on an unbound data_dir.
    if test_images_dir is None or test_list_file is None:
        data_dir = download_data()
        if test_images_dir is None:
            test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
        if test_list_file is None:
            test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)
W
wanghaoshuang 已提交
281 282


283 284 285
def inference(batch_size=1,
              infer_images_dir=None,
              infer_list_file=None,
              cycle=False,
              model="crnn_ctc"):
    '''Create the batched inference reader.

    The arguments are forwarded to DataGenerator.infer_reader and the
    resulting reader is wrapped with paddle.batch using batch_size.
    '''
    sample_reader = DataGenerator(model).infer_reader(
        infer_images_dir, infer_list_file, cycle)
    return paddle.batch(sample_reader, batch_size)
W
wanghaoshuang 已提交
292 293


W
wanghaoshuang 已提交
294 295 296 297 298 299 300 301 302 303 304
def download_data():
    '''Download train and test data.

    The archive is fetched with paddle.dataset.common.download and unpacked
    next to the downloaded file, unless the data directory already exists.

    :return: Path of the extracted data directory.
    :rtype: str
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    cache_dir = path.dirname(tar_file)
    data_dir = path.join(cache_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # The with-block closes the archive even when extraction fails.
        # NOTE(review): extractall() trusts the member paths in the archive;
        # DATA_MD5 is passed to the download helper, presumably for a
        # checksum check -- confirm before pointing DATA_URL elsewhere.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=cache_dir)
    return data_dir