reader.py 1.7 KB
Newer Older
P
peterzhang2029 committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
import os
import cv2

from paddle.v2.image import load_image


class DataGenerator(object):
    """Build reader creators that yield flattened grayscale images.

    Every image is loaded, converted from BGR to grayscale, optionally
    resized to ``image_shape``, flattened and divided by 255.
    """

    def __init__(self, char_dict, image_shape):
        """
        :param char_dict: Mapping from a label character to its label id.
                          ``train_reader`` requires it to contain the key
                          ``'<unk>'``, used for characters not in the map.
        :type char_dict: dict
        :param image_shape: The fixed size every image is resized to. It is
                            passed unchanged to ``cv2.resize`` as ``dsize``,
                            which OpenCV interprets as ``(width, height)``.
                            A falsy value (e.g. ``None``) disables resizing.
        :type image_shape: tuple
        """
        self.image_shape = image_shape
        self.char_dict = char_dict

    def train_reader(self, file_list):
        """
        Reader interface for training.

        :param file_list: The ``(image_path, label)`` pairs for training,
                          where ``label`` is an iterable of characters.
        :type file_list: list
        :return: A no-argument reader creator. Each call returns a generator
                 yielding ``(image_vector, label_ids)``; characters missing
                 from ``char_dict`` are mapped to the ``'<unk>'`` id.
        """

        def reader():
            # Looked up once per pass; raises KeyError if '<unk>' is absent.
            unk_id = self.char_dict['<unk>']
            for image_path, label in file_list:
                label = [self.char_dict.get(c, unk_id) for c in label]
                yield self.load_image(image_path), label

        return reader

    def infer_reader(self, file_list):
        """
        Reader interface for inference.

        :param file_list: The ``(image_path, label)`` pairs for inference.
        :type file_list: list
        :return: A no-argument reader creator. Each call returns a generator
                 yielding ``(image_vector, label)`` with ``label`` passed
                 through unchanged.
        """

        def reader():
            for image_path, label in file_list:
                yield self.load_image(image_path), label

        return reader

    def load_image(self, path):
        """
        Load an image and transform it into a 1-dimensional vector.

        The image is converted from BGR to grayscale, resized to
        ``self.image_shape`` when one is set, flattened and divided by 255.

        :param path: The path of the image data.
        :type path: str
        :return: The flattened grayscale image divided by 255.
        :raises IOError: If the image cannot be read from ``path``.
        """
        # Module-level loader imported from paddle.v2.image (not this method).
        image = load_image(path)
        # cv2.imread-based loaders return None for unreadable files; report
        # the offending path instead of failing inside cv2.cvtColor.
        if image is None:
            raise IOError('Cannot load image: %s' % path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Resize all images to a fixed shape; cv2 expects (width, height).
        if self.image_shape:
            image = cv2.resize(
                image, self.image_shape, interpolation=cv2.INTER_CUBIC)

        # Dividing by a float yields a floating-point vector.
        return image.flatten() / 255.