data_reader.py 10.7 KB
Newer Older
X
xiaoting 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15 16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
X
xiaohang 已提交
17 18
import os
import cv2
W
wanghaoshuang 已提交
19
import tarfile
X
xiaohang 已提交
20
import numpy as np
X
xiaohang 已提交
21
from PIL import Image
W
wanghaoshuang 已提交
22
from os import path
W
wanghaoshuang 已提交
23 24
from paddle.dataset.image import load_image
import paddle
L
LiufangSang 已提交
25
import random
W
wanghaoshuang 已提交
26

W
whs 已提交
27 28 29 30 31
# Python 2/3 compatibility shim: on Python 2 rebind `input` to `raw_input` so
# that the interactive prompt in DataGenerator.infer_reader reads a plain
# string. On Python 3 `raw_input` does not exist, the NameError is ignored and
# the builtin `input` is used as is.
try:
    input = raw_input
except NameError:
    pass

32 33
# Start-of-sequence / end-of-sequence label ids. For models other than
# "crnn_ctc" the readers prepend SOS to the input label sequence and append
# EOS to the target label sequence.
SOS = 0
EOS = 1
34
# Number of label classes of this dataset; returned by num_classes().
NUM_CLASSES = 95
D
Double_V 已提交
35 36
# Nominal image shape returned by data_shape(). The readers only use
# DATA_SHAPE[1] (48) as the height every image is resized to; image width is
# kept per image / per batch, so IMG_WIDTH is not enforced by the readers.
# NOTE(review): layout is presumably [channels, height, width] - confirm.
IMG_WIDTH = 384
DATA_SHAPE = [1, 48, IMG_WIDTH]
X
xiaohang 已提交
37

W
whs 已提交
38 39
# Checksum and location of the dataset archive; both are passed to
# paddle.dataset.common.download() by download_data().
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
W
wanghaoshuang 已提交
40 41 42 43 44 45 46 47
# Names used by download_data(), train() and test() to locate the dataset:
# CACHE_DIR_NAME and SAVED_FILE_NAME are handed to the paddle download helper,
# DATA_DIR_NAME is the directory expected next to the downloaded archive after
# extraction, and the remaining names are joined onto that data directory.
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"

X
xiaohang 已提交
48 49

class DataGenerator(object):
    '''
    Factory of sample readers for the OCR recognition models.

    Every list-file line is expected to hold four space-separated fields,
    ``h w img_name labels``, where ``labels`` is a comma-separated list of
    integer class ids.

    :param model: Name of the model the samples are produced for. With
    "crnn_ctc" a sample is ``(img, label)``; with any other value a sample is
    ``(img, [SOS] + label, label + [EOS])``.
    :type model: str
    '''

    def __init__(self, model="crnn_ctc"):
        self.model = model

    @staticmethod
    def _load_image(img_path, width=None):
        '''
        Load one image as a float array of shape (1, DATA_SHAPE[1], width).

        The image is converted to 8-bit grayscale, resized to the fixed
        height DATA_SHAPE[1] and shifted by -127.5.

        :param img_path: Path of the image file.
        :type img_path: str

        :param width: Target width. If None the image keeps its own width.
        :type width: int|None
        '''
        img = Image.open(img_path).convert('L')
        if width is None:
            width = img.size[0]
        img = img.resize((width, DATA_SHAPE[1]))
        img = np.array(img) - 127.5
        return img[np.newaxis, ...]

    @staticmethod
    def _write_shuffled_list(input_file_path, output_file_path, shuffle,
                             batchsize):
        '''
        Copy the list file to output_file_path, optionally shuffled.

        * shuffle is False: lines are copied in their original order.
        * batchsize == 1: lines are fully shuffled.
        * otherwise: a "partial shuffle" is applied (see comments below) so
          that lines sharing the same first field end up in the same batch.
        '''
        with open(input_file_path, 'r') as input_file:
            lines = [line.strip() for line in input_file]

        if not shuffle:
            out_lines = lines
        elif batchsize == 1:
            random.shuffle(lines)
            out_lines = lines
        else:
            # Prefix every line with a sort key made of its zero-padded first
            # field followed by a random fraction: after sorting, lines with
            # the same first field are adjacent, in random order among
            # themselves.
            keyed = [
                "%04d%.4f " % (int(line.split(' ')[0]), random.random()) +
                line for line in lines
            ]
            keyed.sort()
            # Drop a random-length prefix (1..100 lines), which moves the
            # batch boundaries between calls. NOTE: these samples are
            # discarded, so a list with <= 100 lines may end up empty.
            del keyed[0:random.randint(1, 100)]

            # Merge consecutive lines into batches and shuffle whole batches.
            batches = [
                ' '.join(keyed[start:start + batchsize])
                for start in range(0, len(keyed), batchsize)
            ]
            random.shuffle(batches)

            # Split the batches back into lines. Every sample occupies five
            # tokens (sort key + four fields); the sort key is dropped.
            out_lines = []
            for batch in batches:
                tokens = batch.split(' ')
                for k in range(len(tokens) // 5):
                    out_lines.append(' '.join(tokens[5 * k + 1:5 * k + 5]))

        with open(output_file_path, 'w') as out_file:
            out_file.writelines(
                ["{}\n".format(item) for item in out_lines])

    def train_reader(self,
                     img_root_dir,
                     img_label_list,
                     batchsize,
                     cycle,
                     shuffle=True):
        '''
        Reader interface for training.

        The (optionally shuffled) list is written to "tmp.txt" in the current
        working directory and read back before the reader is returned.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        training.
        :type img_label_list: str

        :param batchsize: Number of samples per yielded batch. Trailing lines
        that do not fill a whole batch are ignored.
        :type batchsize: int

        :param cycle: If number of iterations is greater than dataset_size /
        batch_size it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :param shuffle: Whether to shuffle the list before batching.
        :type shuffle: bool

        :return: A callable creating a generator that yields one list of
        samples per batch. All images of a batch are resized to the width of
        the first image of that batch. The generator raises ValueError on its
        first iteration if the list holds fewer than batchsize lines.
        '''
        to_file = "tmp.txt"
        self._write_shuffled_list(img_label_list, to_file, shuffle, batchsize)
        print("finish batch shuffle")
        with open(to_file, 'r') as shuffled_file:
            img_label_lines = shuffled_file.readlines()

        def reader():
            sizes = len(img_label_lines) // batchsize
            if sizes == 0:
                raise ValueError('Batch size is bigger than the dataset size.')
            while True:
                for i in range(sizes):
                    result = []
                    batch_width = None
                    batch_lines = img_label_lines[i * batchsize:(i + 1) *
                                                  batchsize]
                    for line in batch_lines:
                        # h, w, img_name, labels
                        items = line.split(' ')
                        label = [int(c) for c in items[-1].split(',')]
                        img = self._load_image(
                            os.path.join(img_root_dir, items[2]), batch_width)
                        if batch_width is None:
                            # The first image of the batch fixes the width of
                            # every following image in the same batch.
                            batch_width = img.shape[2]
                        if self.model == "crnn_ctc":
                            result.append([img, label])
                        else:
                            result.append([img, [SOS] + label, label + [EOS]])
                    yield result
                if not cycle:
                    break

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for testing.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        testing.
        :type img_label_list: str

        :return: A callable creating a generator that yields one sample per
        line of the list file, each image keeping its own width.
        '''

        def reader():
            # The list file stays open only while the generator is being
            # consumed and is closed when it is exhausted or closed.
            with open(img_label_list) as list_file:
                for line in list_file:
                    # h, w, img_name, labels
                    items = line.split(' ')
                    label = [int(c) for c in items[-1].split(',')]
                    img = self._load_image(
                        os.path.join(img_root_dir, items[2]))
                    if self.model == "crnn_ctc":
                        yield img, label
                    else:
                        yield img, [SOS] + label, label + [EOS]

        return reader

    def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
        '''A reader interface for inference.

        :param img_root_dir: The root path of the images for inference.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        inference. It should be the path of <image_path> file if img_root_dir
        was None. If img_label_list was set to None, it will read image path
        from stdin.
        :type img_label_list: str

        :param cycle: If number of iterations is greater than dataset_size /
        batch_size it reiterates dataset over as many times as necessary.
        Ignored when image paths are read from stdin (that loop is endless).
        :type cycle: bool

        :return: A callable creating a generator that yields
        ``(img, [[0]])``; ``[[0]]`` is a placeholder label.
        '''

        def reader():
            def yield_img_and_label(lines):
                for line in lines:
                    if img_root_dir is not None:
                        # h, w, img_name, labels
                        img_name = line.split(' ')[2]
                        img_path = os.path.join(img_root_dir, img_name)
                    else:
                        img_path = line.strip("\t\n\r")
                    yield self._load_image(img_path), [[0]]

            if img_label_list is not None:
                with open(img_label_list) as f:
                    lines = f.readlines()
                # Always one pass; further passes only when cycle is set.
                while True:
                    for img, label in yield_img_and_label(lines):
                        yield img, label
                    if not cycle:
                        break
            else:
                while True:
                    img_path = input("Please input the path of image: ")
                    yield self._load_image(img_path), [[0]]

        return reader

W
wanghaoshuang 已提交
238 239

def num_classes():
    '''Return the number of label classes of this dataset (NUM_CLASSES).
    '''
    class_count = NUM_CLASSES
    return class_count


def data_shape():
    '''Return the image shape of this dataset.

    The value is the module-level DATA_SHAPE list itself (not a copy) and is
    only a dummy shape for this dataset.
    '''
    shape = DATA_SHAPE
    return shape


251 252 253 254 255 256
def train(batch_size,
          train_images_dir=None,
          train_list_file=None,
          cycle=False,
          model="crnn_ctc"):
    '''Create the batched training reader.

    :param batch_size: Number of samples per batch.
    :type batch_size: int

    :param train_images_dir: Root directory of the training images. If None,
    the default dataset is downloaded and its train image directory is used.
    :type train_images_dir: str

    :param train_list_file: Path of the training list file. If None, the
    default dataset is downloaded and its train list is used.
    :type train_list_file: str

    :param cycle: Whether the reader restarts from the beginning once the
    dataset is exhausted.
    :type cycle: bool

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: The reader created by DataGenerator.train_reader. Shuffling is
    disabled when the environment variable 'ce_mode' is set.
    '''
    generator = DataGenerator(model)
    # Resolve the default data directory whenever either path is missing.
    # Previously it was only resolved when train_images_dir was None, so
    # passing an image directory without a list file failed with an unbound
    # local variable.
    if train_images_dir is None or train_list_file is None:
        data_dir = download_data()
        if train_images_dir is None:
            train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
        if train_list_file is None:
            train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    shuffle = 'ce_mode' not in os.environ
    return generator.train_reader(
        train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)
W
wanghaoshuang 已提交
267 268


269 270 271 272 273
def test(batch_size=1,
         test_images_dir=None,
         test_list_file=None,
         model="crnn_ctc"):
    '''Create the batched test reader.

    :param batch_size: Number of samples per batch.
    :type batch_size: int

    :param test_images_dir: Root directory of the test images. If None, the
    default dataset is downloaded and its test image directory is used.
    :type test_images_dir: str

    :param test_list_file: Path of the test list file. If None, the default
    dataset is downloaded and its test list is used.
    :type test_list_file: str

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: DataGenerator.test_reader wrapped by paddle.batch.
    '''
    generator = DataGenerator(model)
    # Resolve the default data directory whenever either path is missing.
    # Previously it was only resolved when test_images_dir was None, so
    # passing an image directory without a list file failed with an unbound
    # local variable.
    if test_images_dir is None or test_list_file is None:
        data_dir = download_data()
        if test_images_dir is None:
            test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
        if test_list_file is None:
            test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)
W
wanghaoshuang 已提交
281 282


283 284 285
def inference(batch_size=1,
              infer_images_dir=None,
              infer_list_file=None,
              cycle=False,
              model="crnn_ctc"):
    '''Create the batched inference reader.

    Builds a DataGenerator for the given model, asks it for an inference
    reader over infer_images_dir / infer_list_file (optionally cycling) and
    wraps that reader with paddle.batch using batch_size.
    '''
    sample_reader = DataGenerator(model).infer_reader(
        infer_images_dir, infer_list_file, cycle)
    return paddle.batch(sample_reader, batch_size)
W
wanghaoshuang 已提交
292 293


W
wanghaoshuang 已提交
294 295 296 297 298 299 300 301 302 303 304
def download_data():
    '''Download train and test data.

    The archive at DATA_URL is fetched through paddle.dataset.common.download
    (checked against DATA_MD5) and extracted next to the downloaded file,
    unless the data directory already exists.

    :return: Path of the extracted data directory.
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    cache_dir = path.dirname(tar_file)
    data_dir = path.join(cache_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # The context manager closes the archive even if extraction fails.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=cache_dir)
    return data_dir