提交 bf281041 编写于 作者: Z zhxfl

fix by review

......@@ -15,6 +15,7 @@ from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import CriticalException, ForceExitWrapper
class SampleInfo(object):
......@@ -89,6 +90,7 @@ class SampleInfoBucket(object):
self._split_perturb = split_perturb
self._split_sentence_threshold = split_sentence_threshold
self._split_sub_sentence_len = split_sub_sentence_len
self._rng = random.Random(0)
def generate_sample_info_list(self):
sample_info_list = []
......@@ -213,6 +215,7 @@ class DataReader(object):
self._batch_buffer_size = batch_buffer_size
self._process_num = process_num
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
......@@ -251,15 +254,19 @@ class DataReader(object):
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0
@suppress_complaints(verbose=self._verbose)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue):
for sample_info_bucket in self._bucket_list:
sample_info_list = sample_info_bucket.generate_sample_info_list(
)
self._rng.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
try:
sample_info_list = \
sample_info_bucket.generate_sample_info_list()
except Exception as e:
raise CriticalException(e)
else:
self._rng.shuffle(sample_info_list) # do shuffle here
for sample_info in sample_info_list:
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
for i in xrange(self._process_num):
sample_info_queue.put(EpochEndSignal())
......@@ -269,18 +276,21 @@ class DataReader(object):
feeding_thread.daemon = True
feeding_thread.start()
@suppress_complaints(verbose=self._verbose)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
def read_bytes(fpath, start, size):
f = open(fpath, 'r')
f.seek(start, 0)
binary_bytes = f.read(size)
f.close()
return binary_bytes
try:
f = open(fpath, 'r')
f.seek(start, 0)
binary_bytes = f.read(size)
f.close()
return binary_bytes
except Exception as e:
raise CriticalException(e)
ins = sample_info_queue.get()
......@@ -352,16 +362,21 @@ class DataReader(object):
w.start()
finished_process_num = 0
while finished_process_num < self._process_num:
sample = sample_queue.get()
if isinstance(sample, EpochEndSignal):
finished_process_num += 1
continue
yield sample
feeding_thread.join()
for w in workers:
w.join()
while self._force_exit == False:
try:
sample = sample_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal):
finished_process_num += 1
if finished_process_num >= self._process_num:
break
else:
continue
yield sample
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
......@@ -377,7 +392,7 @@ class DataReader(object):
start += frame_num
return (batch_feature, batch_label)
@suppress_complaints(verbose=self._verbose)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
......@@ -406,7 +421,7 @@ class DataReader(object):
assembling_thread.daemon = True
assembling_thread.start()
while True:
while self._force_exit == False:
try:
batch_data = batch_queue.get_nowait()
except Queue.Empty:
......@@ -415,5 +430,3 @@ class DataReader(object):
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
assembling_thread.join()
......@@ -35,21 +35,40 @@ def lodtensor_to_ndarray(lod_tensor):
return ret, lod_tensor.lod()
class CriticalException(Exception):
pass
def suppress_signal(signo, stack_frame):
pass
def suppress_complaints(verbose):
def suppress_complaints(verbose, notify=None):
def decorator_maker(func):
def suppress_warpper(*args, **kwargs):
try:
func(*args, **kwargs)
except:
et, ev, tb = sys.exc_info()
tb = Traceback(tb)
if verbose == 1:
reraise(et, ev, tb.as_traceback())
if notify is not None:
notify(except_type=et, except_value=ev, traceback=tb)
if verbose == 1 or isinstance(ev, CriticalException):
reraise(et, ev, Traceback(tb).as_traceback())
return suppress_warpper
return decorator_maker
class ForceExitWrapper(object):
def __init__(self, exit_flag):
self._exit_flag = exit_flag
@suppress_complaints(verbose=0)
def __call__(self, *args, **kwargs):
self._exit_flag.value = True
def __eq__(self, flag):
return self._exit_flag.value == flag
import os
import cv2
import numpy as np
from PIL import Image
from paddle.v2.image import load_image
class DataGenerator(object):
def __init__(self):
pass
def train_reader(self, img_root_dir, img_label_list, batchsize):
'''
Reader interface for training.
:param img_root_dir: The root path of the image for training.
:type file_list: str
:param img_label_list: The path of the <image_name, label> file for training.
:type file_list: str
'''
img_label_lines = []
if batchsize == 1:
to_file = "tmp.txt"
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
print "cmd: " + cmd
os.system(cmd)
print "finish batch shuffle"
img_label_lines = open(to_file, 'r').readlines()
else:
to_file = "tmp.txt"
#cmd1: partial shuffle
cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
#cmd2: batch merge and shuffle
cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
batchsize) + " == 0) print \"\";}' | shuf | "
#cmd3: batch split
cmd += "awk '{if(NF == " + str(
batchsize
) + " * 4) {for(i = 0; i < " + str(
batchsize
) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
print "cmd: " + cmd
os.system(cmd)
print "finish batch shuffle"
img_label_lines = open(to_file, 'r').readlines()
def reader():
sizes = len(img_label_lines) / batchsize
for i in range(sizes):
result = []
sz = [0, 0]
for j in range(batchsize):
line = img_label_lines[i * batchsize + j]
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L') #zhuanhuidu
if j == 0:
sz = img.size
img = img.resize((sz[0], sz[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
result.append([img, label])
yield result
return reader
def test_reader(self, img_root_dir, img_label_list):
'''
Reader interface for inference.
:param img_root_dir: The root path of the images for training.
:type file_list: str
:param img_label_list: The path of the <image_name, label> file for testing.
:type file_list: list
'''
def reader():
for line in open(img_label_list):
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[2])).convert(
'L')
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
yield img, label
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册