提交 b0d91bd8 编写于 作者: Y yangyaming

Make buffer queues be objects of class AsyncReader.

上级 61499cda
...@@ -22,7 +22,6 @@ from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal ...@@ -22,7 +22,6 @@ from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
class SampleInfo(object): class SampleInfo(object):
"""SampleInfo holds the necessary information to load a sample from disk. """SampleInfo holds the necessary information to load a sample from disk.
Args: Args:
feature_bin_path (str): File containing the feature data. feature_bin_path (str): File containing the feature data.
feature_start (int): Start position of the sample's feature data. feature_start (int): Start position of the sample's feature data.
...@@ -55,7 +54,6 @@ class SampleInfoBucket(object): ...@@ -55,7 +54,6 @@ class SampleInfoBucket(object):
data, sample start position, sample byte number etc.) to access samples' data, sample start position, sample byte number etc.) to access samples'
feature data and the same with the label description file. SampleInfoBucket feature data and the same with the label description file. SampleInfoBucket
is the minimum unit to do shuffle. is the minimum unit to do shuffle.
Args: Args:
feature_bin_paths (list|tuple): Files containing the binary feature feature_bin_paths (list|tuple): Files containing the binary feature
data. data.
...@@ -165,7 +163,6 @@ class SampleInfoBucket(object): ...@@ -165,7 +163,6 @@ class SampleInfoBucket(object):
class AsyncDataReader(object): class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including """DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation. data loading and data augmentation.
Args: Args:
feature_file_list (str): File containing paths of feature data file and feature_file_list (str): File containing paths of feature data file and
corresponding description file. corresponding description file.
...@@ -209,8 +206,6 @@ class AsyncDataReader(object): ...@@ -209,8 +206,6 @@ class AsyncDataReader(object):
self.generate_bucket_list(True) self.generate_bucket_list(True)
self._order_id = 0 self._order_id = 0
self._manager = Manager() self._manager = Manager()
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num self._proc_num = proc_num
if self._proc_num <= 2: if self._proc_num <= 2:
...@@ -218,6 +213,10 @@ class AsyncDataReader(object): ...@@ -218,6 +213,10 @@ class AsyncDataReader(object):
self._sample_proc_num = self._proc_num - 2 self._sample_proc_num = self._proc_num - 2
self._verbose = verbose self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False)) self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
# buffer queue
self._sample_info_queue = self._manager.Queue(sample_info_buffer_size)
self._sample_queue = self._manager.Queue(sample_buffer_size)
self._batch_queue = self._manager.Queue(batch_buffer_size)
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
...@@ -258,8 +257,6 @@ class AsyncDataReader(object): ...@@ -258,8 +257,6 @@ class AsyncDataReader(object):
shared_ndarray.recycle(self._pool_manager.pool) shared_ndarray.recycle(self._pool_manager.pool)
def _start_async_processing(self): def _start_async_processing(self):
sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0 self._order_id = 0
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
...@@ -284,7 +281,9 @@ class AsyncDataReader(object): ...@@ -284,7 +281,9 @@ class AsyncDataReader(object):
sample_info_queue.put(EpochEndSignal()) sample_info_queue.put(EpochEndSignal())
feeding_proc = DaemonProcessGroup( feeding_proc = DaemonProcessGroup(
proc_num=1, target=ordered_feeding_task, args=(sample_info_queue, )) proc_num=1,
target=ordered_feeding_task,
args=(self._sample_info_queue, ))
feeding_proc.start_all() feeding_proc.start_all()
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
...@@ -361,15 +360,13 @@ class AsyncDataReader(object): ...@@ -361,15 +360,13 @@ class AsyncDataReader(object):
sample_queue.put(EpochEndSignal()) sample_queue.put(EpochEndSignal())
out_order = self._manager.list([0]) out_order = self._manager.list([0])
args = (sample_info_queue, sample_queue, out_order) args = (self._sample_info_queue, self._sample_queue, out_order)
sample_proc = DaemonProcessGroup( sample_proc = DaemonProcessGroup(
proc_num=self._sample_proc_num, proc_num=self._sample_proc_num,
target=ordered_processing_task, target=ordered_processing_task,
args=args) args=args)
sample_proc.start_all() sample_proc.start_all()
return sample_queue
def batch_iterator(self, batch_size, minimum_batch_size): def batch_iterator(self, batch_size, minimum_batch_size):
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_queue, batch_queue, pool): def batch_assembling_task(sample_queue, batch_queue, pool):
...@@ -419,8 +416,7 @@ class AsyncDataReader(object): ...@@ -419,8 +416,7 @@ class AsyncDataReader(object):
batch_queue.put(EpochEndSignal()) batch_queue.put(EpochEndSignal())
sample_queue = self._start_async_processing() self._start_async_processing()
batch_queue = self._manager.Queue(self._batch_buffer_size)
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size * self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager) 3, self._manager)
...@@ -428,12 +424,13 @@ class AsyncDataReader(object): ...@@ -428,12 +424,13 @@ class AsyncDataReader(object):
assembling_proc = DaemonProcessGroup( assembling_proc = DaemonProcessGroup(
proc_num=1, proc_num=1,
target=batch_assembling_task, target=batch_assembling_task,
args=(sample_queue, batch_queue, self._pool_manager.pool)) args=(self._sample_queue, self._batch_queue,
self._pool_manager.pool))
assembling_proc.start_all() assembling_proc.start_all()
while self._force_exit == False: while self._force_exit == False:
try: try:
batch_data = batch_queue.get_nowait() batch_data = self._batch_queue.get_nowait()
except Queue.Empty: except Queue.Empty:
time.sleep(0.001) time.sleep(0.001)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册