提交 0d4df0f6 编写于 作者: D dengkaipeng

refine reader.

上级 30717188
......@@ -13,7 +13,7 @@
## Installation
Running sample code in this directory requires PaddelPaddle Fluid v.1.1.0 and later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle) and make an update.
Running sample code in this directory requires PaddelPaddle Fluid v.1.4 and later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/install_doc.html#paddlepaddle) and make an update.
## Introduction
......
......@@ -13,7 +13,7 @@
## 安装
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.1.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.4或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。
## 简介
......
......@@ -6,14 +6,11 @@ import time
import numpy as np
import threading
import multiprocessing
import cv2
try:
import queue
except ImportError:
import Queue as queue
import image_utils
class GeneratorEnqueuer(object):
"""
......@@ -38,11 +35,11 @@ class GeneratorEnqueuer(object):
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queues = []
self.queue = None
self._manager = None
self.seed = random_seed
def start(self, workers=1, max_queue_size=10, random_sizes=[608]):
def start(self, workers=1, max_queue_size=10):
"""
Start worker threads which add data from the generator into the queue.
......@@ -52,38 +49,16 @@ class GeneratorEnqueuer(object):
(when full, threads could block on `put()`)
"""
self.random_sizes = random_sizes
self.size_num = len(random_sizes)
def data_generator_task():
"""
Data generator task.
"""
def task():
if len(self.queues) > 0:
if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator)
queue_idx = 0
while(True):
if self.queues[queue_idx].full():
queue_idx = (queue_idx + 1) % self.size_num
time.sleep(0.02)
continue
else:
size = self.random_sizes[queue_idx]
for g in generator_output:
g[0] = g[0].transpose((1, 2, 0))
g[0] = image_utils.random_interp(g[0], size)
g[0] = g[0].transpose((2, 0, 1))
try:
self.queues[queue_idx].put_nowait(generator_output)
except:
continue
else:
break
# self.queue.put((generator_output))
self.queue.put((generator_output))
else:
time.sleep(self.wait_time)
......@@ -106,14 +81,11 @@ class GeneratorEnqueuer(object):
try:
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
for i in range(self.size_num):
self.queues.append(self._manager.Queue(maxsize=max_queue_size))
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.genlock = threading.Lock()
# self.queue = queue.Queue()
for i in range(self.size_num):
self.queues.append(queue.Queue())
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
......@@ -162,7 +134,7 @@ class GeneratorEnqueuer(object):
self._stop_event = None
self.queue = None
def get(self, queue_idx):
def get(self):
"""
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
......@@ -171,8 +143,8 @@ class GeneratorEnqueuer(object):
tuple of data in the queue.
"""
while self.is_running():
if not self.queues[queue_idx].empty():
inputs = self.queues[queue_idx].get()
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
......
......@@ -209,7 +209,7 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores)
# img = random_interp(img, size, cv2.INTER_LINEAR)
img = random_interp(img, size)
img, gtboxes = random_flip(img, gtboxes)
gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores)
......
......@@ -221,8 +221,7 @@ class DataSetReader(object):
yield batch_out
batch_out = []
total_iter += 1
if total_iter % 10 == 0:
img_size = get_img_size(size, random_sizes)
img_size = get_img_size(size, random_sizes)
elif mode == 'test':
imgs = self._parse_images_by_mode(mode)
......@@ -253,12 +252,10 @@ def train(size=416,
shuffle=True,
mixup_iter=0,
random_sizes=[],
interval=10,
pyreader_num=1,
num_workers=2,
max_queue=4,
num_workers=12,
max_queue=36,
use_multiprocessing=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/pyreader_num), random_sizes)
generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/num_workers), random_sizes)
if not use_multiprocessing:
return generator
......@@ -271,26 +268,18 @@ def train(size=416,
def reader():
try:
enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=True)
enqueuer.start(max_queue_size=max_queue, workers=num_workers, random_sizes=random_sizes)
infinite_reader(), use_multiprocessing=use_multiprocessing)
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None
np.random.seed(1000)
intervals = pyreader_num * interval
cnt = 0
idx = len(random_sizes) - 1
while True:
while enqueuer.is_running():
if not enqueuer.queues[idx].empty():
generator_out = enqueuer.queues[idx].get()
if not enqueuer.queue.empty():
generator_out = enqueuer.queue.get()
break
else:
time.sleep(0.02)
yield generator_out
generator_out = None
cnt += 1
if cnt % intervals == 0:
idx = np.random.randint(len(random_sizes))
print("Resizing: ", random_sizes[idx])
finally:
if enqueuer is not None:
enqueuer.stop()
......
......@@ -91,7 +91,7 @@ def train():
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter
if cfg.use_pyreader:
train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, interval=10, pyreader_num=devices_num, use_multiprocessing=cfg.use_multiprocess)
train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册