# ctc_reader.py
import os

import cv2
import numpy as np
from PIL import Image
from paddle.v2.image import load_image
import paddle.v2 as paddle

# Value handed out by num_classes().
NUM_CLASSES = 10784
# Value handed out by data_shape(); presumably [channels, height, width]
# of one input image -- TODO confirm against the model definition.
DATA_SHAPE = [1, 48, 512]


class DataGenerator(object):
    '''
    Factory for reader callables that load grayscale images and their labels.

    Every line of a label list file is split on single spaces: field 2 is
    used as the image file name (joined onto the image root directory) and
    the last field as a comma-separated list of integer label ids. The
    in-code note gives the intended field layout as
    ``h, w, img_name, labels``.
    '''

    def __init__(self):
        pass

    @staticmethod
    def _parse_line(line):
        '''Split one list-file line into (image file name, list of int labels).'''
        # h, w, img_name, labels
        items = line.split(' ')
        # int() tolerates the trailing newline on the last label.
        label = [int(c) for c in items[-1].split(',')]
        return items[2], label

    @staticmethod
    def _shuffle_to_file(cmd, to_file):
        '''Run the shell pipeline ``cmd`` and return the lines found in ``to_file``.'''
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("cmd: " + cmd)
        os.system(cmd)
        print("finish batch shuffle")
        with open(to_file, 'r') as shuffled:
            return shuffled.readlines()

    def train_reader(self, img_root_dir, img_label_list, batchsize):
        '''
        Reader interface for training.

        The label list is shuffled by a shell pipeline (``cat``, ``awk``,
        ``sort``, ``sed``, ``shuf``) that runs as soon as this method is
        called and writes its result to ``tmp.txt`` in the current working
        directory. The returned reader iterates over the lines read back
        from that file.

        :param img_root_dir: The root path of the image for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for training.
        :type img_label_list: str

        :param batchsize: Number of samples in every yielded batch.
        :type batchsize: int

        :return: A no-argument generator function. Each yielded item is a
            list of ``batchsize`` ``[image, label]`` pairs, where ``image``
            is the grayscale pixel array minus 127.5 with a new leading
            axis, and ``label`` is a list of ints. All images of a batch
            are resized to the size of the batch's first image.
        '''

        to_file = "tmp.txt"
        # NOTE(review): img_label_list is pasted into the shell command
        # unquoted, so a path containing spaces or shell metacharacters
        # will not work as intended.
        if batchsize == 1:
            # Keep the first four fields of every line and shuffle the lines.
            cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
        else:
            #cmd1: partial shuffle
            # Prefix each line with a sort key built from its zero-padded
            # first field plus a random fraction, sort on it, then delete a
            # leading run of lines whose length comes from the shell's
            # $RANDOM (availability depends on which shell `sh` is).
            cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
            #cmd2: batch merge and shuffle
            # Put fields 2-5 (the original four fields) of `batchsize`
            # consecutive lines on one row, then shuffle the rows.
            cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
                batchsize) + " == 0) print \"\";}' | shuf | "
            #cmd3: batch split
            # Keep only rows holding exactly batchsize * 4 fields and split
            # them back into one sample per line.
            cmd += "awk '{if(NF == " + str(
                batchsize
            ) + " * 4) {for(i = 0; i < " + str(
                batchsize
            ) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
        img_label_lines = self._shuffle_to_file(cmd, to_file)

        def reader():
            # Floor division keeps the batch count an int on Python 3 too;
            # lines that do not fill a whole batch are never read.
            sizes = len(img_label_lines) // batchsize
            for i in range(sizes):
                result = []
                sz = [0, 0]
                for j in range(batchsize):
                    img_name, label = self._parse_line(
                        img_label_lines[i * batchsize + j])
                    # convert to grayscale
                    img = Image.open(os.path.join(img_root_dir,
                                                  img_name)).convert('L')
                    if j == 0:
                        # The first image fixes the size for the whole batch.
                        sz = img.size
                    img = img.resize((sz[0], sz[1]))
                    img = np.array(img) - 127.5
                    img = img[np.newaxis, ...]
                    result.append([img, label])
                yield result

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for inference.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file for testing.
        :type img_label_list: str

        :return: A no-argument generator function yielding one
            ``(image, label)`` tuple per non-blank line of the list file,
            where ``image`` is the grayscale pixel array minus 127.5 with a
            new leading axis, and ``label`` is a list of ints.
        '''

        def reader():
            # The list file is opened lazily, on first iteration, and closed
            # when the generator finishes or is closed.
            with open(img_label_list) as label_file:
                for line in label_file:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    img_name, label = self._parse_line(line)
                    img = Image.open(os.path.join(img_root_dir,
                                                  img_name)).convert('L')
                    img = np.array(img) - 127.5
                    img = img[np.newaxis, ...]
                    yield img, label

        return reader


def num_classes():
    '''
    Return the module-level ``NUM_CLASSES`` constant.
    '''
    class_count = NUM_CLASSES
    return class_count


def data_shape():
    '''
    Return the module-level ``DATA_SHAPE`` list.

    The list object itself is returned, not a copy.
    '''
    shape = DATA_SHAPE
    return shape


def train(
        batch_size,
        train_images_dir="/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train_images/",
        train_list_file="/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/train.list"
):
    '''
    Build the training reader.

    Calling this runs ``DataGenerator.train_reader``, which executes its
    shell shuffle pipeline right away, before any batch is requested.

    :param batch_size: Number of samples per batch, forwarded to
        ``DataGenerator.train_reader``.
    :type batch_size: int

    :param train_images_dir: The root path of the training images. The
        default is the location that used to be hard-coded here.
    :type train_images_dir: str

    :param train_list_file: The path of the training label list file. The
        default is the location that used to be hard-coded here.
    :type train_list_file: str

    :return: The reader returned by ``DataGenerator.train_reader``.
    '''
    generator = DataGenerator()
    return generator.train_reader(train_images_dir, train_list_file,
                                  batch_size)


def test(
        batch_size=1,
        test_images_dir="/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test_images/",
        test_list_file="/home/disk1/wanghaoshuang/models/fluid/ocr_recognition/data/test.list"
):
    '''
    Build the testing reader, wrapped with ``paddle.batch``.

    :param batch_size: Batch size passed to ``paddle.batch``.
    :type batch_size: int

    :param test_images_dir: The root path of the testing images. The
        default is the location that used to be hard-coded here.
    :type test_images_dir: str

    :param test_list_file: The path of the testing label list file. The
        default is the location that used to be hard-coded here.
    :type test_list_file: str

    :return: The result of ``paddle.batch`` applied to the reader from
        ``DataGenerator.test_reader``.
    '''
    generator = DataGenerator()
    return paddle.batch(
        generator.test_reader(test_images_dir, test_list_file), batch_size)