Unverified commit 01f60a0b, authored by Yiqun Liu, committed by GitHub

Add the use of Python multi-processes to read data in deeplabv3+ (#2447)

* Enable the use of Python multi-processes to read data.

* Set default values for some Paddle flags, such as the flags that enable GC.
Parent 2a049dd8
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
import time
import numpy as np
import threading
import multiprocessing
try:
import queue
except ImportError:
import Queue as queue
class GeneratorEnqueuer(object):
"""
Builds a queue out of a data generator.
Args:
generator: a generator function which endlessly yields data
use_multiprocessing (bool): use multiprocessing if True,
otherwise use threading.
wait_time (float): time to sleep in-between calls to `put()`.
random_seed (int): Initial seed for workers,
will be incremented by one for each workers.
"""
def __init__(self,
generator,
use_multiprocessing=False,
wait_time=0.05,
random_seed=None):
self.wait_time = wait_time
self._generator = generator
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queue = None
self._manager = None
self.seed = random_seed
def start(self, workers=1, max_queue_size=10):
"""
Start worker threads which add data from the generator into the queue.
Args:
workers (int): number of worker threads
max_queue_size (int): queue size
(when full, threads could block on `put()`)
"""
def data_generator_task():
"""
Data generator task.
"""
def task():
if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator)
self.queue.put((generator_output))
else:
time.sleep(self.wait_time)
if not self._use_multiprocessing:
while not self._stop_event.is_set():
with self.genlock:
try:
task()
except Exception:
self._stop_event.set()
break
else:
while not self._stop_event.is_set():
try:
task()
except Exception:
self._stop_event.set()
break
try:
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.genlock = threading.Lock()
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
thread = multiprocessing.Process(target=data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
"""
Returns:
bool: Whether the worker theads are running.
"""
return self._stop_event is not None and not self._stop_event.is_set()
def stop(self, timeout=None):
"""
Stops running threads and wait for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Args:
timeout(int|None): maximum time to wait on `thread.join()`.
"""
if self.is_running():
self._stop_event.set()
for thread in self._threads:
if self._use_multiprocessing:
if thread.is_alive():
thread.join(timeout)
else:
thread.join(timeout)
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
self.queue = None
def get(self):
"""
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Yields
tuple of data in the queue.
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
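For illustration, a minimal usage sketch of GeneratorEnqueuer (not part of the commit; the toy count_batches generator and the worker/queue sizes are placeholders):

import itertools
from data_utils import GeneratorEnqueuer  # the class defined above

def count_batches():
    # Toy endless generator standing in for a real batch producer.
    for i in itertools.count():
        yield i

enqueuer = GeneratorEnqueuer(count_batches(), use_multiprocessing=False)
enqueuer.start(workers=2, max_queue_size=8)
for batch in itertools.islice(enqueuer.get(), 5):
    print(batch)  # worker threads race on the queue, so order may vary
enqueuer.stop()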
reader.py

@@ -5,6 +5,8 @@ import cv2
 import numpy as np
 import os
 import six
+import time
+from data_utils import GeneratorEnqueuer

 default_config = {
     "shuffle": True,
@@ -138,21 +140,55 @@ class CityscapeDataset:
             self.next_img()
         return np.array(imgs), np.array(labels), names

-    def get_batch_generator(self, batch_size, total_step):
+    def get_batch_generator(self,
+                            batch_size,
+                            total_step,
+                            num_workers=8,
+                            max_queue=32,
+                            use_multiprocessing=True):
         def do_get_batch():
-            for i in range(total_step):
+            iter_id = 0
+            while True:
                 imgs, labels, names = self.get_batch(batch_size)
                 labels = labels.astype(np.int32)[:, :, :, 0]
                 imgs = imgs[:, :, :, ::-1].transpose(
                     0, 3, 1, 2).astype(np.float32) / (255.0 / 2) - 1
-                yield i, imgs, labels, names
+                yield imgs, labels, names
+                if not use_multiprocessing:
+                    iter_id += 1
+                    if iter_id >= total_step:
+                        break

         batches = do_get_batch()
-        try:
-            from prefetch_generator import BackgroundGenerator
-            batches = BackgroundGenerator(batches, 100)
-        except:
-            print(
-                "You can install 'prefetch_generator' for acceleration of data reading."
-            )
-        return batches
+        if not use_multiprocessing:
+            try:
+                from prefetch_generator import BackgroundGenerator
+                batches = BackgroundGenerator(batches, 100)
+            except:
+                print(
+                    "You can install 'prefetch_generator' for acceleration of data reading."
+                )
+            return batches
+
+        def reader():
+            try:
+                enqueuer = GeneratorEnqueuer(
+                    batches, use_multiprocessing=use_multiprocessing)
+                enqueuer.start(max_queue_size=max_queue, workers=num_workers)
+                generator_out = None
+                for i in range(total_step):
+                    while enqueuer.is_running():
+                        if not enqueuer.queue.empty():
+                            generator_out = enqueuer.queue.get()
+                            break
+                        else:
+                            time.sleep(0.02)
+                    yield generator_out
+                    generator_out = None
+                enqueuer.stop()
+            finally:
+                if enqueuer is not None:
+                    enqueuer.stop()
+
+        data_gen = reader()
+        return data_gen
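A hypothetical call into the new reader path (not part of the commit; `dataset` is assumed to be an already-constructed CityscapeDataset, and the batch size, step count, and worker/queue sizes are placeholder values):

# `dataset` is assumed to exist; only the generator protocol is shown.
batches = dataset.get_batch_generator(
    batch_size=2,
    total_step=100,
    num_workers=4,
    max_queue=16,
    use_multiprocessing=True)
for imgs, labels, names in batches:
    # imgs: float32 NCHW scaled to [-1, 1]; labels: int32 NHW.
    print(imgs.shape, labels.shape, names)

Note that the generator now yields (imgs, labels, names) without the step index that the old code prepended.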
train.py

@@ -2,8 +2,27 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os
-if 'FLAGS_fraction_of_gpu_memory_to_use' not in os.environ:
-    os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = '0.98'
+
+
+def set_paddle_flags(flags):
+    for key, value in flags.items():
+        if os.environ.get(key, None) is None:
+            os.environ[key] = str(value)
+
+
+# NOTE(paddle-dev): All of these flags should be
+# set before `import paddle`. Otherwise, it would
+# not take any effect.
+set_paddle_flags({
+    'FLAGS_eager_delete_tensor_gb': 0,  # enable GC
+    # You can omit the following settings, because the default
+    # value of FLAGS_memory_fraction_of_eager_deletion is 1,
+    # and default value of FLAGS_fast_eager_deletion_mode is 1
+    'FLAGS_memory_fraction_of_eager_deletion': 1,
+    'FLAGS_fast_eager_deletion_mode': 1,
+    # Setting the default used gpu memory
+    'FLAGS_fraction_of_gpu_memory_to_use': 0.98
+})

 import paddle
 import paddle.fluid as fluid
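A small sketch of the precedence rule that set_paddle_flags implements (illustrative, not from the commit): a value already exported in the environment wins over the default passed in, so users can still override any flag from the shell.

import os

os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = '0.5'
set_paddle_flags({'FLAGS_fraction_of_gpu_memory_to_use': 0.98})
# The pre-existing environment value is kept.
assert os.environ['FLAGS_fraction_of_gpu_memory_to_use'] == '0.5'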
@@ -36,6 +55,7 @@ add_arg('memory_optimize', bool, True, "Using memory optimizer.")
 add_arg('norm_type', str, 'bn', "Normalization type, should be 'bn' or 'gn'.")
 add_arg('profile', bool, False, "Enable profiler.")
 add_arg('use_py_reader', bool, True, "Use py reader.")
+add_arg("num_workers", int, 8, "The number of python processes used to read and preprocess data.")
 parser.add_argument(
     '--enable_ce',
     action='store_true',
@@ -192,13 +212,14 @@ if args.use_py_reader:

     def data_gen():
         batches = dataset.get_batch_generator(
             batch_size // fluid.core.get_cuda_device_count(),
-            total_step * fluid.core.get_cuda_device_count())
+            total_step * fluid.core.get_cuda_device_count(),
+            use_multiprocessing=True, num_workers=args.num_workers)
         for b in batches:
-            yield b[1], b[2]
+            yield b[0], b[1]

     py_reader.decorate_tensor_provider(data_gen)
     py_reader.start()
 else:
-    batches = dataset.get_batch_generator(batch_size, total_step)
+    batches = dataset.get_batch_generator(batch_size, total_step, use_multiprocessing=True, num_workers=args.num_workers)

 total_time = 0.0
 epoch_idx = 0
 train_loss = 0
@@ -207,9 +228,8 @@ with profile_context(args.profile):
     for i in range(total_step):
         epoch_idx += 1
         begin_time = time.time()
-        prev_start_time = time.time()
        if not args.use_py_reader:
-            _, imgs, labels, names = next(batches)
+            imgs, labels, names = next(batches)
             train_loss, = exe.run(binary,
                                   feed={'img': imgs,
                                         'label': labels}, fetch_list=[loss_mean])
@@ -221,8 +241,8 @@ with profile_context(args.profile):
         if i % 100 == 0:
             print("Model is saved to", args.save_weights_path)
             save_model()
-        print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f}".format(
-            i, train_loss, end_time - prev_start_time))
+        print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f} s".format(
+            i, train_loss, end_time - begin_time))

 print("Training done. Model is saved to", args.save_weights_path)
 save_model()