data_reader.py 8.7 KB
Newer Older
1 2 3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
X
xiaohang 已提交
4 5
import os
import cv2
W
wanghaoshuang 已提交
6
import tarfile
X
xiaohang 已提交
7
import numpy as np
X
xiaohang 已提交
8
from PIL import Image
W
wanghaoshuang 已提交
9
from os import path
W
wanghaoshuang 已提交
10 11
from paddle.dataset.image import load_image
import paddle
W
wanghaoshuang 已提交
12

W
whs 已提交
13 14 15 16 17
# Python 2/3 compatibility shim: on Python 2 rebind the module-level name
# `input` to `raw_input`, so that input() returns the typed text as a string
# instead of evaluating it. On Python 3 `raw_input` does not exist; the
# resulting NameError is ignored and the builtin `input` is used unchanged.
try:
    input = raw_input
except NameError:
    pass

18 19
# Token ids wrapped around a label by the train/test readers for every model
# other than "crnn_ctc": SOS is prepended to one copy of the label and EOS is
# appended to another copy (presumably start-/end-of-sequence markers).
SOS = 0
EOS = 1
20
# Number of label classes of this dataset; the value returned by num_classes().
NUM_CLASSES = 95
W
wanghaoshuang 已提交
21
# Value returned by data_shape(), which documents it as a dummy shape.
# NOTE(review): looks like [channels, height, width] -- confirm.
DATA_SHAPE = [1, 48, 512]
X
xiaohang 已提交
22

W
whs 已提交
23 24
# Expected MD5 checksum and download URL of the dataset archive; both are
# passed to paddle.dataset.common.download() by download_data().
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
W
wanghaoshuang 已提交
25 26 27 28 29 30 31 32
# Second positional argument given to paddle.dataset.common.download() in
# download_data() (presumably the cache sub-directory name).
CACHE_DIR_NAME = "ctc_data"
# File name under which the downloaded archive is saved (`save_name`).
SAVED_FILE_NAME = "data.tar.gz"
# Directory, next to the downloaded archive, that download_data() returns and
# whose absence triggers extraction of the archive.
DATA_DIR_NAME = "data"
# Entries inside the data directory used as defaults by train() and test().
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"

X
xiaohang 已提交
33 34

class DataGenerator(object):
    '''Factory for the train / test / inference readers of the OCR dataset.

    A list file holds one sample per line with four space-separated fields,
    ``h w img_name labels``, where ``labels`` is a comma-separated sequence of
    integer class ids. Images are opened with PIL, converted to grayscale
    (mode ``'L'``), shifted by -127.5 and given a leading axis, so every image
    produced by the readers is a float array of shape ``(1, H, W)``.

    For the model ``"crnn_ctc"`` a sample is ``(img, label)``; for any other
    model it is ``(img, [SOS] + label, label + [EOS])``.
    '''

    def __init__(self, model="crnn_ctc"):
        '''
        :param model: Name of the model the samples are produced for. Only the
        comparison with "crnn_ctc" matters; it selects the sample layout.
        :type model: str
        '''
        self.model = model

    @staticmethod
    def _parse_label(field):
        '''Turn a comma-separated label field such as "3,17,42" into a list of ints.'''
        return [int(c) for c in field.split(',')]

    @staticmethod
    def _to_array(img):
        '''Convert a PIL image into a zero-centred array with a leading axis.'''
        return (np.array(img) - 127.5)[np.newaxis, ...]

    def _make_sample(self, img, label):
        '''Assemble one sample (as a list) in the layout required by ``self.model``.'''
        if self.model == "crnn_ctc":
            return [img, label]
        return [img, [SOS] + label, label + [EOS]]

    def train_reader(self,
                     img_root_dir,
                     img_label_list,
                     batchsize,
                     cycle,
                     shuffle=True):
        '''
        Reader interface for training.

        The label list is pre-processed once, when this method is called, by a
        shell pipeline (``cat``/``awk``/``sort``/``sed``/``shuf``), so a
        POSIX-like shell with these tools is required. Without shuffling the
        first four fields of every line are kept in file order. With
        ``batchsize == 1`` the lines are shuffled individually. Otherwise the
        lines are sorted by their first field (random order among equal
        values), a few leading lines are dropped, consecutive lines are grouped
        into batches, the batches are shuffled and incomplete batches are
        discarded.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for training.
        :type img_label_list: str

        :param batchsize: Number of samples in every yielded batch.
        :type batchsize: int

        :param cycle: If number of iterations is greater than dataset_size / batch_size
        it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :param shuffle: Whether to shuffle the samples as described above.
        :type shuffle: bool

        :return: A reader; calling it gives a generator that yields lists of
        ``batchsize`` samples. All images of a batch are resized to the size
        of the first image of that batch. The generator raises ValueError on
        first use when fewer than ``batchsize`` usable lines are available.
        '''
        import tempfile
        try:
            from shlex import quote  # Python 3
        except ImportError:
            from pipes import quote  # Python 2

        # A private temporary file instead of a fixed "tmp.txt" in the current
        # directory: concurrent runs no longer overwrite each other's list.
        fd, to_file = tempfile.mkstemp(prefix="ocr_train_list_", suffix=".txt")
        os.close(fd)
        # Quote both paths so that spaces or shell metacharacters in them
        # cannot break (or alter) the command line.
        src = quote(img_label_list)
        dst = quote(to_file)

        if not shuffle:
            cmd = "cat " + src + " | awk '{print $1,$2,$3,$4;}' > " + dst
        elif batchsize == 1:
            cmd = "cat " + src + " | awk '{print $1,$2,$3,$4;}' | shuf > " + dst
        else:
            #cmd1: partial shuffle
            # NOTE(review): $RANDOM is a bash feature; os.system() runs the
            # command through the system shell, which may not provide it.
            cmd = "cat " + src + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
            #cmd2: batch merge and shuffle
            cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
                batchsize) + " == 0) print \"\";}' | shuf | "
            #cmd3: batch split
            cmd += "awk '{if(NF == " + str(
                batchsize
            ) + " * 4) {for(i = 0; i < " + str(
                batchsize
            ) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + dst

        try:
            os.system(cmd)
            print("finish batch shuffle")
            with open(to_file, 'r') as list_file:
                # Blank input lines come out of awk as whitespace-only lines;
                # they carry no sample and would break label parsing.
                img_label_lines = [
                    line for line in list_file if line.strip()
                ]
        finally:
            os.remove(to_file)

        def reader():
            sizes = len(img_label_lines) // batchsize
            if sizes == 0:
                raise ValueError('Batch size is bigger than the dataset size.')
            while True:
                for i in range(sizes):
                    result = []
                    sz = [0, 0]
                    for j in range(batchsize):
                        line = img_label_lines[i * batchsize + j]
                        # h, w, img_name, labels
                        items = line.split(' ')

                        label = self._parse_label(items[-1])
                        # Convert to grayscale.
                        img = Image.open(os.path.join(img_root_dir, items[
                            2])).convert('L')
                        if j == 0:
                            # The first image fixes the size of the batch.
                            sz = img.size
                        img = img.resize((sz[0], sz[1]))
                        result.append(
                            self._make_sample(self._to_array(img), label))
                    yield result
                if not cycle:
                    break

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for testing.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for testing.
        :type img_label_list: str

        :return: A reader; calling it gives a generator that yields one sample
        (a tuple) per non-blank line of the list file, in file order.
        '''

        def reader():
            with open(img_label_list) as list_file:
                for line in list_file:
                    if not line.strip():
                        continue  # tolerate blank lines
                    # h, w, img_name, labels
                    items = line.split(' ')

                    label = self._parse_label(items[-1])
                    img = Image.open(os.path.join(img_root_dir, items[
                        2])).convert('L')
                    yield tuple(
                        self._make_sample(self._to_array(img), label))

        return reader

    def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
        '''A reader interface for inference.

        :param img_root_dir: The root path of the images for inference.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        inference. It should be the path of <image_path> file if img_root_dir
        was None. If img_label_list was set to None, it will read image path
        from stdin.
        :type img_label_list: str

        :param cycle: If number of iterations is greater than dataset_size /
        batch_size it reiterates dataset over as many times as necessary.
        :type cycle: bool

        :return: A reader; calling it gives a generator that yields
        ``(img, label)`` pairs. When no label is available (image paths read
        from a path-only list file or from stdin) the placeholder label
        ``[[0]]`` is yielded.
        '''

        def reader():
            def yield_img_and_label(lines):
                for line in lines:
                    if img_root_dir is not None:
                        # h, w, img_name, labels
                        items = line.split(' ')
                        img_path = os.path.join(img_root_dir, items[2])
                        label = self._parse_label(items[3])
                    else:
                        # The line is a bare image path and carries no label
                        # field; use the same placeholder as the stdin branch.
                        img_path = line.strip("\t\n\r")
                        label = [[0]]
                    img = Image.open(img_path).convert('L')
                    yield self._to_array(img), label

            if img_label_list is not None:
                with open(img_label_list) as f:
                    lines = [line for line in f if line.strip()]
                if not lines:
                    # Nothing to read; with cycle=True an empty list would
                    # otherwise loop forever without yielding anything.
                    return
                while True:
                    for img, label in yield_img_and_label(lines):
                        yield img, label
                    if not cycle:
                        break
            else:
                while True:
                    img_path = input("Please input the path of image: ")
                    img = Image.open(img_path).convert('L')
                    yield self._to_array(img), [[0]]

        return reader

W
wanghaoshuang 已提交
191 192

def num_classes():
    '''Return the number of label classes of this dataset (``NUM_CLASSES``).
    '''
    return NUM_CLASSES


def data_shape():
    '''Return the image shape of this dataset (``DATA_SHAPE``).

    The value is only a dummy shape for this dataset.
    '''
    return DATA_SHAPE


204 205 206 207 208 209
def train(batch_size,
          train_images_dir=None,
          train_list_file=None,
          cycle=False,
          model="crnn_ctc"):
    '''Create the batched training reader.

    :param batch_size: Number of samples per batch.
    :type batch_size: int

    :param train_images_dir: Root directory of the training images. When None,
    the dataset is downloaded and its training image directory is used.
    :type train_images_dir: str

    :param train_list_file: Path of the training list file. When None, the
    dataset is downloaded and its training list is used.
    :type train_list_file: str

    :param cycle: Whether the reader restarts from the beginning once the
    data is exhausted.
    :type cycle: bool

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: The reader built by DataGenerator.train_reader. Shuffling is
    enabled unless the environment variable ``ce_mode`` is set.
    '''
    generator = DataGenerator(model)
    # Resolve the download location whenever either default is needed, so that
    # passing only one of the two paths no longer hits an unbound `data_dir`.
    if train_images_dir is None or train_list_file is None:
        data_dir = download_data()
        if train_images_dir is None:
            train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
        if train_list_file is None:
            train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    shuffle = 'ce_mode' not in os.environ
    return generator.train_reader(
        train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)
W
wanghaoshuang 已提交
220 221


222 223 224 225 226
def test(batch_size=1,
         test_images_dir=None,
         test_list_file=None,
         model="crnn_ctc"):
    '''Create the batched test reader.

    :param batch_size: Number of samples per batch.
    :type batch_size: int

    :param test_images_dir: Root directory of the test images. When None, the
    dataset is downloaded and its test image directory is used.
    :type test_images_dir: str

    :param test_list_file: Path of the test list file. When None, the dataset
    is downloaded and its test list is used.
    :type test_list_file: str

    :param model: Model name forwarded to DataGenerator.
    :type model: str

    :return: DataGenerator.test_reader wrapped by ``paddle.batch``.
    '''
    generator = DataGenerator(model)
    # Resolve the download location whenever either default is needed, so that
    # passing only one of the two paths no longer hits an unbound `data_dir`.
    if test_images_dir is None or test_list_file is None:
        data_dir = download_data()
        if test_images_dir is None:
            test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
        if test_list_file is None:
            test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)
W
wanghaoshuang 已提交
234 235


236 237 238
def inference(batch_size=1,
              infer_images_dir=None,
              infer_list_file=None,
              cycle=False,
              model="crnn_ctc"):
    '''Create the batched inference reader.

    The three reader arguments are handed positionally to
    DataGenerator.infer_reader; the resulting sample reader is grouped into
    batches of ``batch_size`` by ``paddle.batch``.
    '''
    sample_reader = DataGenerator(model).infer_reader(
        infer_images_dir, infer_list_file, cycle)
    return paddle.batch(sample_reader, batch_size)
W
wanghaoshuang 已提交
245 246


W
wanghaoshuang 已提交
247 248 249 250 251 252 253 254 255 256 257
def download_data():
    '''Download train and test data.

    The archive is fetched through ``paddle.dataset.common.download`` using
    DATA_URL, CACHE_DIR_NAME, DATA_MD5 and SAVED_FILE_NAME. It is extracted
    next to the downloaded file unless the data directory already exists.

    :return: Path of the extracted data directory.
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    cache_dir = path.dirname(tar_file)
    data_dir = path.join(cache_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall() trusts the member paths stored in the
        # archive; DATA_MD5 is handed to the download helper, which presumably
        # verifies the archive against it -- confirm.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=cache_dir)
    return data_dir