提交 79cb57cc 编写于 作者: Y yangyaming

Fix the exit logic of the sample generator.

上级 49fb3e6b
...@@ -98,8 +98,8 @@ class DataReader(object): ...@@ -98,8 +98,8 @@ class DataReader(object):
self._label_file_list = label_file_list self._label_file_list = label_file_list
self._block_info_list = None self._block_info_list = None
self._bucket_list = None self._bucket_list = None
self._order_id = 0
self.generate_bucket_list(True) self.generate_bucket_list(True)
self._order_id = 0
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
...@@ -136,7 +136,7 @@ class DataReader(object): ...@@ -136,7 +136,7 @@ class DataReader(object):
manager = Manager() manager = Manager()
sample_info_queue = manager.Queue(1024) sample_info_queue = manager.Queue(1024)
sample_queue = manager.Queue(1024) sample_queue = manager.Queue(1024)
process_num = 1 process_num = 2
self._order_id = 0 self._order_id = 0
def ordered_feeding_worker(sample_info_queue): def ordered_feeding_worker(sample_info_queue):
...@@ -161,6 +161,7 @@ class DataReader(object): ...@@ -161,6 +161,7 @@ class DataReader(object):
ins = sample_info_queue.get() ins = sample_info_queue.get()
while not isinstance(ins, EpochEndSignal): while not isinstance(ins, EpochEndSignal):
# @TODO(pkuyym) add block cache to cache several block (LRU) into memory
sample_info, order_id = ins sample_info, order_id = ins
f_feature = open(sample_info.feature_bin_path, 'r') f_feature = open(sample_info.feature_bin_path, 'r')
f_label = open(sample_info.label_bin_path, 'r') f_label = open(sample_info.label_bin_path, 'r')
...@@ -200,17 +201,11 @@ class DataReader(object): ...@@ -200,17 +201,11 @@ class DataReader(object):
if self._drop_sentence_len >= sample_data[0].shape[0]: if self._drop_sentence_len >= sample_data[0].shape[0]:
sample_queue.put(sample_data) sample_queue.put(sample_data)
print('sub process: %d' % sample_queue.qsize())
out_order[0] += 1 out_order[0] += 1
time.sleep(0.1)
if order_id == self._order_id:
sample_queue.put(EpochEndSignal())
ins = sample_info_queue.get() ins = sample_info_queue.get()
sample_queue.put(EpochEndSignal())
out_order = manager.list([0]) out_order = manager.list([0])
args = (sample_info_queue, sample_queue, out_order) args = (sample_info_queue, sample_queue, out_order)
workers = [ workers = [
...@@ -223,10 +218,13 @@ class DataReader(object): ...@@ -223,10 +218,13 @@ class DataReader(object):
w.daemon = True w.daemon = True
w.start() w.start()
finished_process_num = 0
while True: while True:
print('main thread: %d' % sample_queue.qsize())
sample = sample_queue.get() sample = sample_queue.get()
if isinstance(sample, EpochEndSignal): break if isinstance(sample, EpochEndSignal):
finished_process_num += 1
continue
yield sample yield sample
feeding_thread.join() feeding_thread.join()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册