ctc_reader.py 6.9 KB
Newer Older
X
xiaohang 已提交
1 2
import os
import cv2
W
wanghaoshuang 已提交
3
import tarfile
X
xiaohang 已提交
4
import numpy as np
X
xiaohang 已提交
5
from PIL import Image
W
wanghaoshuang 已提交
6
from os import path
X
xiaohang 已提交
7
from paddle.v2.image import load_image
W
wanghaoshuang 已提交
8 9 10 11
import paddle.v2 as paddle

# Number of label classes in this dataset; returned by num_classes().
NUM_CLASSES = 10784
# Dummy image shape returned by data_shape().
# NOTE(review): presumably [channels, height, width] -- confirm with callers.
DATA_SHAPE = [1, 48, 512]
X
xiaohang 已提交
12

W
whs 已提交
13 14
# Passed to paddle.dataset.common.download() together with DATA_URL;
# presumably the expected MD5 checksum of the archive -- TODO confirm.
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
# Location of the gzipped tar archive fetched by download_data().
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
W
wanghaoshuang 已提交
15 16 17 18 19 20 21 22
# Second argument of paddle.dataset.common.download();
# presumably the name of the cache sub-directory -- TODO confirm.
CACHE_DIR_NAME = "ctc_data"
# Value of the ``save_name`` argument used when downloading the archive.
SAVED_FILE_NAME = "data.tar.gz"
# Directory, next to the downloaded archive, that download_data() returns
# and whose absence triggers extraction of the archive.
DATA_DIR_NAME = "data"
# Entries joined onto the data directory by train() and test() when the
# caller does not supply its own image directory or list file.
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"

X
xiaohang 已提交
23 24 25 26 27

class DataGenerator(object):
    '''Creates readers that yield grayscale images and their label sequences.

    A list file holds one sample per line in the form
    ``h w img_name labels`` where ``labels`` is a comma-separated list of
    integer class ids. Blank lines are ignored by every reader.
    '''

    def __init__(self):
        pass

    @staticmethod
    def _parse_label(text):
        '''Turn a comma-separated string of class ids into a list of ints.'''
        return [int(c) for c in text.split(',')]

    @staticmethod
    def _open_gray(img_path):
        '''Open the image at img_path and convert it to 8-bit grayscale.'''
        return Image.open(img_path).convert('L')

    @staticmethod
    def _to_array(img):
        '''Shift pixel values by -127.5 and prepend an axis of length 1.'''
        return (np.array(img) - 127.5)[np.newaxis, ...]

    @staticmethod
    def _leading_int(text):
        '''Integer value of a field used as sort key; 0 when not numeric.'''
        try:
            return int(float(text))
        except ValueError:
            return 0

    def train_reader(self, img_root_dir, img_label_list, batchsize):
        '''
        Reader interface for training.

        The list file is read and shuffled once, when this method is called;
        every pass over the returned reader yields the same batches in the
        same order.

        For ``batchsize == 1`` all samples are shuffled individually.
        Otherwise the samples are ordered by their first field (random order
        among equal values), a random number (1 to 100) of leading samples
        is dropped so that batch boundaries vary between calls, consecutive
        samples are grouped into batches, incomplete batches are discarded
        and the batches are shuffled.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for training.
        :type img_label_list: str

        :param batchsize: The number of samples in every yielded batch.
        :type batchsize: int

        :return: A reader; each item it yields is a list of ``batchsize``
                 ``[image, label]`` pairs. All images of one batch are
                 resized to the size of the first image of that batch.
        :raises ValueError: If batchsize is smaller than 1.
        '''
        # Local import: the module's import block does not provide `random`.
        import random

        if batchsize < 1:
            raise ValueError("batchsize must be a positive integer")

        # Only the first four fields of a line are used.
        with open(img_label_list, 'r') as list_file:
            records = [
                line.split()[:4] for line in list_file if line.strip()
            ]

        if batchsize == 1:
            random.shuffle(records)
        else:
            # Partial shuffle: group samples by their first field and
            # randomize the order inside every group.
            records.sort(
                key=lambda rec: (self._leading_int(rec[0]), random.random()))
            # Shift the batch boundaries by dropping a few leading samples.
            del records[:random.randint(1, 100)]

        # Group into full batches only; a trailing partial batch is dropped.
        usable = len(records) // batchsize * batchsize
        batches = [
            records[start:start + batchsize]
            for start in range(0, usable, batchsize)
        ]
        if batchsize != 1:
            random.shuffle(batches)
        print("finish batch shuffle")

        def reader():
            for batch in batches:
                result = []
                size = None
                for fields in batch:
                    # fields: h, w, img_name, labels
                    label = self._parse_label(fields[-1])
                    img = self._open_gray(
                        os.path.join(img_root_dir, fields[2]))
                    if size is None:
                        # The first image fixes the size of the whole batch.
                        size = img.size
                    img = img.resize((size[0], size[1]))
                    result.append([self._to_array(img), label])
                yield result

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for testing.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for testing.
        :type img_label_list: str

        :return: A reader yielding ``(image, label)`` for every line of the
                 list file, in file order.
        '''

        def reader():
            with open(img_label_list) as list_file:
                for line in list_file:
                    # h, w, img_name, labels
                    fields = line.split()
                    if not fields:
                        continue
                    label = self._parse_label(fields[-1])
                    img = self._open_gray(
                        os.path.join(img_root_dir, fields[2]))
                    yield self._to_array(img), label

        return reader

    def infer_reader(self, img_root_dir=None, img_label_list=None):
        '''A reader interface for inference.

        :param img_root_dir: The root path of the images for inference.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
        inference. It should be the path of <image_path> file if img_root_dir
        was None. If img_label_list was set to None, it will read image path
        from stdin.
        :type img_label_list: str

        :return: A reader yielding ``(image, label)``. When no label is
                 available (image paths read from an <image_path> file or
                 from stdin) the placeholder label ``[[0]]`` is yielded.
        '''

        def reader():
            if img_label_list is not None:
                with open(img_label_list) as list_file:
                    for line in list_file:
                        if not line.strip():
                            continue
                        if img_root_dir is not None:
                            # h, w, img_name, labels
                            fields = line.split()
                            img_path = os.path.join(img_root_dir, fields[2])
                            img = self._to_array(self._open_gray(img_path))
                            label = self._parse_label(fields[3])
                        else:
                            # The line is just an image path and carries no
                            # label, so use the same placeholder as for
                            # paths read from stdin.
                            img_path = line.strip("\t\n\r")
                            img = self._to_array(self._open_gray(img_path))
                            label = [[0]]
                        yield img, label
            else:
                # raw_input exists on Python 2 only; fall back to input.
                try:
                    read_line = raw_input
                except NameError:
                    read_line = input
                while True:
                    img_path = read_line("Please input the path of image: ")
                    yield self._to_array(self._open_gray(img_path)), [[0]]

        return reader

W
wanghaoshuang 已提交
151 152

def num_classes():
    '''Get classes number of this dataset.

    :return: The value of the module-level ``NUM_CLASSES`` constant.
    :rtype: int
    '''
    return NUM_CLASSES


def data_shape():
    '''Get image shape of this dataset. It is a dummy shape for this dataset.

    :return: The module-level ``DATA_SHAPE`` list (the object itself, not a
             copy).
    :rtype: list
    '''
    return DATA_SHAPE


W
wanghaoshuang 已提交
164
def train(batch_size, train_images_dir=None, train_list_file=None):
    '''Create the batch reader for training.

    The dataset is downloaded only when at least one of the two locations
    is not supplied; every location left as None is replaced by the
    corresponding entry of the downloaded data directory.

    :param batch_size: The number of samples per batch.
    :type batch_size: int

    :param train_images_dir: The root path of the training images.
    :type train_images_dir: str

    :param train_list_file: The path of the training list file.
    :type train_list_file: str

    :return: The reader created by ``DataGenerator.train_reader``.
    '''
    generator = DataGenerator()
    if train_images_dir is None or train_list_file is None:
        # Either default needs the data directory, so resolve it whenever
        # one of the two arguments is missing (not only the images dir).
        data_dir = download_data()
        if train_images_dir is None:
            train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
        if train_list_file is None:
            train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    return generator.train_reader(train_images_dir, train_list_file, batch_size)
W
wanghaoshuang 已提交
172 173


W
wanghaoshuang 已提交
174
def test(batch_size=1, test_images_dir=None, test_list_file=None):
    '''Create the batch reader for testing.

    The dataset is downloaded only when at least one of the two locations
    is not supplied; every location left as None is replaced by the
    corresponding entry of the downloaded data directory.

    :param batch_size: The number of samples per batch.
    :type batch_size: int

    :param test_images_dir: The root path of the test images.
    :type test_images_dir: str

    :param test_list_file: The path of the test list file.
    :type test_list_file: str

    :return: The reader of ``DataGenerator.test_reader`` wrapped by
             ``paddle.batch`` with the given batch size.
    '''
    generator = DataGenerator()
    if test_images_dir is None or test_list_file is None:
        # Either default needs the data directory, so resolve it whenever
        # one of the two arguments is missing (not only the images dir).
        data_dir = download_data()
        if test_images_dir is None:
            test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
        if test_list_file is None:
            test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)
W
wanghaoshuang 已提交
183 184


W
wanghaoshuang 已提交
185 186 187 188 189 190
def inference(infer_images_dir=None, infer_list_file=None):
    '''Create the reader for inference, batched with a batch size of 1.

    Both arguments are forwarded unchanged to
    ``DataGenerator.infer_reader``.

    :param infer_images_dir: The root path of the images for inference.
    :type infer_images_dir: str

    :param infer_list_file: The path of the list file for inference.
    :type infer_list_file: str
    '''
    sample_reader = DataGenerator().infer_reader(infer_images_dir,
                                                 infer_list_file)
    return paddle.batch(sample_reader, 1)


W
wanghaoshuang 已提交
191 192 193 194 195 196 197 198 199 200 201
def download_data():
    '''Download train and test data.

    The archive is fetched with ``paddle.dataset.common.download`` and is
    unpacked into the directory holding the downloaded file, unless the
    data directory already exists there.

    :return: The path of the data directory.
    :rtype: str
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    cache_dir = path.dirname(tar_file)
    data_dir = path.join(cache_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # The context manager closes the archive even if extraction fails.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=cache_dir)
    return data_dir