diff --git a/fluid/DeepASR/data_utils/parallel_reader.py b/fluid/DeepASR/data_utils/parallel_reader.py index be4baf3cd514ceb4e9e2ae7f4d4131d767470266..6b1430a01208b73998f3caa00cbe03797ec3bee4 100644 --- a/fluid/DeepASR/data_utils/parallel_reader.py +++ b/fluid/DeepASR/data_utils/parallel_reader.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import random -import Queue import numpy as np import struct import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm @@ -87,7 +86,10 @@ class DataReader(object): feature_file_list, label_file_list, drop_sentence_len=512, - seed=1): + parallel_num=10, + sample_buffer_size=1024, + sample_info_buffer_size=10000, + random_seed=0): self._drop_sentence_len = drop_sentence_len self._frame_dim = 120 * 11 self._drop_frame_len = 256 @@ -97,8 +99,12 @@ class DataReader(object): self._label_file_list = label_file_list self._block_info_list = None self._bucket_list = None + self._rng = random.Random(random_seed) self.generate_bucket_list(True) self._order_id = 0 + self._sample_buffer_size = sample_buffer_size + self._sample_info_buffer_size = sample_info_buffer_size + self._process_num = parallel_num def generate_bucket_list(self, is_shuffle): if self._block_info_list is None: @@ -112,10 +118,10 @@ class DataReader(object): block_label_info_lines[i], block_label_info_lines[i + 1]) self._block_info_list.append( - map(lambda x: x.strip(), block_info)) + map(lambda line: line.strip(), block_info)) if is_shuffle: - random.shuffle(self._block_info_list) + self._rng.shuffle(self._block_info_list) self._bucket_list = [] for i in xrange(0, len(self._block_info_list), self._shuffle_block_num): @@ -133,21 +139,20 @@ class DataReader(object): def _sample_generator(self): manager = Manager() - sample_info_queue = manager.Queue(1024) - sample_queue = manager.Queue(1024) - process_num = 2 + sample_info_queue = manager.Queue(self._sample_info_buffer_size) + sample_queue = manager.Queue(self._sample_buffer_size) self._order_id = 0 def ordered_feeding_worker(sample_info_queue): for sample_info_bucket in self._bucket_list: sample_info_list = sample_info_bucket.generate_sample_info_list( ) - random.shuffle(sample_info_list) # do shuffle here + self._rng.shuffle(sample_info_list) # do shuffle here for sample_info in sample_info_list: sample_info_queue.put((sample_info, self._order_id)) self._order_id += 1 - for i in xrange(process_num): + for i in xrange(self._process_num): sample_info_queue.put(EpochEndSignal()) feeding_thread = Thread( @@ -215,7 +220,7 @@ class DataReader(object): workers = [ Process( target=ordered_processing_worker, args=args) - for _ in xrange(process_num) + for _ in xrange(self._process_num) ] for w in workers: @@ -224,7 +229,7 @@ class DataReader(object): finished_process_num = 0 - while finished_process_num < process_num: + while finished_process_num < self._process_num: sample = sample_queue.get() if isinstance(sample, EpochEndSignal): finished_process_num += 1