提交 8cae62d7 编写于 作者: S Sylwester Fraczek

added crnn-ctc mkldnn benchmarking

上级 5abba732
......@@ -118,6 +118,10 @@ env CUDA_VISIABLE_DEVICES=0 python ctc_train.py
```
env CUDA_VISIABLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True
```
使用默认数据在CPU上训练:
```
env OMP_NUM_THREADS=<num_of_physical_cores> KMP_AFFINITY=granularity=fine,compact,1,0 python ctc_train.py --use_gpu False --use_mkldnn True --parallel=False
```
执行`python ctc_train.py --help`可查看更多使用方式和参数详细说明。
......
......@@ -4,6 +4,8 @@ import paddle.fluid as fluid
def conv_bn_pool(input,
group,
out_ch,
use_mkldnn,
use_cudnn,
act="relu",
param=None,
bias=None,
......@@ -18,20 +20,23 @@ def conv_bn_pool(input,
padding=1,
param_attr=param if param_0 is None else param_0,
act=None, # LinearActivation
use_cudnn=True)
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
#tmp = fluid.layers.Print(tmp)
tmp = fluid.layers.batch_norm(
input=tmp,
act=act,
param_attr=param,
bias_attr=bias,
use_mkldnn=use_mkldnn,
is_test=is_test)
tmp = fluid.layers.pool2d(
input=tmp,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=True,
use_cudnn=use_cudnn,
use_mkldnn=use_mkldnn,
ceil_mode=True)
return tmp
......@@ -40,6 +45,8 @@ def conv_bn_pool(input,
def ocr_convs(input,
num,
with_bn,
use_mkldnn,
use_cudnn,
regularizer=None,
gradient_clip=None,
is_test=False):
......@@ -59,16 +66,46 @@ def ocr_convs(input,
initializer=fluid.initializer.Normal(0.0, 0.01))
tmp = input
tmp = conv_bn_pool(
tmp, 2, [16, 16], param=w1, bias=b, param_0=w0, is_test=is_test)
tmp,
2, [16, 16],
param=w1,
bias=b,
param_0=w0,
is_test=is_test,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
tmp = conv_bn_pool(tmp, 2, [32, 32], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(tmp, 2, [64, 64], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(
tmp,
2, [32, 32],
param=w1,
bias=b,
is_test=is_test,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
tmp = conv_bn_pool(
tmp,
2, [64, 64],
param=w1,
bias=b,
is_test=is_test,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
tmp = conv_bn_pool(
tmp,
2, [128, 128],
param=w1,
bias=b,
is_test=is_test,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
return tmp
def encoder_net(images,
num_classes,
use_mkldnn,
use_cudnn,
rnn_hidden_size=200,
regularizer=None,
gradient_clip=None,
......@@ -79,6 +116,8 @@ def encoder_net(images,
True,
regularizer=regularizer,
gradient_clip=gradient_clip,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn,
is_test=is_test)
sliced_feature = fluid.layers.im2sequence(
input=conv_features,
......@@ -102,12 +141,13 @@ def encoder_net(images,
fc_1 = fluid.layers.fc(input=sliced_feature,
size=rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=bias_attr_nobias)
bias_attr=bias_attr_nobias,
use_mkldnn=use_mkldnn)
fc_2 = fluid.layers.fc(input=sliced_feature,
size=rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=bias_attr_nobias)
bias_attr=bias_attr_nobias,
use_mkldnn=use_mkldnn)
gru_forward = fluid.layers.dynamic_gru(
input=fc_1,
size=rnn_hidden_size,
......@@ -134,6 +174,7 @@ def encoder_net(images,
fc_out = fluid.layers.fc(input=[gru_forward, gru_backward],
size=num_classes + 1,
param_attr=w_attr,
use_mkldnn=use_mkldnn,
bias_attr=b_attr)
return fc_out
......@@ -145,7 +186,12 @@ def ctc_train_net(images, label, args, num_classes):
MOMENTUM = 0.9
regularizer = fluid.regularizer.L2Decay(L2_RATE)
fc_out = encoder_net(images, num_classes, regularizer=regularizer)
fc_out = encoder_net(
images,
num_classes,
regularizer=regularizer,
use_mkldnn=args.use_mkldnn,
use_cudnn=True if args.use_gpu else False)
cost = fluid.layers.warpctc(
input=fc_out, label=label, blank=num_classes, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
......@@ -167,13 +213,23 @@ def ctc_train_net(images, label, args, num_classes):
return sum_cost, error_evaluator, inference_program, model_average
def ctc_infer(images, num_classes):
fc_out = encoder_net(images, num_classes, is_test=True)
def ctc_infer(images, num_classes, use_mkldnn, use_cudnn):
fc_out = encoder_net(
images,
num_classes,
is_test=True,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
return fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
def ctc_eval(images, label, num_classes):
fc_out = encoder_net(images, num_classes, is_test=True)
def ctc_eval(images, label, num_classes, use_mkldnn, use_cudnn):
fc_out = encoder_net(
images,
num_classes,
is_test=True,
use_mkldnn=use_mkldnn,
use_cudnn=use_cudnn)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
......
......@@ -25,7 +25,7 @@ class DataGenerator(object):
def __init__(self):
pass
def train_reader(self, img_root_dir, img_label_list, batchsize):
def train_reader(self, img_root_dir, img_label_list, batchsize, cycle):
'''
Reader interface for training.
......@@ -65,24 +65,29 @@ class DataGenerator(object):
def reader():
sizes = len(img_label_lines) / batchsize
for i in range(sizes):
result = []
sz = [0, 0]
for j in range(batchsize):
line = img_label_lines[i * batchsize + j]
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L') #zhuanhuidu
if j == 0:
sz = img.size
img = img.resize((sz[0], sz[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
result.append([img, label])
yield result
if sizes == 0:
raise ValueError('Batch size is bigger than the dataset size.')
while True:
for i in range(sizes):
result = []
sz = [0, 0]
for j in range(batchsize):
line = img_label_lines[i * batchsize + j]
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L') #zhuanhuidu
if j == 0:
sz = img.size
img = img.resize((sz[0], sz[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
result.append([img, label])
yield result
if not cycle:
break
return reader
......@@ -111,7 +116,7 @@ class DataGenerator(object):
return reader
def infer_reader(self, img_root_dir=None, img_label_list=None):
def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
'''A reader interface for inference.
:param img_root_dir: The root path of the images for training.
......@@ -125,8 +130,8 @@ class DataGenerator(object):
'''
def reader():
if img_label_list is not None:
for line in open(img_label_list):
def yield_img_and_label(lines):
for line in lines:
if img_root_dir is not None:
# h, w, img_name, labels
img_name = line.split(' ')[2]
......@@ -138,6 +143,16 @@ class DataGenerator(object):
img = img[np.newaxis, ...]
label = [int(c) for c in line.split(' ')[3].split(',')]
yield img, label
if img_label_list is not None:
lines = []
with open(img_label_list) as f:
lines = f.readlines()
for img, label in yield_img_and_label(lines):
yield img, label
while cycle:
for img, label in yield_img_and_label(lines):
yield img, label
else:
while True:
img_path = raw_input("Please input the path of image: ")
......@@ -161,14 +176,15 @@ def data_shape():
return DATA_SHAPE
def train(batch_size, train_images_dir=None, train_list_file=None):
def train(batch_size, train_images_dir=None, train_list_file=None, cycle=False):
generator = DataGenerator()
if train_images_dir is None:
data_dir = download_data()
train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
if train_list_file is None:
train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
return generator.train_reader(train_images_dir, train_list_file, batch_size)
return generator.train_reader(train_images_dir, train_list_file, batch_size,
cycle)
def test(batch_size=1, test_images_dir=None, test_list_file=None):
......@@ -182,10 +198,14 @@ def test(batch_size=1, test_images_dir=None, test_list_file=None):
generator.test_reader(test_images_dir, test_list_file), batch_size)
def inference(infer_images_dir=None, infer_list_file=None):
def inference(batch_size,
infer_images_dir=None,
infer_list_file=None,
cycle=False):
generator = DataGenerator()
return paddle.batch(
generator.infer_reader(infer_images_dir, infer_list_file), 1)
generator.infer_reader(infer_images_dir, infer_list_file, cycle),
batch_size)
def download_data():
......
"""Trainer for OCR CTC model."""
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
......@@ -10,6 +9,8 @@ import time
import os
import numpy as np
import paddle.fluid.profiler as profiler
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
......@@ -20,11 +21,16 @@ add_arg('save_model_period', int, 15000, "Save model period. '-1' means n
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 15625, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 15625, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether to use parallel training.")
add_arg('use_mkldnn', bool, False, "Whether to use mkldnn to train.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('iterations', int, 0, "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.")
add_arg('skip_test', bool, False, "Whether to skip test phase.")
# yapf: enable
......@@ -49,7 +55,8 @@ def train(args, data_reader=ctc_reader):
train_reader = data_reader.train(
args.batch_size,
train_images_dir=train_images,
train_list_file=train_list)
train_list_file=train_list,
cycle=args.iterations > 0)
test_reader = data_reader.test(
test_images_dir=test_images, test_list_file=test_list)
......@@ -74,7 +81,7 @@ def train(args, data_reader=ctc_reader):
error_evaluator.reset(exe)
if args.parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=sum_cost.name)
use_cuda=True if args.use_gpu else False, loss_name=sum_cost.name)
fetch_vars = [sum_cost] + error_evaluator.metrics
......@@ -82,11 +89,11 @@ def train(args, data_reader=ctc_reader):
var_names = [var.name for var in fetch_vars]
if args.parallel:
results = train_exe.run(var_names,
feed=get_feeder_data(data, place))
feed_dict=get_feeder_data(data, place))
results = [np.array(result).sum() for result in results]
else:
results = exe.run(feed=get_feeder_data(data, place),
fetch_list=fetch_vars)
results = train_exe.run(feed=get_feeder_data(data, place),
fetch_list=fetch_vars)
results = [result[0] for result in results]
return results
......@@ -108,9 +115,21 @@ def train(args, data_reader=ctc_reader):
batch_id = 1
total_loss = 0.0
total_seq_error = 0.0
batch_times = []
iters = 0
# train a pass
for data in train_reader():
if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
break
if iters < args.skip_batch_num:
print("Warm-up iteration")
if iters == args.skip_batch_num:
profiler.reset_profiler()
start = time.time()
results = train_one_batch(data)
batch_time = time.time() - start
fps = args.batch_size / batch_time
batch_times.append(batch_time)
total_loss += results[0]
total_seq_error += results[2]
# training log
......@@ -122,7 +141,7 @@ def train(args, data_reader=ctc_reader):
sys.stdout.flush()
# evaluate
if batch_id % args.eval_period == 0:
if not args.skip_test and batch_id % args.eval_period == 0:
if model_average:
with model_average.apply(exe):
test(pass_id, batch_id)
......@@ -138,12 +157,37 @@ def train(args, data_reader=ctc_reader):
save_model(args, exe, pass_id, batch_id)
batch_id += 1
iters += 1
# Postprocess benchmark data
latencies = batch_times[args.skip_batch_num:]
latency_avg = np.average(latencies)
latency_pc99 = np.percentile(latencies, 99)
fpses = np.divide(args.batch_size, latencies)
fps_avg = np.average(fpses)
fps_pc99 = np.percentile(fpses, 1)
# Benchmark output
print('\nTotal examples (incl. warm-up): %d' %
(iters * args.batch_size))
print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
latency_pc99))
print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg,
fps_pc99))
def main():
args = parser.parse_args()
print_arguments(args)
train(args, data_reader=ctc_reader)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
train(args, data_reader=ctc_reader)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(args, data_reader=ctc_reader)
else:
train(args, data_reader=ctc_reader)
if __name__ == "__main__":
......
......@@ -14,7 +14,8 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('model_path', str, None, "The model path to be used for inference.")
add_arg('input_images_dir', str, None, "The directory of images.")
add_arg('input_images_list', str, None, "The list file of images.")
add_arg('use_gpu', bool, True, "Whether use GPU to eval.")
add_arg('use_gpu', bool, True, "Whether use GPU to eval.")
add_arg('use_mkldnn', bool, False, "Whether to use MKLDNN to eval.")
# yapf: enable
......@@ -26,7 +27,8 @@ def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
evaluator, cost = eval(images, label, num_classes)
evaluator, cost = eval(images, label, num_classes, args.use_mkldnn, True
if args.use_gpu else False)
# data reader
test_reader = data_reader.test(
......@@ -35,7 +37,7 @@ def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
# prepare environment
place = fluid.CPUPlace()
if use_gpu:
if args.use_gpu:
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......
import paddle.v2 as paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
import numpy as np
......@@ -7,6 +8,7 @@ import ctc_reader
import argparse
import functools
import os
import time
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -15,7 +17,12 @@ add_arg('model_path', str, None, "The model path to be used for infer
add_arg('input_images_dir', str, None, "The directory of images.")
add_arg('input_images_list', str, None, "The list file of images.")
add_arg('dict', str, None, "The dictionary. The result of inference will be index sequence if the dictionary was None.")
add_arg('use_gpu', bool, True, "Whether use GPU to infer.")
add_arg('use_gpu', bool, True, "Whether use GPU to infer.")
add_arg('use_mkldnn', bool, False, "Whether to use mkldnn. If set to True, set model_path option to a model trained with mkldnn.")
add_arg('iterations', int, 0, "The number of iterations. Zero or less means whole test set. More than 0 means the test set might be looped until # of iterations is reached.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('batch_size', int, 1, "The minibatch size.")
# yapf: enable
......@@ -25,11 +32,18 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
sequence = infer(images, num_classes)
sequence = infer(
images,
num_classes,
use_mkldnn=args.use_mkldnn,
use_cudnn=True if args.use_gpu else False)
# data reader
infer_reader = data_reader.inference(
batch_size=args.batch_size,
infer_images_dir=args.input_images_dir,
infer_list_file=args.input_images_list)
infer_list_file=args.input_images_list,
cycle=True if args.iterations > 0 else False)
# prepare environment
place = fluid.CPUPlace()
if args.use_gpu:
......@@ -56,23 +70,70 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
print "Init model from: %s." % args.model_path
batch_times = []
iters = 0
for data in infer_reader():
if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
break
if iters < args.skip_batch_num:
print("Warm-up itaration")
if iters == args.skip_batch_num:
profiler.reset_profiler()
start = time.time()
result = exe.run(fluid.default_main_program(),
feed=get_feeder_data(
data, place, need_label=False),
fetch_list=[sequence],
return_numpy=False)
batch_time = time.time() - start
fps = args.batch_size / batch_time
batch_times.append(batch_time)
indexes = np.array(result[0]).flatten()
if dict_map is not None:
print "result: %s" % ([dict_map[index] for index in indexes], )
print "Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
iters,
batch_time,
fps,
[dict_map[index] for index in indexes], )
else:
print "result: %s" % (indexes, )
print "Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
iters,
batch_time,
fps,
indexes, )
iters += 1
# Postprocess benchmark data
latencies = batch_times[args.skip_batch_num:]
latency_avg = np.average(latencies)
latency_pc99 = np.percentile(latencies, 99)
fpses = np.divide(args.batch_size, latencies)
fps_avg = np.average(fpses)
fps_pc99 = np.percentile(fpses, 1)
# Benchmark output
print('\nTotal examples (incl. warm-up): %d' % (iters * args.batch_size))
print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
latency_pc99))
print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))
def main():
args = parser.parse_args()
print_arguments(args)
inference(args, data_reader=ctc_reader)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
inference(args, data_reader=ctc_reader)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
inference(args, data_reader=ctc_reader)
else:
inference(args, data_reader=ctc_reader)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册