# data_provider.py (2.8 KB)
# Committed by peterzhang2029.
from __future__ import absolute_import
from __future__ import division

import os
from paddle.v2.image import load_image
import cv2


class AsciiDic(object):
    '''
    Character vocabulary mapping single characters to integer ids.

    Id 0 (``UNK``) is reserved for ``'<unk>'`` and for any character not in
    the vocabulary; the characters ``chr(40)`` .. ``chr(170)`` are assigned
    ids 1 .. 131 in order.
    '''
    UNK = 0

    def __init__(self):
        self.dic = {
            '<unk>': self.UNK,
        }
        self.chars = [chr(i) for i in range(40, 171)]
        # Ids start at 1 because 0 is already taken by UNK.
        for char_id, c in enumerate(self.chars, 1):
            self.dic[c] = char_id

    def lookup(self, w):
        '''Return the id of character ``w``, or ``UNK`` if it is unknown.'''
        return self.dic.get(w, self.UNK)

    def id2word(self):
        '''
        Return the reverse mapping ``{id: character}``.

        The dict is built fresh on every call. It is deliberately NOT stored
        as ``self.id2word``: that would shadow this method, and any later
        call would fail with "'dict' object is not callable".
        '''
        return {value: key for key, value in self.dic.items()}

    def word2ids(self, sent):
        '''
        Transform a word into a list of ids, one per character.

        @sent: str
        '''
        return [self.lookup(c) for c in sent]

    def size(self):
        '''Return the vocabulary size, including the ``'<unk>'`` entry.'''
        return len(self.dic)


class ImageDataset(object):
    '''
    Produce ``(flattened grayscale image, label)`` samples from lists of
    ``(image_path, label)`` pairs.
    '''

    def __init__(self,
                 train_image_paths_generator,
                 test_image_paths_generator,
                 infer_image_paths_generator,
                 fixed_shape=None,
                 is_infer=False):
        '''
        @train_image_paths_generator, @test_image_paths_generator,
        @infer_image_paths_generator: iterable
            Iterated once here (NOT called) and each item is later unpacked
            as an ``(image_path, label)`` pair, e.g. the generator returned
            by ``get_file_list``. When ``is_infer`` is false only the train
            and test iterables are consumed; otherwise only the infer one.
        @fixed_shape: optional size passed to ``cv2.resize`` as ``dsize``
            (OpenCV reads it as ``(width, height)``); a falsy value disables
            resizing.
        @is_infer: bool
        '''
        if not is_infer:
            self.train_filelist = list(train_image_paths_generator)
            self.test_filelist = list(test_image_paths_generator)
        else:
            self.infer_filelist = list(infer_image_paths_generator)

        self.fixed_shape = fixed_shape
        self.ascii_dic = AsciiDic()

    def train(self):
        '''Yield ``(image_vector, label_ids)`` for every training pair.'''
        for image_path, label in self.train_filelist:
            yield self.load_image(image_path), self.ascii_dic.word2ids(label)

    def test(self):
        '''Yield ``(image_vector, label_ids)`` for every test pair.'''
        for image_path, label in self.test_filelist:
            yield self.load_image(image_path), self.ascii_dic.word2ids(label)

    def infer(self):
        '''Yield ``(image_vector, label)`` with the label left as raw text.'''
        for image_path, label in self.infer_filelist:
            yield self.load_image(image_path), label

    def load_image(self, path):
        '''
        Load an image, convert it to grayscale, optionally resize it to
        ``self.fixed_shape``, and return it as a 1-dimensional float array
        divided by 255.
        '''
        # Calls the module-level paddle.v2.image.load_image, not this method.
        image = load_image(path)
        # NOTE(review): assumes a 3-channel BGR array (as cv2.imread returns);
        # COLOR_BGR2GRAY would fail on single-channel input.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        if self.fixed_shape:
            # Resize all images to one fixed shape so flattened vectors
            # share the same length.
            image = cv2.resize(
                image, self.fixed_shape, interpolation=cv2.INTER_CUBIC)

        return image.flatten() / 255.


def get_file_list(image_file_list):
    '''
    Yield ``(image_path, label)`` pairs read from an annotation file.

    Each non-blank line is expected to look like ``word_1.png, "label"``:
    an image file name relative to the annotation file's directory, a comma,
    and a double-quoted label. Blank lines are skipped.

    @image_file_list: str, path to the annotation file.
    '''
    pwd = os.path.dirname(image_file_list)
    with open(image_file_list) as f:
        for line in f:
            line = line.strip()
            if not line:
                # e.g. a trailing empty line at EOF; previously an IndexError.
                continue
            # Split only on the first comma so labels containing commas
            # (e.g. "1,000") are kept intact.
            fields = line.split(',', 1)
            file_name = fields[0].strip()
            # Still raises IndexError for a line with no comma, as before.
            label = fields[1].strip()
            # Remove the surrounding double quotes. Stripping first avoids
            # dropping a real character when no space follows the comma.
            if len(label) >= 2 and label.startswith('"') \
                    and label.endswith('"'):
                label = label[1:-1]
            yield os.path.join(pwd, file_name), label