Commit 49fb3e6b authored by yangyaming

Refine parallel reader.

Parent bfda10aa
@@ -8,8 +8,9 @@ import numpy as np
import struct
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from multiprocessing import Manager, Pool
from multiprocessing import Manager, Process
from threading import Thread
import time
class SampleInfo(object):
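
Note on the import change above: the old reader drove a Pool inside its loading thread, and pool.map blocks until an entire bucket is processed; the refined version keeps long-lived Process workers streaming results through Manager queues instead. A minimal sketch of that producer/consumer layout, with a toy task standing in for real sample decoding (all names here are illustrative, not from the commit):

    from multiprocessing import Manager, Process

    def square_worker(in_queue, out_queue):
        # Consume tasks until the None sentinel arrives, then exit.
        task = in_queue.get()
        while task is not None:
            out_queue.put(task * task)
            task = in_queue.get()

    if __name__ == '__main__':
        manager = Manager()
        in_queue = manager.Queue(1024)
        out_queue = manager.Queue(1024)
        worker = Process(target=square_worker, args=(in_queue, out_queue))
        worker.daemon = True
        worker.start()
        for i in range(8):
            in_queue.put(i)
        in_queue.put(None)  # sentinel: no more work this epoch
        for _ in range(8):
            print(out_queue.get())
        worker.join()
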
@@ -78,7 +79,11 @@
return sample_info_list
def DataReader(object):
class EpochEndSignal():
pass
class DataReader(object):
def __init__(self,
feature_file_list,
label_file_list,
@@ -91,6 +96,9 @@ def DataReader(object):
self._drop_frame_len = 256
self._feature_file_list = feature_file_list
self._label_file_list = label_file_list
self._block_info_list = None
self._bucket_list = None
self._order_id = 0
self.generate_bucket_list(True)
def generate_bucket_list(self, is_shuffle):
@@ -114,7 +122,7 @@
for i in xrange(0, len(self._block_info_list), self._shuffle_block_num):
bucket_block_info = self._block_info_list[i:i +
self._shuffle_block_num]
buket_list.append(
self._bucket_list.append(
SampleInfoBucket(
map(lambda info: info[0], bucket_block_info),
map(lambda info: info[1], bucket_block_info),
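
Besides fixing the buket_list typo (the old line appended to an undefined local instead of self._bucket_list), this hunk shows the bucketing scheme: the block info list is cut into slices of _shuffle_block_num blocks, and each slice becomes one SampleInfoBucket. A standalone sketch of the slicing; the four-field layout of each info tuple is an assumption extrapolated from the map(lambda ...) calls:

    # Assumed layout: (feature_bin, feature_desc, label_bin, label_desc).
    block_info_list = [('feat_%d.bin' % i, 'feat_%d.desc' % i,
                        'label_%d.bin' % i, 'label_%d.desc' % i)
                       for i in range(10)]
    shuffle_block_num = 4

    bucket_list = []
    for i in range(0, len(block_info_list), shuffle_block_num):
        bucket_block_info = block_info_list[i:i + shuffle_block_num]
        # Split column-wise, mirroring the map(lambda info: info[k], ...) calls.
        bucket_list.append(tuple(
            [info[k] for info in bucket_block_info] for k in range(4)))

    # 10 blocks in slices of 4 give bucket sizes 4, 4 and 2.
    print([len(bucket[0]) for bucket in bucket_list])  # [4, 4, 2]
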
@@ -125,12 +133,35 @@
self._transformers = transformers
def _sample_generator(self):
sample_queue = Queue.Queue(1024)
manager = Manager()
sample_info_queue = manager.Queue(1024)
sample_queue = manager.Queue(1024)
process_num = 1
self._order_id = 0
def data_loading_worker(sample_queue):
pool = Pool(processes=10)
def ordered_feeding_worker(sample_info_queue):
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
)
random.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
def sample_processing_worker(sample_info):
for i in xrange(process_num):
sample_info_queue.put(EpochEndSignal())
feeding_thread = Thread(
target=ordered_feeding_worker, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()
def ordered_processing_worker(sample_info_queue, sample_queue,
out_order):
ins = sample_info_queue.get()
while not isinstance(ins, EpochEndSignal):
sample_info, order_id = ins
f_feature = open(sample_info.feature_bin_path, 'r')
f_label = open(sample_info.label_bin_path, 'r')
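
The core of the refinement is visible in this hunk: ordered_feeding_worker tags every sample with a monotonically increasing order id and enqueues one EpochEndSignal per worker process, while ordered_processing_worker (continued below) holds each finished sample until the shared out_order counter reaches its ticket. A condensed, runnable sketch of that ordering scheme, with an integer payload standing in for real sample decoding; queue capacities, task count, and worker count are illustrative:

    import time
    from threading import Thread
    from multiprocessing import Manager, Process

    class EpochEndSignal(object):
        pass

    def ordered_feeding_worker(task_queue, num_tasks, process_num):
        for order_id in range(num_tasks):
            task_queue.put((order_id, order_id))  # (payload, ticket)
        for _ in range(process_num):
            task_queue.put(EpochEndSignal())  # one shutdown signal per worker

    def ordered_processing_worker(task_queue, result_queue, out_order):
        ins = task_queue.get()
        while not isinstance(ins, EpochEndSignal):
            payload, order_id = ins
            result = payload * payload  # stand-in for decode + transform
            # Ticket gate: emit only after all earlier tickets have emitted,
            # so results leave in submission order despite parallel workers.
            while order_id != out_order[0]:
                time.sleep(0.001)
            result_queue.put(result)
            out_order[0] += 1
            ins = task_queue.get()

    if __name__ == '__main__':
        manager = Manager()
        task_queue = manager.Queue(1024)
        result_queue = manager.Queue(1024)
        out_order = manager.list([0])
        num_tasks, process_num = 16, 2

        feeder = Thread(target=ordered_feeding_worker,
                        args=(task_queue, num_tasks, process_num))
        feeder.daemon = True
        feeder.start()

        workers = [Process(target=ordered_processing_worker,
                           args=(task_queue, result_queue, out_order))
                   for _ in range(process_num)]
        for w in workers:
            w.daemon = True
            w.start()

        print([result_queue.get() for _ in range(num_tasks)])  # ordered squares

        feeder.join()
        for w in workers:
            w.join()
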
@@ -138,7 +169,7 @@
label_bytes = f_label.read(sample_info.label_size)
f_label.close()
assert sample_info.label_frame_num * 4 == label_bytes
assert sample_info.label_frame_num * 4 == len(label_bytes)
label_array = struct.unpack('I' * sample_info.label_frame_num,
label_bytes)
label_data = np.array(
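
The assert fix in this hunk compares the expected byte count against len(label_bytes) rather than against the bytes object itself, which an int can never equal. Labels are stored as one 4-byte unsigned int per frame, so a file of N frames is exactly 4*N bytes. A self-contained sketch of that pack/unpack round trip; the file name is illustrative, and binary open modes are used so the sketch also runs on Python 3 (the diff itself is Python 2 and opens with 'r'):

    import struct
    import numpy as np

    label_frame_num = 5
    labels = [3, 1, 4, 1, 5]

    # Write one 4-byte unsigned int per frame, the layout the reader expects.
    with open('labels.bin', 'wb') as f:
        f.write(struct.pack('I' * label_frame_num, *labels))

    with open('labels.bin', 'rb') as f:
        label_bytes = f.read(label_frame_num * 4)

    # Compare byte *counts*: the old assert compared an int against the
    # bytes object itself, which is always false.
    assert label_frame_num * 4 == len(label_bytes)

    label_array = struct.unpack('I' * label_frame_num, label_bytes)
    label_data = np.array(label_array, dtype='int64').reshape(
        (label_frame_num, 1))
    print(label_data.ravel())  # [3 1 4 1 5]
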
@@ -148,55 +179,82 @@
f_feature.seek(sample_info.feature_start, 0)
feature_bytes = f_feature.read(sample_info.feature_size)
f_feature.close()
assert sample_info.feature_frame_num * sample_info.feature_dim * 4 == feature_bytes
assert sample_info.feature_frame_num * sample_info.feature_dim * 4 == len(
feature_bytes)
feature_array = struct.unpack(
'f' * sample_info.feature_frame_num *
sample_info.feature_dim, feature_bytes)
feature_data = np.array(
feature_array, dytpe='float32').reshape((
feature_array, dtype='float32').reshape((
sample_info.feature_frame_num, sample_info.feature_dim))
# drop long sentence
if self._drop_sentence_len < sample_data[0].shape[0]:
return None
sample_data = (feature_data, label_data)
for transformer in self._transformers:
# @TODO(pkuyym) to make transfomer only accept feature_data
sample_data = transformer.perform_trans(sample_data)
return sample_data
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
)
random.shuffle(sample_info_list) # do shuffle here
processed_data = pool.map(
f, sample_info_list) # the result is ordered
while order_id != out_order[0]:
time.sleep(0.001)
for sample_data in processed_data:
if sample_data is None: continue
# drop long sentence
if self._drop_sentence_len >= sample_data[0].shape[0]:
sample_queue.put(sample_data)
sample_queue.put(None)
print('sub process: %d' % sample_queue.qsize())
out_order[0] += 1
t = Thread(target=data_processing_worker, args=(sample_queue))
t.daemon = True
t.start()
time.sleep(0.1)
if order_id == self._order_id:
sample_queue.put(EpochEndSignal())
ins = sample_info_queue.get()
out_order = manager.list([0])
args = (sample_info_queue, sample_queue, out_order)
workers = [
Process(
target=ordered_processing_worker, args=args)
for _ in xrange(process_num)
]
for w in workers:
w.daemon = True
w.start()
while True:
print('main thread: %d' % sample_queue.qsize())
sample = sample_queue.get()
if sample is None: break
if isinstance(sample, EpochEndSignal): break
yield sample
def batch_iterator(self, batch_size, minimum_batch_size):
batch_samples = []
lod = [0]
# check whether need parallel here
for sample in self._sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
feeding_thread.join()
for w in workers:
w.join()
def batch_iterator(self, batch_size, minimum_batch_size):
batch_samples = []
lod = [0]
# check whether need parallel here
for sample in self._sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
yield (batch_feature, batch_label, lod)
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
@@ -207,16 +265,3 @@ def batch_iterator(self, batch_size, minimum_batch_size):
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
yield (batch_feature, batch_label, lod)
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size:
batch_feature = np.zeros((lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
yield (batch_feature, batch_label, lod)
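
For context on the largely unchanged tail of the diff: batch_iterator flattens variable-length utterances along the frame axis and records cumulative offsets in lod, the level-of-detail index used by PaddlePaddle's sequence ops, so rows lod[i]:lod[i+1] of the flat batch belong to sample i. A small self-contained sketch of the assembly, with fabricated frame counts and feature dimension:

    import numpy as np

    frame_dim = 3
    # Fabricated (feature, label) samples with 2, 4 and 1 frames each.
    samples = [(np.full((n, frame_dim), i, dtype='float32'),
                np.full((n, 1), i, dtype='int64'))
               for i, n in enumerate([2, 4, 1])]

    lod = [0]
    for feature, _ in samples:
        lod.append(lod[-1] + feature.shape[0])  # cumulative frame offsets

    # Preallocate the flat buffers, then copy each sample into its rows.
    batch_feature = np.zeros((lod[-1], frame_dim), dtype='float32')
    batch_label = np.zeros((lod[-1], 1), dtype='int64')
    start = 0
    for feature, label in samples:
        frame_num = feature.shape[0]
        batch_feature[start:start + frame_num, :] = feature
        batch_label[start:start + frame_num, :] = label
        start += frame_num

    print(lod)  # [0, 2, 6, 7]: sample i occupies rows lod[i]:lod[i+1]
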