From 68b61067ff83d41ef0eb7e2ccff71fe0978c3ca3 Mon Sep 17 00:00:00 2001 From: xiaohang Date: Mon, 5 Feb 2018 11:44:01 +0800 Subject: [PATCH] add ctc reader --- fluid/ocr_recognition/ctc_reader.py | 82 ++++++++++++++++++----------- 1 file changed, 52 insertions(+), 30 deletions(-) diff --git a/fluid/ocr_recognition/ctc_reader.py b/fluid/ocr_recognition/ctc_reader.py index f3f8c951..e5264c33 100644 --- a/fluid/ocr_recognition/ctc_reader.py +++ b/fluid/ocr_recognition/ctc_reader.py @@ -1,6 +1,7 @@ import os import cv2 import numpy as np +from PIL import Image from paddle.v2.image import load_image @@ -9,40 +10,64 @@ class DataGenerator(object): def __init__(self): pass - def train_reader(self, img_root_dir, img_label_list): + def train_reader(self, img_root_dir, img_label_list, batchsize): ''' Reader interface for training. - :param img_root_dir: The root path of the image for training. + :param img_root_dir: The root path of the image for training. :type file_list: str :param img_label_list: The path of the file for training. :type file_list: str ''' - # sort by height, e.g. 
idx + img_label_lines = [] - for line in open(img_label_list): - # h, w, img_name, labels - items = line.split(' ') - idx = "{:0>5d}".format(int(items[0])) - img_label_lines.append(idx + ' ' + line) - img_label_lines.sort() + if batchsize == 1: + to_file = "tmp.txt" + cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file + print "cmd: " + cmd + os.system(cmd) + print "finish batch shuffle" + img_label_lines = open(to_file, 'r').readlines() + else: + to_file = "tmp.txt" + #cmd1: partial shuffle + cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | " + #cmd2: batch merge and shuffle + cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str( + batchsize) + " == 0) print \"\";}' | shuf | " + #cmd3: batch split + cmd += "awk '{if(NF == " + str( + batchsize + ) + " * 4) {for(i = 0; i < " + str( + batchsize + ) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file + print "cmd: " + cmd + os.system(cmd) + print "finish batch shuffle" + img_label_lines = open(to_file, 'r').readlines() def reader(): - for line in img_label_lines: - # h, w, img_name, labels - items = line.split(' ')[1:] - - assert len(items) == 4 - - label = [int(c) for c in items[-1].split(',')] - - img = load_image(os.path.join(img_root_dir, items[2])) - img = np.transpose(img, (2, 0, 1)) - #img = img[np.newaxis, ...] - - yield img, label + sizes = len(img_label_lines) / batchsize + for i in range(sizes): + result = [] + sz = [0, 0] + for j in range(batchsize): + line = img_label_lines[i * batchsize + j] + # h, w, img_name, labels + items = line.split(' ') + + label = [int(c) for c in items[-1].split(',')] + img = Image.open(os.path.join(img_root_dir, items[ + 2])).convert('L') # convert to grayscale + if j == 0: + sz = img.size + img = img.resize((sz[0], sz[1])) + img = np.array(img) - 127.5 + img = img[np.newaxis, ...] 
+ result.append([img, label]) + yield result return reader @@ -50,7 +75,7 @@ class DataGenerator(object): ''' Reader interface for inference. - :param img_root_dir: The root path of the images for training. + :param img_root_dir: The root path of the images for testing. :type file_list: str :param img_label_list: The path of the file for testing. @@ -62,14 +87,11 @@ class DataGenerator(object): # h, w, img_name, labels items = line.split(' ') - assert len(items) == 4 - label = [int(c) for c in items[-1].split(',')] - - img = load_image(os.path.join(img_root_dir, items[2])) - img = np.transpose(img, (2, 0, 1)) - #img = img[np.newaxis, ...] - + img = Image.open(os.path.join(img_root_dir, items[2])).convert( + 'L') + img = np.array(img) - 127.5 + img = img[np.newaxis, ...] yield img, label return reader -- GitLab