提交 91b3ed00 编写于 作者: Y yangyaming

Make some parameters configurable including buffer size and random seed

etc.
上级 b24a2c20
...@@ -3,7 +3,6 @@ from __future__ import division ...@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import random import random
import Queue
import numpy as np import numpy as np
import struct import struct
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
...@@ -87,7 +86,10 @@ class DataReader(object): ...@@ -87,7 +86,10 @@ class DataReader(object):
feature_file_list, feature_file_list,
label_file_list, label_file_list,
drop_sentence_len=512, drop_sentence_len=512,
seed=1): parallel_num=10,
sample_buffer_size=1024,
sample_info_buffer_size=10000,
random_seed=0):
self._drop_sentence_len = drop_sentence_len self._drop_sentence_len = drop_sentence_len
self._frame_dim = 120 * 11 self._frame_dim = 120 * 11
self._drop_frame_len = 256 self._drop_frame_len = 256
...@@ -97,8 +99,12 @@ class DataReader(object): ...@@ -97,8 +99,12 @@ class DataReader(object):
self._label_file_list = label_file_list self._label_file_list = label_file_list
self._block_info_list = None self._block_info_list = None
self._bucket_list = None self._bucket_list = None
self._rng = random.Random(random_seed)
self.generate_bucket_list(True) self.generate_bucket_list(True)
self._order_id = 0 self._order_id = 0
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._process_num = parallel_num
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
...@@ -112,10 +118,10 @@ class DataReader(object): ...@@ -112,10 +118,10 @@ class DataReader(object):
block_label_info_lines[i], block_label_info_lines[i],
block_label_info_lines[i + 1]) block_label_info_lines[i + 1])
self._block_info_list.append( self._block_info_list.append(
map(lambda x: x.strip(), block_info)) map(lambda line: line.strip(), block_info))
if is_shuffle: if is_shuffle:
random.shuffle(self._block_info_list) self._rng.shuffle(self._block_info_list)
self._bucket_list = [] self._bucket_list = []
for i in xrange(0, len(self._block_info_list), self._shuffle_block_num): for i in xrange(0, len(self._block_info_list), self._shuffle_block_num):
...@@ -133,21 +139,20 @@ class DataReader(object): ...@@ -133,21 +139,20 @@ class DataReader(object):
def _sample_generator(self): def _sample_generator(self):
manager = Manager() manager = Manager()
sample_info_queue = manager.Queue(1024) sample_info_queue = manager.Queue(self._sample_info_buffer_size)
sample_queue = manager.Queue(1024) sample_queue = manager.Queue(self._sample_buffer_size)
process_num = 2
self._order_id = 0 self._order_id = 0
def ordered_feeding_worker(sample_info_queue): def ordered_feeding_worker(sample_info_queue):
for sample_info_bucket in self._bucket_list: for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list( sample_info_list = sample_info_bucket.generate_sample_info_list(
) )
random.shuffle(sample_info_list) # do shuffle here self._rng.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list: for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id)) sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1 self._order_id += 1
for i in xrange(process_num): for i in xrange(self._process_num):
sample_info_queue.put(EpochEndSignal()) sample_info_queue.put(EpochEndSignal())
feeding_thread = Thread( feeding_thread = Thread(
...@@ -215,7 +220,7 @@ class DataReader(object): ...@@ -215,7 +220,7 @@ class DataReader(object):
workers = [ workers = [
Process( Process(
target=ordered_processing_worker, args=args) target=ordered_processing_worker, args=args)
for _ in xrange(process_num) for _ in xrange(self._process_num)
] ]
for w in workers: for w in workers:
...@@ -224,7 +229,7 @@ class DataReader(object): ...@@ -224,7 +229,7 @@ class DataReader(object):
finished_process_num = 0 finished_process_num = 0
while finished_process_num < process_num: while finished_process_num < self._process_num:
sample = sample_queue.get() sample = sample_queue.get()
if isinstance(sample, EpochEndSignal): if isinstance(sample, EpochEndSignal):
finished_process_num += 1 finished_process_num += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册