提交 8bb81322 编写于 作者: Z zhxfl

merge develop

......@@ -10,9 +10,11 @@ import Queue
import time
import numpy as np
from threading import Thread
import signal
from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
class SampleInfo(object):
......@@ -175,6 +177,9 @@ class DataReader(object):
shuffle_block_num (int): Block number indicating the minimum unit to do
shuffle.
random_seed (int): Random seed.
verbose (int): If set to 0, complaints including exceptions and signal
traceback from sub-process will be suppressed. If set
to 1, all complaints will be printed.
"""
def __init__(self,
......@@ -186,7 +191,8 @@ class DataReader(object):
sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=10,
random_seed=0):
random_seed=0,
verbose=0):
self._feature_file_list = feature_file_list
self._label_file_list = label_file_list
self._drop_frame_len = drop_frame_len
......@@ -201,6 +207,7 @@ class DataReader(object):
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._process_num = process_num
self._verbose = verbose
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
......@@ -239,6 +246,7 @@ class DataReader(object):
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0
@suppress_complaints(verbose=self._verbose)
def ordered_feeding_task(sample_info_queue):
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
......@@ -256,7 +264,12 @@ class DataReader(object):
feeding_thread.daemon = True
feeding_thread.start()
@suppress_complaints(verbose=self._verbose)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal())
signal.signal(signal.SIGINT, suppress_signal())
def read_bytes(fpath, start, size):
f = open(fpath, 'r')
f.seek(start, 0)
......@@ -358,6 +371,7 @@ class DataReader(object):
start += frame_num
return (batch_feature, batch_label)
@suppress_complaints(verbose=self._verbose)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
......@@ -386,9 +400,14 @@ class DataReader(object):
assembling_thread.daemon = True
assembling_thread.start()
batch_data = batch_queue.get()
while not isinstance(batch_data, EpochEndSignal):
yield batch_data
batch_data = batch_queue.get()
while True:
try:
batch_data = batch_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
assembling_thread.join()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from six import reraise
from tblib import Traceback
def to_lodtensor(data, place):
......@@ -28,3 +31,23 @@ def lodtensor_to_ndarray(lod_tensor):
for i in xrange(np.product(dims)):
ret.ravel()[i] = lod_tensor.get_float_element(i)
return ret, lod_tensor.lod()
def suppress_signal(signo, stack_frame):
pass
def suppress_complaints(verbose):
def decorator_maker(func):
def suppress_warpper(*args, **kwargs):
try:
func(*args, **kwargs)
except:
et, ev, tb = sys.exc_info()
tb = Traceback(tb)
if verbose == 1:
reraise(et, ev, tb.as_traceback())
return suppress_warpper
return decorator_maker
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册