提交 dd008f15 编写于 作者: Y yangyaming

Make batch assembling parallel.

上级 f12deac8
......@@ -5,13 +5,14 @@ from __future__ import division
from __future__ import print_function
import random
import numpy as np
import struct
import Queue
import time
import numpy as np
from threading import Thread
from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from multiprocessing import Manager, Process
from threading import Thread
import time
class SampleInfo(object):
......@@ -127,6 +128,8 @@ class DataReader(object):
cached.
sample_info_buffer_size (int): Buffer size to indicate the maximum
sample information cached.
batch_buffer_size (int): Buffer size to indicate the maximum batch
cached.
shuffle_block_num (int): Block number indicating the minimum unit to do
shuffle.
random_seed (int): Random seed.
......@@ -141,7 +144,8 @@ class DataReader(object):
drop_frame_len=256,
process_num=10,
sample_buffer_size=1024,
sample_info_buffer_size=10000,
sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=1,
random_seed=0):
self._feature_file_list = feature_file_list
......@@ -158,6 +162,7 @@ class DataReader(object):
self._manager = Manager()
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._process_num = process_num
def generate_bucket_list(self, is_shuffle):
......@@ -197,7 +202,7 @@ class DataReader(object):
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0
def ordered_feeding_worker(sample_info_queue):
def ordered_feeding_task(sample_info_queue):
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
)
......@@ -210,12 +215,11 @@ class DataReader(object):
sample_info_queue.put(EpochEndSignal())
feeding_thread = Thread(
target=ordered_feeding_worker, args=(sample_info_queue, ))
target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()
def ordered_processing_worker(sample_info_queue, sample_queue,
out_order):
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
def read_bytes(fpath, start, size):
f = open(fpath, 'r')
f.seek(start, 0)
......@@ -273,7 +277,7 @@ class DataReader(object):
args = (sample_info_queue, sample_queue, out_order)
workers = [
Process(
target=ordered_processing_worker, args=args)
target=ordered_processing_task, args=args)
for _ in xrange(self._process_num)
]
......@@ -295,13 +299,27 @@ class DataReader(object):
w.join()
def batch_iterator(self, batch_size, minimum_batch_size):
batch_samples = []
lod = [0]
# check whether need parallel here
for sample in self._sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
for sample in sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
batch_queue.put((batch_feature, batch_label, lod))
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
......@@ -311,18 +329,21 @@ class DataReader(object):
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
yield (batch_feature, batch_label, lod)
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
yield (batch_feature, batch_label, lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_queue.put(EpochEndSignal())
batch_queue = Queue.Queue(self._batch_buffer_size)
assembling_thread = Thread(
target=batch_assembling_task,
args=(self._sample_generator, batch_queue))
assembling_thread.daemon = True
assembling_thread.start()
batch_data = batch_queue.get()
while not isinstance(batch_data, EpochEndSignal):
yield batch_data
batch_data = batch_queue.get()
assembling_thread.join()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册