# ctc_reader.py
import os
import cv2
import numpy as np

from paddle.v2.image import load_image


class DataGenerator(object):
    '''
    Factory for reader callables that yield ``(image, label)`` samples.

    Both readers consume a text file in which every non-blank line holds
    four space-separated fields, ``h w img_name labels``:

    * ``h`` -- integer used as the sort key by :meth:`train_reader`
      (the image height, according to the original field comment).
    * ``w`` -- second field; parsed but not otherwise used here.
    * ``img_name`` -- image path relative to the image root directory.
    * ``labels`` -- comma-separated integers.
    '''

    def __init__(self):
        pass

    @staticmethod
    def _load_sample(img_root_dir, line):
        '''
        Turn one label-file line into an ``(image, label)`` pair.

        :param img_root_dir: The root path the image name is joined to.
        :type img_root_dir: str

        :param line: One ``h w img_name labels`` line of the label file.
        :type line: str

        :return: The loaded image with its last axis moved to the front,
                 and the label as a list of ints.
        :raises AssertionError: If the line does not have exactly four
                 space-separated fields.
        :raises ValueError: If a label entry is not an integer.
        '''
        # h, w, img_name, labels
        items = line.split(' ')
        assert len(items) == 4, \
            'expected "h w img_name labels", got {!r}'.format(line)

        # int() ignores surrounding whitespace, so the newline that is
        # still attached to the last field is harmless.
        label = [int(c) for c in items[-1].split(',')]

        img = load_image(os.path.join(img_root_dir, items[2]))
        # Move the last axis to the front (presumably HWC -> CHW; this
        # depends on what load_image returns -- TODO confirm).
        img = np.transpose(img, (2, 0, 1))
        return img, label

    def train_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for training.

        The label file is read once, when this method is called, and its
        lines are ordered by their first field so that samples are yielded
        in ascending order of that value.

        :param img_root_dir: The root path of the images for training.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file
                               for training.
        :type img_label_list: str

        :return: A no-argument callable returning a generator of
                 ``(image, label)`` pairs.
        :raises ValueError: If the first field of a line is not an integer.
        '''
        # The context manager closes the file as soon as it has been read.
        # Blank lines (e.g. a trailing empty line) carry no sample.
        with open(img_label_list) as label_file:
            img_label_lines = [line for line in label_file if line.strip()]

        # Sort numerically on the first field; a zero-padded string key
        # would misorder values with more digits than the padding width.
        # The line text breaks ties so the order stays deterministic.
        img_label_lines.sort(
            key=lambda line: (int(line.split(' ')[0]), line))

        def reader():
            for line in img_label_lines:
                yield DataGenerator._load_sample(img_root_dir, line)

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for inference.

        The label file is opened each time the returned callable is
        invoked, and samples are yielded in file order.

        :param img_root_dir: The root path of the images for testing.
        :type img_root_dir: str

        :param img_label_list: The path of the <image_name, label> file
                               for testing.
        :type img_label_list: str

        :return: A no-argument callable returning a generator of
                 ``(image, label)`` pairs.
        '''

        def reader():
            # The file is closed when the generator is exhausted or
            # discarded.
            with open(img_label_list) as label_file:
                for line in label_file:
                    if not line.strip():
                        # Blank lines carry no sample.
                        continue
                    yield DataGenerator._load_sample(img_root_dir, line)

        return reader