未验证 提交 039e501d 编写于 作者: Q qingqing01 提交者: GitHub

Speed data reader in PyramidBox model. (#1016)

* Speed data processing by multi-threads/multi-process.
* Add profiling scripts.
* Use depthwise transposed conv2d.
上级 8761ab3d
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
import time
import numpy as np
import threading
import multiprocessing
try:
import queue
except ImportError:
import Queue as queue
class GeneratorEnqueuer(object):
"""
Builds a queue out of a data generator.
Args:
generator: a generator function which endlessly yields data
use_multiprocessing (bool): use multiprocessing if True,
otherwise use threading.
wait_time (float): time to sleep in-between calls to `put()`.
random_seed (int): Initial seed for workers,
will be incremented by one for each workers.
"""
def __init__(self,
generator,
use_multiprocessing=False,
wait_time=0.05,
random_seed=None):
self.wait_time = wait_time
self._generator = generator
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queue = None
self._manager = None
self.seed = random_seed
def start(self, workers=1, max_queue_size=10):
"""
Start worker threads which add data from the generator into the queue.
Args:
workers (int): number of worker threads
max_queue_size (int): queue size
(when full, threads could block on `put()`)
"""
def data_generator_task():
"""
Data generator task.
"""
def task():
if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator)
self.queue.put((generator_output))
else:
time.sleep(self.wait_time)
if not self._use_multiprocessing:
while not self._stop_event.is_set():
with self.genlock:
try:
task()
except Exception:
self._stop_event.set()
break
else:
while not self._stop_event.is_set():
try:
task()
except Exception:
self._stop_event.set()
break
try:
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.genlock = threading.Lock()
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
thread = multiprocessing.Process(target=data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
"""
Returns:
bool: Whether the worker theads are running.
"""
return self._stop_event is not None and not self._stop_event.is_set()
def stop(self, timeout=None):
"""
Stops running threads and wait for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Args:
timeout(int|None): maximum time to wait on `thread.join()`.
"""
if self.is_running():
self._stop_event.set()
for thread in self._threads:
if self._use_multiprocessing:
if thread.is_alive():
thread.terminate()
else:
thread.join(timeout)
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
self.queue = None
def get(self):
"""
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Yields
tuple of data in the queue.
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
import os
import shutil
import numpy as np
import time
import argparse
import functools
import reader
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "parallel")
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 20, "Minibatch size.")
add_arg('num_iteration', int, 10, "Epoch number.")
add_arg('skip_reader', bool, False, "Whether to skip data reader.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('model_save_dir', str, 'output', "The path to save model.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image height.")
#yapf: enable
def train(args, config, train_file_list, optimizer_method):
learning_rate = args.learning_rate
batch_size = args.batch_size
height = args.resize_h
width = args.resize_w
use_gpu = args.use_gpu
use_pyramidbox = args.use_pyramidbox
model_save_dir = args.model_save_dir
pretrained_model = args.pretrained_model
skip_reader = args.skip_reader
num_iterations = args.num_iteration
parallel = args.parallel
num_classes = 2
image_shape = [3, height, width]
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
fetches = []
network = PyramidBox(image_shape, num_classes,
sub_network=use_pyramidbox)
if use_pyramidbox:
face_loss, head_loss, loss = network.train()
fetches = [face_loss, head_loss]
else:
loss = network.vgg_ssd_loss()
fetches = [loss]
epocs = 12880 / batch_size
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
]
if optimizer_method == "momentum":
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=boundaries, values=values),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(0.0005),
)
else:
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.0005),
)
optimizer.minimize(loss)
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
start_pass = 0
if pretrained_model:
if pretrained_model.isdigit():
start_pass = int(pretrained_model) + 1
pretrained_model = os.path.join(model_save_dir, pretrained_model)
print("Resume from %s " %(pretrained_model))
if not os.path.exists(pretrained_model):
raise ValueError("The pre-trained model path [%s] does not exist." %
(pretrained_model))
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
if parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_gpu, loss_name=loss.name)
train_reader = reader.train_batch_reader(config, train_file_list, batch_size=batch_size)
def tensor(data, place, lod=None):
t = fluid.core.LoDTensor()
t.set(data, place)
if lod:
t.set_lod(lod)
return t
im, face_box, head_box, labels, lod = next(train_reader)
im_t = tensor(im, place)
box1 = tensor(face_box, place, [lod])
box2 = tensor(head_box, place, [lod])
lbl_t = tensor(labels, place, [lod])
feed_data = {'image': im_t, 'face_box': box1,
'head_box': box2, 'gt_label': lbl_t}
def run(iterations, feed_data):
# global feed_data
reader_time = []
run_time = []
for batch_id in range(iterations):
start_time = time.time()
if not skip_reader:
im, face_box, head_box, labels, lod = next(train_reader)
im_t = tensor(im, place)
box1 = tensor(face_box, place, [lod])
box2 = tensor(head_box, place, [lod])
lbl_t = tensor(labels, place, [lod])
feed_data = {'image': im_t, 'face_box': box1,
'head_box': box2, 'gt_label': lbl_t}
end_time = time.time()
reader_time.append(end_time - start_time)
start_time = time.time()
if parallel:
fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
feed=feed_data)
else:
fetch_vars = exe.run(fluid.default_main_program(),
feed=feed_data,
fetch_list=fetches)
end_time = time.time()
run_time.append(end_time - start_time)
fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
if not args.use_pyramidbox:
print("Batch {0}, loss {1}".format(batch_id, fetch_vars[0]))
else:
print("Batch {0}, face loss {1}, head loss {2}".format(
batch_id, fetch_vars[0], fetch_vars[1]))
return reader_time, run_time
# start-up
run(2, feed_data)
# profiling
start = time.time()
if not parallel:
with profiler.profiler('All', 'total', '/tmp/profile_file'):
reader_time, run_time = run(num_iterations, feed_data)
else:
reader_time, run_time = run(num_iterations, feed_data)
end = time.time()
total_time = end - start
print("Total time: {0}, reader time: {1} s, run time: {2} s".format(
total_time, np.sum(reader_time), np.sum(run_time)))
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = 'data/WIDERFACE/WIDER_train/images/'
train_file_list = 'label/train_gt_widerface.res'
config = reader.Settings(
data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
apply_expand=False,
mean_value=[104., 117., 123.],
ap_version='11point')
train(args, config, train_file_list, optimizer_method="momentum")
......@@ -81,10 +81,7 @@ class PyramidBox(object):
if self.is_infer:
return [self.image]
else:
return [
self.image, self.face_box, self.head_box, self.gt_label,
self.difficult
]
return [self.image, self.face_box, self.head_box, self.gt_label]
def _input(self):
self.image = fluid.layers.data(
......@@ -96,8 +93,6 @@ class PyramidBox(object):
name='head_box', shape=[4], dtype='float32', lod_level=1)
self.gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1)
self.difficult = fluid.layers.data(
name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
def _vgg(self):
self.conv1, self.pool1 = conv_block(self.image, 2, [64] * 2, [3] * 2)
......@@ -144,7 +139,8 @@ class PyramidBox(object):
stride=2,
groups=ch,
param_attr=w_attr,
bias_attr=False)
bias_attr=False,
use_cudnn=True)
else:
upsampling = fluid.layers.resize_bilinear(
conv1, out_shape=up_to.shape[2:])
......
......@@ -24,6 +24,7 @@ import time
import copy
import random
import cv2
from data_util import GeneratorEnqueuer
class Settings(object):
......@@ -184,20 +185,20 @@ def preprocess(img, bbox_labels, mode, settings, image_path):
return img, sampled_labels
def put_txt_in_dict(input_txt):
def load_file_list(input_txt):
with open(input_txt, 'r') as f_dir:
lines_input_txt = f_dir.readlines()
dict_input_txt = {}
file_dict = {}
num_class = 0
for i in range(len(lines_input_txt)):
tmp_line_txt = lines_input_txt[i].strip('\n\t\r')
if '--' in tmp_line_txt:
if i != 0:
num_class += 1
dict_input_txt[num_class] = []
file_dict[num_class] = []
dict_name = tmp_line_txt
dict_input_txt[num_class].append(tmp_line_txt)
file_dict[num_class].append(tmp_line_txt)
if '--' not in tmp_line_txt:
if len(tmp_line_txt) > 6:
split_str = tmp_line_txt.split(' ')
......@@ -207,11 +208,11 @@ def put_txt_in_dict(input_txt):
y2_max = float(split_str[3])
tmp_line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
x2_max) + ' ' + str(y2_max)
dict_input_txt[num_class].append(tmp_line_txt)
file_dict[num_class].append(tmp_line_txt)
else:
dict_input_txt[num_class].append(tmp_line_txt)
file_dict[num_class].append(tmp_line_txt)
return dict_input_txt
return file_dict
def expand_bboxes(bboxes,
......@@ -238,68 +239,106 @@ def expand_bboxes(bboxes,
return expand_boxes
def pyramidbox(settings, file_list, mode, shuffle):
def train_generator(settings, file_list, batch_size, shuffle=True):
file_dict = load_file_list(file_list)
while True:
if shuffle:
random.shuffle(file_dict)
images, face_boxes, head_boxes, label_ids = [], [], [], []
label_offs = [0]
dict_input_txt = {}
dict_input_txt = put_txt_in_dict(file_list)
def reader():
if mode == 'train' and shuffle:
random.shuffle(dict_input_txt)
for index_image in range(len(dict_input_txt)):
image_name = dict_input_txt[index_image][0] + '.jpg'
for index_image in file_dict.keys():
image_name = file_dict[index_image][0] + '.jpg'
image_path = os.path.join(settings.data_dir, image_name)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
# layout: label | xmin | ymin | xmax | ymax
if mode == 'train':
bbox_labels = []
for index_box in range(len(dict_input_txt[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = dict_input_txt[index_image][
index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(1)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
bbox_sample.append(float(ymax) / im_height)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, mode, settings,
image_path)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 1:5]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
yield im, boxes, expand_bboxes(boxes), lbls, difficults
if mode == 'test':
yield im, image_path
return reader
bbox_labels = []
for index_box in range(len(file_dict[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = file_dict[index_image][index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(1)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
bbox_sample.append(float(ymax) / im_height)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, "train", settings,
image_path)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
face_box = sample_labels[:, 1:5]
head_box = expand_bboxes(face_box)
label = [1] * len(face_box)
images.append(im)
face_boxes.extend(face_box)
head_boxes.extend(head_box)
label_ids.extend(label)
label_offs.append(label_offs[-1] + len(face_box))
if len(images) == batch_size:
images = np.array(images).astype('float32')
face_boxes = np.array(face_boxes).astype('float32')
head_boxes = np.array(head_boxes).astype('float32')
label_ids = np.array(label_ids).astype('int32')
yield images, face_boxes, head_boxes, label_ids, label_offs
images, face_boxes, head_boxes = [], [], []
label_ids, label_offs = [], [0]
def train_batch_reader(settings,
file_list,
batch_size,
shuffle=True,
num_workers=8):
try:
enqueuer = GeneratorEnqueuer(
train_generator(settings, file_list, batch_size, shuffle),
use_multiprocessing=False)
enqueuer.start(max_queue_size=24, workers=num_workers)
generator_output = None
while True:
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_output = enqueuer.queue.get()
break
else:
time.sleep(0.01)
yield generator_output
generator_output = None
finally:
if enqueuer is not None:
enqueuer.stop()
def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)
def test(settings, file_list):
file_dict = load_file_list(file_list)
def reader():
for index_image in file_dict.keys():
image_name = file_dict[index_image][0] + '.jpg'
image_path = os.path.join(settings.data_dir, image_name)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
yield im, image_path
def test(settings, file_list):
return pyramidbox(settings, file_list, 'test', False)
return reader
def infer(settings, image_path):
......
......@@ -58,8 +58,9 @@ def train(args, config, train_file_list, optimizer_method):
loss = network.vgg_ssd_loss()
fetches = [loss]
epocs = 12880 / batch_size
boundaries = [epocs * 50, epocs * 80, epocs * 120, epocs * 140]
steps_per_pass = 12880 / batch_size
boundaries = [steps_per_pass * 50, steps_per_pass * 80,
steps_per_pass * 120, steps_per_pass * 140]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
......@@ -104,9 +105,7 @@ def train(args, config, train_file_list, optimizer_method):
train_exe = fluid.ParallelExecutor(
use_cuda=use_gpu, loss_name=loss.name)
train_reader = paddle.batch(
reader.train(config, train_file_list), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=network.feeds())
train_reader = reader.train_batch_reader(config, train_file_list, batch_size=batch_size)
def save_model(postfix):
model_path = os.path.join(model_save_dir, postfix)
......@@ -115,20 +114,34 @@ def train(args, config, train_file_list, optimizer_method):
print 'save models to %s' % (model_path)
fluid.io.save_persistables(exe, model_path)
def tensor(data, place, lod=None):
t = fluid.core.LoDTensor()
t.set(data, place)
if lod:
t.set_lod(lod)
return t
for pass_id in range(start_pass, num_passes):
start_time = time.time()
prev_start_time = start_time
end_time = 0
for batch_id, data in enumerate(train_reader()):
for batch_id in range(steps_per_pass):
im, face_box, head_box, labels, lod = next(train_reader)
im_t = tensor(im, place)
box1 = tensor(face_box, place, [lod])
box2 = tensor(head_box, place, [lod])
lbl_t = tensor(labels, place, [lod])
feeding = {'image': im_t, 'face_box': box1,
'head_box': box2, 'gt_label': lbl_t}
prev_start_time = start_time
start_time = time.time()
if len(data) < 2 * devices_num: continue
if args.parallel:
fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
feed=feeder.feed(data))
feed=feeding)
else:
fetch_vars = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
feed=feeding,
fetch_list=fetches)
end_time = time.time()
fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册