diff --git a/fluid/DeepASR/data_utils/data_reader.py b/fluid/DeepASR/data_utils/data_reader.py
index 51eb325f1fe52c4308606742a477329faafaf5b4..33f878af01ecbb1cf1dae3d569afd10306e295ae 100644
--- a/fluid/DeepASR/data_utils/data_reader.py
+++ b/fluid/DeepASR/data_utils/data_reader.py
@@ -15,6 +15,7 @@ from multiprocessing import Manager, Process
 import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
 import data_utils.augmentor.trans_add_delta as trans_add_delta
 from data_utils.util import suppress_complaints, suppress_signal
+from data_utils.util import CriticalException, ForceExitWrapper
 
 
 class SampleInfo(object):
@@ -89,6 +90,7 @@ class SampleInfoBucket(object):
         self._split_perturb = split_perturb
         self._split_sentence_threshold = split_sentence_threshold
         self._split_sub_sentence_len = split_sub_sentence_len
+        self._rng = random.Random(0)
 
     def generate_sample_info_list(self):
         sample_info_list = []
@@ -213,6 +215,7 @@ class DataReader(object):
         self._batch_buffer_size = batch_buffer_size
         self._process_num = process_num
         self._verbose = verbose
+        self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
 
     def generate_bucket_list(self, is_shuffle):
         if self._block_info_list is None:
@@ -251,15 +254,19 @@ class DataReader(object):
         sample_queue = self._manager.Queue(self._sample_buffer_size)
         self._order_id = 0
 
-        @suppress_complaints(verbose=self._verbose)
+        @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
         def ordered_feeding_task(sample_info_queue):
             for sample_info_bucket in self._bucket_list:
-                sample_info_list = sample_info_bucket.generate_sample_info_list(
-                )
-                self._rng.shuffle(sample_info_list)  # do shuffle here
-                for sample_info in sample_info_list:
-                    sample_info_queue.put((sample_info, self._order_id))
-                    self._order_id += 1
+                try:
+                    sample_info_list = \
+                        sample_info_bucket.generate_sample_info_list()
+                except Exception as e:
+                    raise CriticalException(e)
+                else:
+                    self._rng.shuffle(sample_info_list)  # do shuffle here
+                    for sample_info in sample_info_list:
+                        sample_info_queue.put((sample_info, self._order_id))
+                        self._order_id += 1
 
             for i in xrange(self._process_num):
                 sample_info_queue.put(EpochEndSignal())
@@ -269,18 +276,21 @@ class DataReader(object):
         feeding_thread.daemon = True
         feeding_thread.start()
 
-        @suppress_complaints(verbose=self._verbose)
+        @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
         def ordered_processing_task(sample_info_queue, sample_queue, out_order):
             if self._verbose == 0:
                 signal.signal(signal.SIGTERM, suppress_signal)
                 signal.signal(signal.SIGINT, suppress_signal)
 
             def read_bytes(fpath, start, size):
-                f = open(fpath, 'r')
-                f.seek(start, 0)
-                binary_bytes = f.read(size)
-                f.close()
-                return binary_bytes
+                try:
+                    f = open(fpath, 'r')
+                    f.seek(start, 0)
+                    binary_bytes = f.read(size)
+                    f.close()
+                    return binary_bytes
+                except Exception as e:
+                    raise CriticalException(e)
 
             ins = sample_info_queue.get()
 
@@ -352,16 +362,21 @@ class DataReader(object):
             w.start()
 
         finished_process_num = 0
-        while finished_process_num < self._process_num:
-            sample = sample_queue.get()
-            if isinstance(sample, EpochEndSignal):
-                finished_process_num += 1
-                continue
-            yield sample
-
-        feeding_thread.join()
-        for w in workers:
-            w.join()
+
+        while self._force_exit == False:
+            try:
+                sample = sample_queue.get_nowait()
+            except Queue.Empty:
+                time.sleep(0.001)
+            else:
+                if isinstance(sample, EpochEndSignal):
+                    finished_process_num += 1
+                    if finished_process_num >= self._process_num:
+                        break
+                    else:
+                        continue
+
+                yield sample
 
     def batch_iterator(self, batch_size, minimum_batch_size):
         def batch_to_ndarray(batch_samples, lod):
@@ -377,7 +392,7 @@ class DataReader(object):
                 start += frame_num
             return (batch_feature, batch_label)
 
-        @suppress_complaints(verbose=self._verbose)
+        @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
         def batch_assembling_task(sample_generator, batch_queue):
             batch_samples = []
             lod = [0]
@@ -406,7 +421,7 @@ class DataReader(object):
         assembling_thread.daemon = True
         assembling_thread.start()
 
-        while True:
+        while self._force_exit == False:
             try:
                 batch_data = batch_queue.get_nowait()
             except Queue.Empty:
@@ -415,5 +430,3 @@ class DataReader(object):
             if isinstance(batch_data, EpochEndSignal):
                 break
             yield batch_data
-
-        assembling_thread.join()
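Note: the reworked loops above swap the blocking sample_queue.get() for
get_nowait() polling, so the reader can stop as soon as any worker flips the
shared force-exit flag instead of blocking forever on a dead queue. Below is
a minimal, self-contained sketch of that pattern (illustrative names only,
not part of the patch; Python 2 to match the repository):

import time
import Queue  # Python 2 module, as used by the patch
from multiprocessing import Manager, Process


def worker(queue, exit_flag):
    # A failing producer notifies the consumer via the shared flag
    # instead of leaving it blocked on an empty queue.
    try:
        raise IOError("simulated read failure")
    except Exception:
        exit_flag.value = True


if __name__ == '__main__':
    manager = Manager()
    exit_flag = manager.Value('b', False)
    queue = manager.Queue(8)
    p = Process(target=worker, args=(queue, exit_flag))
    p.start()
    while exit_flag.value == False:
        try:
            sample = queue.get_nowait()
        except Queue.Empty:
            time.sleep(0.001)  # same polling interval as the patch
        else:
            print sample
    p.join()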
diff --git a/fluid/DeepASR/data_utils/util.py b/fluid/DeepASR/data_utils/util.py
index e64417e502363dfc7731c0bc116be7680c833c6f..2670240a7869ebb34975d9273546ff9489cf026a 100644
--- a/fluid/DeepASR/data_utils/util.py
+++ b/fluid/DeepASR/data_utils/util.py
@@ -35,21 +35,40 @@ def lodtensor_to_ndarray(lod_tensor):
     return ret, lod_tensor.lod()
 
 
+class CriticalException(Exception):
+    pass
+
+
 def suppress_signal(signo, stack_frame):
     pass
 
 
-def suppress_complaints(verbose):
+def suppress_complaints(verbose, notify=None):
     def decorator_maker(func):
         def suppress_warpper(*args, **kwargs):
             try:
                 func(*args, **kwargs)
             except:
                 et, ev, tb = sys.exc_info()
-                tb = Traceback(tb)
-                if verbose == 1:
-                    reraise(et, ev, tb.as_traceback())
+
+                if notify is not None:
+                    notify(except_type=et, except_value=ev, traceback=tb)
+
+                if verbose == 1 or isinstance(ev, CriticalException):
+                    reraise(et, ev, Traceback(tb).as_traceback())
 
         return suppress_warpper
 
     return decorator_maker
+
+
+class ForceExitWrapper(object):
+    def __init__(self, exit_flag):
+        self._exit_flag = exit_flag
+
+    @suppress_complaints(verbose=0)
+    def __call__(self, *args, **kwargs):
+        self._exit_flag.value = True
+
+    def __eq__(self, flag):
+        return self._exit_flag.value == flag
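Note: suppress_complaints gains two behaviors worth spelling out: the notify
hook fires on every caught exception, and a CriticalException is re-raised
even when verbose == 0, while ordinary exceptions stay suppressed. A hedged
usage sketch (assumes it runs from fluid/DeepASR so data_utils is
importable):

from multiprocessing import Manager

from data_utils.util import suppress_complaints, ForceExitWrapper

manager = Manager()
force_exit = ForceExitWrapper(manager.Value('b', False))


@suppress_complaints(verbose=0, notify=force_exit)
def flaky_task():
    # Suppressed under verbose=0, but notify() still flips the shared flag;
    # raising CriticalException(e) here would propagate instead.
    raise ValueError("worker failed")


flaky_task()
assert force_exit == True  # __eq__ reads the manager-backed flag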
diff --git a/fluid/ocr_recognition/ctc_reader.py b/fluid/ocr_recognition/ctc_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5264c33de526846161c1e3ada2555addba53e0d
--- /dev/null
+++ b/fluid/ocr_recognition/ctc_reader.py
@@ -0,0 +1,97 @@
+import os
+import cv2
+import numpy as np
+from PIL import Image
+
+from paddle.v2.image import load_image
+
+
+class DataGenerator(object):
+    def __init__(self):
+        pass
+
+    def train_reader(self, img_root_dir, img_label_list, batchsize):
+        '''
+        Reader interface for training.
+
+        :param img_root_dir: The root path of the images for training.
+        :type img_root_dir: str
+
+        :param img_label_list: The path of the label file for training.
+        :type img_label_list: str
+
+        '''
+
+        img_label_lines = []
+        if batchsize == 1:
+            to_file = "tmp.txt"
+            cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
+            print "cmd: " + cmd
+            os.system(cmd)
+            print "finished batch shuffle"
+            img_label_lines = open(to_file, 'r').readlines()
+        else:
+            to_file = "tmp.txt"
+            #cmd1: partial shuffle
+            cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
+            #cmd2: batch merge and shuffle
+            cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
+                batchsize) + " == 0) print \"\";}' | shuf | "
+            #cmd3: batch split
+            cmd += "awk '{if(NF == " + str(
+                batchsize
+            ) + " * 4) {for(i = 0; i < " + str(
+                batchsize
+            ) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
+            print "cmd: " + cmd
+            os.system(cmd)
+            print "finished batch shuffle"
+            img_label_lines = open(to_file, 'r').readlines()
+
+        def reader():
+            sizes = len(img_label_lines) / batchsize
+            for i in range(sizes):
+                result = []
+                sz = [0, 0]
+                for j in range(batchsize):
+                    line = img_label_lines[i * batchsize + j]
+                    # h, w, img_name, labels
+                    items = line.split(' ')
+
+                    label = [int(c) for c in items[-1].split(',')]
+                    img = Image.open(os.path.join(img_root_dir, items[
+                        2])).convert('L')  # convert to grayscale
+                    if j == 0:
+                        sz = img.size
+                    img = img.resize((sz[0], sz[1]))
+                    img = np.array(img) - 127.5
+                    img = img[np.newaxis, ...]
+                    result.append([img, label])
+                yield result
+
+        return reader
+
+    def test_reader(self, img_root_dir, img_label_list):
+        '''
+        Reader interface for inference.
+
+        :param img_root_dir: The root path of the images for testing.
+        :type img_root_dir: str
+
+        :param img_label_list: The path of the label file for testing.
+        :type img_label_list: str
+        '''
+
+        def reader():
+            for line in open(img_label_list):
+                # h, w, img_name, labels
+                items = line.split(' ')
+
+                label = [int(c) for c in items[-1].split(',')]
+                img = Image.open(os.path.join(img_root_dir, items[2])).convert(
+                    'L')
+                img = np.array(img) - 127.5
+                img = img[np.newaxis, ...]
+                yield img, label
+
+        return reader
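Note: the shuffle pipeline above shells out through os.system and depends on
awk, shuf, sed, and bash's RANDOM, so it assumes a POSIX shell environment
and write access to tmp.txt in the working directory. A hedged usage sketch
of the new readers (paths and batch size are placeholders; the label-list
format is inferred from reader(): "h w img_name comma,separated,label,ids"
per line):

from ctc_reader import DataGenerator

generator = DataGenerator()

train = generator.train_reader('./data/images', './data/train.list', 32)
for batch in train():
    img, label = batch[0]   # batch holds `batchsize` [image, label] pairs
    print img.shape, label  # (1, h, w) array centered around zero
    break

test = generator.test_reader('./data/images', './data/test.list')
for img, label in test():
    print img.shape, label
    break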