提交 e7442964 编写于 作者: Y yangyaming

Avoid suppressing critical exception.

上级 2738ca10
......@@ -15,6 +15,7 @@ from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import CriticalException, ForceExitWrapper
class SampleInfo(object):
......@@ -166,6 +167,7 @@ class DataReader(object):
self._batch_buffer_size = batch_buffer_size
self._process_num = process_num
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
......@@ -204,15 +206,19 @@ class DataReader(object):
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0
@suppress_complaints(verbose=self._verbose)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue):
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
)
self._rng.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
try:
sample_info_list = \
sample_info_bucket.generate_sample_info_list()
except Exception as e:
raise CriticalException(e)
else:
self._rng.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
for i in xrange(self._process_num):
sample_info_queue.put(EpochEndSignal())
......@@ -222,18 +228,21 @@ class DataReader(object):
feeding_thread.daemon = True
feeding_thread.start()
@suppress_complaints(verbose=self._verbose)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
def read_bytes(fpath, start, size):
f = open(fpath, 'r')
f.seek(start, 0)
binary_bytes = f.read(size)
f.close()
return binary_bytes
try:
f = open(fpath, 'r')
f.seek(start, 0)
binary_bytes = f.read(size)
f.close()
return binary_bytes
except Exception as e:
raise CriticalException(e)
ins = sample_info_queue.get()
......@@ -295,16 +304,20 @@ class DataReader(object):
finished_process_num = 0
while finished_process_num < self._process_num:
sample = sample_queue.get()
if isinstance(sample, EpochEndSignal):
finished_process_num += 1
continue
yield sample
while self._force_exit == False:
try:
sample = sample_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal):
finished_process_num += 1
if finished_process_num >= self._process_num:
break
else:
continue
feeding_thread.join()
for w in workers:
w.join()
yield sample
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
......@@ -320,7 +333,7 @@ class DataReader(object):
start += frame_num
return (batch_feature, batch_label)
@suppress_complaints(verbose=self._verbose)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
......@@ -349,7 +362,7 @@ class DataReader(object):
assembling_thread.daemon = True
assembling_thread.start()
while True:
while self._force_exit == False:
try:
batch_data = batch_queue.get_nowait()
except Queue.Empty:
......@@ -358,5 +371,3 @@ class DataReader(object):
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
assembling_thread.join()
......@@ -35,21 +35,40 @@ def lodtensor_to_ndarray(lod_tensor):
return ret, lod_tensor.lod()
class CriticalException(Exception):
pass
def suppress_signal(signo, stack_frame):
pass
def suppress_complaints(verbose):
def suppress_complaints(verbose, notify=None):
def decorator_maker(func):
def suppress_warpper(*args, **kwargs):
try:
func(*args, **kwargs)
except:
et, ev, tb = sys.exc_info()
tb = Traceback(tb)
if verbose == 1:
reraise(et, ev, tb.as_traceback())
if notify is not None:
notify(except_type=et, except_value=ev, traceback=tb)
if verbose == 1 or isinstance(ev, CriticalException):
reraise(et, ev, Traceback(tb).as_traceback())
return suppress_warpper
return decorator_maker
class ForceExitWrapper(object):
def __init__(self, exit_flag):
self._exit_flag = exit_flag
@suppress_complaints(verbose=0)
def __call__(self, *args, **kwargs):
self._exit_flag.value = True
def __eq__(self, flag):
return self._exit_flag.value == flag
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册