提交 a3f1ba3a 编写于 作者: X Xinghai Sun

Turn on rnn_use_batch of Paddle for accelartion.

Improve xmap_reader_mp by adding a flush thread.
上级 efef5d92
......@@ -9,6 +9,7 @@ import os
import tarfile
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager
from paddle.v2.dataset.common import md5file
......@@ -100,7 +101,8 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
in_queue.put((order_id, sample))
in_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results to out_queue
# define a worker to handle samples from in_queue by mapper and put results
# to out_queue
def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get()
while not isinstance(sample, XmapEndSignal):
......@@ -109,7 +111,8 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end_flag)
out_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results to out_queue with order
# define a worker to handle samples from in_queue by mapper and put results
# to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
......@@ -123,6 +126,18 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end_flag)
out_queue.put(end_flag)
# define a thread worker to flush samples from Manager.Queue to Queue
# for acceleration
def flush_worker(in_queue, out_queue):
finish = 0
while finish < process_num:
sample = in_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
out_queue.put(sample)
out_queue.put(end_flag)
def xreader():
# prepare shared memory
manager = Manager()
......@@ -147,13 +162,16 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
w.daemon = True
w.start()
# start a thread to read data from slow Manager.Queue
flush_queue = Queue(buffer_size)
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
t.daemon = True
t.start()
# get results
finish = 0
while finish < process_num:
sample = out_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
yield sample
sample = flush_queue.get()
while not isinstance(sample, XmapEndSignal):
yield sample
sample = flush_queue.get()
return xreader
......@@ -116,7 +116,9 @@ def infer():
def main():
print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count)
infer()
......
......@@ -119,7 +119,9 @@ def evaluate():
def main():
print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count)
evaluate()
......
......@@ -217,7 +217,9 @@ def tune():
def main():
print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count)
tune()
......
......@@ -119,6 +119,7 @@ def train():
def main():
print_arguments(args)
paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count,
log_clipping=True)
train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册