提交 0d4df0f6 编写于 作者: D dengkaipeng

refine reader.

上级 30717188
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
## Installation ## Installation
Running sample code in this directory requires PaddlePaddle Fluid v.1.1.0 and later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle) and make an update. Running sample code in this directory requires PaddlePaddle Fluid v.1.4 and later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/install_doc.html#paddlepaddle) and make an update.
## Introduction ## Introduction
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
## 安装 ## 安装
在当前目录下运行样例代码需要PaddlePaddle Fluid的v.1.1.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。 在当前目录下运行样例代码需要PaddlePaddle Fluid的v.1.4或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。
## 简介 ## 简介
......
...@@ -6,14 +6,11 @@ import time ...@@ -6,14 +6,11 @@ import time
import numpy as np import numpy as np
import threading import threading
import multiprocessing import multiprocessing
import cv2
try: try:
import queue import queue
except ImportError: except ImportError:
import Queue as queue import Queue as queue
import image_utils
class GeneratorEnqueuer(object): class GeneratorEnqueuer(object):
""" """
...@@ -38,11 +35,11 @@ class GeneratorEnqueuer(object): ...@@ -38,11 +35,11 @@ class GeneratorEnqueuer(object):
self._use_multiprocessing = use_multiprocessing self._use_multiprocessing = use_multiprocessing
self._threads = [] self._threads = []
self._stop_event = None self._stop_event = None
self.queues = [] self.queue = None
self._manager = None self._manager = None
self.seed = random_seed self.seed = random_seed
def start(self, workers=1, max_queue_size=10, random_sizes=[608]): def start(self, workers=1, max_queue_size=10):
""" """
Start worker threads which add data from the generator into the queue. Start worker threads which add data from the generator into the queue.
...@@ -52,38 +49,16 @@ class GeneratorEnqueuer(object): ...@@ -52,38 +49,16 @@ class GeneratorEnqueuer(object):
(when full, threads could block on `put()`) (when full, threads could block on `put()`)
""" """
self.random_sizes = random_sizes
self.size_num = len(random_sizes)
def data_generator_task(): def data_generator_task():
""" """
Data generator task. Data generator task.
""" """
def task(): def task():
if len(self.queues) > 0: if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator) generator_output = next(self._generator)
queue_idx = 0 self.queue.put((generator_output))
while(True):
if self.queues[queue_idx].full():
queue_idx = (queue_idx + 1) % self.size_num
time.sleep(0.02)
continue
else:
size = self.random_sizes[queue_idx]
for g in generator_output:
g[0] = g[0].transpose((1, 2, 0))
g[0] = image_utils.random_interp(g[0], size)
g[0] = g[0].transpose((2, 0, 1))
try:
self.queues[queue_idx].put_nowait(generator_output)
except:
continue
else:
break
# self.queue.put((generator_output))
else: else:
time.sleep(self.wait_time) time.sleep(self.wait_time)
...@@ -106,14 +81,11 @@ class GeneratorEnqueuer(object): ...@@ -106,14 +81,11 @@ class GeneratorEnqueuer(object):
try: try:
if self._use_multiprocessing: if self._use_multiprocessing:
self._manager = multiprocessing.Manager() self._manager = multiprocessing.Manager()
for i in range(self.size_num): self.queue = self._manager.Queue(maxsize=max_queue_size)
self.queues.append(self._manager.Queue(maxsize=max_queue_size))
self._stop_event = multiprocessing.Event() self._stop_event = multiprocessing.Event()
else: else:
self.genlock = threading.Lock() self.genlock = threading.Lock()
# self.queue = queue.Queue() self.queue = queue.Queue()
for i in range(self.size_num):
self.queues.append(queue.Queue())
self._stop_event = threading.Event() self._stop_event = threading.Event()
for _ in range(workers): for _ in range(workers):
if self._use_multiprocessing: if self._use_multiprocessing:
...@@ -162,7 +134,7 @@ class GeneratorEnqueuer(object): ...@@ -162,7 +134,7 @@ class GeneratorEnqueuer(object):
self._stop_event = None self._stop_event = None
self.queue = None self.queue = None
def get(self, queue_idx): def get(self):
""" """
Creates a generator to extract data from the queue. Creates a generator to extract data from the queue.
Skip the data if it is `None`. Skip the data if it is `None`.
...@@ -171,8 +143,8 @@ class GeneratorEnqueuer(object): ...@@ -171,8 +143,8 @@ class GeneratorEnqueuer(object):
tuple of data in the queue. tuple of data in the queue.
""" """
while self.is_running(): while self.is_running():
if not self.queues[queue_idx].empty(): if not self.queue.empty():
inputs = self.queues[queue_idx].get() inputs = self.queue.get()
if inputs is not None: if inputs is not None:
yield inputs yield inputs
else: else:
......
...@@ -209,7 +209,7 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None): ...@@ -209,7 +209,7 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img) img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means) img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores) img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores)
# img = random_interp(img, size, cv2.INTER_LINEAR) img = random_interp(img, size)
img, gtboxes = random_flip(img, gtboxes) img, gtboxes = random_flip(img, gtboxes)
gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores) gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores)
......
...@@ -221,8 +221,7 @@ class DataSetReader(object): ...@@ -221,8 +221,7 @@ class DataSetReader(object):
yield batch_out yield batch_out
batch_out = [] batch_out = []
total_iter += 1 total_iter += 1
if total_iter % 10 == 0: img_size = get_img_size(size, random_sizes)
img_size = get_img_size(size, random_sizes)
elif mode == 'test': elif mode == 'test':
imgs = self._parse_images_by_mode(mode) imgs = self._parse_images_by_mode(mode)
...@@ -253,12 +252,10 @@ def train(size=416, ...@@ -253,12 +252,10 @@ def train(size=416,
shuffle=True, shuffle=True,
mixup_iter=0, mixup_iter=0,
random_sizes=[], random_sizes=[],
interval=10, num_workers=12,
pyreader_num=1, max_queue=36,
num_workers=2,
max_queue=4,
use_multiprocessing=True): use_multiprocessing=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/pyreader_num), random_sizes) generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/num_workers), random_sizes)
if not use_multiprocessing: if not use_multiprocessing:
return generator return generator
...@@ -271,26 +268,18 @@ def train(size=416, ...@@ -271,26 +268,18 @@ def train(size=416,
def reader(): def reader():
try: try:
enqueuer = GeneratorEnqueuer( enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=True) infinite_reader(), use_multiprocessing=use_multiprocessing)
enqueuer.start(max_queue_size=max_queue, workers=num_workers, random_sizes=random_sizes) enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None generator_out = None
np.random.seed(1000)
intervals = pyreader_num * interval
cnt = 0
idx = len(random_sizes) - 1
while True: while True:
while enqueuer.is_running(): while enqueuer.is_running():
if not enqueuer.queues[idx].empty(): if not enqueuer.queue.empty():
generator_out = enqueuer.queues[idx].get() generator_out = enqueuer.queue.get()
break break
else: else:
time.sleep(0.02) time.sleep(0.02)
yield generator_out yield generator_out
generator_out = None generator_out = None
cnt += 1
if cnt % intervals == 0:
idx = np.random.randint(len(random_sizes))
print("Resizing: ", random_sizes[idx])
finally: finally:
if enqueuer is not None: if enqueuer is not None:
enqueuer.stop() enqueuer.stop()
......
...@@ -91,7 +91,7 @@ def train(): ...@@ -91,7 +91,7 @@ def train():
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter
if cfg.use_pyreader: if cfg.use_pyreader:
train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, interval=10, pyreader_num=devices_num, use_multiprocessing=cfg.use_multiprocess) train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader) py_reader.decorate_paddle_reader(train_reader)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册