from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
import tarfile
import numpy as np
from PIL import Image
from os import path
from paddle.dataset.image import load_image
import paddle

# Special token ids used by the attention-style decoder targets.
SOS = 0  # start-of-sequence marker
EOS = 1  # end-of-sequence marker
# Number of character classes in this dataset.
NUM_CLASSES = 95
# Dummy input shape; presumably (channels, height, width) — see data_shape().
DATA_SHAPE = [1, 48, 512]

# Remote location and checksum of the packaged OCR dataset.
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
# Layout of the extracted archive: data/{train_images,test_images,*.list}.
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"


class DataGenerator(object):
    """Builds reader functions over the OCR ``<h, w, img_name, labels>``
    list-file format for training, testing and inference.

    :param model: Which model consumes the data. "crnn_ctc" yields
        (image, label) samples; any other value yields attention-style
        (image, [SOS] + label, label + [EOS]) samples.
    :type model: str
    """

    def __init__(self, model="crnn_ctc"):
        self.model = model

    def train_reader(self,
                     img_root_dir,
                     img_label_list,
                     batchsize,
                     cycle,
                     shuffle=True):
        '''
        Reader interface for training.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for training.
        :type img_label_list: str

        :param batchsize: Number of samples per yielded batch. All images in
        a batch are resized to the first image's size.
        :type batchsize: int

        :param cycle: If number of iterations is greater than dataset_size / batch_size
        it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :param shuffle: Whether to shuffle the list file (batch-wise when
        batchsize > 1).
        :type shuffle: bool
        '''
        to_file = "tmp.txt"
        # NOTE(review): img_label_list is interpolated into a shell pipeline
        # below (os.system) — do not pass untrusted paths here. Requires
        # cat/awk/sed/shuf on PATH.
        if not shuffle:
            cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' > " + to_file
        elif batchsize == 1:
            cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
        else:
            # cmd1: partial shuffle
            cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
            # cmd2: batch merge and shuffle
            cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
                batchsize) + " == 0) print \"\";}' | shuf | "
            # cmd3: batch split
            cmd += "awk '{if(NF == " + str(
                batchsize
            ) + " * 4) {for(i = 0; i < " + str(
                batchsize
            ) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
        os.system(cmd)
        print("finish batch shuffle")
        # Fix: read via a context manager instead of leaking the file handle.
        with open(to_file, 'r') as f:
            img_label_lines = f.readlines()

        def reader():
            sizes = len(img_label_lines) // batchsize
            if sizes == 0:
                raise ValueError('Batch size is bigger than the dataset size.')
            while True:
                for i in range(sizes):
                    result = []
                    sz = [0, 0]
                    for j in range(batchsize):
                        line = img_label_lines[i * batchsize + j]
                        # h, w, img_name, labels
                        items = line.split(' ')

                        label = [int(c) for c in items[-1].split(',')]
                        # Convert image to grayscale.
                        img = Image.open(os.path.join(img_root_dir, items[
                            2])).convert('L')
                        if j == 0:
                            # Every image in the batch is resized to the
                            # first image's size.
                            sz = img.size
                        img = img.resize((sz[0], sz[1]))
                        # Center pixel values around zero.
                        img = np.array(img) - 127.5
                        img = img[np.newaxis, ...]
                        if self.model == "crnn_ctc":
                            result.append([img, label])
                        else:
                            result.append([img, [SOS] + label, label + [EOS]])
                    yield result
                if not cycle:
                    break

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for inference.

        :param img_root_dir: The root path of the images for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for testing.
        :type img_label_list: str
        '''

        def reader():
            # Fix: close the list file deterministically.
            with open(img_label_list) as f:
                lines = f.readlines()
            for line in lines:
                # h, w, img_name, labels
                items = line.split(' ')

                label = [int(c) for c in items[-1].split(',')]
                img = Image.open(os.path.join(img_root_dir, items[2])).convert(
                    'L')
                img = np.array(img) - 127.5
                img = img[np.newaxis, ...]
                if self.model == "crnn_ctc":
                    yield img, label
                else:
                    yield img, [SOS] + label, label + [EOS]

        return reader

    def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
        '''A reader interface for inference.

        :param img_root_dir: The root path of the images for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        inference. It should be the path of <image_path> file if img_root_dir
        was None. If img_label_list was set to None, it will read image path
        from stdin.
        :type img_label_list: str

        :param cycle: If number of iterations is greater than dataset_size /
        batch_size it reiterates dataset over as many times as necessary.
        :type cycle: bool
        '''

        def reader():
            def yield_img_and_label(lines):
                for line in lines:
                    if img_root_dir is not None:
                        # h, w, img_name, labels
                        img_name = line.split(' ')[2]
                        img_path = os.path.join(img_root_dir, img_name)
                    else:
                        img_path = line.strip("\t\n\r")
                    img = Image.open(img_path).convert('L')
                    img = np.array(img) - 127.5
                    img = img[np.newaxis, ...]
                    label = [int(c) for c in line.split(' ')[3].split(',')]
                    yield img, label

            if img_label_list is not None:
                with open(img_label_list) as f:
                    lines = f.readlines()
                for img, label in yield_img_and_label(lines):
                    yield img, label
                while cycle:
                    for img, label in yield_img_and_label(lines):
                        yield img, label
            else:
                # Fix: `raw_input` only exists on Python 2; fall back to
                # `input` on Python 3. Resolve once, outside the loop.
                try:
                    prompt = raw_input  # noqa: F821 (Python 2)
                except NameError:
                    prompt = input  # Python 3
                while True:
                    img_path = prompt("Please input the path of image: ")
                    img = Image.open(img_path).convert('L')
                    img = np.array(img) - 127.5
                    img = img[np.newaxis, ...]
                    yield img, [[0]]

        return reader


def num_classes():
    """Return the number of character classes in this dataset."""
    return NUM_CLASSES


def data_shape():
    """Return the (dummy) image shape declared for this dataset."""
    return DATA_SHAPE


def train(batch_size,
          train_images_dir=None,
          train_list_file=None,
          cycle=False,
          model="crnn_ctc"):
    '''Build the training reader.

    :param batch_size: Number of samples per yielded batch.
    :type batch_size: int
    :param train_images_dir: Root directory of training images; the default
    dataset is downloaded when None.
    :type train_images_dir: str
    :param train_list_file: Path of the <h, w, img_name, labels> list file;
    taken from the downloaded dataset when None.
    :type train_list_file: str
    :param cycle: Whether the reader restarts the dataset indefinitely.
    :type cycle: bool
    :param model: "crnn_ctc" or an attention-style model name.
    :type model: str
    '''
    generator = DataGenerator(model)

    data_dir = None
    if train_images_dir is None:
        data_dir = download_data()
        train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
    if train_list_file is None:
        # Fix: previously `data_dir` was read here even when the first
        # branch had not run (train_images_dir given, train_list_file
        # omitted), raising NameError.
        if data_dir is None:
            data_dir = download_data()
        train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    # Continuous-evaluation runs need a deterministic sample order.
    shuffle = True
    if 'ce_mode' in os.environ:
        shuffle = False
    return generator.train_reader(
        train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)


def test(batch_size=1,
         test_images_dir=None,
         test_list_file=None,
         model="crnn_ctc"):
    '''Build a batched reader over the test split.

    :param batch_size: Number of samples per batch.
    :type batch_size: int
    :param test_images_dir: Root directory of test images; the default
    dataset is downloaded when None.
    :type test_images_dir: str
    :param test_list_file: Path of the <h, w, img_name, labels> list file;
    taken from the downloaded dataset when None.
    :type test_list_file: str
    :param model: "crnn_ctc" or an attention-style model name.
    :type model: str
    '''
    generator = DataGenerator(model)

    data_dir = None
    if test_images_dir is None:
        data_dir = download_data()
        test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
    if test_list_file is None:
        # Fix: previously `data_dir` was read here even when the first
        # branch had not run (test_images_dir given, test_list_file
        # omitted), raising NameError.
        if data_dir is None:
            data_dir = download_data()
        test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)


def inference(batch_size=1,
              infer_images_dir=None,
              infer_list_file=None,
              cycle=False,
              model="crnn_ctc"):
    '''Build a batched reader for inference.

    :param batch_size: Number of samples per batch.
    :type batch_size: int
    :param infer_images_dir: Root directory of the images, or None to read
    full paths from the list file / stdin.
    :type infer_images_dir: str
    :param infer_list_file: Path of the list file, or None to read image
    paths from stdin.
    :type infer_list_file: str
    :param cycle: Whether to iterate the dataset repeatedly.
    :type cycle: bool
    :param model: "crnn_ctc" or an attention-style model name.
    :type model: str
    '''
    generator = DataGenerator(model)
    sample_reader = generator.infer_reader(infer_images_dir, infer_list_file,
                                           cycle)
    return paddle.batch(sample_reader, batch_size)


def download_data():
    '''Download train and test data, extract once, and return the data dir.

    The archive is fetched into paddle's dataset cache (MD5-verified) and
    only extracted if the target directory does not exist yet.
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    data_dir = path.join(path.dirname(tar_file), DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # Fix: use a context manager so the tar handle is closed even if
        # extraction raises. NOTE(review): extractall trusts member paths;
        # acceptable here because the archive is MD5-checked above.
        with tarfile.open(tar_file, "r:gz") as t:
            t.extractall(path=path.dirname(tar_file))
    return data_dir