Commit 49fb3e6b authored by yangyaming

Refine parallel reader.

Parent bfda10aa
@@ -8,8 +8,9 @@ import numpy as np
 import struct
 import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
 import data_utils.augmentor.trans_add_delta as trans_add_delta
-from multiprocessing import Manager, Pool
+from multiprocessing import Manager, Process
 from threading import Thread
+import time
 
 
 class SampleInfo(object):
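Note: the import change above is the core of this refactor: `multiprocessing.Pool` gives way to explicit `Process` workers fed through `Manager().Queue()` proxies, which can be shared safely between a feeding thread and worker processes. A minimal sketch of that pattern (the `consume` worker and its payloads are illustrative, not part of the patch):

    from multiprocessing import Manager, Process

    def consume(q):
        while True:
            item = q.get()
            if item is None:  # sentinel ends the worker
                break
            print('got %d' % item)

    if __name__ == '__main__':
        manager = Manager()
        q = manager.Queue(1024)  # proxy queue, shareable across processes
        w = Process(target=consume, args=(q, ))
        w.start()
        for i in range(4):
            q.put(i)
        q.put(None)  # one sentinel per consumer
        w.join()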
@@ -78,7 +79,11 @@ class SampleInfoBucket(object):
         return sample_info_list
 
 
-def DataReader(object):
+class EpochEndSignal():
+    pass
+
+
+class DataReader(object):
     def __init__(self,
                  feature_file_list,
                  label_file_list,
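Note: `EpochEndSignal` is an empty marker class used as an end-of-epoch sentinel. Testing `isinstance(item, EpochEndSignal)` rather than `item is None` keeps `None` (and every other value) available as ordinary queue data. A small illustration (`drain` and the sample items are made up for the example):

    class EpochEndSignal():
        pass

    def drain(queue_get):
        # yield items until the sentinel arrives
        item = queue_get()
        while not isinstance(item, EpochEndSignal):
            yield item
            item = queue_get()

    items = iter([1, None, 2, EpochEndSignal()])
    print(list(drain(lambda: next(items))))  # [1, None, 2] -- None survives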
@@ -91,6 +96,9 @@ def DataReader(object):
         self._drop_frame_len = 256
         self._feature_file_list = feature_file_list
         self._label_file_list = label_file_list
+        self._block_info_list = None
+        self._bucket_list = None
+        self._order_id = 0
         self.generate_bucket_list(True)
 
     def generate_bucket_list(self, is_shuffle):
@@ -114,7 +122,7 @@ def DataReader(object):
         for i in xrange(0, len(self._block_info_list), self._shuffle_block_num):
             bucket_block_info = self._block_info_list[i:i +
                                                       self._shuffle_block_num]
-            buket_list.append(
+            self._bucket_list.append(
                 SampleInfoBucket(
                     map(lambda info: info[0], bucket_block_info),
                     map(lambda info: info[1], bucket_block_info),
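Note: besides fixing the `buket_list` NameError, this hunk groups every `_shuffle_block_num` consecutive blocks into one `SampleInfoBucket`, so later shuffling stays within fixed-size groups. The slicing idiom reduced to a standalone sketch (block values and chunk size are made up):

    blocks = list(range(10))
    shuffle_block_num = 3
    buckets = [blocks[i:i + shuffle_block_num]
               for i in range(0, len(blocks), shuffle_block_num)]
    print(buckets)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]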
@@ -125,12 +133,35 @@ def DataReader(object):
         self._transformers = transformers
 
     def _sample_generator(self):
-        sample_queue = Queue.Queue(1024)
+        manager = Manager()
+        sample_info_queue = manager.Queue(1024)
+        sample_queue = manager.Queue(1024)
+        process_num = 1
+        self._order_id = 0
 
-        def data_loading_worker(sample_queue):
-            pool = Pool(processes=10)
+        def ordered_feeding_worker(sample_info_queue):
+            for sample_info_bucket in self._bucket_list:
+                sample_info_list = sample_info_bucket.generate_sample_info_list(
+                )
+                random.shuffle(sample_info_list)  # do shuffle here
+                for sample_info in sample_info_list:
+                    sample_info_queue.put((sample_info, self._order_id))
+                    self._order_id += 1
 
-            def sample_processing_worker(sample_info):
+            for i in xrange(process_num):
+                sample_info_queue.put(EpochEndSignal())
+
+        feeding_thread = Thread(
+            target=ordered_feeding_worker, args=(sample_info_queue, ))
+        feeding_thread.daemon = True
+        feeding_thread.start()
+
+        def ordered_processing_worker(sample_info_queue, sample_queue,
+                                      out_order):
+            ins = sample_info_queue.get()
+
+            while not isinstance(ins, EpochEndSignal):
+                sample_info, order_id = ins
+
                 f_feature = open(sample_info.feature_bin_path, 'r')
                 f_label = open(sample_info.label_bin_path, 'r')
@@ -138,7 +169,7 @@ def DataReader(object):
                 label_bytes = f_label.read(sample_info.label_size)
                 f_label.close()
 
-                assert sample_info.label_frame_num * 4 == label_bytes
+                assert sample_info.label_frame_num * 4 == len(label_bytes)
                 label_array = struct.unpack('I' * sample_info.label_frame_num,
                                             label_bytes)
                 label_data = np.array(
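Note: the asserts corrected in this hunk and the next guard the binary decode step: labels are uint32 ('I') and features float32 ('f'), 4 bytes per element, so the byte count must equal 4 times the element count. The same decode path as a self-contained sketch, with synthetic bytes standing in for the real .bin files:

    import struct
    import numpy as np

    frame_num, dim = 2, 3
    feature_bytes = struct.pack('f' * frame_num * dim, *range(frame_num * dim))
    assert frame_num * dim * 4 == len(feature_bytes)  # 4 bytes per float32

    feature_array = struct.unpack('f' * frame_num * dim, feature_bytes)
    feature_data = np.array(feature_array, dtype='float32').reshape(
        (frame_num, dim))
    print(feature_data.shape)  # (2, 3)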
@@ -148,48 +179,61 @@ def DataReader(object):
                 f_feature.seek(sample_info.feature_start, 0)
                 feature_bytes = f_feature.read(sample_info.feature_size)
                 f_feature.close()
-                assert sample_info.feature_frame_num * sample_info.feature_dim * 4 == feature_bytes
+                assert sample_info.feature_frame_num * sample_info.feature_dim * 4 == len(
+                    feature_bytes)
                 feature_array = struct.unpack(
                     'f' * sample_info.feature_frame_num *
                     sample_info.feature_dim, feature_bytes)
                 feature_data = np.array(
-                    feature_array, dytpe='float32').reshape((
+                    feature_array, dtype='float32').reshape((
                         sample_info.feature_frame_num, sample_info.feature_dim))
-                # drop long sentence
-                if self._drop_sentence_len < sample_data[0].shape[0]:
-                    return None
                 sample_data = (feature_data, label_data)
                 for transformer in self._transformers:
                     # @TODO(pkuyym) to make transfomer only accept feature_data
                     sample_data = transformer.perform_trans(sample_data)
-                return sample_data
 
-            for sample_info_bucket in self._bucket_list:
-                sample_info_list = sample_info_bucket.generate_sample_info_list(
-                )
-                random.shuffle(sample_info_list)  # do shuffle here
-                processed_data = pool.map(
-                    f, sample_info_list)  # the result is ordered
-                for sample_data in processed_data:
-                    if sample_data is None: continue
-                    sample_queue.put(sample_data)
-            sample_queue.put(None)
+                while order_id != out_order[0]:
+                    time.sleep(0.001)
 
-        t = Thread(target=data_processing_worker, args=(sample_queue))
-        t.daemon = True
-        t.start()
+                # drop long sentence
+                if self._drop_sentence_len >= sample_data[0].shape[0]:
+                    sample_queue.put(sample_data)
+                print('sub process: %d' % sample_queue.qsize())
+                out_order[0] += 1
+                time.sleep(0.1)
+
+                if order_id == self._order_id:
+                    sample_queue.put(EpochEndSignal())
+
+                ins = sample_info_queue.get()
+
+        out_order = manager.list([0])
+        args = (sample_info_queue, sample_queue, out_order)
+        workers = [
+            Process(
+                target=ordered_processing_worker, args=args)
+            for _ in xrange(process_num)
+        ]
+        for w in workers:
+            w.daemon = True
+            w.start()
 
         while True:
+            print('main thread: %d' % sample_queue.qsize())
             sample = sample_queue.get()
-            if sample is None: break
+            if isinstance(sample, EpochEndSignal): break
             yield sample
 
+        feeding_thread.join()
+        for w in workers:
+            w.join()
+
     def batch_iterator(self, batch_size, minimum_batch_size):
         batch_samples = []
         lod = [0]
         # check whether need parallel here
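Note: this hunk is the heart of the new ordered pipeline. The feeding thread tags every `sample_info` with an increasing `order_id`; each worker transforms its sample, then spins until the shared counter `out_order[0]` equals its tag before enqueueing the result and advancing the counter, so consumers see samples in feeding order despite parallel processing. Only the current ticket holder ever writes the counter, which is why the unlocked increment is safe. A standalone sketch of the ticket-style hand-off (the squaring task and queue contents are illustrative):

    import time
    from multiprocessing import Manager, Process

    def worker(in_q, out_q, turn):
        while True:
            item = in_q.get()
            if item is None:  # sentinel: no more work
                break
            value, ticket = item
            result = value * value  # stand-in for the real transformation
            while turn[0] != ticket:  # busy-wait for our turn
                time.sleep(0.001)
            out_q.put(result)
            turn[0] += 1  # let the next ticket through

    if __name__ == '__main__':
        manager = Manager()
        in_q, out_q = manager.Queue(), manager.Queue()
        turn = manager.list([0])
        procs = [Process(target=worker, args=(in_q, out_q, turn))
                 for _ in range(2)]
        for p in procs:
            p.start()
        for ticket, value in enumerate([3, 1, 4, 1, 5]):
            in_q.put((value, ticket))
        for _ in procs:
            in_q.put(None)
        print([out_q.get() for _ in range(5)])  # [9, 1, 16, 1, 25], in order
        for p in procs:
            p.join()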
@@ -211,7 +255,8 @@ def batch_iterator(self, batch_size, minimum_batch_size):
             lod = [0]
 
         if len(batch_samples) >= minimum_batch_size:
-            batch_feature = np.zeros((lod[-1], self._frame_dim), dtype="float32")
+            batch_feature = np.zeros(
+                (lod[-1], self._frame_dim), dtype="float32")
             batch_label = np.zeros((lod[-1], 1), dtype="int64")
             start = 0
             for sample in batch_samples:
...
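Note: `batch_iterator` packs variable-length utterances into one dense `(total_frames, frame_dim)` matrix, with `lod` recording the cumulative frame offsets where each utterance starts (the level-of-detail layout PaddlePaddle uses for sequence input). The bookkeeping reduced to a sketch (sample lengths and `frame_dim` are made up):

    import numpy as np

    samples = [np.ones((n, 8), dtype='float32') for n in (3, 5, 2)]
    lod = [0]
    for s in samples:
        lod.append(lod[-1] + s.shape[0])  # lod[i] = start row of sample i

    batch_feature = np.zeros((lod[-1], 8), dtype='float32')
    start = 0
    for s in samples:
        batch_feature[start:start + s.shape[0], :] = s
        start += s.shape[0]

    print(lod)                  # [0, 3, 8, 10]
    print(batch_feature.shape)  # (10, 8)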