提交 8bb81322 编写于 作者: Z zhxfl

merge develop

...@@ -10,9 +10,11 @@ import Queue ...@@ -10,9 +10,11 @@ import Queue
import time import time
import numpy as np import numpy as np
from threading import Thread from threading import Thread
import signal
from multiprocessing import Manager, Process from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
class SampleInfo(object): class SampleInfo(object):
...@@ -175,6 +177,9 @@ class DataReader(object): ...@@ -175,6 +177,9 @@ class DataReader(object):
shuffle_block_num (int): Block number indicating the minimum unit to do shuffle_block_num (int): Block number indicating the minimum unit to do
shuffle. shuffle.
random_seed (int): Random seed. random_seed (int): Random seed.
verbose (int): If set to 0, complaints including exceptions and signal
traceback from sub-process will be suppressed. If set
to 1, all complaints will be printed.
""" """
def __init__(self, def __init__(self,
...@@ -186,7 +191,8 @@ class DataReader(object): ...@@ -186,7 +191,8 @@ class DataReader(object):
sample_info_buffer_size=1024, sample_info_buffer_size=1024,
batch_buffer_size=1024, batch_buffer_size=1024,
shuffle_block_num=10, shuffle_block_num=10,
random_seed=0): random_seed=0,
verbose=0):
self._feature_file_list = feature_file_list self._feature_file_list = feature_file_list
self._label_file_list = label_file_list self._label_file_list = label_file_list
self._drop_frame_len = drop_frame_len self._drop_frame_len = drop_frame_len
...@@ -201,6 +207,7 @@ class DataReader(object): ...@@ -201,6 +207,7 @@ class DataReader(object):
self._sample_info_buffer_size = sample_info_buffer_size self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size self._batch_buffer_size = batch_buffer_size
self._process_num = process_num self._process_num = process_num
self._verbose = verbose
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
...@@ -239,6 +246,7 @@ class DataReader(object): ...@@ -239,6 +246,7 @@ class DataReader(object):
sample_queue = self._manager.Queue(self._sample_buffer_size) sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0 self._order_id = 0
@suppress_complaints(verbose=self._verbose)
def ordered_feeding_task(sample_info_queue): def ordered_feeding_task(sample_info_queue):
for sample_info_bucket in self._bucket_list: for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list( sample_info_list = sample_info_bucket.generate_sample_info_list(
...@@ -256,7 +264,12 @@ class DataReader(object): ...@@ -256,7 +264,12 @@ class DataReader(object):
feeding_thread.daemon = True feeding_thread.daemon = True
feeding_thread.start() feeding_thread.start()
@suppress_complaints(verbose=self._verbose)
def ordered_processing_task(sample_info_queue, sample_queue, out_order): def ordered_processing_task(sample_info_queue, sample_queue, out_order):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal())
signal.signal(signal.SIGINT, suppress_signal())
def read_bytes(fpath, start, size): def read_bytes(fpath, start, size):
f = open(fpath, 'r') f = open(fpath, 'r')
f.seek(start, 0) f.seek(start, 0)
...@@ -358,6 +371,7 @@ class DataReader(object): ...@@ -358,6 +371,7 @@ class DataReader(object):
start += frame_num start += frame_num
return (batch_feature, batch_label) return (batch_feature, batch_label)
@suppress_complaints(verbose=self._verbose)
def batch_assembling_task(sample_generator, batch_queue): def batch_assembling_task(sample_generator, batch_queue):
batch_samples = [] batch_samples = []
lod = [0] lod = [0]
...@@ -386,9 +400,14 @@ class DataReader(object): ...@@ -386,9 +400,14 @@ class DataReader(object):
assembling_thread.daemon = True assembling_thread.daemon = True
assembling_thread.start() assembling_thread.start()
batch_data = batch_queue.get() while True:
while not isinstance(batch_data, EpochEndSignal): try:
yield batch_data batch_data = batch_queue.get_nowait()
batch_data = batch_queue.get() except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
assembling_thread.join() assembling_thread.join()
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys
from six import reraise
from tblib import Traceback
def to_lodtensor(data, place): def to_lodtensor(data, place):
...@@ -28,3 +31,23 @@ def lodtensor_to_ndarray(lod_tensor): ...@@ -28,3 +31,23 @@ def lodtensor_to_ndarray(lod_tensor):
for i in xrange(np.product(dims)): for i in xrange(np.product(dims)):
ret.ravel()[i] = lod_tensor.get_float_element(i) ret.ravel()[i] = lod_tensor.get_float_element(i)
return ret, lod_tensor.lod() return ret, lod_tensor.lod()
def suppress_signal(signo, stack_frame):
pass
def suppress_complaints(verbose):
def decorator_maker(func):
def suppress_warpper(*args, **kwargs):
try:
func(*args, **kwargs)
except:
et, ev, tb = sys.exc_info()
tb = Traceback(tb)
if verbose == 1:
reraise(et, ev, tb.as_traceback())
return suppress_warpper
return decorator_maker
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册