提交 bf485999 编写于 作者: L Liu Yiqun

Merge branch 'develop' into core_inference_prepare

......@@ -48,6 +48,13 @@ parser.add_argument(
type=int,
default=16,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test')
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
"--dict_size",
type=int,
......@@ -72,16 +79,21 @@ parser.add_argument(
default=3,
help="The width for beam searching. (default: %(default)d)")
parser.add_argument(
"--use_gpu",
type=distutils.util.strtobool,
default=True,
help="Whether to use gpu. (default: %(default)d)")
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
"--max_length",
type=int,
default=250,
help="The maximum length of sequence when doing generation. "
"(default: %(default)d)")
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
def lstm_step(x_t, hidden_t_prev, cell_t_prev, size):
......@@ -281,7 +293,7 @@ def train():
paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
batch_size=args.batch_size)
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = Executor(place)
exe.run(framework.default_startup_program())
......@@ -307,14 +319,20 @@ def train():
return total_loss / count
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in xrange(args.pass_num):
pass_start_time = time.time()
words_seen = 0
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_batch_generator()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
src_seq, word_num = to_lodtensor(map(lambda x: x[0], data), place)
words_seen += word_num
num_samples += word_num
trg_seq, word_num = to_lodtensor(map(lambda x: x[1], data), place)
words_seen += word_num
num_samples += word_num
lbl_seq, _ = to_lodtensor(map(lambda x: x[2], data), place)
fetch_outs = exe.run(framework.default_main_program(),
......@@ -325,24 +343,36 @@ def train():
},
fetch_list=[avg_cost])
avg_cost_val = np.array(fetch_outs[0])
print('pass_id=%d, batch_id=%d, train_loss: %f' %
(pass_id, batch_id, avg_cost_val))
iters += 1
loss = np.array(fetch_outs[0])
print(
"Pass = %d, Iter = %d, Loss = %f" % (pass_id, iters, loss)
) # The accuracy is the accumulation of batches, but not the current batch.
pass_end_time = time.time()
test_loss = do_validation()
time_consumed = pass_end_time - pass_start_time
words_per_sec = words_seen / time_consumed
print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
(pass_id, test_loss, words_per_sec, time_consumed))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
test_loss = do_validation()
exit(0)
def infer():
pass
def print_arguments(args):
print('----------- seq2seq Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
if args.infer_only:
infer()
else:
......
......@@ -35,6 +35,12 @@ def parse_args():
parser = argparse.ArgumentParser("mnist model benchmark.")
parser.add_argument(
'--batch_size', type=int, default=128, help='The minibatch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=35, help='The number of minibatches.')
parser.add_argument(
......@@ -53,19 +59,14 @@ def parse_args():
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
args = parser.parse_args()
return args
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def cnn_model(data):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=data,
......@@ -138,9 +139,6 @@ def run_benchmark(model, args):
# inference program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(
target_vars=[batch_acc, batch_size_tensor])
# Optimization
opt = fluid.optimizer.AdamOptimizer(
......@@ -160,39 +158,60 @@ def run_benchmark(model, args):
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
accuracy = fluid.average.WeightedAverage()
accuracy = fluid.metrics.Accuracy()
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
accuracy.reset()
pass_start = time.time()
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
img_data = np.array(
map(lambda x: x[0].reshape([1, 28, 28]), data)).astype(DTYPE)
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([len(y_data), 1])
start = time.time()
outs = exe.run(
fluid.default_main_program(),
feed={"pixel": img_data,
"label": y_data},
fetch_list=[avg_cost, batch_acc, batch_size_tensor]
) # The accuracy is the accumulation of batches, but not the current batch.
accuracy.add(value=outs[1], weight=outs[2])
end = time.time()
accuracy.update(value=outs[1], weight=outs[2])
iters += 1
num_samples += len(y_data)
loss = np.array(outs[0])
acc = np.array(outs[1])
print("pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" %
(pass_id, batch_id, loss, 1 - acc, (end - start) / 1000))
train_losses.append(loss)
train_accs.append(acc)
print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %
(pass_id, iters, loss, acc))
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
pass_end = time.time()
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
test_avg_acc = eval_test(exe, batch_acc, batch_size_tensor,
inference_program)
exit(0)
train_avg_acc = accuracy.eval()
test_avg_acc = eval_test(exe, batch_acc, batch_size_tensor,
inference_program)
print("pass=%d, train_avg_acc=%f, test_avg_acc=%f, elapse=%f" %
(pass_id, train_avg_acc, test_avg_acc,
(pass_end - pass_start) / 1000))
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- mnist Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
......
......@@ -87,15 +87,6 @@ def parse_args():
return args
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
conv1 = fluid.layers.conv2d(
input=input,
......@@ -279,32 +270,31 @@ def run_benchmark(model, args):
'label': label},
fetch_list=[avg_cost, batch_acc, batch_size_tensor])
iters += 1
num_samples += label[0]
num_samples += len(label)
accuracy.add(value=acc, weight=weight)
train_losses.append(loss)
train_accs.append(acc)
print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %
(pass_id, iters, loss, acc))
pass_train_acc = accuracy.eval()
# evaluation
if args.with_test:
pass_test_acc = test(exe)
train_elapsed = time.time() - start_time
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
pass_test_acc = test(exe)
exit(0)
if args.use_cprof:
pr.disable()
s = StringIO.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- resnet Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
......
#!/bin/bash
# This script benchmarking the PaddlePaddle Fluid on
# single thread single GPU.
export CUDNN_PATH=/paddle/cudnn_v5/cuda/lib
#export FLAGS_fraction_of_gpu_memory_to_use=0.0
export CUDNN_PATH=/paddle/cudnn_v5
# disable openmp and mkl parallel
#https://github.com/PaddlePaddle/Paddle/issues/7199
......@@ -25,25 +27,79 @@ export CUDA_VISIBLE_DEVICES=0
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$CUDNN_PATH:$LD_LIBRARY_PATH
# only query the gpu used
nohup stdbuf -oL nvidia-smi \
--id=${CUDA_VISIBLE_DEVICES} \
--query-gpu=timestamp \
--query-compute-apps=pid,process_name,used_memory \
--format=csv \
--filename=mem.log \
-l 1 &
# mnist
# mnist gpu mnist 128
FLAGS_benchmark=true stdbuf -oL python fluid/mnist.py \
--device=GPU \
--batch_size=128 \
--skip_batch_num=5 \
--iterations=500 \
2>&1 | tee -a mnist_gpu_128.log
# vgg16
# cifar10 gpu cifar10 128
FLAGS_benchmark=true python fluid/vgg.py \
# gpu cifar10 128
FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
--device=GPU \
--batch_size=128 \
--skip_batch_num=5 \
--iterations=30 \
2>&1 > vgg16_gpu_128.log
--iterations=30 \
2>&1 | tee -a vgg16_gpu_128.log
# flowers gpu 128
FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
--device=GPU \
--batch_size=32 \
--data_set=flowers \
--skip_batch_num=5 \
--iterations=30 \
2>&1 | tee -a vgg16_gpu_flowers_32.log
# resnet50
# resnet50 gpu cifar10 128
FLAGS_benchmark=true python fluid/resnet.py \
FLAGS_benchmark=true stdbuf -oL python fluid/resnet50.py \
--device=GPU \
--batch_size=128 \
--data_set=cifar10 \
--model=resnet_cifar10 \
--skip_batch_num=5 \
--iterations=30 \
2>&1 > resnet50_gpu_128.log
2>&1 | tee -a resnet50_gpu_128.log
# resnet50 gpu flowers 64
FLAGS_benchmark=true stdbuf -oL python fluid/resnet50.py \
--device=GPU \
--batch_size=64 \
--data_set=flowers \
--model=resnet_imagenet \
--skip_batch_num=5 \
--iterations=30 \
2>&1 | tee -a resnet50_gpu_flowers_64.log
# lstm
# lstm gpu imdb 32 # tensorflow only support batch=32
FLAGS_benchmark=true stdbuf -oL python fluid/stacked_dynamic_lstm.py \
--device=GPU \
--batch_size=32 \
--skip_batch_num=5 \
--iterations=30 \
--hidden_dim=512 \
--emb_dim=512 \
--crop_size=1500 \
2>&1 | tee -a lstm_gpu_32.log
# seq2seq
# seq2seq gpu wmb 128
FLAGS_benchmark=true stdbuf -oL python fluid/machine_translation.py \
--device=GPU \
--batch_size=128 \
--skip_batch_num=5 \
--iterations=30 \
2>&1 | tee -a lstm_gpu_128.log
......@@ -37,6 +37,14 @@ def parse_args():
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--emb_dim',
type=int,
......@@ -64,6 +72,10 @@ def parse_args():
default=int(os.environ.get('CROP_SIZE', '1500')),
help='The max sentence length of input. Since this model use plain RNN,'
' Gradient could be explored if sentence is too long')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
args = parser.parse_args()
return args
......@@ -157,37 +169,43 @@ def main():
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
def train_loop(pass_num, crop_size):
with profiler.profiler(args.device, 'total') as prof:
for pass_id in range(pass_num):
train_reader = batch(
paddle.reader.shuffle(
crop_sentence(imdb.train(word_dict), crop_size),
buf_size=25000),
batch_size=args.batch_size)
word_nums = 0
pass_start_time = time.time()
for batch_id, data in enumerate(train_reader()):
tensor_words = to_lodtensor([x[0] for x in data], place)
for x in data:
word_nums += len(x[0])
label = numpy.array([x[1] for x in data]).astype("int64")
label = label.reshape((-1, 1))
loss_np, acc, weight = exe.run(
fluid.default_main_program(),
feed={"words": tensor_words,
"label": label},
fetch_list=[loss, batch_acc, batch_size_tensor])
print("pass_id=%d, batch_id=%d, loss=%f, acc=%f" %
(pass_id, batch_id, loss_np, acc))
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
words_per_sec = word_nums / time_consumed
print("pass_id=%d, sec/pass: %f, words/s: %f" %
(pass_id, time_consumed, words_per_sec))
train_loop(args.pass_num, args.crop_size)
train_reader = batch(
paddle.reader.shuffle(
crop_sentence(imdb.train(word_dict), args.crop_size),
buf_size=25000),
batch_size=args.batch_size)
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
tensor_words = to_lodtensor([x[0] for x in data], place)
label = numpy.array([x[1] for x in data]).astype("int64")
label = label.reshape((-1, 1))
loss_np, acc, weight = exe.run(
fluid.default_main_program(),
feed={"words": tensor_words,
"label": label},
fetch_list=[loss, batch_acc, batch_size_tensor])
iters += 1
for x in data:
num_samples += len(x[0])
print(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" %
(pass_id, iters, loss_np, acc)
) # The accuracy is the accumulation of batches, but not the current batch.
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
exit(0)
def to_lodtensor(data, place):
......@@ -205,5 +223,14 @@ def to_lodtensor(data, place):
return res
def print_arguments(args):
print('----------- lstm Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
main()
......@@ -191,25 +191,29 @@ def main():
fetch_list=[avg_cost, batch_acc, batch_size_tensor])
accuracy.add(value=acc, weight=weight)
iters += 1
num_samples += len(data)
num_samples += len(y_data)
print(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" %
(pass_id, iters, loss, acc)
) # The accuracy is the accumulation of batches, but not the current batch.
pass_train_acc = accuracy.eval()
# pass_train_acc = accuracy.eval()
train_losses.append(loss)
train_accs.append(acc)
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
pass_test_acc = test(exe)
train_elapsed = time.time() - start_time
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
exit(0)
def print_arguments():
print('----------- Configuration Arguments -----------')
print('----------- vgg Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
......
此差异已折叠。
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
import numpy as np
import tensorflow as tf
import paddle.v2 as paddle
DTYPE = tf.float32
def parse_args():
parser = argparse.ArgumentParser("mnist model benchmark.")
parser.add_argument(
'--batch_size', type=int, default=128, help='The minibatch size.')
parser.add_argument(
'--iterations', type=int, default=35, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=5, help='The number of passes.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
args = parser.parse_args()
return args
def run_benchmark(args):
def weight_variable(dtype, shape):
initial = tf.truncated_normal(shape, stddev=0.1, dtype=dtype)
return tf.Variable(initial)
def bias_variable(dtype, shape):
initial = tf.constant(0.1, shape=shape, dtype=dtype)
return tf.Variable(initial)
device = '/cpu:0' if args.device == 'CPU' else '/device:GPU:0'
with tf.device(device):
images = tf.placeholder(DTYPE, shape=(None, 28, 28, 1))
labels = tf.placeholder(tf.int64, shape=(None, ))
# conv1, relu, pool1
conv1_weights = weight_variable(DTYPE, [5, 5, 1, 20])
conv1_bias = bias_variable(DTYPE, [20])
conv1 = tf.nn.conv2d(
images, conv1_weights, strides=[1, 1, 1, 1], padding="VALID")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_bias))
pool1 = tf.nn.max_pool(
relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
# conv2, relu, pool2
conv2_weights = weight_variable(DTYPE, [5, 5, 20, 50])
conv2_bias = bias_variable(DTYPE, [50])
conv2 = tf.nn.conv2d(
pool1, conv2_weights, strides=[1, 1, 1, 1], padding="VALID")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
pool2 = tf.nn.max_pool(
relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
# FC
pool_shape = pool2.get_shape().as_list()
hidden_dim = reduce(lambda a, b: a * b, pool_shape[1:], 1)
reshape = tf.reshape(pool2, shape=(tf.shape(pool2)[0], hidden_dim))
fc_weights = weight_variable(DTYPE, [hidden_dim, 10])
fc_bias = bias_variable(DTYPE, [10])
logits = tf.matmul(reshape, fc_weights) + fc_bias
# Get prediction
prediction = tf.nn.softmax(logits)
# Loss
one_hot_labels = tf.one_hot(labels, depth=10)
cost = -tf.reduce_sum(tf.log(prediction) * one_hot_labels, [1])
avg_cost = tf.reduce_mean(cost)
# Get accuracy
correct = tf.equal(tf.argmax(prediction, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
# metrics, g_accuracy
with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
g_accuracy = tf.metrics.accuracy(
labels, tf.argmax(
prediction, axis=1))
vars = tf.contrib.framework.get_variables(
scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
g_accuracy_reset_op = tf.variables_initializer(vars)
# Optimizer
opt = tf.train.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
train_op = opt.minimize(avg_cost)
# train_op = tf.train.AdamOptimizer(1e-4).minimize(avg_cost)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=args.batch_size)
def eval_test():
sess.run(g_accuracy_reset_op)
for batch_id, data in enumerate(test_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype("int64")
loss, acc, g_acc = sess.run(
[avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
return g_acc[1]
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
for pass_id in range(args.pass_num):
sess.run(g_accuracy_reset_op)
pass_start = time.time()
for batch_id, data in enumerate(train_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype(
"int64")
start = time.time()
_, loss, acc, g_acc = sess.run(
[train_op, avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
end = time.time()
print("pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" %
(pass_id, batch_id, loss, 1 - acc, (end - start) / 1000))
pass_end = time.time()
test_avg_acc = eval_test()
print(
"pass=%d, training_avg_accuracy=%f, test_avg_acc=%f, elapse=%f"
% (pass_id, g_acc[1], test_avg_acc,
(pass_end - pass_start) / 1000))
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
run_benchmark(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
based on https://github.com/tensorflow/models/blob/master/official/resnet/resnet_model.py
Get help: python resnet.py --help
See performance on flowers: python resnet.py
Train on cifar10: python resnet.py --data=cifar10 --with_test
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
import numpy as np
import paddle.v2 as paddle
import tensorflow as tf
DTYPE = tf.float32
def parse_args():
parser = argparse.ArgumentParser('Convolution model benchmark.')
parser.add_argument(
'--model',
type=str,
choices=['resnet'],
default='resnet',
help='The model architecture.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
parser.add_argument(
'--use_fake_data',
action='store_true',
help='use real data or fake data')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations',
type=int,
default=105,
help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=300, help='The number of passes.')
parser.add_argument(
'--order',
type=str,
default='NHWC',
choices=['NCHW', 'NHWC'],
help='The data order, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--data',
type=str,
default='flowers102',
choices=['flowers102', 'cifar10'],
help='The kinds of data.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
args = parser.parse_args()
return args
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
vars(args)['iterations'] = vars(args)['pass_num'] * 1000 if vars(args)[
'with_test'] else vars(args)['iterations']
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def fixed_padding(inputs, kernel_size, data_format):
"""Pads the input along the spatial dimensions independently of input size.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
Should be a positive integer.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
A tensor with the same format as the input with the data either intact
(if kernel_size == 1) or padded (if kernel_size > 1).
"""
pad_total = kernel_size - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
if data_format == 'channels_first':
padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], [pad_beg, pad_end],
[pad_beg, pad_end]])
else:
padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
[pad_beg, pad_end], [0, 0]])
return padded_inputs
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
"""Strided 2-D convolution with explicit padding."""
# The padding is consistent and is based only on `kernel_size`, not on the
# dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
# This is consistent with PaddlePaddle.
# In addition, the calculation for output size in TensorFlow can refer:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/common_shape_fns.cc
if strides > 1:
inputs = fixed_padding(inputs, kernel_size, data_format)
return tf.layers.conv2d(
inputs=inputs,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=('SAME' if strides == 1 else 'VALID'),
use_bias=False,
kernel_initializer=tf.variance_scaling_initializer(),
data_format=data_format)
def conv_bn(inputs,
filters,
kernel_size,
strides,
is_training,
data_format,
act=True):
# def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
# set fused=True for a significant performance boost. See
# https://www.tensorflow.org/performance/performance_guide#common_fused_ops
inputs = conv2d_fixed_padding(
inputs=inputs,
filters=filters,
kernel_size=kernel_size,
strides=strides,
data_format=data_format)
inputs = tf.layers.batch_normalization(
inputs=inputs,
axis=1 if data_format == 'channels_first' else 3,
momentum=0.9,
epsilon=1e-05,
center=True,
scale=True,
training=is_training,
fused=True)
if act:
inputs = tf.nn.relu(inputs)
return inputs
def basicblock(inputs, filters, is_training, projection_shortcut, strides,
data_format):
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv_bn(inputs, filters, 3, strides, is_training, data_format)
inputs = conv_bn(inputs, filters, 3, 1, is_training, data_format, act=False)
inputs = inputs + shortcut
inputs = tf.nn.relu(inputs)
return inputs
def bottleneck(inputs, filters, is_training, projection_shortcut, strides,
data_format):
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv_bn(inputs, filters, 1, strides, is_training, data_format)
inputs = conv_bn(inputs, filters, 3, 1, is_training, data_format, act=False)
inputs = conv_bn(
inputs, filters * 4, 1, 1, is_training, data_format, act=False)
inputs = inputs + shortcut
inputs = tf.nn.relu(inputs)
return inputs
def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name,
data_format):
# Bottleneck blocks end with 4x the number of filters as they start with
filters_out = 4 * filters if block_fn is bottleneck else filters
def projection_shortcut(inputs):
return conv2d_fixed_padding(
inputs=inputs,
filters=filters_out,
kernel_size=1,
strides=strides,
data_format=data_format)
# Only the first block per block_layer uses projection_shortcut and strides
inputs = block_fn(inputs, filters, is_training, projection_shortcut,
strides, data_format)
for _ in range(1, blocks):
inputs = block_fn(inputs, filters, is_training, None, 1, data_format)
return tf.identity(inputs, name)
def resnet_imagenet(depth, class_dim, data_format):
"""Returns the ResNet model for a given size and number of output classes."""
def resnet_generator(block_fn,
layers,
num_classes,
data_format='channels_last'):
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
def model(inputs, is_training):
"""Constructs the ResNet model given the inputs."""
if data_format == 'channels_first':
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])
inputs = conv_bn(inputs, 64, 7, 2, is_training, data_format)
inputs = tf.identity(inputs, 'initial_conv')
inputs = tf.layers.max_pooling2d(
inputs=inputs,
pool_size=3,
strides=2,
padding='SAME',
data_format=data_format)
inputs = tf.identity(inputs, 'initial_max_pool')
inputs = block_layer(inputs, 64, block_fn, layers[0], 1,
is_training, 'block_layer1', data_format)
inputs = block_layer(inputs, 128, block_fn, layers[1], 2,
is_training, 'block_layer2', data_format)
inputs = block_layer(inputs, 256, block_fn, layers[2], 2,
is_training, 'block_layer3', data_format)
inputs = block_layer(inputs, 512, block_fn, layers[3], 2,
is_training, 'block_layer4', data_format)
inputs = tf.layers.average_pooling2d(
inputs=inputs,
pool_size=7,
strides=1,
padding='VALID',
data_format=data_format)
inputs = tf.identity(inputs, 'final_avg_pool')
inputs = tf.reshape(inputs,
[-1, 512 if block_fn is basicblock else 2048])
inputs = tf.layers.dense(inputs=inputs, units=num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs
return model
model_params = {
18: {
'block': basicblock,
'layers': [2, 2, 2, 2]
},
34: {
'block': basicblock,
'layers': [3, 4, 6, 3]
},
50: {
'block': bottleneck,
'layers': [3, 4, 6, 3]
},
101: {
'block': bottleneck,
'layers': [3, 4, 23, 3]
},
152: {
'block': bottleneck,
'layers': [3, 8, 36, 3]
},
200: {
'block': bottleneck,
'layers': [3, 24, 36, 3]
}
}
if depth not in model_params:
raise ValueError('Not a valid depth:', depth)
params = model_params[depth]
return resnet_generator(params['block'], params['layers'], class_dim,
data_format)
def resnet_cifar10(depth, num_classes, data_format):
if depth % 6 != 2:
raise ValueError('depth must be 6n + 2:', depth)
num_blocks = (depth - 2) // 6
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
def model(inputs, is_training):
inputs = conv_bn(inputs, 16, 3, 1, is_training, data_format)
inputs = tf.identity(inputs, 'initial_conv')
inputs = block_layer(inputs, 16, basicblock, num_blocks, 1, is_training,
'block_layer1', data_format)
inputs = block_layer(inputs, 32, basicblock, num_blocks, 2, is_training,
'block_layer2', data_format)
inputs = block_layer(inputs, 64, basicblock, num_blocks, 2, is_training,
'block_layer3', data_format)
inputs = tf.layers.average_pooling2d(
inputs=inputs,
pool_size=8,
strides=1,
padding='VALID',
data_format=data_format)
inputs = tf.identity(inputs, 'final_avg_pool')
inputs = tf.reshape(inputs, [-1, 64])
inputs = tf.layers.dense(inputs=inputs, units=num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs
return model
def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
"""Our model_fn for ResNet to be used with our Estimator."""
class_dim = 1000
dshape = (None, 224, 224, 3)
pdshape = (3, 224, 224)
if args.data == 'flowers102':
class_dim = 102
dshape = (None, 224, 224, 3)
pdshape = (3, 224, 224)
elif args.data == 'cifar10':
class_dim = 10
dshape = (None, 32, 32, 3)
pdshape = (3, 32, 32)
with tf.device(device):
images = tf.placeholder(DTYPE, shape=dshape)
labels = tf.placeholder(tf.int64, shape=(None, ))
is_training = tf.placeholder('bool')
onehot_labels = tf.one_hot(labels, depth=class_dim)
network = resnet_cifar10(
32, class_dim,
data_format) if args.data == 'cifar10' else resnet_imagenet(
50, class_dim, data_format)
logits = network(inputs=images, is_training=is_training)
cross_entropy = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=onehot_labels)
avg_cost = tf.reduce_mean(cross_entropy)
correct = tf.equal(tf.argmax(logits, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
lr = 0.1 if args.data == 'cifar10' else 0.01
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
# Batch norm requires update_ops to be added as a train_op dependency.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(avg_cost)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10()
if args.data == 'cifar10' else paddle.dataset.flowers.train(),
buf_size=5120),
batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.dataset.cifar.test10()
if args.data == 'cifar10' else paddle.dataset.flowers.test(),
batch_size=100)
def test():
test_accs = []
for batch_id, data in enumerate(test_reader()):
test_images = np.array(
map(lambda x: np.transpose(x[0].reshape(pdshape),
axes=[1, 2, 0]), data)).astype("float32")
test_labels = np.array(map(lambda x: x[1], data)).astype('int64')
test_accs.append(
accuracy.eval(feed_dict={
images: test_images,
labels: test_labels,
is_training: False
}))
print("Pass = %d, Train performance = %f imgs/s, Test accuracy = %f\n" %
(pass_id, num_samples / train_elapsed, np.mean(test_accs)))
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
if args.use_fake_data:
data = train_reader().next()
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape(pdshape),
axes=[1, 2, 0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype('int64')
iters, num_samples, start_time = 0, 0, 0.0
for pass_id in range(args.pass_num):
if iters == args.iterations:
break
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
if not args.use_fake_data:
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape(pdshape),
axes=[1, 2, 0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype(
'int64')
_, loss, acc = sess.run([train_op, avg_cost, accuracy],
feed_dict={
images: images_data,
labels: labels_data,
is_training: True
})
iters += 1
train_accs.append(acc)
train_losses.append(loss)
num_samples += len(data)
print("Pass=%d, Iter=%d, Loss=%f, Accuray=%f\n" %
(pass_id, iters, loss, acc))
train_elapsed = time.time() - start_time
print("Pass=%d, Loss=%f, Accuray=%f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
# evaluation
if args.with_test:
test()
if not args.with_test:
duration = time.time() - start_time
examples_per_sec = num_samples / duration
sec_per_batch = duration / (iters - args.skip_batch_num)
print('Total examples: %d, total time: %.5f' %
(num_samples, duration))
print('%.5f examples/sec, %.5f sec/batch' %
(examples_per_sec, sec_per_batch))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
if tf.test.is_built_with_cuda():
device = '/device:GPU:0'
if args.order == 'NHWC':
data_format = 'channels_last'
else:
data_format = 'channels_first'
else:
device = '/cpu:0'
if args.order == 'NHWC':
data_format = 'channels_last'
else:
raise ValueError('Only support NHWC order in CPU mode')
run_benchmark(args, data_format, device)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import argparse
import time
import tensorflow as tf
import paddle.v2 as paddle
def parse_args():
parser = argparse.ArgumentParser("LSTM model benchmark.")
parser.add_argument(
'--batch_size',
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
default=5,
help='Number of lstm layers to stack. (default: %(default)d)')
parser.add_argument(
'--embedding_dim',
type=int,
default=512,
help='Dimension of embedding table. (default: %(default)d)')
parser.add_argument(
'--hidden_dim',
type=int,
default=512,
help='Hidden size of lstm unit. (default: %(default)d)')
parser.add_argument(
'--pass_num',
type=int,
default=10,
help='Epoch number to train. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
default=0.0002,
help='Learning rate used to train. (default: %(default)f)')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
args = parser.parse_args()
return args
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def dynamic_lstm_model(dict_size,
embedding_dim,
hidden_dim,
stacked_num,
class_num=2,
is_train=True):
word_idx = tf.placeholder(tf.int64, shape=[None, None])
sequence_length = tf.placeholder(tf.int64, shape=[None, ])
embedding_weights = tf.get_variable('word_embeddings',
[dict_size, embedding_dim])
embedding = tf.nn.embedding_lookup(embedding_weights, word_idx)
lstm_cell = tf.nn.rnn_cell.LSTMCell(
num_units=hidden_dim, use_peepholes=False)
stacked_cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * stacked_num)
# final_state [LSTMTuple(c, h), LSTMTuple(c, h) ...] total stacked_num LSTMTuples
_, final_state = tf.nn.dynamic_rnn(
cell=stacked_cell,
inputs=embedding,
dtype=tf.float32,
sequence_length=sequence_length)
w = tf.Variable(
tf.truncated_normal([hidden_dim, class_num]), dtype=tf.float32)
bias = tf.Variable(
tf.constant(
value=0.0, shape=[class_num], dtype=tf.float32))
prediction = tf.matmul(final_state[-1][1], w) + bias
if not is_train:
return (word_idx, sequence_length), tf.nn.softmax(prediction)
label = tf.placeholder(tf.int64, shape=[None, ])
loss = tf.nn.softmax_cross_entropy_with_logits(
labels=tf.one_hot(label, 2), logits=prediction)
avg_loss = tf.reduce_mean(loss)
correct_count = tf.equal(tf.argmax(prediction, 1), label)
acc = tf.reduce_mean(tf.cast(correct_count, tf.float32))
with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
g_acc = tf.metrics.accuracy(label, tf.argmax(prediction, axis=1))
vars = tf.contrib.framework.get_variables(
scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
reset_op = tf.variables_initializer(vars)
return (word_idx, sequence_length, label), avg_loss, acc, g_acc, reset_op
def padding_data(data, padding_size, value):
data = data + [value] * padding_size
return data[:padding_size]
def train(args):
word_dict = paddle.dataset.imdb.word_dict()
dict_size = len(word_dict)
feeding_list, avg_loss, acc, g_acc, reset_op = dynamic_lstm_model(
dict_size, args.embedding_dim, args.hidden_dim, args.stacked_num)
adam_optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
train_op = adam_optimizer.minimize(avg_loss)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=25000),
batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.test(word_dict), buf_size=25000),
batch_size=args.batch_size)
def do_validation(sess):
sess.run(reset_op)
for batch_id, data in enumerate(test_reader()):
word_idx = map(lambda x: x[0], data)
sequence_length = np.array(
[len(seq) for seq in word_idx]).astype('int64')
maxlen = np.max(sequence_length)
word_idx = [padding_data(seq, maxlen, 0) for seq in word_idx]
word_idx = np.array(word_idx).astype('int64')
label = np.array(map(lambda x: x[1], data)).astype('int64')
_, loss, fetch_acc, fetch_g_acc = sess.run(
[train_op, avg_loss, acc, g_acc],
feed_dict={
feeding_list[0]: word_idx,
feeding_list[1]: sequence_length,
feeding_list[2]: label
})
return fetch_g_acc[1]
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_l)
sess.run(init_g)
for pass_id in xrange(args.pass_num):
# clear accuracy local variable
sess.run(reset_op)
pass_start_time = time.time()
words_seen = 0
for batch_id, data in enumerate(train_reader()):
word_idx = map(lambda x: x[0], data)
sequence_length = np.array(
[len(seq) for seq in word_idx]).astype('int64')
words_seen += np.sum(sequence_length)
maxlen = np.max(sequence_length)
word_idx = [padding_data(seq, maxlen, 0) for seq in word_idx]
word_idx = np.array(word_idx).astype('int64')
label = np.array(map(lambda x: x[1], data)).astype('int64')
_, loss, fetch_acc, fetch_g_acc = sess.run(
[train_op, avg_loss, acc, g_acc],
feed_dict={
feeding_list[0]: word_idx,
feeding_list[1]: sequence_length,
feeding_list[2]: label
})
print("pass_id=%d, batch_id=%d, loss: %f, acc: %f, avg_acc: %f"
% (pass_id, batch_id, loss, fetch_acc, fetch_g_acc[1]))
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
words_per_sec = words_seen / time_consumed
test_acc = do_validation(sess)
print("pass_id=%d, test_acc: %f, words/s: %f, sec/pass: %f" %
(pass_id, test_acc, words_per_sec, time_consumed))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
if args.infer_only:
pass
else:
train(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VGG16 benchmark in TensorFlow"""
import tensorflow as tf
import paddle.v2 as paddle
import numpy as np
import argparse
import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--batch_size', type=int, default=128, help="Batch size for training.")
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test')
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--learning_rate',
type=float,
default=1e-3,
help="Learning rate for training.")
parser.add_argument('--num_passes', type=int, default=50, help="No. of passes.")
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
'--data_format',
type=str,
default='NHWC',
choices=['NCHW', 'NHWC'],
help='The data order, NCHW=[batch, channels, height, width].'
'Only support NHWC right now.')
parser.add_argument(
'--data_set',
type=str,
default='cifar10',
choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.')
args = parser.parse_args()
class VGG16Model(object):
def __init__(self):
self.parameters = []
def batch_norm_relu(self, inputs, is_training):
"""Performs a batch normalization followed by a ReLU."""
# We set fused=True for a significant speed boost. See
# https://www.tensorflow.org/speed/speed_guide#common_fused_ops
inputs = tf.layers.batch_normalization(
inputs=inputs,
axis=1 if args.data_format == 'NCHW' else -1,
momentum=0.9,
epsilon=1e-05,
center=True,
scale=True,
training=is_training,
fused=True)
inputs = tf.nn.relu(inputs)
return inputs
def conv_bn_layer(self,
name,
images,
kernel_shape,
is_training,
drop_rate=0.0):
with tf.name_scope(name) as scope:
kernel = tf.Variable(
tf.truncated_normal(
kernel_shape, dtype=tf.float32, stddev=1e-1),
name='weights')
conv = tf.nn.conv2d(
images,
kernel, [1, 1, 1, 1],
data_format=args.data_format,
padding='SAME')
biases = tf.Variable(
tf.constant(
0.0, shape=[kernel_shape[-1]], dtype=tf.float32),
trainable=True,
name='biases')
out = tf.nn.bias_add(conv, biases)
out = self.batch_norm_relu(out, is_training)
out = tf.layers.dropout(out, rate=drop_rate, training=is_training)
return out
def fc_layer(self, name, inputs, shape):
with tf.name_scope(name) as scope:
fc_w = tf.Variable(
tf.truncated_normal(
shape, dtype=tf.float32, stddev=1e-1),
name='weights')
fc_b = tf.Variable(
tf.constant(
0.0, shape=[shape[-1]], dtype=tf.float32),
trainable=True,
name='biases')
out = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
return out
def network(self, images, class_dim, is_training):
""" VGG16 model structure.
TODO(kuke): enable this network to support the 'NCHW' data format
"""
# conv1
conv1_1 = self.conv_bn_layer(
'conv1_1', images, [3, 3, 3, 64], is_training, drop_rate=0.3)
conv1_2 = self.conv_bn_layer(
'conv1_2', conv1_1, [3, 3, 64, 64], is_training, drop_rate=0.0)
# pool1
pool1 = tf.nn.max_pool(
conv1_2,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool1')
# conv2
conv2_1 = self.conv_bn_layer(
'conv2_1', pool1, [3, 3, 64, 128], is_training, drop_rate=0.4)
conv2_2 = self.conv_bn_layer(
'conv2_2', conv2_1, [3, 3, 128, 128], is_training, drop_rate=0.0)
# pool2
pool2 = tf.nn.max_pool(
conv2_2,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool2')
# conv3
conv3_1 = self.conv_bn_layer(
'conv3_1', pool2, [3, 3, 128, 256], is_training, drop_rate=0.4)
conv3_2 = self.conv_bn_layer(
'conv3_2', conv3_1, [3, 3, 256, 256], is_training, drop_rate=0.4)
conv3_3 = self.conv_bn_layer(
'conv3_3', conv3_2, [3, 3, 256, 256], is_training, drop_rate=0.0)
# pool3
pool3 = tf.nn.max_pool(
conv3_3,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool3')
# conv4
conv4_1 = self.conv_bn_layer(
'conv4_1', pool3, [3, 3, 256, 512], is_training, drop_rate=0.4)
conv4_2 = self.conv_bn_layer(
'conv4_2', conv4_1, [3, 3, 512, 512], is_training, drop_rate=0.4)
conv4_3 = self.conv_bn_layer(
'conv4_3', conv4_2, [3, 3, 512, 512], is_training, drop_rate=0.0)
# pool4
pool4 = tf.nn.max_pool(
conv4_3,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool4')
# conv5
conv5_1 = self.conv_bn_layer(
'conv5_1', pool4, [3, 3, 512, 512], is_training, drop_rate=0.4)
conv5_2 = self.conv_bn_layer(
'conv5_2', conv5_1, [3, 3, 512, 512], is_training, drop_rate=0.4)
conv5_3 = self.conv_bn_layer(
'conv5_3', conv5_2, [3, 3, 512, 512], is_training, drop_rate=0.0)
# pool5
pool5 = tf.nn.max_pool(
conv5_3,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool4')
# flatten
shape = int(np.prod(pool5.get_shape()[1:]))
pool5_flat = tf.reshape(pool5, [-1, shape])
# fc1
drop = tf.layers.dropout(pool5_flat, rate=0.5, training=is_training)
fc1 = self.fc_layer('fc1', drop, [shape, 512])
# fc2
bn = self.batch_norm_relu(fc1, is_training)
drop = tf.layers.dropout(bn, rate=0.5, training=is_training)
fc2 = self.fc_layer('fc2', drop, [512, 512])
fc3 = self.fc_layer('fc3', fc2, [512, class_dim])
return fc3
def run_benchmark():
"""Run benchmark on cifar10 or flowers."""
if args.data_set == "cifar10":
class_dim = 10
raw_shape = (3, 32, 32)
dat_shape = (None, 32, 32, 3) if args.data_format == 'NHWC' else (
None, 3, 32, 32)
else:
class_dim = 102
raw_shape = (3, 224, 224)
dat_shape = (None, 224, 224, 3) if args.data_format == 'NHWC' else (
None, 3, 224, 224)
device = '/cpu:0' if args.device == 'CPU' else '/device:GPU:0'
with tf.device(device):
images = tf.placeholder(tf.float32, shape=dat_shape)
labels = tf.placeholder(tf.int64, shape=(None, ))
is_training = tf.placeholder('bool')
onehot_labels = tf.one_hot(labels, depth=class_dim)
vgg16 = VGG16Model()
logits = vgg16.network(images, class_dim, is_training)
loss = tf.losses.softmax_cross_entropy(
onehot_labels=onehot_labels, logits=logits)
avg_loss = tf.reduce_mean(loss)
correct = tf.equal(tf.argmax(logits, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(avg_loss)
# data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
buf_size=5120),
batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.test10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
buf_size=5120),
batch_size=args.batch_size)
# test
def test():
test_accs = []
for batch_id, data in enumerate(test_reader()):
test_images = np.array(
map(lambda x: np.transpose(x[0].reshape(raw_shape),
axes=[1, 2, 0]) if args.data_format == 'NHWC' else x[0], data)).astype("float32")
test_labels = np.array(map(lambda x: x[1], data)).astype('int64')
test_accs.append(
accuracy.eval(feed_dict={
images: test_images,
labels: test_labels,
is_training: False
}))
return np.mean(test_accs)
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.num_passes):
# train
num_samples = 0
start_time = time.time()
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
train_images = np.array(
map(lambda x: np.transpose(x[0].reshape(raw_shape),
axes=[1, 2, 0]) if args.data_format == 'NHWC' else x[0], data)).astype("float32")
train_labels = np.array(map(lambda x: x[1], data)).astype(
'int64')
_, loss, acc = sess.run([train_op, avg_loss, accuracy],
feed_dict={
images: train_images,
labels: train_labels,
is_training: True
})
iters += 1
num_samples += len(data)
print("Pass = %d, Iters = %d, Loss = %f, Accuracy = %f" %
(pass_id, iters, loss, acc))
train_elapsed = time.time() - start_time
# test
pass_test_acc = test()
print("Pass = %d, Train speed = %f imgs/s, Test accuracy = %f\n" %
(pass_id, num_samples / train_elapsed, pass_test_acc))
def print_arguments():
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
print_arguments()
run_benchmark()
......@@ -36,7 +36,8 @@ MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR})
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include mkldnn.h
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(NOT WITH_GPU)
return()
endif()
include(ExternalProject)
set(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
include_directories(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies
set(NCCL_BUILD_COMMAND "")
set(NCCL_INSTALL_COMMAND "")
set(NCCL_INSTALL_DIR "")
else()
# otherwise, we build nccl and link it.
set(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
# Note: cuda 8.0 is needed to make nccl
# When cuda is not installed on the system directory, need to set CUDA_HOME to your cuda root
set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install PREFIX=${NCCL_INSTALL_DIR}")
endif()
ExternalProject_Add(
extern_nccl
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
GIT_TAG "v1.3.4-1"
PREFIX "${NCCL_SOURCE_DIR}"
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
INSTALL_DIR "${NCCL_INSTALL_DIR}"
TEST_COMMAND ""
)
if(WITH_DSO)
if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile})
else()
add_library(nccl INTERFACE)
endif()
else()
add_library(nccl STATIC IMPORTED GLOBAL)
set_property(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl_static.a)
endif()
add_dependencies(nccl extern_nccl)
......@@ -244,11 +244,11 @@ function(cc_test TARGET_NAME)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS})
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main memory gtest gflags glog)
if("${cc_test_DEPS}" MATCHES "ARCHIVE_START")
list(REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END)
endif()
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main memory gtest gflags glog)
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......@@ -311,8 +311,8 @@ function(nv_test TARGET_NAME)
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog)
add_test(${TARGET_NAME} ${TARGET_NAME})
endif()
endfunction(nv_test)
......@@ -387,8 +387,8 @@ function(hip_test TARGET_NAME)
endif()
add_executable(${TARGET_NAME} ${_cmake_options} ${_generated_files} ${_sources})
set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags)
add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags)
add_test(${TARGET_NAME} ${TARGET_NAME})
endif()
endfunction(hip_test)
......
......@@ -16,3 +16,4 @@
block.md
scope.md
executor.md
parallel_executor.md
......@@ -16,3 +16,4 @@ Core Concepts
block.md
scope.md
executor.md
parallel_executor.md
# Problem
# Kernel Hint Design
## Problem
In PaddlePaddle's [Design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md), one Operator may have multiple kernels. Users may have some personal preference to choose a certain type of kernel for an operator, such as `force_cpu` to choose a CPU kernel, `use_cudnn` to choose a CUDNN kernel, we need to provide a way for users to do this.
In the current design, we use KernelType to describe one kernel.
......
# Background
# Kernel Selection
## Background
Every operator has many kernels because there are multiple data types, places, data layout, library type that Fluid supports. We use the `OpKernelType ` to describe kernel types that operators can hold.
The `OpKernelType ` is as follows:
......
Install and Build
=================
install and Compile
==========
.. _install_steps:
Install Steps
++++++++
PaddlePaddle provides various methods of installation for many different users
You can choose either pip or Docker to complete your install:
Focus on Deep Learning Model Development
-----------------
PaddlePaddle provides lots of packages of python wheel , that pip can install:
.. toctree::
:maxdepth: 1
:maxdepth: 1
pip_install_en.rst
docker_install_en.rst
pip_install_en.rst
Build from Source
-----------------
This is the most convenient way of installation. Please choose the right installation package with machine configure and system.
Follow the Bottom Frame
----------
PaddlePaddle also supports installation using Docker. Please refer to the tutorial below:
.. toctree::
:maxdepth: 1
docker_install_en.rst
.. warning::
We recommend running PaddlePaddle in Docker. This method has the following advantages:
We recommend to directly install via above installation steps, you'll only need to build PaddlePaddle from source when you need a modifed binary.
- Does not require installation of third-party dependencies.
- Easy to share runtime environment.
.. toctree::
Lastly, users can also compile and install PaddlePaddle from source code. The instructions are below:
.. toctree::
:maxdepth: 1
build_from_source_en.md
build_from_source_en.rst
.. warning::
One caveat with this approach is that developers will have to download, compile and install all third-party dependencies. Thus this process of installation is more time consuming.
FAQ
++++++++++
-----------
For any problems during installation, please refer to the page below for answers:
:ref:`常见问题解答 <install_faq>`
If the problem still persists, you are welcome to seek assistance from the PaddlePaddle community:
`FAQ <http://www.paddlepaddle.org/docs/develop/documentation/zh/faq/build_and_install/index_en.html>`_
`创建issue <https://github.com/PaddlePaddle/Paddle/issues/new>`_
......@@ -65,39 +65,55 @@ PaddlePaddle.org工具可以配合Docker使用,需要在系统里先安装好D
不使用PaddlePaddle.org工具
--------------------------
使用Docker构建PaddlePaddle的文档,需要在系统里先安装好Docker工具包。Docker安装请参考 `Docker的官网 <https://docs.docker.com/>`_ 。安装好Docker之后可以使用源码目录下的脚本构建文档,即
使用Docker构建PaddlePaddle的文档,需要在系统里先安装好Docker工具包。Docker安装请参考 `Docker的官网 <https://docs.docker.com/>`_ 。该方法与 `从源码编译PaddlePaddle <http://paddlepaddle.org/docs/develop/documentation/zh/build_and_install/build_from_source_cn.html>`_ 相似,通过从源码中构建可用于编译PaddlePaddle文档的Docker镜像并运行,在进入Docker容器后使用源码中的脚本构建PaddlePaddle文档,具体步骤如下:
[TBD]
.. code-block:: bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
# 从源码中构建可用于编译PaddlePaddle文档的Docker镜像
docker build -t paddle:dev .
docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" -e "WITH_DOC=ON" paddle:dev /bin/bash
# 进入Docker容器后使用build.sh脚本构建PaddlePaddle文档
bash -x /paddle/paddle/scripts/docker/build.sh
注:上述命令把当前目录(源码根目录)映射为 container 里的 :code:`/paddle` 目录。
编译完成后,会产生 ``doc/v2`` 和 ``doc/fluid`` 两个目录,在这两个目录下分别都生成 ``cn/html/`` 、 ``en/html`` 、 ``api/en/html`` 共三个子目录,分别进入这些目录下,执行以下命令:
.. code-block:: bash
python -m SimpleHTTPServer 8088
在浏览器中输入 http://localhost:8088 就可以看到编译生成的 ``v2`` 和 ``fluid`` 两种版本的中/英文的文档页面和英文的API页面。
如果不想使用Docker,也可以使用以下命令直接构建PaddlePaddle文档,即
.. code-block:: bash
mkdir paddle
cd paddle
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON
# 如果只需要构建使用文档,则执行以下命令
make -j $processors gen_proto_py
make -j $processors paddle_docs paddle_docs_cn
make -j $processors paddle_docs
# 如果只需要构建API,则执行以下命令
make -j $processors gen_proto_py framework_py_proto
make -j $processors copy_paddle_pybind
make -j $processors paddle_api_docs
make -j $processors paddle_apis
其中$processors代表启动和CPU核一样多的进程来并行编译,可以根据本机的CPU核数设置相应的值。
编译完成后,进入 ``doc/v2`` 目录,如果选择构建文档则会在该目录下生成 ``cn/html/`` 、 ``en/html`` 两个子目录,选择构建API则会生成 ``api/en/html`` 目录,分别进入这些目录下,执行以下命令:
编译完成后,同样会产生 ``doc/v2`` 和 ``doc/fluid`` 两个目录,如果选择构建文档则会在这两个目录下分别都生成 ``cn/html/`` 、 ``en/html`` 两个子目录,选择构建API则会在这两个目录下分别生成 ``api/en/html`` 目录,分别进入这些子目录下,执行以下命令:
.. code-block:: bash
python -m SimpleHTTPServer 8088
在浏览器中输入 http://localhost:8088 就可以看到编译生成的中/英文的文档页面和英文的API页面,下图为生成的英文文档首页示例。注意,示例中由于使用了sphinx的原始主题,所以页面的风格与官网并不一致,但这并不影响开发者进行调试。
在浏览器中输入 http://localhost:8088 就可以看到编译生成的 ``v2`` 和 ``fluid`` 两种版本的中/英文的文档页面和英文的API页面。下图为生成的 ``v2`` 英文文档首页示例。注意,示例中由于使用了sphinx的原始主题,所以页面的风格与官网并不一致,但这并不影响开发者进行调试。
.. image:: src/doc_en.png
:align: center
......
......@@ -68,39 +68,56 @@ Please `click here <https://github.com/PaddlePaddle/PaddlePaddle.org/blob/develo
Manually Building the Documentation
-------------------------------------
Build PaddlePaddle's documentation with Docker,you need to install Docker first. Please refer to `Docker's official website <https://docs.docker.com/>`_ on how to install Docker. After Docker is installed, you could use the scripts in the source directory to build the documentation.
Build PaddlePaddle's documentation with Docker,you need to install Docker first. Please refer to `Docker's official website <https://docs.docker.com/>`_ on how to install Docker. This method is quite similar to ` Build From Sources <http://paddlepaddle.org/docs/develop/documentation/en/build_and_install/build_from_source_en.html>`_ , by constructing, from source code, a docker image that can be used to build PaddlePaddle documentation. Enter the Docker container and use the script ``build.sh`` in the source directory to build the PaddlePaddle documentation. The specific steps are as follows:
[TBD]
.. code-block:: bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
# Construct a docker image from source code
docker build -t paddle:dev .
docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" -e "WITH_DOC=ON" paddle:dev /bin/bash
# Use build.sh to build PaddlePaddle documentation
bash -x /paddle/paddle/scripts/docker/build.sh
Note: The above commands maps the current directory (source root directory) to the :code:`/paddle` directory in the container.
After compiling, there should be two generated directories: ``doc/v2`` and ``doc/fluid``, where three subdirectories ``cn/html/``, ``en/html`` and ``api/en/html`` are generated. Please enter these directories respectively and execute the following commands:
.. code-block:: bash
python -m SimpleHTTPServer 8088
Use a web browser and navigate to http://localhost:8000, you could see the compiled ``v2`` 's and ``fluid`` 's Chinese/English documents page and English APIs page.
If you do not wish to use Docker, you can also use the following commands to directly build the PaddlePaddle documentation.
.. code-block:: bash
mkdir paddle
cd paddle
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON
# If you only need to build documents, use the following commands
make -j $processors gen_proto_py
make -j $processors paddle_docs paddle_docs_cn
make -j $processors paddle_docs
# If you only need to build APIs, use the following commands
make -j $processors gen_proto_py framework_py_proto
make -j $processors copy_paddle_pybind
make -j $processors paddle_api_docs
make -j $processors paddle_apis
$processors indicates that as many processes as the CPU cores are started to compile in parallel. It should be set according to the number of CPU cores of your machine.
After the compilation is complete, enter the ``doc/v2`` directory. If you chose to build documents, it will generate ``cn/html/`` and ``en/html`` subdirectories under this directory. If you chose to build APIs,it will generate``api/en/html`` subdirectory. Please enter these directories respectively and execute the following commands:
After compiling, there also should be two generated directories: ``doc/v2`` and ``doc/fluid`` . If you chose to build documents, two subdirectories ``cn/html/`` and ``en/html`` will be generated in both two directories. If you chose to build APIs,a subdirectory ``api/en/html`` will be generated. Please enter these directories respectively and execute the following commands:
.. code-block:: bash
python -m SimpleHTTPServer 8088
Use a web browser and navigate to http://localhost:8000, you could see the compiled Chinese/English documents page and the English APIs page. The following figure is an example of the built English documents home page. Note that due to the sphinx's original theme used in the example, the style of the page is not consistent with the official website, but this does not affect the developer's debugging.
Use a web browser and navigate to http://localhost:8000, you could see the compiled ``v2`` 's and ``fluid`` 's Chinese/English documents page and English APIs page. The following figure is an example of the built ``v2`` 's English documents home page. Note that due to the sphinx's original theme used in the example, the style of the page is not consistent with the official website, but this does not affect the developer's debugging.
.. image:: src/doc_en.png
:align: center
......
## Install and Build
TBD
### Download & Install
Download the latest C-API development package from CI system and install. You can find the required version in the table below:
<table>
<thead>
<tr>
<th>Version Tips</th>
<th>C-API</th>
</tr>
</thead>
<tbody>
<tr>
<td>cpu_avx_mkl</td>
<td><a href="https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddle.tgz" rel="nofollow">paddle.tgz</a></td>
</tr>
<tr>
<td>cpu_avx_openblas</td>
<td>-</td>
</tr>
<tr>
<td>cpu_noavx_openblas</td>
<td><a href="https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddle.tgz" rel="nofollow">paddle.tgz</a></td>
</tr>
<tr>
<td>cuda7.5_cudnn5_avx_mkl</td>
<td><a href="https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddle.tgz" rel="nofollow">paddle.tgz</a></td>
</tr>
<tr>
<td>cuda8.0_cudnn5_avx_mkl</td>
<td><a href="https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddle.tgz" rel="nofollow">paddle.tgz</a></td>
</tr>
<tr>
<td>cuda8.0_cudnn7_avx_mkl</td>
<td><a href="https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddle.tgz" rel="nofollow">paddle.tgz</a></td>
</tr></tbody></table>
### From source
Users can also compile the C-API library from PaddlePaddle source code by compiling with the following compilation options:
<table>
<thead>
<tr>
<th>Options</th>
<th>Value</th>
</tr>
</thead>
<tbody>
<tr>
<td>WITH_C_API</td>
<td>ON</td>
</tr>
<tr>
<td>WITH_PYTHON</td>
<td>OFF(recommended)</td>
</tr>
<tr>
<td>WITH_SWIG_PY</td>
<td>OFF(recommended)</td>
</tr>
<tr>
<td>WITH_GOLANG</td>
<td>OFF(recommended)</td>
</tr>
<tr>
<td>WITH_GPU</td>
<td>ON/OFF</td>
</tr>
<tr>
<td>WITH_MKL</td>
<td>ON/OFF</td>
</tr></tbody></table>
It is best to set up with recommended values to avoid linking with unnecessary libraries. Set other compilation options as you need.
Pull the latest following code snippet from github, and configure compilation options(replace PADDLE_ROOT with the installation path of the PaddlePaddle C-API inference library):
```shell
PADDLE_ROOT=/path/of/capi
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX=$PADDLE_ROOT \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_GOLANG=OFF \
-DWITH_PYTHON=OFF \
-DWITH_MKL=OFF \
-DWITH_GPU=OFF \
..
```
After running the above code to generate Makefile , run: `make && make install`. After successful compilation, the dependencies required by C-API(includes: (1)PaddlePaddle inference library and header files; (2) Third-party libraries and header files) will be stored in the `PADDLE_ROOT` directory.
If the compilation is successful, see the following directory structure under `PADDLE_ROOT`(includes PaddlePaddle header files and libraries, and third-party libraries and header files(determined by the link methods if necessary)):
```text
├── include
│   └── paddle
│   ├── arguments.h
│   ├── capi.h
│   ├── capi_private.h
│   ├── config.h
│   ├── error.h
│   ├── gradient_machine.h
│   ├── main.h
│   ├── matrix.h
│   ├── paddle_capi.map
│   └── vector.h
├── lib
│   ├── libpaddle_capi_engine.a
│   ├── libpaddle_capi_layers.a
│   ├── libpaddle_capi_shared.so
│   └── libpaddle_capi_whole.a
└── third_party
├── gflags
│   ├── include
│   │   └── gflags
│   │   ├── gflags_completions.h
│   │   ├── gflags_declare.h
│   │   ...
│   └── lib
│   └── libgflags.a
├── glog
│   ├── include
│   │   └── glog
│   │   ├── config.h
│   │   ...
│   └── lib
│   └── libglog.a
├── openblas
│   ├── include
│   │   ├── cblas.h
│   │   ...
│   └── lib
│   ...
├── protobuf
│   ├── include
│   │   └── google
│   │   └── protobuf
│   │   ...
│   └── lib
│   └── libprotobuf-lite.a
└── zlib
├── include
│   ...
└── lib
...
```
### Linking Description:
There are three kinds of linking methods:
1. Linking with dynamic library `libpaddle_capi_shared.so`(This way is much more convenient and easier, **Without special requirements, it is recommended**), refer to the following:
1. Compiling with CPU version and using `OpenBLAS`; only need to link one library named `libpaddle_capi_shared.so` to develop prediction program through C-API.
1. Compiling with CPU version and using `MKL` lib, you need to link MKL library directly to develop prediction program through PaddlePaddle C-API, due to `MKL` has its own dynamic library.
1. Compiling with GPU version, CUDA library will be loaded dynamically on prediction program run-time, and also set CUDA library to  `LD_LIBRARY_PATH` environment variable.
2. Linking with static library `libpaddle_capi_whole.a`,refer to the following:
1. Specify `-Wl,--whole-archive` linking options.
1. Explicitly link third-party libraries such as `gflags``glog``libz``protobuf` .etc, you can find them under `PADDLE_ROOT/third_party` directory.
1. Use OpenBLAS library if compiling C-API,must explicitly link `libopenblas.a`.
1. Use MKL when compiling C-API, must explicitly link MKL dynamic library.
3. Linking with static library `libpaddle_capi_layers.a` and `libpaddle_capi_engine.a`,refer to the following:
1. This linking methods is mainly used for mobile prediction.
1. Split `libpaddle_capi_whole.a` into two static linking library at least to reduce the size of linking libraries.
1. Specify `-Wl,--whole-archive -lpaddle_capi_layers`  and `-Wl,--no-whole-archive -lpaddle_capi_engine` for linking.
1. The third-party dependencies need explicitly link same as method 2 above.
# Kubernetes Distributed
# Distributed Training on Kubernetes
TBD
We introduced how to create a PaddlePaddle Job with a single node on Kuberentes in the
previous document.
In this article, we will introduce how to create a PaddlePaddle job with multiple nodes
on Kubernetes cluster.
## Overall Architecture
Before creating a training job, the users need to slice the training data and deploy
the Python scripts along with it into the distributed file system
(We can use the different type of Kuberentes Volumes to mount different distributed
file systems). Before training starts, The program will copy the training data into the
Container and also save the models at the same path during training. The global architecture
is as follows:
![PaddlePaddle on Kubernetes Architecture](src/k8s-paddle-arch.png)
The above figure describes a distributed training architecture which contains 3 nodes, each
Pod mounts a folder of the distributed file system to save training data and models
by Kubernetes Volume. Kubernetes created 3 Pods for this training phase and scheduled these on
3 nodes, each Pod has a PaddlePaddle container. After the containers car created,
PaddlePaddle starts up the communication between PServer and Trainer and read training
data for this training job.
As the description above, we can start up a PaddlePaddle distributed training job on a
Kubernetes ready cluster with the following steps:
1. [Build PaddlePaddle Docker Image](#Build a Docker Image)
1. [Split training data and upload to the distributed file system](#Upload Training Data)
1. [Edit a YAML file and create a Kubernetes Job](#Create a Job)
1. [Check the output](#Check The Output)
We will introduce these steps as follows:
### Build a Docker Image
Training docker image needs to package the paddle pserver and paddle trainer runtimes, as well as two more processes before we can kick off the training:
- Copying the training data into container.
- Generating the initialization arguments for `Paddle PServer` and `Paddle Training` processes.
Since the paddlepaddle official docker image already has the runtimes we need, we'll take it as the base image and pack some additional scripts for the processes mentioned above to build our training image. for more detail, please find from the following link:
- https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/usage/cluster/src/k8s_train/Dockerfile
```bash
$ cd doc/howto/usage/k8s/src/k8s_train
$ docker build -t [YOUR_REPO]/paddle:mypaddle .
```
And then upload the new Docker Image to a Docker hub:
```bash
docker push [YOUR_REPO]/paddle:mypaddle
```
**[NOTE]**, in the above command arguments, `[YOUR_REPO]` represents your Docker repository,
you need to use your repository instead of it. We will replace it with your respository name to
represent the Docker Image which built in this step.
### Prepare Training Data
We can download and split the training job by creating a Kubernetes Job, or custom your image
by editing [k8s_train](./src/k8s_train/).
Before creating a Job, we need to bind a [persistenVolumeClaim](https://kubernetes.io/docs/user-guide/persistent-volumes) by the different type of
the different file system, the generated dataset would be saved on this volume.
```yaml
apiVersion: batch/v1
kind: Job
metadata:
name: paddle-data
spec:
template:
metadata:
name: pi
spec:
hostNetwork: true
containers:
- name: paddle-data
image: paddlepaddle/paddle-tutorial:k8s_data
imagePullPolicy: Always
volumeMounts:
- mountPath: "/mnt"
name: nfs
env:
- name: OUT_DIR
value: /home/work/mfs/paddle-cluster-job
- name: SPLIT_COUNT
value: "3"
volumes:
- name: nfs
persistentVolumeClaim:
claimName: mfs
restartPolicy: Never
```
Create the Job with the following command:
```bash
> kubectl create -f xxx.yaml
```
If created successfully, you can see some information like this:
```base
[root@paddle-kubernetes-node0 nfsdir]$ tree -d
.
`-- paddle-cluster-job
|-- 0
| `-- data
|-- 1
| `-- data
|-- 2
| `-- data
|-- output
|-- quick_start
```
The `paddle-cluster-job` above is the job name for this training job; we need 3
PaddlePaddle training nodes and save the split training data in `paddle-cluster-job` path,
the folder `0`, `1` and `2` represents the `training_id` on each node, `quick_start` folder is used to store training data, `output` folder is used to store the models and logs.
### Create a Job
Kubernetes allow users to create objects with YAML files, and we can use a command-line tool
to create it.
The Job YAML file describes that which Docker Image would be used in this training job, how much nodes would be created, what's the startup arguments of `Paddle PServer/Trainer` process and what's the type of Volumes. You can find the details of the YAML filed in
[Kubernetes Job API](http://kubernetes.io/docs/api-reference/batch/v1/definitions/#_v1_job).
The following is an example for this training job:
```yaml
apiVersion: batch/v1
kind: Job
metadata:
name: paddle-cluster-job
spec:
parallelism: 3
completions: 3
template:
metadata:
name: paddle-cluster-job
spec:
volumes:
- name: jobpath
hostPath:
path: /home/work/mfs
containers:
- name: trainer
image: [YOUR_REPO]/paddle:mypaddle
command: ["bin/bash", "-c", "/root/start.sh"]
env:
- name: JOB_NAME
value: paddle-cluster-job
- name: JOB_PATH
value: /home/jobpath
- name: JOB_NAMESPACE
value: default
- name: TRAIN_CONFIG_DIR
value: recommendation
- name: CONF_PADDLE_NIC
value: eth0
- name: CONF_PADDLE_PORT
value: "7164"
- name: CONF_PADDLE_PORTS_NUM
value: "2"
- name: CONF_PADDLE_PORTS_NUM_SPARSE
value: "2"
- name: CONF_PADDLE_GRADIENT_NUM
value: "3"
volumeMounts:
- name: jobpath
mountPath: /home/jobpath
restartPolicy: Never
```
In the above YAML file:
- `metadata.name`, The job name.
- `parallelism`, Whether the Kubernetes Job would create `parallelism` Pods at the same time.
- `completions`, The Job would become the success status only when the number of successful Pod(the exit code is 0)
is equal to `completions`.
- `volumeMounts`, the name field `jobpath` is a key, the `mountPath` field represents
the path in the container, and we can define the `jobpath` in `volumes` filed, use `hostPath`
to configure the host path we want to mount.
- `env`, the environment variables in the Container, we pass some startup arguments by
this approach, some details are as following:
- JOB_PATH:the mount path in the container
- JOB_NAME:the job name
- TRAIN_CONFIG_DIR:the job path in the container, we can find the training data path by
combine with JOB_NAME.
- CONF_PADDLE_NIC: the argument `--nics` of `Paddle PServer` process, the network
device name.
- CONF_PADDLE_PORT: the argument `--port` of `Paddle PServer` process.
- CONF_PADDLE_PORTS_NUM: the argument `--ports_num` of `Paddle PServer`, the port number
for dense prameter update.
- CONF_PADDLE_PORTS_NUM_SPARSE:the argument `--ports_num_for_sparse` of `Paddle PServer`,
the port number for sparse parameter update.
- CONF_PADDLE_GRADIENT_NUM:the number of training node, the argument
`--num_gradient_servers` of `Paddle PServer` and `Paddle Trainer`.
You can find some details information at [here]
(http://www.paddlepaddle.org/docs/develop/documentation/zh/howto/usage/cmd_parameter/detail_introduction_cn.html)。
We can use the command-line tool of Kubernetes to create a Job when we finish the YAML file:
```bash
kubectl create -f job.yaml
```
Upon successful creation, Kubernetes would create 3 Pods as PaddlePaddle training node,
pull the Docker image and begin to train.
### Checkout the Output
At the process of training, we can check the logs and the output models which is stored in
the `output` folder.
**NOTE**, `node_0`, `node_1` and `node_2` represent the
`trainer_id` of the PaddlePaddle training job rather than the node id of Kubernetes.
```bash
[root@paddle-kubernetes-node0 output]# tree -d
.
├── node_0
│   ├── server.log
│   └── train.log
├── node_1
│   ├── server.log
│   └── train.log
├── node_2
......
├── pass-00002
│   ├── done
│   ├── ___embedding_0__.w0
│   ├── ___embedding_1__.w0
......
```
We can checkout the status of each training Pod by viewing the logs:
```bash
[root@paddle-kubernetes-node0 node_0]# cat train.log
I1116 09:10:17.123121 50 Util.cpp:155] commandline:
/usr/local/bin/../opt/paddle/bin/paddle_trainer
--nics=eth0 --port=7164
--ports_num=2 --comment=paddle_process_by_paddle
--pservers=192.168.129.66,192.168.223.143,192.168.129.71
--ports_num_for_sparse=2 --config=./trainer_config.py
--trainer_count=4 --num_passes=10 --use_gpu=0
--log_period=50 --dot_period=10 --saving_period=1
--local=0 --trainer_id=0
--save_dir=/home/jobpath/paddle-cluster-job/output
I1116 09:10:17.123440 50 Util.cpp:130] Calling runInitFunctions
I1116 09:10:17.123764 50 Util.cpp:143] Call runInitFunctions done.
[WARNING 2016-11-16 09:10:17,227 default_decorators.py:40] please use keyword arguments in paddle config.
[INFO 2016-11-16 09:10:17,239 networks.py:1282] The input order is [movie_id, title, genres, user_id, gender, age, occupation, rating]
[INFO 2016-11-16 09:10:17,239 networks.py:1289] The output order is [__square_error_cost_0__]
I1116 09:10:17.392917 50 Trainer.cpp:170] trainer mode: Normal
I1116 09:10:17.613910 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process
I1116 09:10:17.680917 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process
I1116 09:10:17.681543 50 GradientMachine.cpp:134] Initing parameters..
I1116 09:10:18.012390 50 GradientMachine.cpp:141] Init parameters done.
I1116 09:10:18.018641 50 ParameterClient2.cpp:122] pserver 0 192.168.129.66:7164
I1116 09:10:18.018950 50 ParameterClient2.cpp:122] pserver 1 192.168.129.66:7165
I1116 09:10:18.019069 50 ParameterClient2.cpp:122] pserver 2 192.168.223.143:7164
I1116 09:10:18.019492 50 ParameterClient2.cpp:122] pserver 3 192.168.223.143:7165
I1116 09:10:18.019716 50 ParameterClient2.cpp:122] pserver 4 192.168.129.71:7164
I1116 09:10:18.019836 50 ParameterClient2.cpp:122] pserver 5 192.168.129.71:7165
```
## Some Additional Details
### Using Environment Variables
Usually we use the environment varialbes to configurate the PaddlePaddle Job which runs in
Kubernetes, `start_paddle.py` provides a start up script to convert the environment variable
to the start up arguments of PaddlePaddle process:
```bash
API = "/api/v1/namespaces/"
JOBSELECTOR = "labelSelector=job-name="
JOB_PATH = os.getenv("JOB_PATH") + "/" + os.getenv("JOB_NAME")
JOB_PATH_OUTPUT = JOB_PATH + "/output"
JOBNAME = os.getenv("JOB_NAME")
NAMESPACE = os.getenv("JOB_NAMESPACE")
PADDLE_NIC = os.getenv("CONF_PADDLE_NIC")
PADDLE_PORT = os.getenv("CONF_PADDLE_PORT")
PADDLE_PORTS_NUM = os.getenv("CONF_PADDLE_PORTS_NUM")
PADDLE_PORTS_NUM_SPARSE = os.getenv("CONF_PADDLE_PORTS_NUM_SPARSE")
PADDLE_SERVER_NUM = os.getenv("CONF_PADDLE_GRADIENT_NUM")
```
### Communication between Pods
At the begin of `start_paddle.py`, it would initializes and parses the arguments.
```python
parser = argparse.ArgumentParser(prog="start_paddle.py",
description='simple tool for k8s')
args, train_args_list = parser.parse_known_args()
train_args = refine_unknown_args(train_args_list)
train_args_dict = dict(zip(train_args[:-1:2], train_args[1::2]))
podlist = getPodList()
```
And then query the status of all the other Pods of this Job by the function `getPodList()`, and fetch `triner_id` by the function `getIdMap(podlist)` if all the Pods status is `RUNNING`.
```python
podlist = getPodList()
# need to wait until all pods are running
while not isPodAllRunning(podlist):
time.sleep(10)
podlist = getPodList()
idMap = getIdMap(podlist)
```
**NOTE**: `getPodList()` would prefetch all the Pods in the current namespace, if some
Pods are alreay running, it may cause some error. We will use [statfulesets](https://kubernetes.io/docs/concepts/abstractions/controllers/statefulsets) instead of
Kubernetes Pod or Replicaset in the future.
The function `getIdMap(podlist)` fetches IPs addresses of `podlist` and then sort them
to generate `trainer_id`.
```python
def getIdMap(podlist):
'''
generate tainer_id by ip
'''
ips = []
for pod in podlist["items"]:
ips.append(pod["status"]["podIP"])
ips.sort()
idMap = {}
for i in range(len(ips)):
idMap[ips[i]] = i
return idMap
```
After getting the `idMap`, we can generate the arguments of `Paddle PServer` and `Paddle Trainer`
so that we can start up them by `startPaddle(idMap, train_args_dict)`.
### Create Job
The main goal of `startPaddle` is generating the arguments of `Paddle PServer` and
`Paddle Trainer` processes. Take `Paddle Trainer` as an example, we parse the
environment variable and then get `PADDLE_NIC`, `PADDLE_PORT`, `PADDLE_PORTS_NUM` and etc...,
finally find `trainerId` from `idMap` according to its IP address.
```python
program = 'paddle train'
args = " --nics=" + PADDLE_NIC
args += " --port=" + str(PADDLE_PORT)
args += " --ports_num=" + str(PADDLE_PORTS_NUM)
args += " --comment=" + "paddle_process_by_paddle"
ip_string = ""
for ip in idMap.keys():
ip_string += (ip + ",")
ip_string = ip_string.rstrip(",")
args += " --pservers=" + ip_string
args_ext = ""
for key, value in train_args_dict.items():
args_ext += (' --' + key + '=' + value)
localIP = socket.gethostbyname(socket.gethostname())
trainerId = idMap[localIP]
args += " " + args_ext + " --trainer_id=" + \
str(trainerId) + " --save_dir=" + JOB_PATH_OUTPUT
```
.timestamp
*.o
*.a
.svn
......
......@@ -7,9 +7,9 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
if(WITH_GPU)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto)
else()
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto)
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto)
endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
......@@ -21,9 +21,9 @@ endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init)
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place memory device_context init)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
......
......@@ -13,11 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include <queue>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include <queue>
namespace paddle {
namespace framework {
......@@ -147,52 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return;
}
auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
std::vector<std::string> &v) {
auto in_names = (*op)->InputArgumentNames();
v.insert(v.end(), in_names.begin(), in_names.end());
auto out_names = (*op)->OutputArgumentNames();
v.insert(v.end(), out_names.begin(), out_names.end());
std::sort(v.begin(), v.end());
auto last = std::unique(v.begin(), v.end());
v.erase(last, v.end());
};
need_update_ = true;
for (size_t i = s; i < e; i++) {
// since remove op one by one, every time remove the first op.
auto op = ops_.begin() + s;
// collect input and output variables from current delete op
std::vector<std::string> cur_vars;
get_vars(op, cur_vars);
// remove current op
ops_.erase(ops_.begin() + s);
// collect input and output variables from other ops
std::vector<std::string> other_vars;
for (auto it = ops_.begin(); it != ops_.end(); it++) {
get_vars(it, other_vars);
}
// variables should be deleted
std::vector<std::string> delete_vars;
// delete_vars = cur_vars - cur_vars ^ other_input_vars
std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
other_vars.end(),
std::inserter(delete_vars, delete_vars.end()));
// remove variables
for (size_t i = 0; i < delete_vars.size(); i++) {
auto name = delete_vars[i];
auto it = vars_.find(name);
PADDLE_ENFORCE(it != vars_.end(),
"%s is not in variable list, it should not be deleted",
name);
vars_.erase(it);
VLOG(3) << "deleting variable " << name;
}
}
ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
std::vector<OpDesc *> BlockDesc::AllOps() const {
......
......@@ -105,7 +105,7 @@ static void BuildVar(const std::string& param_name,
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
InitDevices();
InitDevices(true);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
......
......@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
......@@ -15,7 +16,7 @@ else()
set(multi_devices_graph_builder_deps)
endif()
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
......@@ -54,12 +55,37 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p,
const size_t &i) const {
auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
}
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
auto graph = new SSAGraph();
SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast;
result.vars_.resize(places_.size());
// We cannot invoke resize. It is a bug of GCC 4.8
result.vars_ = std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
......@@ -72,27 +98,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// append send op if program is distributed trainer main program.
// always use the first device
if (!is_forwarding && op->Type() == "send") {
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(&result, op, p, 0);
continue;
}
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
CreateOpHandleIOs(&result, op, p, i);
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
auto var_names = op->OutputArgumentNames();
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
......@@ -147,15 +174,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad.get());
auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
vars.emplace_back(new VarHandle);
auto &var = vars.back();
var->place_ = p;
var->name_ = og;
var->version_ = vars.size() - 1;
op_handle->AddOutput(&var);
op_handle->AddOutput(var.get());
}
#else
PADDLE_ENFORCE("Not implemented");
......
......@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
......@@ -41,6 +44,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
op_->Run(*local_scope_, place_);
}
std::string SendOpHandle::Name() const { return "send"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -16,6 +16,8 @@
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
......@@ -24,7 +26,9 @@ namespace framework {
namespace details {
struct SSAGraph {
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_;
......
......@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *write_op = (*it_new)->generated_op_;
auto &read_ops = (*it_old)->pending_ops_;
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
......@@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle);
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
init_var->place_ = place;
init_var->name_ = each_var_name;
init_var->generated_op_ = nullptr;
init_var->version_ = 0;
var = init_var.get();
} else {
var = &var_holder.rbegin()->second;
var = var_holder.rbegin()->get();
}
return var;
}
......@@ -72,11 +73,12 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.name_ = each_var_name;
var.place_ = place;
op_handle->AddOutput(&var);
vars.emplace_back(new VarHandle());
auto &var = vars.back();
var->version_ = version;
var->name_ = each_var_name;
var->place_ = place;
op_handle->AddOutput(var.get());
}
template <typename Callback>
......@@ -84,7 +86,7 @@ void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(pair2.second);
callback(*pair2);
}
}
}
......
......@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(version_pair.second);
InsertPendingVar(*version_pair);
}
}
}
......@@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
......
......@@ -93,6 +93,43 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name);
}
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
auto& global_block = pdesc.Block(block_id);
const Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
if (ancestor_scope != scope) {
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id);
......@@ -188,8 +225,8 @@ static bool has_fetch_operators(
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name, bool create_vars) {
bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
......@@ -282,38 +319,13 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);
Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
}
CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
}
for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
......
......@@ -54,9 +54,9 @@ class Executor {
void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
const std::string& fetch_holder_name = "fetch");
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
......@@ -64,6 +64,8 @@ class Executor {
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
bool create_vars = true);
......
......@@ -64,7 +64,7 @@ void InitP2P(int count) {
#endif
}
void InitDevices() {
void InitDevices(bool init_p2p) {
/*Init all avaiable devices by default */
std::vector<platform::Place> places;
......@@ -85,7 +85,9 @@ void InitDevices() {
for (int i = 0; i < count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
}
InitP2P(count);
if (init_p2p) {
InitP2P(count);
}
platform::DeviceContextPool::Init(places);
}
......
......@@ -24,7 +24,7 @@ void InitGflags(std::vector<std::string> &argv);
void InitGLOG(const std::string &prog_name);
void InitDevices();
void InitDevices(bool init_p2p);
} // namespace framework
} // namespace paddle
......@@ -21,7 +21,7 @@ TEST(InitDevices, CPU) {
using paddle::platform::DeviceContextPool;
#ifndef PADDLE_WITH_CUDA
InitDevices();
InitDevices(true);
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U);
#endif
......@@ -33,7 +33,7 @@ TEST(InitDevices, CUDA) {
#ifdef PADDLE_WITH_CUDA
int count = paddle::platform::GetCUDADeviceCount();
InitDevices();
InitDevices(true);
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U + static_cast<unsigned>(count));
#endif
......
......@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace paddle {
namespace framework {
......@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
}
void WriteToRecordIO(recordio::Writer &writer,
void WriteToRecordIO(recordio::Writer *writer,
const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) {
std::stringstream buffer;
......@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx);
}
writer.Write(buffer.str());
writer->Write(buffer.str());
}
std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
std::vector<LoDTensor> result;
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
if (scanner->HasNext()) {
std::istringstream sin(scanner->Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
}
return result;
}
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
......@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer,
extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx);
recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework
} // namespace paddle
......@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace framework {
......@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{
recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(&writer, {tensor, tensor}, ctx);
WriteToRecordIO(&writer, {tensor, tensor}, ctx);
writer.Flush();
}
......@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx);
auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx);
tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
......
......@@ -30,7 +30,7 @@ __global__ void test(size_t* a, int size) {
}
TEST(LoD, data) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
......@@ -46,7 +46,7 @@ TEST(LoD, data) {
}
TEST(LoDTensor, LoDInGPU) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::LoDTensor lod_tensor;
paddle::platform::CUDAPlace place(0);
......
......@@ -72,7 +72,7 @@ REGISTER_OP_WITHOUT_GRADIENT(test_operator,
paddle::framework::OpWithoutKernelCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
......@@ -198,7 +198,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
......@@ -228,7 +228,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
......@@ -269,7 +269,7 @@ class OperatorClone : public paddle::framework::OperatorBase {
};
TEST(Operator, Clone) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include <vector>
......@@ -24,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -43,30 +43,40 @@ class ParallelExecutorPrivate {
#endif
};
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_;
}
ParallelExecutor::ParallelExecutor(
size_t num_threads, bool use_event,
const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope, bool allow_op_delay)
const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
// Step 1. RunStartupProgram and Bcast the params to devs.
Executor exe(places[0]);
exe.Run(startup_program, scope, 0);
// Step 1. Bcast the params to devs.
// Create local scopes
for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.push_back(&scope->NewScope());
if (local_scopes.empty()) {
for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.push_back(&scope->NewScope());
}
} else {
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.push_back(local_scopes[i]);
}
}
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
#endif
if (platform::is_gpu_place(places[0]) &&
member_->local_scopes_.size() != 1) { // Is CUDA
BCastParamsToGPUs(startup_program);
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
local_scopes.empty()) { // Is CUDA
BCastParamsToGPUs(bcast_vars);
}
// Startup Program has been run. All local scopes has correct parameters.
......@@ -99,48 +109,45 @@ ParallelExecutor::ParallelExecutor(
}
void ParallelExecutor::BCastParamsToGPUs(
const ProgramDesc &startup_program) const {
const std::unordered_set<std::string> &vars) const {
#ifdef PADDLE_WITH_CUDA
auto *main_scope = member_->local_scopes_[0];
for (auto *var_desc : startup_program.Block(0).AllVars()) {
size_t idx = var_desc->Name().find("@GRAD");
if (idx != std::string::npos) continue;
if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
auto &main_tensor =
main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i];
void *buffer;
if (i == 0) {
buffer = const_cast<void *>(main_tensor.data<void>());
} else {
auto local_scope = member_->local_scopes_[i];
auto *t =
local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
t->Resize(dims);
buffer = t->mutable_data(place, main_tensor.type());
}
auto &nccl_ctx = member_->nccl_ctxs_->at(place);
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
} else {
platform::CPUPlace cpu;
for (size_t i = 1; i < member_->places_.size(); ++i) {
for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var);
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i];
void *buffer;
if (i == 0) {
buffer = const_cast<void *>(main_tensor.data<void>());
} else {
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
t->Resize(dims);
t->mutable_data(cpu, main_tensor.type());
paddle::framework::TensorCopy(main_tensor, cpu, t);
buffer = t->mutable_data(place, main_tensor.type());
}
auto &nccl_ctx = member_->nccl_ctxs_->at(place);
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
} else {
platform::CPUPlace cpu;
for (size_t i = 1; i < member_->places_.size(); ++i) {
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
t->Resize(dims);
t->mutable_data(cpu, main_tensor.type());
paddle::framework::TensorCopy(main_tensor, cpu, t);
}
}
member_->nccl_ctxs_->WaitAll();
......@@ -165,12 +172,17 @@ void ParallelExecutor::SplitTensorToPlaces(
const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
for (auto it : feed_tensors) {
auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
PADDLE_ENFORCE_EQ(
member_->places_.size(), lod_tensors.size(),
"The number of samples of current batch is less than the count of "
"devices, currently, it is not allowed. (%d vs %d)",
member_->places_.size(), lod_tensors.size());
for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var?
member_->local_scopes_[j]
->Var(it.first)
->GetMutable<LoDTensor>()
->ShareDataWith(lod_tensors[j]);
auto t =
member_->local_scopes_[j]->Var(it.first)->GetMutable<LoDTensor>();
t->ShareDataWith(lod_tensors[j]);
t->set_lod(lod_tensors[j].lod());
}
}
}
......
......@@ -36,22 +36,25 @@ class ParallelExecutor {
explicit ParallelExecutor(size_t num_threads, bool use_event,
const std::vector<platform::Place>& places,
const std::unordered_set<std::string>& params,
const ProgramDesc& startup_program,
const std::unordered_set<std::string>& bcast_vars,
const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes,
bool allow_op_delay);
std::vector<Scope*>& GetLocalScopes();
void Run(const std::vector<std::string>& fetch_tensors,
const std::string& fetched_var_name,
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
private:
void SplitTensorToPlaces(
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
ParallelExecutorPrivate* member_;
void BCastParamsToGPUs(const ProgramDesc& startup_program) const;
};
} // namespace framework
......
......@@ -85,9 +85,9 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
}
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
BlockDesc *global_block = blocks_[0].get();
auto &global_block = Block(0);
std::vector<std::string> feed_target_names;
for (auto *op : global_block->AllOps()) {
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]);
}
......@@ -96,9 +96,9 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
}
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
BlockDesc *global_block = blocks_[0].get();
auto &global_block = Block(0);
std::vector<std::string> fetch_target_names;
for (auto *op : global_block->AllOps()) {
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
fetch_target_names.push_back(op->Input("X")[0]);
}
......@@ -106,5 +106,43 @@ const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
return fetch_target_names;
}
void ProgramDesc::SetFeedHolderName(const std::string &feed_holder_name) {
auto *global_block = MutableBlock(0);
int index = 0;
for (auto *op : global_block->AllOps()) {
if (op->Type() == kFeedOpType) {
// Unify the input's name of all feed_ops to feed_holder_name
global_block->RemoveVar(op->Input("X")[0]);
op->SetInput("X", {feed_holder_name});
op->SetAttr("col", {index});
op->CheckAttrs();
index++;
}
}
auto *feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
feed_holder->SetPersistable(true);
}
void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
auto *global_block = MutableBlock(0);
int index = 0;
for (auto *op : global_block->AllOps()) {
if (op->Type() == kFetchOpType) {
// Unify the output's name of all fetch_ops to fetch_holder_name
global_block->RemoveVar(op->Output("Out")[0]);
op->SetOutput("Out", {fetch_holder_name});
op->SetAttr("col", {index});
op->CheckAttrs();
index++;
}
}
auto *fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true);
}
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
......@@ -52,9 +53,26 @@ class ProgramDesc {
proto::ProgramDesc *Proto();
// The output variable of feed_op is referenced as feed_target.
// This function is used to collect the output variable's name of all
// feed_ops.
const std::vector<std::string> GetFeedTargetNames();
// The input variable of fetch_op is referenced as fetch_target.
// This function is used to collect the input variable's name of all
// fetch_ops.
const std::vector<std::string> GetFetchTargetNames();
// The input variable of feed_op that holds input Tensor provided by users is
// referenced as feed_holder.
// This function is used to change or unify the feed_holder variables' name.
void SetFeedHolderName(const std::string &feed_holder_name);
// The output variable of fetch_op that holds output Tensor needed by users is
// referenced as fetch_holder.
// This function is used to change or unify the fetch_holder variables' name.
void SetFetchHolderName(const std::string &fetch_holder_name);
private:
proto::ProgramDesc desc_;
......
......@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
if (out->empty()) {
return;
}
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];
......
......@@ -14,14 +14,13 @@
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle {
namespace framework {
......@@ -31,8 +30,6 @@ class ReaderBase {
virtual void ReInit() = 0;
virtual bool HasNext() const = 0;
virtual ~ReaderBase();
};
......@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected:
ReaderBase* reader_;
};
......@@ -80,8 +75,6 @@ class ReaderHolder {
reader_->ReInit();
}
bool HasNext() const { return reader_->HasNext(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
......@@ -39,6 +38,7 @@ Scope::~Scope() {
}
Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this));
return *kids_.back();
}
......@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}
void Scope::DeleteScope(Scope* scope) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
......@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
}
}
void Scope::EraseVars(std::vector<std::string>& var_names) {
void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <list>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
......@@ -51,13 +52,13 @@ class Scope {
/// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr);
void EraseVars(std::vector<std::string>& var_names);
void EraseVars(const std::vector<std::string>& var_names);
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;
const Scope& parent() const { return *parent_; }
const Scope* parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
......@@ -88,6 +89,9 @@ class Scope {
Scope const* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope);
private:
mutable std::mutex mutex_;
};
} // namespace framework
} // namespace paddle
set(FLUID_CORE_MODULES proto_desc paddle_memory lod_tensor executor prune init)
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor prune init)
cc_library(paddle_fluid_api
SRCS io.cc
......
......@@ -24,7 +24,8 @@ function(inference_test TARGET_NAME)
endforeach()
endfunction(inference_test)
inference_test(fit_a_line)
# This unittest is buggy!
#inference_test(fit_a_line)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
inference_test(recognize_digits ARGS mlp conv)
......
......@@ -12,6 +12,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/inference/tests/test_multi_thread_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
......@@ -26,32 +27,63 @@ TEST(inference, fit_a_line) {
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
paddle::framework::LoDTensor input;
// The second dim of the input tensor should be 13
// The input data should be >= 0
int64_t batch_size = 10;
SetupTensor<float>(&input, {batch_size, 13}, static_cast<float>(0),
static_cast<float>(10));
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
for (int num_threads : {1, 2}) {
std::vector<std::vector<paddle::framework::LoDTensor*>> cpu_feeds;
cpu_feeds.resize(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto* input = new paddle::framework::LoDTensor();
// The second dim of the input tensor should be 13
// The input data should be >= 0
int64_t batch_size = 10;
SetupTensor<float>(input, {batch_size, 13}, static_cast<float>(0),
static_cast<float>(10));
cpu_feeds[i].push_back(input);
}
paddle::framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
std::vector<std::vector<paddle::framework::LoDTensor*>> cpu_fetchs1;
cpu_fetchs1.resize(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto* output = new paddle::framework::LoDTensor();
cpu_fetchs1[i].push_back(output);
}
// Run inference on CPU
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
// Run inference on CPU
LOG(INFO) << "--- CPU Runs (num_threads: " << num_threads << "): ---";
if (num_threads == 1) {
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds[0],
cpu_fetchs1[0]);
} else {
TestMultiThreadInference<paddle::platform::CPUPlace>(
dirname, cpu_feeds, cpu_fetchs1, num_threads);
}
#ifdef PADDLE_WITH_CUDA
paddle::framework::LoDTensor output2;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
cpu_fetchs2.push_back(&output2);
std::vector<std::vector<paddle::framework::LoDTensor*>> cpu_fetchs2;
cpu_fetchs2.resize(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto* output = new paddle::framework::LoDTensor();
cpu_fetchs2[i].push_back(output);
}
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs (num_threads: " << num_threads << "): ---";
if (num_threads == 1) {
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds[0],
cpu_fetchs2[0]);
} else {
TestMultiThreadInference<paddle::platform::CUDAPlace>(
dirname, cpu_feeds, cpu_fetchs2, num_threads);
}
CheckError<float>(output1, output2);
for (int i = 0; i < num_threads; ++i) {
CheckError<float>(*cpu_fetchs1[i][0], *cpu_fetchs2[i][0]);
delete cpu_fetchs2[i][0];
}
#endif
for (int i = 0; i < num_threads; ++i) {
delete cpu_feeds[i][0];
delete cpu_fetchs1[i][0];
}
} // num_threads-loop
}
......@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace, true>(dirname, cpu_feeds,
cpu_fetchs1, FLAGS_repeat);
TestInference<paddle::platform::CPUPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
......@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace, true>(dirname, cpu_feeds,
cpu_fetchs2, FLAGS_repeat);
TestInference<paddle::platform::CUDAPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
......
......@@ -25,7 +25,8 @@ limitations under the License. */
template <typename T>
void SetupTensor(paddle::framework::LoDTensor* input,
paddle::framework::DDim dims, T lower, T upper) {
std::mt19937 rng(100); // An arbitrarily chosen but fixed seed.
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* input_ptr = input->mutable_data<T>(dims, paddle::platform::CPUPlace());
......@@ -88,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
template <typename Place, bool PrepareContext = false>
template <typename Place, bool CreateVars = true, bool PrepareContext = false>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
......@@ -166,6 +167,13 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
if (!CreateVars) {
// If users don't want to create and destroy variables every time they
// run, they need to set `create_vars` to false and manually call
// `CreateVariables` before running.
executor.CreateVariables(*inference_program, scope, 0);
}
// Ignore the profiling results of the first run
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
if (PrepareContext) {
......@@ -173,7 +181,8 @@ void TestInference(const std::string& dirname,
executor.RunPreparedContext(ctx.get(), scope, feed_targets,
fetch_targets);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
}
// Enable the profiler
......@@ -191,7 +200,8 @@ void TestInference(const std::string& dirname,
executor.RunPreparedContext(ctx.get(), scope, feed_targets,
fetch_targets);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/io.h"
void ThreadedRunInference(
const std::unique_ptr<paddle::framework::ProgramDesc>& inference_program,
paddle::framework::Executor* executor, paddle::framework::Scope* scope,
const int thread_id,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
auto copy_program = std::unique_ptr<paddle::framework::ProgramDesc>(
new paddle::framework::ProgramDesc(*inference_program));
std::string feed_holder_name = "feed_" + paddle::string::to_string(thread_id);
std::string fetch_holder_name =
"fetch_" + paddle::string::to_string(thread_id);
copy_program->SetFeedHolderName(feed_holder_name);
copy_program->SetFetchHolderName(fetch_holder_name);
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
copy_program->GetFeedTargetNames();
const std::vector<std::string>& fetch_target_names =
copy_program->GetFetchTargetNames();
// 4. Prepare inputs: set up maps for feed targets
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
for (size_t i = 0; i < feed_target_names.size(); ++i) {
// Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets[feed_target_names[i]] = cpu_feeds[i];
}
// 5. Define Tensor to get the outputs: set up maps for fetch targets
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
for (size_t i = 0; i < fetch_target_names.size(); ++i) {
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
}
// 6. Run the inference program
executor->Run(*copy_program, scope, feed_targets, fetch_targets, true,
feed_holder_name, fetch_holder_name);
}
template <typename Place>
void TestMultiThreadInference(
const std::string& dirname,
const std::vector<std::vector<paddle::framework::LoDTensor*>>& cpu_feeds,
const std::vector<std::vector<paddle::framework::LoDTensor*>>& cpu_fetchs,
const int num_threads) {
// 1. Define place, executor, scope
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program =
paddle::inference::Load(executor, *scope, dirname);
std::vector<std::thread*> threads;
for (int i = 0; i < num_threads; ++i) {
threads.push_back(new std::thread(
ThreadedRunInference, std::ref(inference_program), &executor, scope, i,
std::ref(cpu_feeds[i]), std::ref(cpu_fetchs[i])));
}
for (int i = 0; i < num_threads; ++i) {
threads[i]->join();
delete threads[i];
}
delete scope;
}
add_subdirectory(detail)
cc_library(memory SRCS memory.cc DEPS place enforce)
cc_library(malloc SRCS malloc.cc DEPS buddy_allocator place enforce)
cc_library(memcpy SRCS memcpy.cc DEPS place)
cc_library(paddle_memory
cc_library(memory
DEPS
memory
memcpy
meta_data
meta_cache
memory_block
buddy_allocator
system_allocator)
malloc
memcpy)
cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory)
cc_test(malloc_test SRCS malloc_test.cc DEPS malloc)
#if (WITH_GPU)
# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place paddle_memory)
# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory)
#endif()
cc_library(memory_block SRCS memory_block.cc memory_block_desc.cc meta_cache.cc)
if(${WITH_GPU})
nv_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info)
else(${WITH_GPU})
......@@ -6,10 +8,4 @@ endif(${WITH_GPU})
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator)
cc_library(meta_data SRCS meta_data.cc)
cc_library(meta_cache SRCS meta_cache.cc)
cc_library(memory_block SRCS memory_block.cc)
cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS glog)
cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS memory_block system_allocator glog)
......@@ -46,7 +46,8 @@ inline size_t align(size_t size, size_t alignment) {
void* BuddyAllocator::Alloc(size_t unaligned_size) {
// adjust allocation alignment
size_t size = align(unaligned_size + sizeof(Metadata), min_chunk_size_);
size_t size =
align(unaligned_size + sizeof(MemoryBlock::Desc), min_chunk_size_);
// acquire the allocator lock
std::lock_guard<std::mutex> lock(mutex_);
......@@ -103,7 +104,7 @@ void BuddyAllocator::Free(void* p) {
return;
}
block->mark_as_free(cache_);
block->mark_as_free(&cache_);
total_used_ -= block->total_size(cache_);
total_free_ += block->total_size(cache_);
......@@ -122,7 +123,7 @@ void BuddyAllocator::Free(void* p) {
right_buddy));
// merge its right buddy to the block
block->merge(cache_, right_buddy);
block->merge(&cache_, right_buddy);
}
}
......@@ -139,7 +140,7 @@ void BuddyAllocator::Free(void* p) {
left_buddy->total_size(cache_), left_buddy));
// merge the block to its left buddy
left_buddy->merge(cache_, block);
left_buddy->merge(&cache_, block);
block = left_buddy;
}
}
......@@ -163,13 +164,13 @@ size_t BuddyAllocator::Used() { return total_used_; }
void* BuddyAllocator::SystemAlloc(size_t size) {
size_t index = 0;
void* p = system_allocator_->Alloc(index, size);
void* p = system_allocator_->Alloc(&index, size);
VLOG(10) << "Allocated " << p << " from system allocator.";
if (p == nullptr) return nullptr;
static_cast<MemoryBlock*>(p)->init(cache_, MemoryBlock::HUGE_CHUNK, index,
static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::HUGE_CHUNK, index,
size, nullptr, nullptr);
return static_cast<MemoryBlock*>(p)->data();
......@@ -187,14 +188,14 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
// Allocate a new maximum sized block
size_t index = 0;
void* p = system_allocator_->Alloc(index, max_chunk_size_);
void* p = system_allocator_->Alloc(&index, max_chunk_size_);
if (p == nullptr) return pool_.end();
VLOG(10) << "Creating and inserting new block " << p
<< " from system allocator";
static_cast<MemoryBlock*>(p)->init(cache_, MemoryBlock::FREE_CHUNK, index,
static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::FREE_CHUNK, index,
max_chunk_size_, nullptr, nullptr);
// gpu fallback allocation
......@@ -238,11 +239,11 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
VLOG(10) << "Split block (" << block << ", " << block->total_size(cache_)
<< ") into";
block->split(cache_, size);
block->split(&cache_, size);
VLOG(10) << "Left block (" << block << ", " << block->total_size(cache_)
<< ")";
block->set_type(cache_, MemoryBlock::ARENA_CHUNK);
block->set_type(&cache_, MemoryBlock::ARENA_CHUNK);
// the rest of memory if exist
if (block->has_right_buddy(cache_)) {
......
......@@ -14,18 +14,18 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/memory/detail/meta_cache.h"
#include "paddle/fluid/memory/detail/meta_data.h"
#include <mutex> // NOLINT
#include <set>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#include <mutex>
#include <set>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace memory {
namespace detail {
......
......@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/detail/meta_data.h"
#include <functional>
#include "paddle/fluid/memory/detail/memory_block.h"
namespace paddle {
namespace memory {
namespace detail {
Metadata::Metadata(MemoryBlock::Type t, size_t i, size_t s, size_t ts,
MemoryBlock* l, MemoryBlock* r)
MemoryBlock::Desc::Desc(MemoryBlock::Type t, size_t i, size_t s, size_t ts,
MemoryBlock* l, MemoryBlock* r)
: type(t),
index(i),
size(s),
......@@ -29,7 +29,7 @@ Metadata::Metadata(MemoryBlock::Type t, size_t i, size_t s, size_t ts,
left_buddy(l),
right_buddy(r) {}
Metadata::Metadata()
MemoryBlock::Desc::Desc()
: type(MemoryBlock::INVALID_CHUNK),
index(0),
size(0),
......@@ -37,32 +37,36 @@ Metadata::Metadata()
left_buddy(nullptr),
right_buddy(nullptr) {}
namespace {
template <class T>
inline void hash_combine(std::size_t& seed, const T& v) {
inline void hash_combine(std::size_t* seed, const T& v) {
std::hash<T> hasher;
seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
(*seed) ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2);
}
inline size_t hash(const Metadata* metadata, size_t initial_seed) {
inline size_t hash(const MemoryBlock::Desc& metadata, size_t initial_seed) {
size_t seed = initial_seed;
hash_combine(seed, (size_t)metadata->type);
hash_combine(seed, metadata->index);
hash_combine(seed, metadata->size);
hash_combine(seed, metadata->total_size);
hash_combine(seed, metadata->left_buddy);
hash_combine(seed, metadata->right_buddy);
hash_combine(&seed, static_cast<size_t>(metadata.type));
hash_combine(&seed, metadata.index);
hash_combine(&seed, metadata.size);
hash_combine(&seed, metadata.total_size);
hash_combine(&seed, metadata.left_buddy);
hash_combine(&seed, metadata.right_buddy);
return seed;
}
void Metadata::update_guards() {
guard_begin = hash(this, 1);
guard_end = hash(this, 2);
} // namespace
void MemoryBlock::Desc::update_guards() {
guard_begin = hash(*this, 1);
guard_end = hash(*this, 2);
}
bool Metadata::check_guards() const {
return guard_begin == hash(this, 1) && guard_end == hash(this, 2);
bool MemoryBlock::Desc::check_guards() const {
return guard_begin == hash(*this, 1) && guard_end == hash(*this, 2);
}
} // namespace detail
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/detail/meta_cache.h"
#include "glog/logging.h"
#include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/platform/assert.h"
......@@ -23,29 +22,28 @@ namespace detail {
MetadataCache::MetadataCache(bool uses_gpu) : uses_gpu_(uses_gpu) {}
Metadata MetadataCache::load(const MemoryBlock* block) {
MemoryBlock::Desc MetadataCache::load(const MemoryBlock* block) const {
if (uses_gpu_) {
auto existing_metadata = cache_.find(block);
PADDLE_ASSERT(existing_metadata->second.check_guards());
return existing_metadata->second;
auto existing_desc = cache_.find(block);
PADDLE_ASSERT(existing_desc->second.check_guards());
return existing_desc->second;
} else {
auto* meta = reinterpret_cast<const Metadata*>(block);
VLOG(10) << "Load MetaData type=" << meta->type;
PADDLE_ASSERT(meta->check_guards());
return *reinterpret_cast<const Metadata*>(block);
auto* desc = reinterpret_cast<const MemoryBlock::Desc*>(block);
VLOG(10) << "Load MemoryBlock::Desc type=" << desc->type;
PADDLE_ASSERT(desc->check_guards());
return *reinterpret_cast<const MemoryBlock::Desc*>(block);
}
}
void MetadataCache::store(MemoryBlock* block,
const Metadata& original_metadata) {
auto metadata = original_metadata;
metadata.update_guards();
void MetadataCache::save(MemoryBlock* block,
const MemoryBlock::Desc& original_desc) {
auto desc = original_desc;
desc.update_guards();
if (uses_gpu_) {
cache_[block] = metadata;
cache_[block] = desc;
} else {
*reinterpret_cast<Metadata*>(block) = metadata;
*reinterpret_cast<MemoryBlock::Desc*>(block) = desc;
}
}
......
此差异已折叠。
此差异已折叠。
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/malloc.h"
#include "glog/logging.h"
......
此差异已折叠。
此差异已折叠。
......@@ -15,7 +15,6 @@ limitations under the License. */
#include <unordered_map>
#include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/memory/detail/meta_data.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......
......@@ -263,7 +263,7 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册