diff --git a/data_utils/utility.py b/data_utils/utility.py index 96df2485f4c8a2f9960eea6f47158e33fb81ffa2..49eed6d8dcaff133953733dca48ff491cfc8c1d4 100644 --- a/data_utils/utility.py +++ b/data_utils/utility.py @@ -9,6 +9,7 @@ import os import tarfile import time from Queue import Queue +from threading import Thread from multiprocessing import Process, Manager from paddle.v2.dataset.common import md5file @@ -100,7 +101,8 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): in_queue.put((order_id, sample)) in_queue.put(end_flag) - # define a worker to handle samples from in_queue by mapper and put results to out_queue + # define a worker to handle samples from in_queue by mapper and put results + # to out_queue def handle_worker(in_queue, out_queue, mapper): sample = in_queue.get() while not isinstance(sample, XmapEndSignal): @@ -109,7 +111,8 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): in_queue.put(end_flag) out_queue.put(end_flag) - # define a worker to handle samples from in_queue by mapper and put results to out_queue with order + # define a worker to handle samples from in_queue by mapper and put results + # to out_queue with order def order_handle_worker(in_queue, out_queue, mapper, out_order): ins = in_queue.get() while not isinstance(ins, XmapEndSignal): @@ -123,6 +126,18 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): in_queue.put(end_flag) out_queue.put(end_flag) + # define a thread worker to flush samples from Manager.Queue to Queue + # for acceleration + def flush_worker(in_queue, out_queue): + finish = 0 + while finish < process_num: + sample = in_queue.get() + if isinstance(sample, XmapEndSignal): + finish += 1 + else: + out_queue.put(sample) + out_queue.put(end_flag) + def xreader(): # prepare shared memory manager = Manager() @@ -147,13 +162,16 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): w.daemon = True w.start() + # start a thread to read data from slow Manager.Queue + flush_queue = Queue(buffer_size) + t = Thread(target=flush_worker, args=(out_queue, flush_queue)) + t.daemon = True + t.start() + # get results - finish = 0 - while finish < process_num: - sample = out_queue.get() - if isinstance(sample, XmapEndSignal): - finish += 1 - else: - yield sample + sample = flush_queue.get() + while not isinstance(sample, XmapEndSignal): + yield sample + sample = flush_queue.get() return xreader diff --git a/infer.py b/infer.py index e635f6d0f973b3763320896416b1c4d4aa35b2c4..5d9439cf2b6830f34aa6a15f3bf8500d91f15cb4 100644 --- a/infer.py +++ b/infer.py @@ -116,7 +116,9 @@ def infer(): def main(): print_arguments(args) - paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) + paddle.init(use_gpu=args.use_gpu, + rnn_use_batch=True, + trainer_count=args.trainer_count) infer() diff --git a/test.py b/test.py index 51c725c5f76fd6b1e59d042988546b0b6934e43b..1fe0fbb7c2975b741473376d7b9e1916ab58d36f 100644 --- a/test.py +++ b/test.py @@ -119,7 +119,9 @@ def evaluate(): def main(): print_arguments(args) - paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) + paddle.init(use_gpu=args.use_gpu, + rnn_use_batch=True, + trainer_count=args.trainer_count) evaluate() diff --git a/tools/tune.py b/tools/tune.py index 85c2d73887e4805770dd8c66255f83698ffb46bb..83c71e7dbecbedf2e0e89be4de6be4b0a47c8e04 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -217,7 +217,9 @@ def tune(): def main(): print_arguments(args) - paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) + paddle.init(use_gpu=args.use_gpu, + rnn_use_batch=True, + trainer_count=args.trainer_count) tune() diff --git a/train.py b/train.py index 017cc73f64aaac82f176c54fb6063b6afa22e829..a9c7157692c88ec756b8ee57dd57b362bfec679b 100644 --- a/train.py +++ b/train.py @@ -119,6 +119,7 @@ def train(): def main(): print_arguments(args) paddle.init(use_gpu=args.use_gpu, + rnn_use_batch=True, trainer_count=args.trainer_count, log_clipping=True) train()