ctc_reader.py 3.3 KB
Newer Older
X
xiaohang 已提交
1 2 3
import os
import cv2
import numpy as np
X
xiaohang 已提交
4
from PIL import Image
X
xiaohang 已提交
5 6 7 8 9 10 11 12

from paddle.v2.image import load_image


class DataGenerator(object):
    '''
    Create reader callables over a text list of labelled images.

    Every line of a list file is expected to hold four whitespace-separated
    fields, ``h w image_name labels``: image height, image width, the image
    file name relative to the image root directory, and a comma-separated
    list of integer labels.
    '''

    def __init__(self):
        pass

    @staticmethod
    def _shuffle_lines(lines, batchsize):
        '''
        Return the list-file lines as shuffled ``[h, w, image_name, labels]``
        records, ordered so that consecutive runs of ``batchsize`` records
        form the training batches.

        :param lines: Raw lines of the list file.
        :type lines: list of str
        :param batchsize: Number of samples per batch (must be >= 1).
        :type batchsize: int
        :return: Records (lists of four strings). For ``batchsize > 1`` the
            length is a multiple of ``batchsize``.
        :raises ValueError: If ``batchsize`` is smaller than 1.
        '''
        import random

        if batchsize < 1:
            raise ValueError("batchsize must be >= 1, got %r" % (batchsize, ))

        # Only the first four whitespace-separated fields are used; lines with
        # fewer fields (e.g. blank lines) cannot form a sample and are skipped.
        records = [
            fields[:4] for fields in (line.split() for line in lines)
            if len(fields) >= 4
        ]
        random.shuffle(records)
        if batchsize == 1:
            return records

        # Stable sort on height after the shuffle: images of equal height end
        # up next to each other (so a batch holds similarly sized images),
        # while their relative order stays random.
        records.sort(key=lambda record: int(record[0]))
        # Drop a random prefix of 1-100 records so batch boundaries move
        # between calls.
        records = records[random.randint(1, 100):]
        # Cut into full batches (an incomplete tail is discarded) and shuffle
        # the batches themselves.
        num_batches = len(records) // batchsize
        batches = [
            records[i * batchsize:(i + 1) * batchsize]
            for i in range(num_batches)
        ]
        random.shuffle(batches)
        return [record for batch in batches for record in batch]

    def train_reader(self, img_root_dir, img_label_list, batchsize):
        '''
        Reader interface for training.

        The list file is read and shuffled once, when this method is called.
        With ``batchsize == 1`` lines are shuffled uniformly. Otherwise they
        are grouped by image height (random order within a height), a random
        number (1-100) of leading lines is dropped, the rest is cut into full
        batches, and the batches are shuffled.

        :param img_root_dir: The root path of the images for training.
        :type img_root_dir: str
        :param img_label_list: Path of the ``h w image_name labels`` list file
            for training.
        :type img_label_list: str
        :param batchsize: Number of samples per yielded batch (must be >= 1).
        :type batchsize: int
        :return: A zero-argument callable; each call returns a generator that
            yields lists of ``batchsize`` ``[image, label]`` pairs. Images are
            float arrays of shape (1, H, W) with pixel values shifted by
            -127.5; every image of a batch is resized to the size of the
            batch's first image.
        :raises ValueError: If ``batchsize`` is smaller than 1.
        '''
        with open(img_label_list, 'r') as f:
            records = self._shuffle_lines(f.readlines(), batchsize)
        print("finish batch shuffle")

        def reader():
            for i in range(len(records) // batchsize):
                result = []
                size = None
                for _h, _w, img_name, labels in records[i * batchsize:(
                        i + 1) * batchsize]:
                    label = [int(c) for c in labels.split(',')]
                    # Convert to single-channel (grayscale) image.
                    img = Image.open(os.path.join(img_root_dir,
                                                  img_name)).convert('L')
                    if size is None:
                        size = img.size
                    img = img.resize((size[0], size[1]))
                    img = np.array(img) - 127.5
                    result.append([img[np.newaxis, ...], label])
                yield result

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for inference.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str
        :param img_label_list: Path of the ``h w image_name labels`` list file
            for testing.
        :type img_label_list: str
        :return: A zero-argument callable; each call returns a generator that
            yields one ``(image, label)`` tuple per non-blank line, in file
            order. Images are float arrays of shape (1, H, W) with pixel
            values shifted by -127.5; they are not resized.
        '''

        def reader():
            # The list file is closed once the generator is exhausted/closed.
            with open(img_label_list) as f:
                for line in f:
                    items = line.split()
                    if not items:
                        # Skip blank lines such as a trailing empty line.
                        continue
                    label = [int(c) for c in items[-1].split(',')]
                    # Convert to single-channel (grayscale) image.
                    img = Image.open(os.path.join(img_root_dir,
                                                  items[2])).convert('L')
                    img = np.array(img) - 127.5
                    yield img[np.newaxis, ...], label

        return reader