From 79cb57cc098974abb8857dfd979567e4892f2cdb Mon Sep 17 00:00:00 2001 From: yangyaming Date: Sat, 3 Feb 2018 22:51:59 +0800 Subject: [PATCH] Fix the exiting logic for sample generator. --- fluid/DeepASR/data_utils/parallel_reader.py | 22 ++++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/fluid/DeepASR/data_utils/parallel_reader.py b/fluid/DeepASR/data_utils/parallel_reader.py index c4915c49..aa31ad53 100644 --- a/fluid/DeepASR/data_utils/parallel_reader.py +++ b/fluid/DeepASR/data_utils/parallel_reader.py @@ -98,8 +98,8 @@ class DataReader(object): self._label_file_list = label_file_list self._block_info_list = None self._bucket_list = None - self._order_id = 0 self.generate_bucket_list(True) + self._order_id = 0 def generate_bucket_list(self, is_shuffle): if self._block_info_list is None: @@ -136,7 +136,7 @@ class DataReader(object): manager = Manager() sample_info_queue = manager.Queue(1024) sample_queue = manager.Queue(1024) - process_num = 1 + process_num = 2 self._order_id = 0 def ordered_feeding_worker(sample_info_queue): @@ -161,6 +161,7 @@ class DataReader(object): ins = sample_info_queue.get() while not isinstance(ins, EpochEndSignal): + # @TODO(pkuyym) add block cache to cache several block (LRU) into memory sample_info, order_id = ins f_feature = open(sample_info.feature_bin_path, 'r') f_label = open(sample_info.label_bin_path, 'r') @@ -200,17 +201,11 @@ class DataReader(object): if self._drop_sentence_len >= sample_data[0].shape[0]: sample_queue.put(sample_data) - print('sub process: %d' % sample_queue.qsize()) - out_order[0] += 1 - - time.sleep(0.1) - - if order_id == self._order_id: - sample_queue.put(EpochEndSignal()) - ins = sample_info_queue.get() + sample_queue.put(EpochEndSignal()) + out_order = manager.list([0]) args = (sample_info_queue, sample_queue, out_order) workers = [ @@ -223,10 +218,13 @@ class DataReader(object): w.daemon = True w.start() + finished_process_num = 0 + while True: - print('main thread: %d' % sample_queue.qsize()) sample = sample_queue.get() - if isinstance(sample, EpochEndSignal): break + if isinstance(sample, EpochEndSignal): + finished_process_num += 1 + continue yield sample feeding_thread.join() -- GitLab