ctc_reader.py 4.9 KB
Newer Older
X
xiaohang 已提交
1 2
import os
import cv2
W
wanghaoshuang 已提交
3
import tarfile
X
xiaohang 已提交
4
import numpy as np
X
xiaohang 已提交
5
from PIL import Image
W
wanghaoshuang 已提交
6
from os import path
X
xiaohang 已提交
7
from paddle.v2.image import load_image
W
wanghaoshuang 已提交
8 9 10 11
import paddle.v2 as paddle

# Value returned by num_classes(): the classes number of this dataset.
NUM_CLASSES = 10784
# Value returned by data_shape(); documented there as a dummy shape.
DATA_SHAPE = [1, 48, 512]

# Checksum and URL handed to paddle.dataset.common.download() by
# download_data().
DATA_MD5 = "1de60d54d19632022144e4e58c2637b5"
DATA_URL = "http://cloud.dlnel.org/filepub/?uuid=df937251-3c0b-480d-9a7b-0080dfeee65c"
# Second positional argument of the download call (presumably the cache
# sub-directory name -- confirm against paddle.dataset.common.download).
CACHE_DIR_NAME = "ctc_data"
# Name under which the downloaded archive is saved (save_name argument).
SAVED_FILE_NAME = "data.tar.gz"
# Directory, next to the downloaded archive, that download_data() returns and
# extracts the archive for when it does not exist yet.
DATA_DIR_NAME = "data"
# Entries looked up inside the data directory for each split.
TRAIN_DATA_DIR_NAME = "train_images"
TEST_DATA_DIR_NAME = "test_images"
TRAIN_LIST_FILE_NAME = "train.list"
TEST_LIST_FILE_NAME = "test.list"

X
xiaohang 已提交
23 24 25 26 27

class DataGenerator(object):
    '''Factory for reader creators over a list file and an image directory.

    Every line of a list file is split on single spaces: the third field is
    the image file name (joined onto the image root directory) and the last
    field is a comma-separated sequence of integer label ids.
    '''

    def __init__(self):
        pass

    @staticmethod
    def _parse_line(line):
        '''Return ``(image_name, label_ids)`` parsed from one list-file line.'''
        # Field layout according to the original inline note:
        # h, w, img_name, labels
        items = line.split(' ')
        # int() tolerates the trailing newline left on the last field.
        label = [int(c) for c in items[-1].split(',')]
        return items[2], label

    @staticmethod
    def _to_array(img):
        '''Return ``img`` as an array shifted by -127.5 with a new first axis.'''
        data = np.array(img) - 127.5
        return data[np.newaxis, ...]

    @staticmethod
    def _shuffle_command(img_label_list, batchsize, to_file):
        '''Build the shell pipeline that writes the shuffled list to to_file.

        NOTE(review): ``img_label_list`` and ``to_file`` are concatenated into
        the command unquoted, so paths containing spaces or shell
        metacharacters will not work as intended.
        '''
        if batchsize == 1:
            # Keep the first four fields of every line and shuffle the lines.
            return ("cat " + img_label_list +
                    " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file)

        size = str(batchsize)
        # cmd1: partial shuffle -- prefix each line with a sort key made of
        # the zero-padded first field plus a random fraction, sort on it, then
        # drop a random number of leading lines.
        cmd = ("cat " + img_label_list +
               " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}'"
               " | sort | sed 1,$((1 + RANDOM % 100))d | ")
        # cmd2: batch merge and shuffle -- join every `batchsize` consecutive
        # lines (without the sort key) into one line, then shuffle those.
        cmd += ("awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + size +
                " == 0) print \"\";}' | shuf | ")
        # cmd3: batch split -- expand each merged line that holds exactly
        # batchsize * 4 fields back into one line per sample.
        cmd += ("awk '{if(NF == " + size + " * 4) {for(i = 0; i < " + size +
                "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > "
                + to_file)
        return cmd

    def train_reader(self, img_root_dir, img_label_list, batchsize):
        '''
        Reader interface for training.

        The list file is shuffled once, when this method is called, by an
        external shell pipeline (cat/awk/sort/sed/shuf run through
        ``os.system``) that writes its result to ``tmp.txt`` in the current
        working directory. The shuffled lines are then read into memory.

        :param img_root_dir: The root path of the images for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
            training.
        :type img_label_list: str

        :param batchsize: The number of samples in every yielded batch.
        :type batchsize: int

        :return: A reader creator. Each item it yields is a list of
            ``batchsize`` ``[image, label]`` pairs; every image of a batch is
            resized to the size of the first image of that batch.
        '''
        to_file = "tmp.txt"
        cmd = self._shuffle_command(img_label_list, batchsize, to_file)
        print("cmd: " + cmd)
        os.system(cmd)
        print("finish batch shuffle")
        with open(to_file, 'r') as shuffled:
            img_label_lines = shuffled.readlines()

        def reader():
            # Floor division: a trailing partial batch is never yielded.
            num_batches = len(img_label_lines) // batchsize
            for i in range(num_batches):
                batch = []
                first_size = None
                for j in range(batchsize):
                    img_name, label = self._parse_line(
                        img_label_lines[i * batchsize + j])
                    # Convert to grayscale.
                    img = Image.open(os.path.join(img_root_dir,
                                                  img_name)).convert('L')
                    if j == 0:
                        first_size = img.size
                    img = img.resize((first_size[0], first_size[1]))
                    batch.append([self._to_array(img), label])
                yield batch

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for inference.

        :param img_root_dir: The root path of the images to read.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for
            testing.
        :type img_label_list: str

        :return: A reader creator yielding one ``(image, label)`` tuple per
            line of the list file, in file order and without resizing.
        '''

        def reader():
            # The list file is reopened on every call of the reader and is
            # closed when the generator is exhausted or closed.
            with open(img_label_list) as list_file:
                for line in list_file:
                    img_name, label = self._parse_line(line)
                    img = Image.open(os.path.join(img_root_dir,
                                                  img_name)).convert('L')
                    yield self._to_array(img), label

        return reader
W
wanghaoshuang 已提交
113 114 115


def num_classes():
    '''Get classes number of this dataset.

    :return: The module-level ``NUM_CLASSES`` constant.
    '''
    return NUM_CLASSES


def data_shape():
    '''Get image shape of this dataset. It is a dummy shape for this dataset.

    :return: The module-level ``DATA_SHAPE`` list itself (not a copy).
    '''
    return DATA_SHAPE


def train(batch_size):
    '''Create the training reader.

    Calls download_data() to obtain the data directory, then builds a
    DataGenerator.train_reader over its training image directory and training
    list file.

    :param batch_size: The number of samples per batch, forwarded to
        ``train_reader`` as ``batchsize``.
    :type batch_size: int

    :return: The reader creator returned by ``DataGenerator.train_reader``.
    '''
    generator = DataGenerator()
    data_dir = download_data()
    train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
    train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
    return generator.train_reader(train_images_dir, train_list_file,
                                  batch_size)
W
wanghaoshuang 已提交
133 134 135 136


def test(batch_size=1):
    '''Create the batched test reader.

    Calls download_data() to obtain the data directory, builds a
    DataGenerator.test_reader over its test image directory and test list
    file, and wraps it with ``paddle.batch``.

    :param batch_size: The batch size handed to ``paddle.batch``.
    :type batch_size: int

    :return: The result of ``paddle.batch`` applied to the test reader.
    '''
    generator = DataGenerator()
    data_dir = download_data()
    # Read the test split: this previously pointed at the TRAIN_* directory
    # and list file, so evaluation ran on the training data.
    test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
    test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)


def download_data():
    '''Download train and test data.

    Fetches the archive through ``paddle.dataset.common.download`` and, when
    the data directory next to the archive does not exist yet, extracts the
    archive into the archive's own directory.

    :return: The path of the data directory (``DATA_DIR_NAME`` next to the
        downloaded archive).
    '''
    tar_file = paddle.dataset.common.download(
        DATA_URL, CACHE_DIR_NAME, DATA_MD5, save_name=SAVED_FILE_NAME)
    archive_dir = path.dirname(tar_file)
    data_dir = path.join(archive_dir, DATA_DIR_NAME)
    if not path.isdir(data_dir):
        # The context manager closes the archive even when extraction fails.
        # NOTE(review): extractall() trusts the member paths in the archive;
        # this relies on the download being verified against DATA_MD5.
        with tarfile.open(tar_file, "r:gz") as archive:
            archive.extractall(path=archive_dir)
    return data_dir