提交 91b3ed00 编写于 作者: Y yangyaming

Make some parameters configurable including buffer size and random seed

etc.
上级 b24a2c20
......@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import random
import Queue
import numpy as np
import struct
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
......@@ -87,7 +86,10 @@ class DataReader(object):
feature_file_list,
label_file_list,
drop_sentence_len=512,
seed=1):
parallel_num=10,
sample_buffer_size=1024,
sample_info_buffer_size=10000,
random_seed=0):
self._drop_sentence_len = drop_sentence_len
self._frame_dim = 120 * 11
self._drop_frame_len = 256
......@@ -97,8 +99,12 @@ class DataReader(object):
self._label_file_list = label_file_list
self._block_info_list = None
self._bucket_list = None
self._rng = random.Random(random_seed)
self.generate_bucket_list(True)
self._order_id = 0
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._process_num = parallel_num
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
......@@ -112,10 +118,10 @@ class DataReader(object):
block_label_info_lines[i],
block_label_info_lines[i + 1])
self._block_info_list.append(
map(lambda x: x.strip(), block_info))
map(lambda line: line.strip(), block_info))
if is_shuffle:
random.shuffle(self._block_info_list)
self._rng.shuffle(self._block_info_list)
self._bucket_list = []
for i in xrange(0, len(self._block_info_list), self._shuffle_block_num):
......@@ -133,21 +139,20 @@ class DataReader(object):
def _sample_generator(self):
manager = Manager()
sample_info_queue = manager.Queue(1024)
sample_queue = manager.Queue(1024)
process_num = 2
sample_info_queue = manager.Queue(self._sample_info_buffer_size)
sample_queue = manager.Queue(self._sample_buffer_size)
self._order_id = 0
def ordered_feeding_worker(sample_info_queue):
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
)
random.shuffle(sample_info_list) # do shuffle here
self._rng.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
for i in xrange(process_num):
for i in xrange(self._process_num):
sample_info_queue.put(EpochEndSignal())
feeding_thread = Thread(
......@@ -215,7 +220,7 @@ class DataReader(object):
workers = [
Process(
target=ordered_processing_worker, args=args)
for _ in xrange(process_num)
for _ in xrange(self._process_num)
]
for w in workers:
......@@ -224,7 +229,7 @@ class DataReader(object):
finished_process_num = 0
while finished_process_num < process_num:
while finished_process_num < self._process_num:
sample = sample_queue.get()
if isinstance(sample, EpochEndSignal):
finished_process_num += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册