"""Sample readers for the CTC-based OCR recognition model.

Originally added as ``fluid/ocr_recognition/ctc_reader.py`` by the
"add ctc reader" patch.

Both readers consume a *label list* file with one sample per line::

    <height> <width> <img_name> <label>[,<label>...]

and yield ``(img, label)`` pairs.
"""
import os

import cv2
import numpy as np
from paddle.v2.image import load_image


def _parse_label_line(line):
    """Split one label-list line into ``(img_name, label)``.

    :param line: A line holding four whitespace-separated fields:
        height, width, image name and a comma-separated list of integers.
    :type line: str

    :returns: The image name and the label as a list of ints.
    :raises ValueError: If the line does not hold exactly four fields or a
        label entry is not an integer.
    """
    items = line.split()
    # Raise explicitly instead of asserting: asserts vanish under ``-O``.
    if len(items) != 4:
        raise ValueError(
            "expected 4 fields (h, w, img_name, labels), got %d: %r" %
            (len(items), line))
    label = [int(c) for c in items[3].split(',')]
    return items[2], label


def _load_sample(img_root_dir, line):
    """Build the ``(img, label)`` sample described by one label-list line."""
    img_name, label = _parse_label_line(line)
    img = load_image(os.path.join(img_root_dir, img_name))
    # Move the last axis to the front (presumably HWC -> CHW; confirm
    # against the layout returned by ``load_image``).
    img = np.transpose(img, (2, 0, 1))
    return img, label


class DataGenerator(object):
    """Factory for training and inference sample readers."""

    def train_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for training.

        The label list is read once, when this method is called, and the
        samples are ordered by ascending image height (the first field of
        each line). Lines with equal height keep a deterministic order:
        they are compared by their text. Blank lines are ignored.

        :param img_root_dir: The root path of the images for training.
        :type img_root_dir: str

        :param img_label_list: The path of the label list file for training.
        :type img_label_list: str

        :returns: A no-argument callable that returns a generator of
            ``(img, label)`` pairs, ``label`` being a list of ints.
        :raises ValueError: If a height field is not an integer (raised
            here), or if a line is malformed (raised while iterating).
        '''
        # Close the file deterministically instead of leaking the handle.
        with open(img_label_list) as label_file:
            img_label_lines = [line for line in label_file if line.strip()]

        # Compare heights numerically so that values with more than five
        # digits still sort correctly; the line text breaks ties.
        img_label_lines.sort(key=lambda line: (int(line.split()[0]), line))

        def reader():
            for line in img_label_lines:
                yield _load_sample(img_root_dir, line)

        return reader

    def test_reader(self, img_root_dir, img_label_list):
        '''
        Reader interface for inference.

        The label list is opened each time the returned reader is invoked
        and is streamed in file order. Blank lines are ignored.

        :param img_root_dir: The root path of the images for inference.
        :type img_root_dir: str

        :param img_label_list: The path of the label list file for testing.
        :type img_label_list: str

        :returns: A no-argument callable that returns a generator of
            ``(img, label)`` pairs, ``label`` being a list of ints.
        :raises ValueError: While iterating, if a line is malformed.
        '''

        def reader():
            # The ``with`` block closes the file once the generator is
            # exhausted or closed.
            with open(img_label_list) as label_file:
                for line in label_file:
                    if not line.strip():
                        continue
                    yield _load_sample(img_root_dir, line)

        return reader