Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6e2424e4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e2424e4
编写于
4月 08, 2018
作者:
A
Abhinav Arora
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'origin/develop' into cpplint_ops_a
上级
11487de9
31464f34
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
571 addition
and
262 deletion
+571
-262
benchmark/fluid/machine_translation.py
benchmark/fluid/machine_translation.py
+48
-18
benchmark/fluid/mnist.py
benchmark/fluid/mnist.py
+43
-21
benchmark/fluid/resnet.py
benchmark/fluid/resnet.py
+14
-24
benchmark/fluid/run.sh
benchmark/fluid/run.sh
+63
-7
benchmark/fluid/stacked_dynamic_lstm.py
benchmark/fluid/stacked_dynamic_lstm.py
+58
-31
benchmark/fluid/vgg.py
benchmark/fluid/vgg.py
+10
-6
paddle/.gitignore
paddle/.gitignore
+1
-0
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+42
-30
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+4
-2
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+56
-47
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+5
-2
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+1
-1
paddle/fluid/inference/tests/book/test_inference_image_classification.cc
...ference/tests/book/test_inference_image_classification.cc
+4
-4
paddle/fluid/inference/tests/test_helper.h
paddle/fluid/inference/tests/test_helper.h
+12
-3
paddle/fluid/operators/go_op.cc
paddle/fluid/operators/go_op.cc
+2
-2
paddle/fluid/operators/lod_reset_op.h
paddle/fluid/operators/lod_reset_op.h
+3
-1
paddle/fluid/platform/float16_test.cc
paddle/fluid/platform/float16_test.cc
+22
-18
paddle/fluid/platform/float16_test.cu
paddle/fluid/platform/float16_test.cu
+21
-21
paddle/fluid/pybind/.gitignore
paddle/fluid/pybind/.gitignore
+1
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+12
-5
python/.gitignore
python/.gitignore
+1
-0
python/paddle/.gitignore
python/paddle/.gitignore
+1
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+11
-4
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+57
-6
python/paddle/fluid/tests/unittests/test_activation_op.py
python/paddle/fluid/tests/unittests/test_activation_op.py
+36
-8
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+43
-1
未找到文件。
benchmark/fluid/machine_translation.py
浏览文件 @
6e2424e4
...
...
@@ -48,6 +48,13 @@ parser.add_argument(
type
=
int
,
default
=
16
,
help
=
"The sequence number of a mini-batch data. (default: %(default)d)"
)
parser
.
add_argument
(
'--skip_batch_num'
,
type
=
int
,
default
=
5
,
help
=
'The first num of minibatch num to skip, for better performance test'
)
parser
.
add_argument
(
'--iterations'
,
type
=
int
,
default
=
80
,
help
=
'The number of minibatches.'
)
parser
.
add_argument
(
"--dict_size"
,
type
=
int
,
...
...
@@ -72,16 +79,21 @@ parser.add_argument(
default
=
3
,
help
=
"The width for beam searching. (default: %(default)d)"
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
distutils
.
util
.
strtobool
,
default
=
True
,
help
=
"Whether to use gpu. (default: %(default)d)"
)
'--device'
,
type
=
str
,
default
=
'GPU'
,
choices
=
[
'CPU'
,
'GPU'
],
help
=
"The device type."
)
parser
.
add_argument
(
"--max_length"
,
type
=
int
,
default
=
250
,
help
=
"The maximum length of sequence when doing generation. "
"(default: %(default)d)"
)
parser
.
add_argument
(
'--with_test'
,
action
=
'store_true'
,
help
=
'If set, test the testset during training.'
)
def
lstm_step
(
x_t
,
hidden_t_prev
,
cell_t_prev
,
size
):
...
...
@@ -281,7 +293,7 @@ def train():
paddle
.
dataset
.
wmt14
.
test
(
args
.
dict_size
),
buf_size
=
1000
),
batch_size
=
args
.
batch_size
)
place
=
core
.
C
UDAPlace
(
0
)
if
args
.
use_gpu
else
core
.
CPUPlace
(
)
place
=
core
.
C
PUPlace
()
if
args
.
device
==
'CPU'
else
core
.
CUDAPlace
(
0
)
exe
=
Executor
(
place
)
exe
.
run
(
framework
.
default_startup_program
())
...
...
@@ -307,14 +319,20 @@ def train():
return
total_loss
/
count
iters
,
num_samples
,
start_time
=
0
,
0
,
time
.
time
()
for
pass_id
in
xrange
(
args
.
pass_num
):
pass_start_time
=
time
.
time
()
words_seen
=
0
train_accs
=
[]
train_losses
=
[]
for
batch_id
,
data
in
enumerate
(
train_batch_generator
()):
if
iters
==
args
.
skip_batch_num
:
start_time
=
time
.
time
()
num_samples
=
0
if
iters
==
args
.
iterations
:
break
src_seq
,
word_num
=
to_lodtensor
(
map
(
lambda
x
:
x
[
0
],
data
),
place
)
words_seen
+=
word_num
num_samples
+=
word_num
trg_seq
,
word_num
=
to_lodtensor
(
map
(
lambda
x
:
x
[
1
],
data
),
place
)
words_seen
+=
word_num
num_samples
+=
word_num
lbl_seq
,
_
=
to_lodtensor
(
map
(
lambda
x
:
x
[
2
],
data
),
place
)
fetch_outs
=
exe
.
run
(
framework
.
default_main_program
(),
...
...
@@ -325,24 +343,36 @@ def train():
},
fetch_list
=
[
avg_cost
])
avg_cost_val
=
np
.
array
(
fetch_outs
[
0
])
print
(
'pass_id=%d, batch_id=%d, train_loss: %f'
%
(
pass_id
,
batch_id
,
avg_cost_val
))
iters
+=
1
loss
=
np
.
array
(
fetch_outs
[
0
])
print
(
"Pass = %d, Iter = %d, Loss = %f"
%
(
pass_id
,
iters
,
loss
)
)
# The accuracy is the accumulation of batches, but not the current batch.
pass_end_time
=
time
.
time
()
test_loss
=
do_validation
()
time_consumed
=
pass_end_time
-
pass_start_time
words_per_sec
=
words_seen
/
time_consumed
print
(
"pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f"
%
(
pass_id
,
test_loss
,
words_per_sec
,
time_consumed
))
train_elapsed
=
time
.
time
()
-
start_time
examples_per_sec
=
num_samples
/
train_elapsed
print
(
'
\n
Total examples: %d, total time: %.5f, %.5f examples/sed
\n
'
%
(
num_samples
,
train_elapsed
,
examples_per_sec
))
# evaluation
if
args
.
with_test
:
test_loss
=
do_validation
()
exit
(
0
)
def
infer
():
pass
def
print_arguments
(
args
):
print
(
'----------- seq2seq Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
if
__name__
==
'__main__'
:
args
=
parser
.
parse_args
()
print_arguments
(
args
)
if
args
.
infer_only
:
infer
()
else
:
...
...
benchmark/fluid/mnist.py
浏览文件 @
6e2424e4
...
...
@@ -35,6 +35,12 @@ def parse_args():
parser
=
argparse
.
ArgumentParser
(
"mnist model benchmark."
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
128
,
help
=
'The minibatch size.'
)
parser
.
add_argument
(
'--skip_batch_num'
,
type
=
int
,
default
=
5
,
help
=
'The first num of minibatch num to skip, for better performance test'
)
parser
.
add_argument
(
'--iterations'
,
type
=
int
,
default
=
35
,
help
=
'The number of minibatches.'
)
parser
.
add_argument
(
...
...
@@ -53,19 +59,14 @@ def parse_args():
'--use_nvprof'
,
action
=
'store_true'
,
help
=
'If set, use nvprof for CUDA.'
)
parser
.
add_argument
(
'--with_test'
,
action
=
'store_true'
,
help
=
'If set, test the testset during training.'
)
args
=
parser
.
parse_args
()
return
args
def
print_arguments
(
args
):
vars
(
args
)[
'use_nvprof'
]
=
(
vars
(
args
)[
'use_nvprof'
]
and
vars
(
args
)[
'device'
]
==
'GPU'
)
print
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
def
cnn_model
(
data
):
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
data
,
...
...
@@ -161,16 +162,22 @@ def run_benchmark(model, args):
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
args
.
batch_size
)
accuracy
=
fluid
.
average
.
WeightedAverage
()
iters
,
num_samples
,
start_time
=
0
,
0
,
time
.
time
()
for
pass_id
in
range
(
args
.
pass_num
):
accuracy
.
reset
()
pass_start
=
time
.
time
()
train_accs
=
[]
train_losses
=
[]
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
iters
==
args
.
skip_batch_num
:
start_time
=
time
.
time
()
num_samples
=
0
if
iters
==
args
.
iterations
:
break
img_data
=
np
.
array
(
map
(
lambda
x
:
x
[
0
].
reshape
([
1
,
28
,
28
]),
data
)).
astype
(
DTYPE
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int64"
)
y_data
=
y_data
.
reshape
([
len
(
y_data
),
1
])
start
=
time
.
time
()
outs
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"pixel"
:
img_data
,
...
...
@@ -178,21 +185,36 @@ def run_benchmark(model, args):
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size_tensor
]
)
# The accuracy is the accumulation of batches, but not the current batch.
accuracy
.
add
(
value
=
outs
[
1
],
weight
=
outs
[
2
])
end
=
time
.
time
()
iters
+=
1
num_samples
+=
len
(
y_data
)
loss
=
np
.
array
(
outs
[
0
])
acc
=
np
.
array
(
outs
[
1
])
print
(
"pass=%d, batch=%d, loss=%f, error=%f, elapse=%f"
%
(
pass_id
,
batch_id
,
loss
,
1
-
acc
,
(
end
-
start
)
/
1000
))
train_losses
.
append
(
loss
)
train_accs
.
append
(
acc
)
print
(
"Pass: %d, Iter: %d, Loss: %f, Accuracy: %f"
%
(
pass_id
,
iters
,
loss
,
acc
))
print
(
"Pass: %d, Loss: %f, Train Accuray: %f
\n
"
%
(
pass_id
,
np
.
mean
(
train_losses
),
np
.
mean
(
train_accs
)))
train_elapsed
=
time
.
time
()
-
start_time
examples_per_sec
=
num_samples
/
train_elapsed
pass_end
=
time
.
time
()
print
(
'
\n
Total examples: %d, total time: %.5f, %.5f examples/sed
\n
'
%
(
num_samples
,
train_elapsed
,
examples_per_sec
))
# evaluation
if
args
.
with_test
:
test_avg_acc
=
eval_test
(
exe
,
batch_acc
,
batch_size_tensor
,
inference_program
)
exit
(
0
)
train_avg_acc
=
accuracy
.
eval
()
test_avg_acc
=
eval_test
(
exe
,
batch_acc
,
batch_size_tensor
,
inference_program
)
print
(
"pass=%d, train_avg_acc=%f, test_avg_acc=%f, elapse=%f"
%
(
pass_id
,
train_avg_acc
,
test_avg_acc
,
(
pass_end
-
pass_start
)
/
1000
))
def
print_arguments
(
args
):
vars
(
args
)[
'use_nvprof'
]
=
(
vars
(
args
)[
'use_nvprof'
]
and
vars
(
args
)[
'device'
]
==
'GPU'
)
print
(
'----------- mnist Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
if
__name__
==
'__main__'
:
...
...
benchmark/fluid/resnet.py
浏览文件 @
6e2424e4
...
...
@@ -87,15 +87,6 @@ def parse_args():
return
args
def
print_arguments
(
args
):
vars
(
args
)[
'use_nvprof'
]
=
(
vars
(
args
)[
'use_nvprof'
]
and
vars
(
args
)[
'device'
]
==
'GPU'
)
print
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'relu'
):
conv1
=
fluid
.
layers
.
conv2d
(
input
=
input
,
...
...
@@ -279,32 +270,31 @@ def run_benchmark(model, args):
'label'
:
label
},
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size_tensor
])
iters
+=
1
num_samples
+=
l
abel
[
0
]
num_samples
+=
l
en
(
label
)
accuracy
.
add
(
value
=
acc
,
weight
=
weight
)
train_losses
.
append
(
loss
)
train_accs
.
append
(
acc
)
print
(
"Pass: %d, Iter: %d, Loss: %f, Accuracy: %f"
%
(
pass_id
,
iters
,
loss
,
acc
))
pass_train_acc
=
accuracy
.
eval
()
# evaluation
if
args
.
with_test
:
pass_test_acc
=
test
(
exe
)
train_elapsed
=
time
.
time
()
-
start_time
print
(
"Pass: %d, Loss: %f, Train Accuray: %f
\n
"
%
(
pass_id
,
np
.
mean
(
train_losses
),
np
.
mean
(
train_accs
)))
train_elapsed
=
time
.
time
()
-
start_time
examples_per_sec
=
num_samples
/
train_elapsed
print
(
'
\n
Total examples: %d, total time: %.5f, %.5f examples/sed
\n
'
%
(
num_samples
,
train_elapsed
,
examples_per_sec
))
# evaluation
if
args
.
with_test
:
pass_test_acc
=
test
(
exe
)
exit
(
0
)
if
args
.
use_cprof
:
pr
.
disable
()
s
=
StringIO
.
StringIO
()
sortby
=
'cumulative'
ps
=
pstats
.
Stats
(
pr
,
stream
=
s
).
sort_stats
(
sortby
)
ps
.
print_stats
()
print
(
s
.
getvalue
())
def
print_arguments
(
args
):
vars
(
args
)[
'use_nvprof'
]
=
(
vars
(
args
)[
'use_nvprof'
]
and
vars
(
args
)[
'device'
]
==
'GPU'
)
print
(
'----------- resnet Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
if
__name__
==
'__main__'
:
...
...
benchmark/fluid/run.sh
浏览文件 @
6e2424e4
#!/bin/bash
# This script benchmarking the PaddlePaddle Fluid on
# single thread single GPU.
export
CUDNN_PATH
=
/paddle/cudnn_v5/cuda/lib
#export FLAGS_fraction_of_gpu_memory_to_use=0.0
export
CUDNN_PATH
=
/paddle/cudnn_v5
# disable openmp and mkl parallel
#https://github.com/PaddlePaddle/Paddle/issues/7199
...
...
@@ -25,25 +27,79 @@ export CUDA_VISIBLE_DEVICES=0
export
LD_LIBRARY_PATH
=
/usr/local/lib:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
$CUDNN_PATH
:
$LD_LIBRARY_PATH
# only query the gpu used
nohup stdbuf
-oL
nvidia-smi
\
--id
=
${
CUDA_VISIBLE_DEVICES
}
\
--query-gpu
=
timestamp
\
--query-compute-apps
=
pid,process_name,used_memory
\
--format
=
csv
\
--filename
=
mem.log
\
-l
1 &
# mnist
# mnist gpu mnist 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/mnist.py
\
--device
=
GPU
\
--batch_size
=
128
\
--skip_batch_num
=
5
\
--iterations
=
500
\
2>&1 |
tee
-a
mnist_gpu_128.log
# vgg16
#
cifar10
gpu cifar10 128
FLAGS_benchmark
=
true
python fluid/vgg
.py
\
# gpu cifar10 128
FLAGS_benchmark
=
true
stdbuf
-oL
python fluid/vgg16
.py
\
--device
=
GPU
\
--batch_size
=
128
\
--skip_batch_num
=
5
\
--iterations
=
30
\
2>&1
>
vgg16_gpu_128.log
--iterations
=
30
\
2>&1 |
tee
-a
vgg16_gpu_128.log
# flowers gpu 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/vgg16.py
\
--device
=
GPU
\
--batch_size
=
32
\
--data_set
=
flowers
\
--skip_batch_num
=
5
\
--iterations
=
30
\
2>&1 |
tee
-a
vgg16_gpu_flowers_32.log
# resnet50
# resnet50 gpu cifar10 128
FLAGS_benchmark
=
true
python fluid/resnet
.py
\
FLAGS_benchmark
=
true
stdbuf
-oL
python fluid/resnet50
.py
\
--device
=
GPU
\
--batch_size
=
128
\
--data_set
=
cifar10
\
--model
=
resnet_cifar10
\
--skip_batch_num
=
5
\
--iterations
=
30
\
2>&1
>
resnet50_gpu_128.log
2>&1 |
tee
-a
resnet50_gpu_128.log
# resnet50 gpu flowers 64
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/resnet50.py
\
--device
=
GPU
\
--batch_size
=
64
\
--data_set
=
flowers
\
--model
=
resnet_imagenet
\
--skip_batch_num
=
5
\
--iterations
=
30
\
2>&1 |
tee
-a
resnet50_gpu_flowers_64.log
# lstm
# lstm gpu imdb 32 # tensorflow only support batch=32
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/stacked_dynamic_lstm.py
\
--device
=
GPU
\
--batch_size
=
32
\
--skip_batch_num
=
5
\
--iterations
=
30
\
--hidden_dim
=
512
\
--emb_dim
=
512
\
--crop_size
=
1500
\
2>&1 |
tee
-a
lstm_gpu_32.log
# seq2seq
# seq2seq gpu wmb 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/machine_translation.py
\
--device
=
GPU
\
--batch_size
=
128
\
--skip_batch_num
=
5
\
--iterations
=
30
\
2>&1 |
tee
-a
lstm_gpu_128.log
benchmark/fluid/stacked_dynamic_lstm.py
浏览文件 @
6e2424e4
...
...
@@ -37,6 +37,14 @@ def parse_args():
type
=
int
,
default
=
32
,
help
=
'The sequence number of a batch data. (default: %(default)d)'
)
parser
.
add_argument
(
'--skip_batch_num'
,
type
=
int
,
default
=
5
,
help
=
'The first num of minibatch num to skip, for better performance test'
)
parser
.
add_argument
(
'--iterations'
,
type
=
int
,
default
=
80
,
help
=
'The number of minibatches.'
)
parser
.
add_argument
(
'--emb_dim'
,
type
=
int
,
...
...
@@ -64,6 +72,10 @@ def parse_args():
default
=
int
(
os
.
environ
.
get
(
'CROP_SIZE'
,
'1500'
)),
help
=
'The max sentence length of input. Since this model use plain RNN,'
' Gradient could be explored if sentence is too long'
)
parser
.
add_argument
(
'--with_test'
,
action
=
'store_true'
,
help
=
'If set, test the testset during training.'
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -157,37 +169,43 @@ def main():
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
def
train_loop
(
pass_num
,
crop_size
):
with
profiler
.
profiler
(
args
.
device
,
'total'
)
as
prof
:
for
pass_id
in
range
(
pass_num
):
train_reader
=
batch
(
paddle
.
reader
.
shuffle
(
crop_sentence
(
imdb
.
train
(
word_dict
),
crop_size
),
buf_size
=
25000
),
batch_size
=
args
.
batch_size
)
word_nums
=
0
pass_start_time
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_reader
()):
tensor_words
=
to_lodtensor
([
x
[
0
]
for
x
in
data
],
place
)
for
x
in
data
:
word_nums
+=
len
(
x
[
0
])
label
=
numpy
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
"int64"
)
label
=
label
.
reshape
((
-
1
,
1
))
loss_np
,
acc
,
weight
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"words"
:
tensor_words
,
"label"
:
label
},
fetch_list
=
[
loss
,
batch_acc
,
batch_size_tensor
])
print
(
"pass_id=%d, batch_id=%d, loss=%f, acc=%f"
%
(
pass_id
,
batch_id
,
loss_np
,
acc
))
pass_end_time
=
time
.
time
()
time_consumed
=
pass_end_time
-
pass_start_time
words_per_sec
=
word_nums
/
time_consumed
print
(
"pass_id=%d, sec/pass: %f, words/s: %f"
%
(
pass_id
,
time_consumed
,
words_per_sec
))
train_loop
(
args
.
pass_num
,
args
.
crop_size
)
train_reader
=
batch
(
paddle
.
reader
.
shuffle
(
crop_sentence
(
imdb
.
train
(
word_dict
),
args
.
crop_size
),
buf_size
=
25000
),
batch_size
=
args
.
batch_size
)
iters
,
num_samples
,
start_time
=
0
,
0
,
time
.
time
()
for
pass_id
in
range
(
args
.
pass_num
):
train_accs
=
[]
train_losses
=
[]
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
iters
==
args
.
skip_batch_num
:
start_time
=
time
.
time
()
num_samples
=
0
if
iters
==
args
.
iterations
:
break
tensor_words
=
to_lodtensor
([
x
[
0
]
for
x
in
data
],
place
)
label
=
numpy
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
"int64"
)
label
=
label
.
reshape
((
-
1
,
1
))
loss_np
,
acc
,
weight
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"words"
:
tensor_words
,
"label"
:
label
},
fetch_list
=
[
loss
,
batch_acc
,
batch_size_tensor
])
iters
+=
1
for
x
in
data
:
num_samples
+=
len
(
x
[
0
])
print
(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f"
%
(
pass_id
,
iters
,
loss_np
,
acc
)
)
# The accuracy is the accumulation of batches, but not the current batch.
train_elapsed
=
time
.
time
()
-
start_time
examples_per_sec
=
num_samples
/
train_elapsed
print
(
'
\n
Total examples: %d, total time: %.5f, %.5f examples/sed
\n
'
%
(
num_samples
,
train_elapsed
,
examples_per_sec
))
exit
(
0
)
def
to_lodtensor
(
data
,
place
):
...
...
@@ -205,5 +223,14 @@ def to_lodtensor(data, place):
return
res
def
print_arguments
(
args
):
print
(
'----------- lstm Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
if
__name__
==
'__main__'
:
args
=
parse_args
()
print_arguments
(
args
)
main
()
benchmark/fluid/vgg.py
浏览文件 @
6e2424e4
...
...
@@ -191,25 +191,29 @@ def main():
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size_tensor
])
accuracy
.
add
(
value
=
acc
,
weight
=
weight
)
iters
+=
1
num_samples
+=
len
(
data
)
num_samples
+=
len
(
y_
data
)
print
(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f"
%
(
pass_id
,
iters
,
loss
,
acc
)
)
# The accuracy is the accumulation of batches, but not the current batch.
pass_train_acc
=
accuracy
.
eval
()
#
pass_train_acc = accuracy.eval()
train_losses
.
append
(
loss
)
train_accs
.
append
(
acc
)
print
(
"Pass: %d, Loss: %f, Train Accuray: %f
\n
"
%
(
pass_id
,
np
.
mean
(
train_losses
),
np
.
mean
(
train_accs
)))
train_elapsed
=
time
.
time
()
-
start_time
examples_per_sec
=
num_samples
/
train_elapsed
print
(
'
\n
Total examples: %d, total time: %.5f, %.5f examples/sed
\n
'
%
(
num_samples
,
train_elapsed
,
examples_per_sec
))
# evaluation
if
args
.
with_test
:
pass_test_acc
=
test
(
exe
)
train_elapsed
=
time
.
time
()
-
start_time
print
(
"Pass: %d, Loss: %f, Train Accuray: %f
\n
"
%
(
pass_id
,
np
.
mean
(
train_losses
),
np
.
mean
(
train_accs
)))
exit
(
0
)
def
print_arguments
():
print
(
'----------- Configuration Arguments -----------'
)
print
(
'-----------
vgg
Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
...
...
paddle/.gitignore
浏览文件 @
6e2424e4
.timestamp
*.o
*.a
.svn
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
6e2424e4
...
...
@@ -93,6 +93,43 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN"
,
name
);
}
void
Executor
::
CreateVariables
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
)
{
auto
&
global_block
=
pdesc
.
Block
(
block_id
);
const
Scope
*
ancestor_scope
=
scope
;
while
(
ancestor_scope
->
parent
())
{
ancestor_scope
=
ancestor_scope
->
parent
();
}
if
(
ancestor_scope
!=
scope
)
{
for
(
auto
&
var
:
global_block
.
AllVars
())
{
if
(
var
->
Name
()
==
framework
::
kEmptyVarName
)
{
continue
;
}
if
(
var
->
Persistable
())
{
auto
*
ptr
=
const_cast
<
Scope
*>
(
ancestor_scope
)
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" global, which pointer is "
<<
ptr
;
}
else
{
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" locally, which pointer is "
<<
ptr
;
}
}
}
else
{
for
(
auto
&
var
:
global_block
.
AllVars
())
{
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create variable "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
}
}
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
platform
::
RecordBlock
b
(
block_id
);
...
...
@@ -184,8 +221,8 @@ static bool has_fetch_operators(
void
Executor
::
Run
(
const
ProgramDesc
&
program
,
Scope
*
scope
,
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
feed_targets
,
std
::
map
<
std
::
string
,
LoDTensor
*>&
fetch_targets
,
const
std
::
string
&
feed_holder_name
,
const
std
::
string
&
fetch_holder_name
,
bool
create_vars
)
{
bool
create_vars
,
const
std
::
string
&
feed_holder_name
,
const
std
::
string
&
fetch_holder_name
)
{
platform
::
RecordBlock
b
(
kProgramId
);
bool
has_feed_ops
=
has_feed_operators
(
program
.
Block
(
0
),
feed_targets
,
feed_holder_name
);
...
...
@@ -296,38 +333,13 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void
Executor
::
RunPreparedContext
(
ExecutorPrepareContext
*
ctx
,
Scope
*
scope
,
bool
create_local_scope
,
bool
create_vars
)
{
auto
&
block
=
ctx
->
prog_
.
Block
(
ctx
->
block_id_
);
Scope
*
local_scope
=
scope
;
if
(
create_vars
)
{
if
(
create_local_scope
)
{
local_scope
=
&
scope
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Name
()
==
framework
::
kEmptyVarName
)
{
continue
;
}
if
(
var
->
Persistable
())
{
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" global, which pointer is "
<<
ptr
;
}
else
{
auto
*
ptr
=
local_scope
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create Variable "
<<
var
->
Name
()
<<
" locally, which pointer is "
<<
ptr
;
}
}
}
else
{
for
(
auto
&
var
:
block
.
AllVars
())
{
auto
*
ptr
=
local_scope
->
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
VLOG
(
3
)
<<
"Create variable "
<<
var
->
Name
()
<<
", which pointer is "
<<
ptr
;
}
}
// if (create_local_scope)
}
// if (create_vars)
}
CreateVariables
(
ctx
->
prog_
,
local_scope
,
ctx
->
block_id_
);
}
for
(
auto
&
op
:
ctx
->
ops_
)
{
VLOG
(
3
)
<<
place_
<<
" "
<<
op
->
DebugStringEx
(
local_scope
);
...
...
paddle/fluid/framework/executor.h
浏览文件 @
6e2424e4
...
...
@@ -54,9 +54,9 @@ class Executor {
void
Run
(
const
ProgramDesc
&
program
,
Scope
*
scope
,
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
feed_targets
,
std
::
map
<
std
::
string
,
LoDTensor
*>&
fetch_targets
,
bool
create_vars
=
true
,
const
std
::
string
&
feed_holder_name
=
"feed"
,
const
std
::
string
&
fetch_holder_name
=
"fetch"
,
bool
create_vars
=
true
);
const
std
::
string
&
fetch_holder_name
=
"fetch"
);
static
std
::
unique_ptr
<
ExecutorPrepareContext
>
Prepare
(
const
ProgramDesc
&
program
,
int
block_id
);
...
...
@@ -64,6 +64,8 @@ class Executor {
static
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Prepare
(
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
);
void
CreateVariables
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
);
void
RunPreparedContext
(
ExecutorPrepareContext
*
ctx
,
Scope
*
scope
,
bool
create_local_scope
=
true
,
bool
create_vars
=
true
);
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
6e2424e4
...
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include <vector>
...
...
@@ -24,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -43,30 +43,40 @@ class ParallelExecutorPrivate {
#endif
};
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
return
member_
->
local_scopes_
;
}
ParallelExecutor
::
ParallelExecutor
(
size_t
num_threads
,
bool
use_event
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
ProgramDesc
&
startup_program
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
bool
allow_op_delay
)
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
allow_op_delay
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
member_
->
global_scope_
=
scope
;
// Step 1. RunStartupProgram and Bcast the params to devs.
Executor
exe
(
places
[
0
]);
exe
.
Run
(
startup_program
,
scope
,
0
);
// Step 1. Bcast the params to devs.
// Create local scopes
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
member_
->
local_scopes_
.
push_back
(
&
scope
->
NewScope
());
if
(
local_scopes
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
member_
->
local_scopes_
.
push_back
(
&
scope
->
NewScope
());
}
}
else
{
PADDLE_ENFORCE_EQ
(
member_
->
places_
.
size
(),
local_scopes
.
size
());
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
member_
->
local_scopes_
.
push_back
(
local_scopes
[
i
]);
}
}
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
));
#endif
if
(
platform
::
is_gpu_place
(
places
[
0
])
&&
member_
->
local_scopes_
.
size
()
!=
1
)
{
// Is CUDA
BCastParamsToGPUs
(
startup_program
);
if
(
platform
::
is_gpu_place
(
places
[
0
])
&&
member_
->
local_scopes_
.
size
()
!=
1
&&
local_scopes
.
empty
()
)
{
// Is CUDA
BCastParamsToGPUs
(
bcast_vars
);
}
// Startup Program has been run. All local scopes has correct parameters.
...
...
@@ -99,48 +109,47 @@ ParallelExecutor::ParallelExecutor(
}
void
ParallelExecutor
::
BCastParamsToGPUs
(
const
ProgramDesc
&
startup_program
)
const
{
const
std
::
unordered_set
<
std
::
string
>
&
vars
)
const
{
#ifdef PADDLE_WITH_CUDA
auto
*
main_scope
=
member_
->
local_scopes_
[
0
];
for
(
auto
*
var_desc
:
startup_program
.
Block
(
0
).
AllVars
())
{
size_t
idx
=
var_desc
->
Name
().
find
(
"@GRAD"
);
if
(
idx
!=
std
::
string
::
npos
)
continue
;
if
(
var_desc
->
GetType
()
==
proto
::
VarType
::
LOD_TENSOR
)
{
auto
&
main_tensor
=
main_scope
->
FindVar
(
var_desc
->
Name
())
->
Get
<
LoDTensor
>
();
auto
&
dims
=
main_tensor
.
dims
();
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
size_t
numel
=
main_tensor
.
numel
();
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
platform
::
NCCLGroupGuard
guard
;
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
auto
place
=
member_
->
places_
[
i
];
void
*
buffer
;
if
(
i
==
0
)
{
buffer
=
const_cast
<
void
*>
(
main_tensor
.
data
<
void
>
());
}
else
{
auto
local_scope
=
member_
->
local_scopes_
[
i
];
auto
*
t
=
local_scope
->
Var
(
var_desc
->
Name
())
->
GetMutable
<
LoDTensor
>
();
t
->
Resize
(
dims
);
buffer
=
t
->
mutable_data
(
place
,
main_tensor
.
type
());
}
auto
&
nccl_ctx
=
member_
->
nccl_ctxs_
->
at
(
place
);
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
0
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
}
}
else
{
platform
::
CPUPlace
cpu
;
for
(
size_t
i
=
1
;
i
<
member_
->
places_
.
size
();
++
i
)
{
for
(
auto
&
var
:
vars
)
{
auto
*
main_var
=
main_scope
->
FindVar
(
var
);
if
(
!
main_var
->
IsType
<
LoDTensor
>
())
{
continue
;
}
auto
&
main_tensor
=
main_var
->
Get
<
LoDTensor
>
();
auto
&
dims
=
main_tensor
.
dims
();
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
size_t
numel
=
main_tensor
.
numel
();
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
platform
::
NCCLGroupGuard
guard
;
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
auto
place
=
member_
->
places_
[
i
];
void
*
buffer
;
if
(
i
==
0
)
{
buffer
=
const_cast
<
void
*>
(
main_tensor
.
data
<
void
>
());
}
else
{
auto
local_scope
=
member_
->
local_scopes_
[
i
];
auto
*
t
=
local_scope
->
Var
(
var
_desc
->
Name
()
)
->
GetMutable
<
LoDTensor
>
();
auto
*
t
=
local_scope
->
Var
(
var
)
->
GetMutable
<
LoDTensor
>
();
t
->
Resize
(
dims
);
t
->
mutable_data
(
cpu
,
main_tensor
.
type
());
paddle
::
framework
::
TensorCopy
(
main_tensor
,
cpu
,
t
);
buffer
=
t
->
mutable_data
(
place
,
main_tensor
.
type
());
}
auto
&
nccl_ctx
=
member_
->
nccl_ctxs_
->
at
(
place
);
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
0
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
}
}
else
{
platform
::
CPUPlace
cpu
;
for
(
size_t
i
=
1
;
i
<
member_
->
places_
.
size
();
++
i
)
{
auto
local_scope
=
member_
->
local_scopes_
[
i
];
auto
*
t
=
local_scope
->
Var
(
var
)
->
GetMutable
<
LoDTensor
>
();
t
->
Resize
(
dims
);
t
->
mutable_data
(
cpu
,
main_tensor
.
type
());
paddle
::
framework
::
TensorCopy
(
main_tensor
,
cpu
,
t
);
}
}
member_
->
nccl_ctxs_
->
WaitAll
();
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
6e2424e4
...
...
@@ -36,11 +36,14 @@ class ParallelExecutor {
explicit
ParallelExecutor
(
size_t
num_threads
,
bool
use_event
,
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
unordered_set
<
std
::
string
>&
params
,
const
ProgramDesc
&
startup_program
,
const
std
::
unordered_set
<
std
::
string
>&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
bool
allow_op_delay
);
std
::
vector
<
Scope
*>&
GetLocalScopes
();
void
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
,
const
std
::
unordered_map
<
std
::
string
,
LoDTensor
>&
feed_tensors
);
...
...
@@ -51,7 +54,7 @@ class ParallelExecutor {
ParallelExecutorPrivate
*
member_
;
void
BCastParamsToGPUs
(
const
ProgramDesc
&
startup_program
)
const
;
void
BCastParamsToGPUs
(
const
std
::
unordered_set
<
std
::
string
>&
vars
)
const
;
};
}
// namespace framework
...
...
paddle/fluid/framework/scope.h
浏览文件 @
6e2424e4
...
...
@@ -58,7 +58,7 @@ class Scope {
/// nullptr if cannot find.
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
const
Scope
&
parent
()
const
{
return
*
parent_
;
}
const
Scope
*
parent
()
const
{
return
parent_
;
}
/// Find the scope or an ancestor scope that contains the given variable.
const
Scope
*
FindScope
(
const
Variable
*
var
)
const
;
...
...
paddle/fluid/inference/tests/book/test_inference_image_classification.cc
浏览文件 @
6e2424e4
...
...
@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
// Run inference on CPU
LOG
(
INFO
)
<<
"--- CPU Runs: ---"
;
TestInference
<
paddle
::
platform
::
CPUPlace
>
(
dirname
,
cpu_feeds
,
cpu_fetchs1
,
FLAGS_repeat
);
TestInference
<
paddle
::
platform
::
CPUPlace
,
false
>
(
dirname
,
cpu_feeds
,
cpu_fetchs1
,
FLAGS_repeat
);
LOG
(
INFO
)
<<
output1
.
dims
();
#ifdef PADDLE_WITH_CUDA
...
...
@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU
LOG
(
INFO
)
<<
"--- GPU Runs: ---"
;
TestInference
<
paddle
::
platform
::
CUDAPlace
>
(
dirname
,
cpu_feeds
,
cpu_fetchs2
,
FLAGS_repeat
);
TestInference
<
paddle
::
platform
::
CUDAPlace
,
false
>
(
dirname
,
cpu_feeds
,
cpu_fetchs2
,
FLAGS_repeat
);
LOG
(
INFO
)
<<
output2
.
dims
();
CheckError
<
float
>
(
output1
,
output2
);
...
...
paddle/fluid/inference/tests/test_helper.h
浏览文件 @
6e2424e4
...
...
@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ
(
count
,
0U
)
<<
"There are "
<<
count
<<
" different elements."
;
}
template
<
typename
Place
>
template
<
typename
Place
,
bool
CreateVars
=
true
>
void
TestInference
(
const
std
::
string
&
dirname
,
const
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_feeds
,
const
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_fetchs
,
...
...
@@ -166,8 +166,16 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
if
(
!
CreateVars
)
{
// If users don't want to create and destroy variables every time they
// run, they need to set `create_vars` to false and manually call
// `CreateVariables` before running.
executor
.
CreateVariables
(
*
inference_program
,
scope
,
0
);
}
// Ignore the profiling results of the first run
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
,
CreateVars
);
// Enable the profiler
paddle
::
platform
::
EnableProfiler
(
state
);
...
...
@@ -178,7 +186,8 @@ void TestInference(const std::string& dirname,
"run_inference"
,
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
,
CreateVars
);
}
// Disable the profiler and print the timing information
...
...
paddle/fluid/operators/go_op.cc
浏览文件 @
6e2424e4
...
...
@@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase {
// TODO(varunarora): Consider moving this root scope lookup to scope.h.
const
framework
::
Scope
*
root_scope
=
&
scope
;
const
framework
::
Scope
*
parent_scope
=
&
(
root_scope
->
parent
()
);
const
framework
::
Scope
*
parent_scope
=
root_scope
->
parent
(
);
while
(
parent_scope
!=
nullptr
)
{
root_scope
=
parent_scope
;
parent_scope
=
&
(
parent_scope
->
parent
()
);
parent_scope
=
parent_scope
->
parent
(
);
}
framework
::
BlockDesc
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kBlock
);
...
...
paddle/fluid/operators/lod_reset_op.h
浏览文件 @
6e2424e4
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -35,7 +37,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
if
(
lod_t
->
lod
().
size
()
>
0
)
{
auto
y_lod
=
lod_t
->
lod
();
auto
last_level
=
y_lod
[
y_lod
.
size
()
-
1
];
PADDLE_ENFORCE_EQ
(
last_level
.
back
(
),
in
->
dims
()[
0
],
PADDLE_ENFORCE_EQ
(
(
int64_t
)(
last_level
.
back
()
),
in
->
dims
()[
0
],
"Last value of `Y`'s last level LoD should be equal "
"to the first dimension of `X`"
);
out
->
set_lod
(
y_lod
);
...
...
paddle/fluid/platform/float16_test.cc
浏览文件 @
6e2424e4
...
...
@@ -8,13 +8,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
platform
{
...
...
@@ -74,24 +75,27 @@ TEST(float16, conversion_cpu) {
// Conversion operator
EXPECT_EQ
(
Eigen
::
half
(
float16
(
1.0
f
)).
x
,
0x3c00
);
EXPECT_EQ
(
float
(
float16
(
0.5
f
)),
0.5
f
);
EXPECT_NEAR
(
double
(
float16
(
0.33333
)),
0.33333
,
0.0001
);
EXPECT_EQ
(
int
(
float16
(
-
1
)),
-
1
);
EXPECT_EQ
(
bool
(
float16
(
true
)),
true
);
EXPECT_EQ
(
static_cast
<
float
>
(
float16
(
0.5
f
)),
0.5
f
);
EXPECT_NEAR
(
static_cast
<
double
>
(
float16
(
0.33333
)),
0.33333
,
0.0001
);
EXPECT_EQ
(
static_cast
<
int
>
(
float16
(
-
1
)),
-
1
);
EXPECT_EQ
(
static_cast
<
bool
>
(
float16
(
true
)),
true
);
}
TEST
(
float16
,
arithmetic_cpu
)
{
EXPECT_EQ
(
float
(
float16
(
1
)
+
float16
(
1
)),
2
);
EXPECT_EQ
(
float
(
float16
(
5
)
+
float16
(
-
5
)),
0
);
EXPECT_NEAR
(
float
(
float16
(
0.33333
f
)
+
float16
(
0.66667
f
)),
1.0
f
,
0.001
);
EXPECT_EQ
(
float
(
float16
(
3
)
-
float16
(
5
)),
-
2
);
EXPECT_NEAR
(
float
(
float16
(
0.66667
f
)
-
float16
(
0.33333
f
)),
0.33334
f
,
0.001
);
EXPECT_NEAR
(
float
(
float16
(
3.3
f
)
*
float16
(
2.0
f
)),
6.6
f
,
0.01
);
EXPECT_NEAR
(
float
(
float16
(
-
2.1
f
)
*
float16
(
-
3.0
f
)),
6.3
f
,
0.01
);
EXPECT_NEAR
(
float
(
float16
(
2.0
f
)
/
float16
(
3.0
f
)),
0.66667
f
,
0.001
);
EXPECT_EQ
(
float
(
float16
(
1.0
f
)
/
float16
(
2.0
f
)),
0.5
f
);
EXPECT_EQ
(
float
(
-
float16
(
512.0
f
)),
-
512.0
f
);
EXPECT_EQ
(
float
(
-
float16
(
-
512.0
f
)),
512.0
f
);
EXPECT_EQ
(
static_cast
<
float
>
(
float16
(
1
)
+
float16
(
1
)),
2
);
EXPECT_EQ
(
static_cast
<
float
>
(
float16
(
5
)
+
float16
(
-
5
)),
0
);
EXPECT_NEAR
(
static_cast
<
float
>
(
float16
(
0.33333
f
)
+
float16
(
0.66667
f
)),
1.0
f
,
0.001
);
EXPECT_EQ
(
static_cast
<
float
>
(
float16
(
3
)
-
float16
(
5
)),
-
2
);
EXPECT_NEAR
(
static_cast
<
float
>
(
float16
(
0.66667
f
)
-
float16
(
0.33333
f
)),
0.33334
f
,
0.001
);
EXPECT_NEAR
(
static_cast
<
float
>
(
float16
(
3.3
f
)
*
float16
(
2.0
f
)),
6.6
f
,
0.01
);
EXPECT_NEAR
(
static_cast
<
float
>
(
float16
(
-
2.1
f
)
*
float16
(
-
3.0
f
)),
6.3
f
,
0.01
);
EXPECT_NEAR
(
static_cast
<
float
>
(
float16
(
2.0
f
)
/
float16
(
3.0
f
)),
0.66667
f
,
0.001
);
EXPECT_EQ
(
static_cast
<
float
>
(
float16
(
1.0
f
)
/
float16
(
2.0
f
)),
0.5
f
);
EXPECT_EQ
(
static_cast
<
float
>
(
-
float16
(
512.0
f
)),
-
512.0
f
);
EXPECT_EQ
(
static_cast
<
float
>
(
-
float16
(
-
512.0
f
)),
512.0
f
);
}
TEST
(
float16
,
comparison_cpu
)
{
...
...
paddle/fluid/platform/float16_test.cu
浏览文件 @
6e2424e4
...
...
@@ -36,19 +36,19 @@ limitations under the License. */
half *in1, *in2, *out; \
half *d_in1, *d_in2, *d_out; \
int size = sizeof(half); \
cudaMalloc(
(void**)&d_in1, size);
\
cudaMalloc(
(void**)&d_in2, size);
\
cudaMalloc(
(void**)&d_out, size);
\
in1 =
(half*)malloc(size);
\
in2 =
(half*)malloc(size);
\
out =
(half*)malloc(size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_in1), size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_in2), size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_out), size);
\
in1 =
reinterpret_cast<half*>(malloc(size));
\
in2 =
reinterpret_cast<half*>(malloc(size));
\
out =
reinterpret_cast<half*>(malloc(size));
\
in1[0] = half(float16(v_in1)); \
in2[0] = half(float16(v_in2)); \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op_type<<<1, 1>>>(d_in1, d_in2, d_out); \
cudaMemcpy(out, d_out, size, cudaMemcpyDeviceToHost); \
EXPECT_EQ(
float(float16(out[0])), v_out);
\
EXPECT_EQ(
static_cast<float>(float16(out[0])), v_out);
\
free(in1); \
free(in2); \
free(out); \
...
...
@@ -63,17 +63,17 @@ limitations under the License. */
half *in1, *in2; \
half *d_in1, *d_in2; \
int size = sizeof(half); \
cudaMalloc(
(void**)&d_in1, size);
\
cudaMalloc(
(void**)&d_in2, size);
\
in1 =
(half*)malloc(size);
\
in2 =
(half*)malloc(size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_in1), size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_in2), size);
\
in1 =
reinterpret_cast<half*>(malloc(size));
\
in2 =
reinterpret_cast<half*>(malloc(size));
\
in1[0] = half(float16(v_in1)); \
in2[0] = half(float16(v_in2)); \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op_type<<<1, 1>>>(d_in1, d_in2); \
cudaMemcpy(in1, d_in1, size, cudaMemcpyDeviceToHost); \
EXPECT_EQ(
float(float16(in1[0])), v_out);
\
EXPECT_EQ(
static_cast<float>(float16(in1[0])), v_out);
\
free(in1); \
free(in2); \
cudaFree(d_in1); \
...
...
@@ -87,12 +87,12 @@ limitations under the License. */
half *d_in1, *d_in2; \
bool *out, *d_out; \
int size = sizeof(half); \
cudaMalloc(
(void**)&d_in1, size);
\
cudaMalloc(
(void**)&d_in2, size);
\
cudaMalloc(
(void**)&d_out, 1);
\
in1 =
(half*)malloc(size);
\
in2 =
(half*)malloc(size);
\
out =
(bool*)malloc(1);
\
cudaMalloc(
reinterpret_cast<void**>(&d_in1), size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_in2), size);
\
cudaMalloc(
reinterpret_cast<void**>(&d_out), 1);
\
in1 =
reinterpret_cast<half*>(malloc(size));
\
in2 =
reinterpret_cast<half*>(malloc(size));
\
out =
reinterpret_cast<bool*>(malloc(1));
\
in1[0] = half(float16(v_in1)); \
in2[0] = half(float16(v_in2)); \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
...
...
@@ -130,13 +130,13 @@ void TestNeg(float v_in, float v_out) {
LOG
(
INFO
)
<<
"Test Neg on GPU!"
;
half
*
in
,
*
d_in
;
int
size
=
sizeof
(
half
);
cudaMalloc
(
(
void
**
)
&
d_in
,
size
);
in
=
(
half
*
)
malloc
(
size
);
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_in
)
,
size
);
in
=
reinterpret_cast
<
half
*>
(
malloc
(
size
)
);
in
[
0
]
=
half
(
float16
(
v_in
));
cudaMemcpy
(
d_in
,
in
,
size
,
cudaMemcpyHostToDevice
);
Neg
<<<
1
,
1
>>>
(
d_in
);
cudaMemcpy
(
in
,
d_in
,
size
,
cudaMemcpyDeviceToHost
);
EXPECT_EQ
(
float
(
float16
(
in
[
0
])),
v_out
);
EXPECT_EQ
(
static_cast
<
float
>
(
float16
(
in
[
0
])),
v_out
);
free
(
in
);
cudaFree
(
d_in
);
}
...
...
paddle/fluid/pybind/.gitignore
0 → 100644
浏览文件 @
6e2424e4
pybind.h
paddle/fluid/pybind/pybind.cc
浏览文件 @
6e2424e4
...
...
@@ -544,13 +544,20 @@ All parameter, weight, gradient are variables in Paddle.
[](
ParallelExecutor
&
self
,
size_t
num_threads
,
bool
use_event
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
ProgramDesc
&
startup_program
,
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
bool
allow_op_delay
)
{
new
(
&
self
)
ParallelExecutor
(
num_threads
,
use_event
,
places
,
params
,
startup_program
,
main_program
,
loss_var_name
,
scope
,
allow_op_delay
);
Scope
*
scope
,
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
allow_op_delay
)
{
new
(
&
self
)
ParallelExecutor
(
num_threads
,
use_event
,
places
,
params
,
bcast_vars
,
main_program
,
loss_var_name
,
scope
,
local_scopes
,
allow_op_delay
);
})
.
def
(
"local_scopes"
,
[](
ParallelExecutor
&
self
)
->
std
::
vector
<
Scope
*>
*
{
return
&
self
.
GetLocalScopes
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"run"
,
&
ParallelExecutor
::
Run
);
BindRecordIOWriter
(
&
m
);
...
...
python/.gitignore
浏览文件 @
6e2424e4
*pyc
build
dist
paddlepaddle.egg-info
paddle.egg-info
paddlepaddle_gpu.egg-info
.idea
...
...
python/paddle/.gitignore
0 → 100644
浏览文件 @
6e2424e4
version.py
python/paddle/fluid/framework.py
浏览文件 @
6e2424e4
...
...
@@ -659,7 +659,7 @@ class Block(object):
def
__init__
(
self
,
program
,
idx
):
self
.
desc
=
program
.
desc
.
block
(
idx
)
self
.
vars
=
dict
()
# var_name --> var
self
.
ops
=
collections
.
deque
()
# operator list
self
.
ops
=
list
()
# operator list
self
.
program
=
program
self
.
removed_vars
=
dict
()
...
...
@@ -831,6 +831,13 @@ class Block(object):
self
.
ops
.
append
(
op
)
return
op
def
insert_op
(
self
,
index
,
*
args
,
**
kwargs
):
self
.
sync_with_cpp
()
op_desc
=
self
.
desc
.
insert_op
(
index
)
op
=
Operator
(
block
=
self
,
desc
=
op_desc
,
*
args
,
**
kwargs
)
self
.
ops
.
insert
(
index
,
op
)
return
op
def
delete_ops
(
self
,
ops
):
# remove from cpp
# FIXME(typhoonzero): remove only the first occurrence.
...
...
@@ -842,12 +849,12 @@ class Block(object):
self
.
desc
.
remove_op
(
start
,
end
+
1
)
def
slice_ops
(
self
,
start
,
end
):
return
list
(
self
.
ops
)
[
start
:
end
]
return
self
.
ops
[
start
:
end
]
def
prepend_op
(
self
,
*
args
,
**
kwargs
):
op_desc
=
self
.
desc
.
prepend_op
()
op
=
Operator
(
self
,
op_desc
,
*
args
,
**
kwargs
)
self
.
ops
.
appendleft
(
op
)
self
.
ops
.
insert
(
0
,
op
)
return
op
def
sync_with_cpp
(
self
):
...
...
@@ -892,7 +899,7 @@ class Block(object):
for
index
in
range
((
start_index
-
1
-
1
),
-
1
,
-
1
):
op_desc
=
ops_in_cpp
[
index
]
op
=
Operator
(
self
,
op_desc
)
self
.
ops
.
appendleft
(
op
)
self
.
ops
.
insert
(
0
,
op
)
# sync ops append to the end of cpp_ops
for
index
in
range
((
end_index
+
1
),
len
(
ops_in_cpp
)):
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
6e2424e4
...
...
@@ -22,10 +22,49 @@ __all__ = ['ParallelExecutor']
class
ParallelExecutor
(
object
):
def
__init__
(
self
,
loss_name
,
use_cuda
,
loss_name
=
None
,
main_program
=
None
,
num_threads
=
None
,
allow_op_delay
=
False
):
allow_op_delay
=
False
,
share_vars_from
=
None
):
"""
ParallelExecutor can run program in parallel.
Args:
use_cuda(bool): Whether to use CUDA or not.
loss_name(str, default None): The loss name must set in training.
main_program(Program, default None): The program that need to run,
if not provided, then default_main_program will be used.
num_threads(int, default None): How many threads are used for
training.
allow_op_delay(bool, default False): Whether to delay and buffer
some operators together for scheduling or not, which may
improve performance in some cases, defalut False.
share_vars_from(ParallelExecutor, default None): If provied,
it will share variables from the specified ParallelExecutor.
Returns:
A ParallelExecutor object.
Raises:
TypeError: If share_vars_from is provided, but not ParallelExecutor
object.
Examples:
.. code-block:: python
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name)
test_exe = fluid.ParallelExecutor(
use_cuda=True,
main_program=test_program,
share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
"""
self
.
_places
=
[]
self
.
_act_places
=
[]
if
use_cuda
:
...
...
@@ -50,10 +89,21 @@ class ParallelExecutor(object):
else
:
min
(
len
(
self
.
_places
)
*
2
,
multiprocessing
.
cpu_count
())
startup
=
framework
.
default_startup_program
()
main
=
framework
.
default_main_program
()
main
=
main_program
main
=
main
if
main
else
framework
.
default_main_program
()
scope
=
executor
.
global_scope
()
if
share_vars_from
and
not
isinstance
(
share_vars_from
,
ParallelExecutor
):
raise
TypeError
(
"share_vars_from must be ParallelExecutor."
)
local_scopes
=
share_vars_from
.
executor
.
local_scopes
(
)
if
share_vars_from
else
[]
persistable_vars
=
[
v
.
name
for
v
in
filter
(
lambda
var
:
var
.
persistable
,
main
.
list_vars
())
]
self
.
executor
=
core
.
ParallelExecutor
(
num_threads
,
True
if
use_cuda
else
False
,
# use_event
...
...
@@ -62,10 +112,11 @@ class ParallelExecutor(object):
p
.
name
for
p
in
main
.
global_block
().
iter_parameters
()
if
not
p
.
stop_gradient
]),
s
tartup
.
desc
,
s
et
(
persistable_vars
)
,
main
.
desc
,
loss_name
,
loss_name
if
loss_name
else
''
,
scope
,
local_scopes
,
allow_op_delay
)
self
.
scope
=
scope
...
...
python/paddle/fluid/tests/unittests/test_activation_op.py
浏览文件 @
6e2424e4
...
...
@@ -535,9 +535,37 @@ class TestSwish(OpTest):
#--------------------test MKLDNN--------------------
class
TestMKLDNNRelu
(
TestRelu
):
class
TestMKLDNNRelu
Dim2
(
TestRelu
):
def
setUp
(
self
):
super
(
TestMKLDNNRelu
,
self
).
setUp
()
super
(
TestMKLDNNReluDim2
,
self
).
setUp
()
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNTanhDim2
(
TestTanh
):
def
setUp
(
self
):
super
(
TestMKLDNNTanhDim2
,
self
).
setUp
()
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNSqrtDim2
(
TestSqrt
):
def
setUp
(
self
):
super
(
TestMKLDNNSqrtDim2
,
self
).
setUp
()
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNAbsDim2
(
TestAbs
):
def
setUp
(
self
):
super
(
TestMKLDNNAbsDim2
,
self
).
setUp
()
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNReluDim4
(
TestRelu
):
def
setUp
(
self
):
super
(
TestMKLDNNReluDim4
,
self
).
setUp
()
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
2
,
4
,
3
,
5
]).
astype
(
"float32"
)
# The same reason with TestAbs
...
...
@@ -549,9 +577,9 @@ class TestMKLDNNRelu(TestRelu):
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNTanh
(
TestTanh
):
class
TestMKLDNNTanh
Dim4
(
TestTanh
):
def
setUp
(
self
):
super
(
TestMKLDNNTanh
,
self
).
setUp
()
super
(
TestMKLDNNTanh
Dim4
,
self
).
setUp
()
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
4
,
3
,
5
]).
astype
(
"float32"
)
...
...
@@ -560,9 +588,9 @@ class TestMKLDNNTanh(TestTanh):
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNSqrt
(
TestSqrt
):
class
TestMKLDNNSqrt
Dim4
(
TestSqrt
):
def
setUp
(
self
):
super
(
TestMKLDNNSqrt
,
self
).
setUp
()
super
(
TestMKLDNNSqrt
Dim4
,
self
).
setUp
()
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
4
,
3
,
5
]).
astype
(
"float32"
)
...
...
@@ -571,9 +599,9 @@ class TestMKLDNNSqrt(TestSqrt):
self
.
attrs
=
{
"use_mkldnn"
:
True
}
class
TestMKLDNNAbs
(
TestAbs
):
class
TestMKLDNNAbs
Dim4
(
TestAbs
):
def
setUp
(
self
):
super
(
TestMKLDNNAbs
,
self
).
setUp
()
super
(
TestMKLDNNAbs
Dim4
,
self
).
setUp
()
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
2
,
4
,
3
,
5
]).
astype
(
"float32"
)
# The same reason with TestAbs
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
6e2424e4
...
...
@@ -207,7 +207,11 @@ class TestParallelExecutorBase(unittest.TestCase):
if
memory_opt
:
fluid
.
memory_optimize
(
main
)
exe
=
fluid
.
ParallelExecutor
(
loss_name
=
loss
.
name
,
use_cuda
=
True
)
place
=
fluid
.
CUDAPlace
(
0
)
startup_exe
=
fluid
.
Executor
(
place
)
startup_exe
.
run
(
startup
)
exe
=
fluid
.
ParallelExecutor
(
True
,
loss_name
=
loss
.
name
)
if
batch_size
is
not
None
:
batch_size
*=
fluid
.
core
.
get_cuda_device_count
()
begin
=
time
.
time
()
...
...
@@ -453,3 +457,41 @@ class TestTransformer(TestParallelExecutorBase):
@
unittest
.
skip
(
"transformer is buggy in multi gpu"
)
def
test_main
(
self
):
self
.
check_network_convergence
(
transformer
)
class
ParallelExecutorTestingDuringTraining
(
unittest
.
TestCase
):
def
test_parallel_testing
(
self
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
loss
=
simple_fc_net
(
True
)
test_program
=
main
.
clone
(
for_test
=
True
)
opt
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.0001
)
opt
.
minimize
(
loss
)
batch_size
=
32
image
=
numpy
.
random
.
normal
(
size
=
(
batch_size
,
784
)).
astype
(
'float32'
)
label
=
numpy
.
random
.
randint
(
0
,
10
,
(
batch_size
,
1
),
dtype
=
"int64"
)
place
=
fluid
.
CUDAPlace
(
0
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
feed_dict
=
{
'image'
:
image
,
'label'
:
label
}
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
loss
.
name
,
main_program
=
main
)
test_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
main_program
=
test_program
,
share_vars_from
=
train_exe
)
for
i
in
xrange
(
5
):
test_loss
,
=
test_exe
.
run
([
loss
.
name
],
feed_dict
=
feed_dict
)
test_loss
=
numpy
.
array
(
test_loss
)
train_loss
,
=
train_exe
.
run
([
loss
.
name
],
feed_dict
=
feed_dict
)
train_loss
=
numpy
.
array
(
train_loss
)
self
.
assertTrue
(
numpy
.
allclose
(
train_loss
,
test_loss
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录