提交 dd008f15 编写于 作者: Y yangyaming

Make batch assembling parallel.

上级 f12deac8
...@@ -5,13 +5,14 @@ from __future__ import division ...@@ -5,13 +5,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import random import random
import numpy as np
import struct import struct
import Queue
import time
import numpy as np
from threading import Thread
from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
from multiprocessing import Manager, Process
from threading import Thread
import time
class SampleInfo(object): class SampleInfo(object):
...@@ -127,6 +128,8 @@ class DataReader(object): ...@@ -127,6 +128,8 @@ class DataReader(object):
cached. cached.
sample_info_buffer_size (int): Buffer size to indicate the maximum sample_info_buffer_size (int): Buffer size to indicate the maximum
sample information cached. sample information cached.
batch_buffer_size (int): Buffer size to indicate the maximum batch
cached.
shuffle_block_num (int): Block number indicating the minimum unit to do shuffle_block_num (int): Block number indicating the minimum unit to do
shuffle. shuffle.
random_seed (int): Random seed. random_seed (int): Random seed.
...@@ -141,7 +144,8 @@ class DataReader(object): ...@@ -141,7 +144,8 @@ class DataReader(object):
drop_frame_len=256, drop_frame_len=256,
process_num=10, process_num=10,
sample_buffer_size=1024, sample_buffer_size=1024,
sample_info_buffer_size=10000, sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=1, shuffle_block_num=1,
random_seed=0): random_seed=0):
self._feature_file_list = feature_file_list self._feature_file_list = feature_file_list
...@@ -158,6 +162,7 @@ class DataReader(object): ...@@ -158,6 +162,7 @@ class DataReader(object):
self._manager = Manager() self._manager = Manager()
self._sample_buffer_size = sample_buffer_size self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._process_num = process_num self._process_num = process_num
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
...@@ -197,7 +202,7 @@ class DataReader(object): ...@@ -197,7 +202,7 @@ class DataReader(object):
sample_queue = self._manager.Queue(self._sample_buffer_size) sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0 self._order_id = 0
def ordered_feeding_worker(sample_info_queue): def ordered_feeding_task(sample_info_queue):
for sample_info_bucket in self._bucket_list: for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list( sample_info_list = sample_info_bucket.generate_sample_info_list(
) )
...@@ -210,12 +215,11 @@ class DataReader(object): ...@@ -210,12 +215,11 @@ class DataReader(object):
sample_info_queue.put(EpochEndSignal()) sample_info_queue.put(EpochEndSignal())
feeding_thread = Thread( feeding_thread = Thread(
target=ordered_feeding_worker, args=(sample_info_queue, )) target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True feeding_thread.daemon = True
feeding_thread.start() feeding_thread.start()
def ordered_processing_worker(sample_info_queue, sample_queue, def ordered_processing_task(sample_info_queue, sample_queue, out_order):
out_order):
def read_bytes(fpath, start, size): def read_bytes(fpath, start, size):
f = open(fpath, 'r') f = open(fpath, 'r')
f.seek(start, 0) f.seek(start, 0)
...@@ -273,7 +277,7 @@ class DataReader(object): ...@@ -273,7 +277,7 @@ class DataReader(object):
args = (sample_info_queue, sample_queue, out_order) args = (sample_info_queue, sample_queue, out_order)
workers = [ workers = [
Process( Process(
target=ordered_processing_worker, args=args) target=ordered_processing_task, args=args)
for _ in xrange(self._process_num) for _ in xrange(self._process_num)
] ]
...@@ -295,13 +299,27 @@ class DataReader(object): ...@@ -295,13 +299,27 @@ class DataReader(object):
w.join() w.join()
def batch_iterator(self, batch_size, minimum_batch_size):
    """Yield training batches assembled in a background thread.

    A daemon thread pulls samples from ``self._sample_generator`` and
    packs them into dense ``(feature, label, lod)`` batches, which are
    handed to the consumer through a bounded ``Queue.Queue`` of size
    ``self._batch_buffer_size`` so assembling overlaps with training.

    Args:
        batch_size (int): Number of samples per full batch.
        minimum_batch_size (int): Minimum size the trailing partial
            batch must reach to be emitted; smaller leftovers are
            dropped.

    Yields:
        tuple: ``(batch_feature, batch_label, lod)`` where
        ``batch_feature`` is float32 of shape (total_frames, frame_dim),
        ``batch_label`` is int64 of shape (total_frames, 1), and ``lod``
        is the cumulative frame-count offsets per sample.
    """

    def assemble_batch(batch_samples, lod):
        # Concatenate the per-sample frame matrices into one dense
        # feature array and one label column, using the precomputed
        # lod offsets for the total frame count.
        batch_feature = np.zeros(
            (lod[-1], self._frame_dim), dtype="float32")
        batch_label = np.zeros((lod[-1], 1), dtype="int64")
        start = 0
        for sample in batch_samples:
            frame_num = sample[0].shape[0]
            batch_feature[start:start + frame_num, :] = sample[0]
            batch_label[start:start + frame_num, :] = sample[1]
            start += frame_num
        return (batch_feature, batch_label, lod)

    def batch_assembling_task(sample_generator, batch_queue):
        # Producer: group samples into batches and push them to the
        # consumer; ends the epoch with an EpochEndSignal sentinel.
        batch_samples = []
        lod = [0]
        for sample in sample_generator():
            batch_samples.append(sample)
            lod.append(lod[-1] + sample[0].shape[0])
            if len(batch_samples) == batch_size:
                batch_queue.put(assemble_batch(batch_samples, lod))
                batch_samples = []
                lod = [0]
        # Flush the trailing partial batch only if it is big enough.
        if len(batch_samples) >= minimum_batch_size:
            batch_queue.put(assemble_batch(batch_samples, lod))

        batch_queue.put(EpochEndSignal())

    batch_queue = Queue.Queue(self._batch_buffer_size)

    assembling_thread = Thread(
        target=batch_assembling_task,
        args=(self._sample_generator, batch_queue))
    # Daemonize so an abandoned iterator cannot keep the process alive.
    assembling_thread.daemon = True
    assembling_thread.start()

    # Consumer: drain the queue until the end-of-epoch sentinel.
    # NOTE(review): if the producer thread dies unexpectedly this
    # get() blocks forever — relies on the task not raising.
    batch_data = batch_queue.get()
    while not isinstance(batch_data, EpochEndSignal):
        yield batch_data
        batch_data = batch_queue.get()

    assembling_thread.join()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册