Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ea73fb84
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea73fb84
编写于
6月 06, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into dev_reverse_op
上级
42d71747
ea408d55
变更
57
显示空白变更内容
内联
并排
Showing
57 changed file
with
798 addition
and
302 deletion
+798
-302
Dockerfile
Dockerfile
+2
-3
benchmark/fluid/Dockerfile
benchmark/fluid/Dockerfile
+2
-2
benchmark/fluid/fluid_benchmark.py
benchmark/fluid/fluid_benchmark.py
+10
-5
benchmark/fluid/models/mnist.py
benchmark/fluid/models/mnist.py
+24
-9
benchmark/fluid/models/resnet.py
benchmark/fluid/models/resnet.py
+22
-7
benchmark/fluid/models/stacked_dynamic_lstm.py
benchmark/fluid/models/stacked_dynamic_lstm.py
+1
-2
cmake/external/grpc.cmake
cmake/external/grpc.cmake
+1
-0
doc/fluid/howto/optimization/host_memory_profiling_cn.md
doc/fluid/howto/optimization/host_memory_profiling_cn.md
+89
-0
paddle/contrib/inference/demo/simple_on_word2vec.cc
paddle/contrib/inference/demo/simple_on_word2vec.cc
+3
-0
paddle/contrib/inference/paddle_inference_api.h
paddle/contrib/inference/paddle_inference_api.h
+2
-1
paddle/contrib/inference/paddle_inference_api_impl.cc
paddle/contrib/inference/paddle_inference_api_impl.cc
+27
-17
paddle/contrib/inference/paddle_inference_api_impl.h
paddle/contrib/inference/paddle_inference_api_impl.h
+6
-3
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-0
paddle/fluid/framework/details/execution_strategy.h
paddle/fluid/framework/details/execution_strategy.h
+1
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+76
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+53
-0
paddle/fluid/framework/details/ssa_graph_executor.cc
paddle/fluid/framework/details/ssa_graph_executor.cc
+0
-4
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+1
-5
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+4
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+1
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+16
-42
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+6
-0
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h
+68
-0
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc
...fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc
+46
-0
paddle/fluid/operators/conv_cudnn_op.cu.cc
paddle/fluid/operators/conv_cudnn_op.cu.cc
+3
-3
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+41
-55
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+18
-5
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+1
-6
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+4
-7
paddle/fluid/operators/detection/box_coder_op.cc
paddle/fluid/operators/detection/box_coder_op.cc
+11
-8
paddle/fluid/operators/detection/box_coder_op.cu
paddle/fluid/operators/detection/box_coder_op.cu
+38
-21
paddle/fluid/operators/detection/box_coder_op.h
paddle/fluid/operators/detection/box_coder_op.h
+52
-33
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+2
-2
paddle/fluid/operators/pool_cudnn_op.cu.cc
paddle/fluid/operators/pool_cudnn_op.cu.cc
+5
-1
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+1
-1
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+1
-1
paddle/fluid/operators/reduce_op.h
paddle/fluid/operators/reduce_op.h
+10
-9
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-2
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+4
-4
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+6
-7
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+17
-3
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+0
-1
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+0
-3
paddle/fluid/platform/dynload/cublas.h
paddle/fluid/platform/dynload/cublas.h
+1
-1
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+1
-1
paddle/fluid/platform/dynload/cupti.h
paddle/fluid/platform/dynload/cupti.h
+1
-1
paddle/fluid/platform/dynload/curand.h
paddle/fluid/platform/dynload/curand.h
+1
-1
paddle/fluid/platform/dynload/nccl.h
paddle/fluid/platform/dynload/nccl.h
+1
-1
paddle/fluid/platform/dynload/tensorrt.h
paddle/fluid/platform/dynload/tensorrt.h
+1
-1
paddle/fluid/platform/dynload/warpctc.h
paddle/fluid/platform/dynload/warpctc.h
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+8
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+75
-15
python/paddle/fluid/tests/no_test_concurrency.py
python/paddle/fluid/tests/no_test_concurrency.py
+0
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-4
python/paddle/fluid/tests/unittests/test_box_coder_op.py
python/paddle/fluid/tests/unittests/test_box_coder_op.py
+26
-0
未找到文件。
Dockerfile
浏览文件 @
ea73fb84
...
@@ -24,7 +24,7 @@ COPY ./paddle/scripts/docker/root/ /root/
...
@@ -24,7 +24,7 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN
apt-get update
&&
\
RUN
apt-get update
&&
\
apt-get
install
-y
--allow-downgrades
\
apt-get
install
-y
--allow-downgrades
\
git python-pip python-dev openssh-server bison
\
git python-pip python-dev
python-opencv
openssh-server bison
\
libnccl2
=
2.1.2-1+cuda8.0 libnccl-dev
=
2.1.2-1+cuda8.0
\
libnccl2
=
2.1.2-1+cuda8.0 libnccl-dev
=
2.1.2-1+cuda8.0
\
wget unzip unrar
tar
xz-utils bzip2
gzip
coreutils ntp
\
wget unzip unrar
tar
xz-utils bzip2
gzip
coreutils ntp
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
...
@@ -76,8 +76,7 @@ RUN easy_install -U pip && \
...
@@ -76,8 +76,7 @@ RUN easy_install -U pip && \
pip
install
sphinx-rtd-theme
==
0.1.9 recommonmark
pip
install
sphinx-rtd-theme
==
0.1.9 recommonmark
RUN
pip
install
pre-commit
'ipython==5.3.0'
&&
\
RUN
pip
install
pre-commit
'ipython==5.3.0'
&&
\
pip
install
'ipykernel==4.6.0'
'jupyter==1.0.0'
&&
\
pip
install
'ipykernel==4.6.0'
'jupyter==1.0.0'
pip
install
opencv-python
#For docstring checker
#For docstring checker
RUN
pip
install
pylint pytest astroid isort
RUN
pip
install
pylint pytest astroid isort
...
...
benchmark/fluid/Dockerfile
浏览文件 @
ea73fb84
FROM
nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
FROM
nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
RUN
apt-get update
&&
apt-get
install
-y
python python-pip iputils-ping libgtk2.0-dev wget vim net-tools iftop
RUN
apt-get update
&&
apt-get
install
-y
python python-pip iputils-ping libgtk2.0-dev wget vim net-tools iftop
python-opencv
RUN
ln
-s
/usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so
&&
ln
-s
/usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/lib/libnccl.so
RUN
ln
-s
/usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so
&&
ln
-s
/usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/lib/libnccl.so
RUN
pip
install
-U
pip
RUN
pip
install
-U
pip
RUN
pip
install
-U
kubernetes
opencv-python
paddlepaddle
RUN
pip
install
-U
kubernetes paddlepaddle
# IMPORTANT:
# IMPORTANT:
# Add "ENV http_proxy=http://ip:port" if your download is slow, and don't forget to unset it at runtime.
# Add "ENV http_proxy=http://ip:port" if your download is slow, and don't forget to unset it at runtime.
...
...
benchmark/fluid/fluid_benchmark.py
浏览文件 @
ea73fb84
...
@@ -69,6 +69,11 @@ def parse_args():
...
@@ -69,6 +69,11 @@ def parse_args():
type
=
int
,
type
=
int
,
default
=
1
,
default
=
1
,
help
=
'If gpus > 1, will use ParallelExecutor to run, else use Executor.'
)
help
=
'If gpus > 1, will use ParallelExecutor to run, else use Executor.'
)
parser
.
add_argument
(
'--cpus'
,
type
=
int
,
default
=
1
,
help
=
'If cpus > 1, will use ParallelDo to run, else use Executor.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--data_set'
,
'--data_set'
,
type
=
str
,
type
=
str
,
...
@@ -85,8 +90,8 @@ def parse_args():
...
@@ -85,8 +90,8 @@ def parse_args():
help
=
'If set, use nvprof for CUDA.'
)
help
=
'If set, use nvprof for CUDA.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--no_test'
,
'--no_test'
,
action
=
'store_
fals
e'
,
action
=
'store_
tru
e'
,
help
=
'If set, test the testset during training.'
)
help
=
'If set,
do not
test the testset during training.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--memory_optimize'
,
'--memory_optimize'
,
action
=
'store_true'
,
action
=
'store_true'
,
...
@@ -229,9 +234,9 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
...
@@ -229,9 +234,9 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print
(
"Pass: %d, Iter: %d, Loss: %f
\n
"
%
print
(
"Pass: %d, Iter: %d, Loss: %f
\n
"
%
(
pass_id
,
iters
,
np
.
mean
(
train_losses
)))
(
pass_id
,
iters
,
np
.
mean
(
train_losses
)))
print_train_time
(
start_time
,
time
.
time
(),
num_samples
)
print_train_time
(
start_time
,
time
.
time
(),
num_samples
)
print
(
"Pass: %d, Loss: %f"
%
(
pass_id
,
np
.
mean
(
train_losses
)))
print
(
"Pass: %d, Loss: %f"
%
(
pass_id
,
np
.
mean
(
train_losses
)))
,
# evaluation
# evaluation
if
not
args
.
no_test
and
batch_acc
!=
None
:
if
not
args
.
no_test
and
batch_acc
:
pass_test_acc
=
test
(
exe
,
infer_prog
,
test_reader
,
feeder
,
pass_test_acc
=
test
(
exe
,
infer_prog
,
test_reader
,
feeder
,
batch_acc
)
batch_acc
)
print
(
", Test Accuracy: %f"
%
pass_test_acc
)
print
(
", Test Accuracy: %f"
%
pass_test_acc
)
...
@@ -310,7 +315,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
...
@@ -310,7 +315,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
print
(
"Pass %d, batch %d, loss %s"
%
print
(
"Pass %d, batch %d, loss %s"
%
(
pass_id
,
batch_id
,
np
.
array
(
loss
)))
(
pass_id
,
batch_id
,
np
.
array
(
loss
)))
print_train_time
(
start_time
,
time
.
time
(),
num_samples
)
print_train_time
(
start_time
,
time
.
time
(),
num_samples
)
if
not
args
.
no_test
and
batch_acc
!=
None
:
if
not
args
.
no_test
and
batch_acc
:
test_acc
=
test
(
startup_exe
,
infer_prog
,
test_reader
,
feeder
,
test_acc
=
test
(
startup_exe
,
infer_prog
,
test_reader
,
feeder
,
batch_acc
)
batch_acc
)
print
(
"Pass: %d, Test Accuracy: %f
\n
"
%
(
pass_id
,
test_acc
))
print
(
"Pass: %d, Test Accuracy: %f
\n
"
%
(
pass_id
,
test_acc
))
...
...
benchmark/fluid/models/mnist.py
浏览文件 @
ea73fb84
...
@@ -69,15 +69,30 @@ def get_model(args):
...
@@ -69,15 +69,30 @@ def get_model(args):
images
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
DTYPE
)
images
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
DTYPE
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
if
args
.
device
==
'CPU'
and
args
.
cpus
>
1
:
places
=
fluid
.
layers
.
get_places
(
args
.
cpus
)
pd
=
fluid
.
layers
.
ParallelDo
(
places
)
with
pd
.
do
():
predict
=
cnn_model
(
pd
.
read_input
(
images
))
label
=
pd
.
read_input
(
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
pd
.
write_output
(
avg_cost
)
pd
.
write_output
(
batch_acc
)
avg_cost
,
batch_acc
=
pd
()
avg_cost
=
fluid
.
layers
.
mean
(
avg_cost
)
batch_acc
=
fluid
.
layers
.
mean
(
batch_acc
)
else
:
# Train program
# Train program
predict
=
cnn_model
(
images
)
predict
=
cnn_model
(
images
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
# Evaluator
# Evaluator
batch_size_tensor
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
,
total
=
batch_size_tensor
)
# inference program
# inference program
inference_program
=
fluid
.
default_main_program
().
clone
()
inference_program
=
fluid
.
default_main_program
().
clone
()
...
...
benchmark/fluid/models/resnet.py
浏览文件 @
ea73fb84
...
@@ -132,18 +132,33 @@ def get_model(args):
...
@@ -132,18 +132,33 @@ def get_model(args):
input
=
fluid
.
layers
.
data
(
name
=
'data'
,
shape
=
dshape
,
dtype
=
'float32'
)
input
=
fluid
.
layers
.
data
(
name
=
'data'
,
shape
=
dshape
,
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
predict
=
model
(
input
,
class_dim
)
if
args
.
device
==
'CPU'
and
args
.
cpus
>
1
:
places
=
fluid
.
layers
.
get_places
(
args
.
cpus
)
pd
=
fluid
.
layers
.
ParallelDo
(
places
)
with
pd
.
do
():
predict
=
model
(
pd
.
read_input
(
input
),
class_dim
)
label
=
pd
.
read_input
(
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
batch_size_tensor
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
pd
.
write_output
(
avg_cost
)
batch_acc
=
fluid
.
layers
.
accuracy
(
pd
.
write_output
(
batch_acc
)
input
=
predict
,
label
=
label
,
total
=
batch_size_tensor
)
avg_cost
,
batch_acc
=
pd
()
avg_cost
=
fluid
.
layers
.
mean
(
avg_cost
)
batch_acc
=
fluid
.
layers
.
mean
(
batch_acc
)
else
:
predict
=
model
(
input
,
class_dim
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
inference_program
=
fluid
.
default_main_program
().
clone
()
inference_program
=
fluid
.
default_main_program
().
clone
()
with
fluid
.
program_guard
(
inference_program
):
with
fluid
.
program_guard
(
inference_program
):
inference_program
=
fluid
.
io
.
get_inference_program
(
inference_program
=
fluid
.
io
.
get_inference_program
(
target_vars
=
[
batch_acc
,
batch_size_tensor
])
target_vars
=
[
batch_acc
])
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
)
...
...
benchmark/fluid/models/stacked_dynamic_lstm.py
浏览文件 @
ea73fb84
...
@@ -101,9 +101,8 @@ def get_model(args):
...
@@ -101,9 +101,8 @@ def get_model(args):
loss
=
fluid
.
layers
.
mean
(
x
=
loss
)
loss
=
fluid
.
layers
.
mean
(
x
=
loss
)
# add acc
# add acc
batch_size_tensor
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
logit
,
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
\
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
logit
,
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
\
shape
=
[
1
],
dtype
=
'int64'
)
,
total
=
batch_size_tensor
)
shape
=
[
1
],
dtype
=
'int64'
))
inference_program
=
fluid
.
default_main_program
().
clone
()
inference_program
=
fluid
.
default_main_program
().
clone
()
with
fluid
.
program_guard
(
inference_program
):
with
fluid
.
program_guard
(
inference_program
):
...
...
cmake/external/grpc.cmake
浏览文件 @
ea73fb84
...
@@ -45,6 +45,7 @@ ExternalProject_Add(
...
@@ -45,6 +45,7 @@ ExternalProject_Add(
# checkout and clean other dirs under third_party
# checkout and clean other dirs under third_party
# 4. remove .git, and package the directory.
# 4. remove .git, and package the directory.
URL
"http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
URL
"http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
URL_MD5
"c9c58ee7d0e8929a63155af6a2ecdbd0"
PREFIX
${
GRPC_SOURCES_DIR
}
PREFIX
${
GRPC_SOURCES_DIR
}
UPDATE_COMMAND
""
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
CONFIGURE_COMMAND
""
...
...
doc/fluid/howto/optimization/host_memory_profiling_cn.md
0 → 100644
浏览文件 @
ea73fb84
## 堆内存分析和优化
计算机程序都可能有内存泄漏的风险。
**内存泄漏**
一般是由于程序在堆(heap)上分配了内存而没有释放,随着程序的运行,占用的内存越来越大。这一方面会影响程序自身的稳定性,可能让运行速度越来越慢,或者造成OOM;另一方面甚至会影响运行程序的机器的稳定性,造成宕机。
目前有很多内存泄漏分析工具,比较经典的有
[
valgrind
](
http://valgrind.org/docs/manual/quick-start.html#quick-start.intro
)
,
[
gperftools
](
https://gperftools.github.io/gperftools/
)
。
因为Fluid是用Python驱动C++ core来运行,valgrind直接分析非常困难,需要自己编译debug版本的、带valgrind支持的专用Python版本,而且输出的信息中大部分是Python自己的符号和调用信息,分析起来很困难,另外使用valgrind会让程序运行速度变得非常慢,所以不建议使用。
本教程主要介绍
[
gperftools
](
https://gperftools.github.io/gperftools/
)
的使用。
gperftools主要支持以下四个功能:
-
thread-caching malloc
-
heap-checking using tcmalloc
-
heap-profiling using tcmalloc
-
CPU profiler
Paddle也提供了基于gperftool的
[
CPU性能分析教程
](
https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/howto/optimization/cpu_profiling_cn.md
)
。
对于堆内存的分析,主要用到thread-caching malloc和heap-profiling using tcmalloc。
## 使用流程
#### 环境
本教程基于paddle提供的Docker开发环境paddlepaddle/paddle:latest-dev,基于Ubuntu 16.04.4 LTS环境。
#### 使用流程
-
安装google-perftools
```
apt-get install libunwind-dev
apt-get install google-perftools
```
-
安装pprof
```
go get -u github.com/google/pprof
```
-
设置运行环境
```
export PPROF_PATH=/root/gopath/bin/pprof
export PPROF_BINARY_PATH=/root/gopath/bin/pprof
export LD_PRELOAD=/usr/lib/libtcmalloc.so.4
```
-
使用heap profile来运行python程序。本质上是周期性的对堆的分配情况做一次快照。
```
# HEAPPROFILE 设置生成的堆分析文件的目录和文件前缀
# HEAP_PROFILE_ALLOCATION_INTERVAL 设置每分配多少存储dump一次,默认1GB
env HEAPPROFILE="./perf_log/test.log" HEAP_PROFILE_ALLOCATION_INTERVAL=209715200 python trainer.py
```
随着程序的运行,会在perf_log这个文件夹下生成很多文件,如下:
```
-rw-r--r-- 1 root root 1.0M Jun 1 15:00 test.log.0001.heap
-rw-r--r-- 1 root root 1.0M Jun 1 15:00 test.log.0002.heap
-rw-r--r-- 1 root root 1.0M Jun 1 15:00 test.log.0003.heap
-rw-r--r-- 1 root root 1.0M Jun 1 15:00 test.log.0004.heap
-rw-r--r-- 1 root root 1.0M Jun 1 15:00 test.log.0005.heap
-rw-r--r-- 1 root root 1.0M Jun 1 15:00 test.log.0006.heap
```
-
使用pprof对heap文件进行分析。分析有两种模式:
-
完整模式。会对当前heap做一个分析,显示目前分配内存一些调用路径。
```
pprof --pdf python test.log.0012.heap
```
上述命令会生成一个profile00x.pdf的文件,可以直接打开,例如:[memory_cpu_allocator](https://github.com/jacquesqiao/Paddle/blob/bd2ea0e1f84bb6522a66d44a072598153634cade/doc/fluid/howto/optimization/memory_cpu_allocator.pdf)。从下图可以看出,在CPU版本fluid的运行过程中,分配存储最多的模块是CPUAllocator。而别的模块相对而言分配内存较少,所以被忽略了,这对于分析内存泄漏是很不方便的,因为泄漏是一个缓慢的过程,在这种图中是无法看到的。

- Diff模式。可以对两个时刻的heap做diff,把一些内存分配没有发生变化的模块去掉,而把增量部分显示出来。
```
pprof --pdf --base test.log.0010.heap python test.log.1045.heap
```
生成的结果为:[`memory_leak_protobuf`](https://github.com/jacquesqiao/Paddle/blob/bd2ea0e1f84bb6522a66d44a072598153634cade/doc/fluid/howto/optimization/memory_leak_protobuf.pdf)
从图中可以看出:ProgramDesc这个结构,在两次快照之间增长了200MB+,所以这里有很大的内存泄漏的可能性,最终结果也确实证明是这里造成了泄漏。


paddle/contrib/inference/demo/simple_on_word2vec.cc
浏览文件 @
ea73fb84
...
@@ -65,7 +65,10 @@ void Main(bool use_gpu) {
...
@@ -65,7 +65,10 @@ void Main(bool use_gpu) {
}
}
TEST
(
demo
,
word2vec_cpu
)
{
Main
(
false
/*use_gpu*/
);
}
TEST
(
demo
,
word2vec_cpu
)
{
Main
(
false
/*use_gpu*/
);
}
#ifdef PADDLE_WITH_CUDA
TEST
(
demo
,
word2vec_gpu
)
{
Main
(
true
/*use_gpu*/
);
}
TEST
(
demo
,
word2vec_gpu
)
{
Main
(
true
/*use_gpu*/
);
}
#endif
}
// namespace demo
}
// namespace demo
}
// namespace paddle
}
// namespace paddle
paddle/contrib/inference/paddle_inference_api.h
浏览文件 @
ea73fb84
...
@@ -63,6 +63,7 @@ class PaddlePredictor {
...
@@ -63,6 +63,7 @@ class PaddlePredictor {
struct
Config
;
struct
Config
;
PaddlePredictor
()
=
default
;
PaddlePredictor
()
=
default
;
PaddlePredictor
(
const
PaddlePredictor
&
)
=
delete
;
PaddlePredictor
(
const
PaddlePredictor
&
)
=
delete
;
PaddlePredictor
&
operator
=
(
const
PaddlePredictor
&
)
=
delete
;
// Predict an record.
// Predict an record.
// The caller should be responsible for allocating and releasing the memory of
// The caller should be responsible for allocating and releasing the memory of
...
@@ -76,7 +77,7 @@ class PaddlePredictor {
...
@@ -76,7 +77,7 @@ class PaddlePredictor {
virtual
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
=
0
;
virtual
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
=
0
;
// Destroy the Predictor.
// Destroy the Predictor.
virtual
~
PaddlePredictor
()
{}
virtual
~
PaddlePredictor
()
=
default
;
// The common configs for all the predictors.
// The common configs for all the predictors.
struct
Config
{
struct
Config
{
...
...
paddle/contrib/inference/paddle_inference_api_impl.cc
浏览文件 @
ea73fb84
...
@@ -54,7 +54,8 @@ std::string num2str(T a) {
...
@@ -54,7 +54,8 @@ std::string num2str(T a) {
}
}
}
// namespace
}
// namespace
bool
NativePaddlePredictor
::
Init
()
{
bool
NativePaddlePredictor
::
Init
(
std
::
shared_ptr
<
framework
::
Scope
>
parent_scope
)
{
VLOG
(
3
)
<<
"Predictor::init()"
;
VLOG
(
3
)
<<
"Predictor::init()"
;
if
(
config_
.
use_gpu
)
{
if
(
config_
.
use_gpu
)
{
...
@@ -62,9 +63,15 @@ bool NativePaddlePredictor::Init() {
...
@@ -62,9 +63,15 @@ bool NativePaddlePredictor::Init() {
}
else
{
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
place_
=
paddle
::
platform
::
CPUPlace
();
}
}
if
(
parent_scope
)
{
scope_
=
parent_scope
;
sub_scope_
=
&
(
parent_scope
->
NewScope
());
}
else
{
paddle
::
framework
::
InitDevices
(
false
);
paddle
::
framework
::
InitDevices
(
false
);
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
}
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
// Initialize the inference program
// Initialize the inference program
if
(
!
config_
.
model_dir
.
empty
())
{
if
(
!
config_
.
model_dir
.
empty
())
{
...
@@ -83,13 +90,8 @@ bool NativePaddlePredictor::Init() {
...
@@ -83,13 +90,8 @@ bool NativePaddlePredictor::Init() {
return
false
;
return
false
;
}
}
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
executor_
->
CreateVariables
(
// Create temporary variables first, so that the first batch do not need to
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// create variables in the runtime. This is the logics of the old inference
// API.
// TODO(Superjomn) this should be modified when `Clone` is valid for
// multi-thread application.
executor_
->
CreateVariables
(
*
inference_program_
,
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
// Get the feed_target_names and fetch_target_names
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
...
@@ -97,6 +99,13 @@ bool NativePaddlePredictor::Init() {
...
@@ -97,6 +99,13 @@ bool NativePaddlePredictor::Init() {
return
true
;
return
true
;
}
}
NativePaddlePredictor
::~
NativePaddlePredictor
()
{
if
(
sub_scope_
)
{
PADDLE_ENFORCE_NOT_NULL
(
scope_
,
"Should have parent scope!"
);
scope_
->
DeleteScope
(
sub_scope_
);
}
};
bool
NativePaddlePredictor
::
Run
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
bool
NativePaddlePredictor
::
Run
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
std
::
vector
<
PaddleTensor
>
*
output_data
)
{
std
::
vector
<
PaddleTensor
>
*
output_data
)
{
VLOG
(
3
)
<<
"Predictor::predict"
;
VLOG
(
3
)
<<
"Predictor::predict"
;
...
@@ -121,8 +130,9 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
...
@@ -121,8 +130,9 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
}
}
// Run the inference program
// Run the inference program
// if share variables, we need not create variables
// if share variables, we need not create variables
executor_
->
RunPreparedContext
(
ctx_
.
get
(),
executor_
->
RunPreparedContext
(
scope_
.
get
(),
ctx_
.
get
(),
sub_scope_
!=
nullptr
?
sub_scope_
:
scope_
.
get
(),
&
feed_targets
,
&
feed_targets
,
&
fetch_targets
,
&
fetch_targets
,
false
/* don't create variable eatch time */
);
false
/* don't create variable eatch time */
);
...
@@ -138,7 +148,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
...
@@ -138,7 +148,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
VLOG
(
3
)
<<
"Predictor::clone"
;
VLOG
(
3
)
<<
"Predictor::clone"
;
std
::
unique_ptr
<
PaddlePredictor
>
cls
(
new
NativePaddlePredictor
(
config_
));
std
::
unique_ptr
<
PaddlePredictor
>
cls
(
new
NativePaddlePredictor
(
config_
));
if
(
!
dynamic_cast
<
NativePaddlePredictor
*>
(
cls
.
get
())
->
Init
())
{
if
(
!
dynamic_cast
<
NativePaddlePredictor
*>
(
cls
.
get
())
->
Init
(
scope_
))
{
LOG
(
ERROR
)
<<
"fail to call Init"
;
LOG
(
ERROR
)
<<
"fail to call Init"
;
return
nullptr
;
return
nullptr
;
}
}
...
@@ -266,7 +276,7 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
...
@@ -266,7 +276,7 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
}
}
std
::
unique_ptr
<
PaddlePredictor
>
predictor
(
new
NativePaddlePredictor
(
config
));
std
::
unique_ptr
<
PaddlePredictor
>
predictor
(
new
NativePaddlePredictor
(
config
));
if
(
!
dynamic_cast
<
NativePaddlePredictor
*>
(
predictor
.
get
())
->
Init
())
{
if
(
!
dynamic_cast
<
NativePaddlePredictor
*>
(
predictor
.
get
())
->
Init
(
nullptr
))
{
return
nullptr
;
return
nullptr
;
}
}
return
std
::
move
(
predictor
);
return
std
::
move
(
predictor
);
...
...
paddle/contrib/inference/paddle_inference_api_impl.h
浏览文件 @
ea73fb84
...
@@ -34,14 +34,15 @@ class NativePaddlePredictor : public PaddlePredictor {
...
@@ -34,14 +34,15 @@ class NativePaddlePredictor : public PaddlePredictor {
explicit
NativePaddlePredictor
(
const
NativeConfig
&
config
)
explicit
NativePaddlePredictor
(
const
NativeConfig
&
config
)
:
config_
(
config
)
{}
:
config_
(
config
)
{}
bool
Init
();
// will only create sub scope if have global scope
bool
Init
(
std
::
shared_ptr
<
framework
::
Scope
>
parent_scope
);
bool
Run
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
bool
Run
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
std
::
vector
<
PaddleTensor
>
*
output_data
)
override
;
std
::
vector
<
PaddleTensor
>
*
output_data
)
override
;
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
override
;
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
override
;
~
NativePaddlePredictor
()
override
{}
;
~
NativePaddlePredictor
()
override
;
private:
private:
bool
SetFeed
(
const
std
::
vector
<
PaddleTensor
>
&
input_datas
,
bool
SetFeed
(
const
std
::
vector
<
PaddleTensor
>
&
input_datas
,
...
@@ -52,11 +53,13 @@ class NativePaddlePredictor : public PaddlePredictor {
...
@@ -52,11 +53,13 @@ class NativePaddlePredictor : public PaddlePredictor {
NativeConfig
config_
;
NativeConfig
config_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
std
::
unique_ptr
<
framework
::
Executor
>
executor_
;
std
::
unique_ptr
<
framework
::
Executor
>
executor_
;
std
::
unique
_ptr
<
framework
::
Scope
>
scope_
;
std
::
shared
_ptr
<
framework
::
Scope
>
scope_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
ctx_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
ctx_
;
std
::
unique_ptr
<
framework
::
ProgramDesc
>
inference_program_
;
std
::
unique_ptr
<
framework
::
ProgramDesc
>
inference_program_
;
std
::
vector
<
std
::
string
>
feed_target_names_
;
std
::
vector
<
std
::
string
>
feed_target_names_
;
std
::
vector
<
std
::
string
>
fetch_target_names_
;
std
::
vector
<
std
::
string
>
fetch_target_names_
;
// Do not use unique_ptr, use parent scope to delete
framework
::
Scope
*
sub_scope_
{
nullptr
};
};
};
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
ea73fb84
...
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
...
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method
)
framework_proto glog lod_rank_table feed_fetch_method
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor
scope_buffered_ssa_graph_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
ea73fb84
...
@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
...
@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
device_context broadcast_op_handle
)
device_context broadcast_op_handle
)
cc_test
(
gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
cc_test
(
gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context gather_op_handle
)
device_context gather_op_handle
)
cc_library
(
scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor
)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
# device_context reduce_op_handle )
paddle/fluid/framework/details/execution_strategy.h
浏览文件 @
ea73fb84
...
@@ -22,6 +22,7 @@ struct ExecutionStrategy {
...
@@ -22,6 +22,7 @@ struct ExecutionStrategy {
size_t
num_threads_
{
0
};
size_t
num_threads_
{
0
};
bool
use_event_
{
true
};
bool
use_event_
{
true
};
bool
allow_op_delay_
{
false
};
bool
allow_op_delay_
{
false
};
size_t
num_iteration_per_drop_scope_
{
100
};
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
0 → 100644
浏览文件 @
ea73fb84
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Builds an executor that forwards to |underlying_executor| and keeps the
// execution strategy, scopes, variable descriptions and places it is given.
//
// strategy, local_scopes, var_infos and places are taken by value and moved
// into the corresponding members; ownership of |underlying_executor| is
// transferred to this object.
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
    ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
    std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
    std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
    : strategy_(std::move(strategy)),
      underlying_executor_(std::move(underlying_executor)),
      local_scopes_(std::move(local_scopes)),
      var_infos_(std::move(var_infos)),
      places_(std::move(places)) {}
// Runs one iteration through the wrapped executor.
//
// When the drop counter is zero, a fresh child scope is created under every
// entry of local_scopes_ and published through the kLocalExecScopeName
// variable of that entry. Variables listed in var_infos_ that cannot already
// be found are then initialized: persistable ones in the parent scope, the
// others in the new child scope.
//
// After the underlying run, the child scopes are deleted (and the counter is
// reset) if anything was fetched or if the counter has reached
// strategy_.num_iteration_per_drop_scope_; otherwise they are kept for the
// next call.
//
// Returns whatever the underlying executor returned for |fetch_tensors|.
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  const bool scopes_need_setup = (drop_scope_counter_ == 0);
  if (scopes_need_setup) {
    // Create local scopes, visiting local_scopes_ from last to first.
    for (size_t idx = local_scopes_.size(); idx > 0; --idx) {
      Scope *parent = local_scopes_[idx - 1];
      Scope &exec_scope = parent->NewScope();
      *parent->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
          &exec_scope;

      for (const auto &var_info : var_infos_) {
        // Only variables that are not reachable from the parent are created.
        if (parent->FindVar(var_info.name_) == nullptr) {
          // Persistable variables live in the parent scope; everything else
          // goes into the child scope that is dropped later.
          Scope &owner = var_info.persistable_ ? *parent : exec_scope;
          InitializeVariable(owner.Var(var_info.name_), var_info.type_);
        }
      }
    }
  }

  auto fetch_data = underlying_executor_->Run(fetch_tensors);
  ++drop_scope_counter_;

  const bool keep_scopes =
      fetch_tensors.empty() &&
      drop_scope_counter_ != strategy_.num_iteration_per_drop_scope_;
  if (keep_scopes) {
    return fetch_data;
  }

  drop_scope_counter_ = 0;

  // Wait All computational streams
  for (auto place : places_) {
    platform::DeviceContextPool::Instance().Get(place)->Wait();
  }

  // Drop the child scope recorded for every local scope.
  for (Scope *parent : local_scopes_) {
    Scope *exec_scope =
        *parent->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
    parent->DeleteScope(exec_scope);
  }
  return fetch_data;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
0 → 100644
浏览文件 @
ea73fb84
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Plain description of one variable that ScopeBufferedSSAGraphExecutor may
// have to create before running.
struct VariableInfo {
  // Name used to look the variable up in, and create it in, a scope.
  std::string name_;
  // Type passed to InitializeVariable when the variable is created.
  proto::VarType::Type type_;
  // True: created in the parent scope; false: created in the per-run child
  // scope.
  bool persistable_;
};
// An SSAGraphExecutor that wraps another executor and manages the child
// scopes used for each run: it creates them before delegating to the wrapped
// executor and only deletes them every
// ExecutionStrategy::num_iteration_per_drop_scope_ runs, or whenever a run
// fetches tensors.
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
 public:
  // Takes ownership of |underlying_executor|; the other arguments are moved
  // into the matching members.
  ScopeBufferedSSAGraphExecutor(
      ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
      std::vector<VariableInfo> var_infos,
      std::vector<platform::Place> places,
      std::unique_ptr<SSAGraphExecutor>&& underlying_executor);

  // Prepares scopes if needed, runs the wrapped executor and returns its
  // fetch result.
  FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;

 private:
  // Number of runs since the child scopes were last dropped.
  size_t drop_scope_counter_{0};

  ExecutionStrategy strategy_;
  // Executor that performs the actual graph execution.
  std::unique_ptr<SSAGraphExecutor> underlying_executor_;
  // Parent scopes; each one gets its own child scope.
  std::vector<Scope*> local_scopes_;
  // Variables to create in the scopes before running.
  std::vector<VariableInfo> var_infos_;
  // Places whose device contexts are waited on before scopes are dropped.
  std::vector<platform::Place> places_;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_executor.cc
浏览文件 @
ea73fb84
...
@@ -17,10 +17,6 @@
...
@@ -17,10 +17,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
SSAGraphExecutor
::
SSAGraphExecutor
(
std
::
unique_ptr
<
SSAGraph
>
&&
graph
)
:
graph_
(
std
::
move
(
graph
))
{}
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
ea73fb84
...
@@ -28,15 +28,11 @@ class SSAGraphExecutor {
...
@@ -28,15 +28,11 @@ class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN
(
SSAGraphExecutor
);
DISABLE_COPY_AND_ASSIGN
(
SSAGraphExecutor
);
public:
public:
// Steal graph inside
SSAGraphExecutor
()
{}
explicit
SSAGraphExecutor
(
std
::
unique_ptr
<
SSAGraph
>
&&
graph
);
virtual
~
SSAGraphExecutor
();
virtual
~
SSAGraphExecutor
();
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
protected:
std
::
unique_ptr
<
SSAGraph
>
graph_
;
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
ea73fb84
...
@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
...
@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
SSAGraph
>
&&
graph
)
std
::
unique_ptr
<
SSAGraph
>
&&
graph
)
:
SSAGraphExecutor
(
std
::
move
(
graph
)),
:
graph_
(
std
::
move
(
graph
)),
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
:
nullptr
),
:
nullptr
),
local_scopes_
(
local_scopes
),
local_scopes_
(
local_scopes
),
...
@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
...
@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
try
{
try
{
if
(
VLOG_IS_ON
(
10
))
{
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
}
op
->
Run
(
strategy_
.
use_event_
);
op
->
Run
(
strategy_
.
use_event_
);
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
running_ops_
--
;
running_ops_
--
;
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
ea73fb84
...
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details
::
OpHandleBase
*
op
);
details
::
OpHandleBase
*
op
);
private:
private:
std
::
unique_ptr
<
SSAGraph
>
graph_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
ea73fb84
...
@@ -23,6 +23,7 @@ limitations under the License. */
...
@@ -23,6 +23,7 @@ limitations under the License. */
#endif
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -42,8 +43,6 @@ class ParallelExecutorPrivate {
...
@@ -42,8 +43,6 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
nccl_ctxs_
;
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
nccl_ctxs_
;
#endif
#endif
std
::
vector
<
std
::
tuple
<
std
::
string
,
proto
::
VarType
::
Type
,
bool
>>
var_types_
;
bool
own_local_scope
;
bool
own_local_scope
;
};
};
...
@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor(
...
@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor(
local_scopes
.
empty
())
{
// Is CUDA
local_scopes
.
empty
())
{
// Is CUDA
BCastParamsToGPUs
(
bcast_vars
);
BCastParamsToGPUs
(
bcast_vars
);
}
}
// Startup Program has been run. All local scopes has correct parameters.
// Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// Step 2. Create vars in each scope;
std
::
vector
<
details
::
VariableInfo
>
var_infos
;
for
(
auto
*
var
:
main_program
.
Block
(
0
).
AllVars
())
{
var_infos
.
emplace_back
();
var_infos
.
back
().
name_
=
var
->
Name
();
var_infos
.
back
().
type_
=
var
->
GetType
();
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// ncclOp
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
details
::
MultiDevSSAGraphBuilder
builder
(
details
::
MultiDevSSAGraphBuilder
builder
(
...
@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor(
...
@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor(
params
,
member_
->
local_scopes_
,
params
,
member_
->
local_scopes_
,
build_strategy
);
build_strategy
);
#endif
#endif
auto
graph
=
builder
.
Build
(
main_program
);
auto
graph
=
builder
.
Build
(
main_program
);
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
// Step 3. Create vars in each scope;
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
for
(
auto
*
var
:
main_program
.
Block
(
0
).
AllVars
())
{
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
member_
->
var_types_
.
emplace_back
(
var
->
Name
(),
var
->
GetType
(),
member_
->
places_
,
std
::
move
(
member_
->
executor_
)));
var
->
Persistable
());
}
}
}
void
ParallelExecutor
::
BCastParamsToGPUs
(
void
ParallelExecutor
::
BCastParamsToGPUs
(
...
@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs(
...
@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
platform
::
RecordBlock
b
(
0
);
// Create local scopes.
for
(
auto
it
=
member_
->
local_scopes_
.
rbegin
();
it
!=
member_
->
local_scopes_
.
rend
();
++
it
)
{
auto
&
scope
=
*
it
;
Scope
&
local_scope
=
scope
->
NewScope
();
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
()
=
&
local_scope
;
for
(
auto
&
name_type_pair
:
member_
->
var_types_
)
{
if
(
scope
->
FindVar
(
std
::
get
<
0
>
(
name_type_pair
))
!=
nullptr
)
{
continue
;
}
if
(
std
::
get
<
2
>
(
name_type_pair
))
{
// Persistable
InitializeVariable
(
scope
->
Var
(
std
::
get
<
0
>
(
name_type_pair
)),
std
::
get
<
1
>
(
name_type_pair
));
}
else
{
InitializeVariable
(
local_scope
.
Var
(
std
::
get
<
0
>
(
name_type_pair
)),
std
::
get
<
1
>
(
name_type_pair
));
}
}
}
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
fetch_data
;
fetch_data
;
// Wait All computational streams
for
(
auto
p
:
member_
->
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
}
for
(
auto
&
scope
:
member_
->
local_scopes_
)
{
auto
&
local_scope
=
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
();
scope
->
DeleteScope
(
local_scope
);
}
}
}
void
ParallelExecutor
::
FeedTensorsIntoLocalScopes
(
void
ParallelExecutor
::
FeedTensorsIntoLocalScopes
(
...
...
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
ea73fb84
...
@@ -15,3 +15,9 @@ cc_test(test_subgraph_splitter
...
@@ -15,3 +15,9 @@ cc_test(test_subgraph_splitter
DEPS analysis paddle_fluid tensor
DEPS analysis paddle_fluid tensor
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
)
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
)
set_tests_properties
(
test_subgraph_splitter PROPERTIES DEPENDS test_word2vec
)
set_tests_properties
(
test_subgraph_splitter PROPERTIES DEPENDS test_word2vec
)
cc_test
(
test_dfg_graphviz_draw_pass
SRCS dfg_graphviz_draw_pass_tester.cc
DEPS analysis
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
)
set_tests_properties
(
test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec
)
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h
0 → 100644
浏览文件 @
ea73fb84
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines DFG_GraphvizDrawPass, a pass that helps draw a data flow
* graph's structure using Graphviz.
*/
#pragma once
#include <fstream>
#include <string>
#include "paddle/fluid/inference/analysis/pass.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
/*
* Output a dot file and write to some place.
*/
class
DFG_GraphvizDrawPass
:
public
DataFlowGraphPass
{
public:
DFG_GraphvizDrawPass
(
const
std
::
string
&
dir
,
const
std
::
string
&
id
)
:
dir_
(
dir
),
id_
(
id
)
{}
bool
Initialize
()
override
{
return
Pass
::
Initialize
();
}
void
Run
(
DataFlowGraph
*
graph
)
override
{
auto
content
=
Draw
(
graph
);
std
::
ofstream
file
(
GenDotPath
());
file
.
write
(
content
.
c_str
(),
content
.
size
());
file
.
close
();
LOG
(
INFO
)
<<
"draw dot to "
<<
GenDotPath
();
}
bool
Finalize
()
override
{
return
Pass
::
Finalize
();
}
Pass
*
CreatePrinterPass
(
std
::
ostream
&
os
,
const
std
::
string
&
banner
)
const
override
{
return
nullptr
;
}
private:
// Path of the dot file to output.
std
::
string
GenDotPath
()
const
{
return
dir_
+
"/"
+
"graph_"
+
id_
+
".dot"
;
}
std
::
string
Draw
(
DataFlowGraph
*
graph
)
{
return
graph
->
DotString
();
}
std
::
string
dir_
;
std
::
string
id_
;
};
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc
0 → 100644
浏览文件 @
ea73fb84
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include <gtest/gtest.h>
#include <fstream>
#include <string>
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
// Runs DFG_GraphvizDrawPass on the fixture's program and checks that the
// generated "./graph_test.dot" exists and has the expected number of lines.
TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
  auto dfg = ProgramDescToDFG(desc);
  DFG_GraphvizDrawPass pass("./", "test");
  pass.Initialize();
  pass.Run(&dfg);

  // test content
  std::ifstream file("./graph_test.dot");
  ASSERT_TRUE(file.is_open());

  // Count the lines of the generated dot file.
  int line_count = 0;
  for (std::string line; std::getline(file, line);) {
    ++line_count;
  }
  ASSERT_EQ(line_count, 82);
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/conv_cudnn_op.cu.cc
浏览文件 @
ea73fb84
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
DEFINE_bool
(
cudnn_
algo_use_autotune
,
true
,
DEFINE_bool
(
cudnn_
deterministic
,
true
,
"Whether allow using an autotuning algorithm for convolution "
"Whether allow using an autotuning algorithm for convolution "
"operator. The autotuning algorithm may be non-deterministic. If "
"operator. The autotuning algorithm may be non-deterministic. If "
"false, the algorithm is deterministic."
);
"false, the algorithm is deterministic."
);
...
@@ -272,7 +272,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -272,7 +272,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
if
(
input_grad
)
{
if
(
input_grad
)
{
if
(
FLAGS_cudnn_
algo_use_autotune
)
{
if
(
FLAGS_cudnn_
deterministic
)
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
handle
,
cudnn_filter_desc
,
handle
,
cudnn_filter_desc
,
...
@@ -297,7 +297,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -297,7 +297,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}
}
if
(
filter_grad
)
{
if
(
filter_grad
)
{
if
(
FLAGS_cudnn_
algo_use_autotune
)
{
if
(
FLAGS_cudnn_
deterministic
)
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
ea73fb84
...
@@ -38,6 +38,25 @@ void RPCClient::Init() {
...
@@ -38,6 +38,25 @@ void RPCClient::Init() {
if
(
rpc_client_
.
get
()
==
nullptr
)
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
RPCClient
());
rpc_client_
.
reset
(
new
RPCClient
());
}
}
rpc_client_
->
InitEventLoop
();
}
void
RPCClient
::
InitEventLoop
()
{
// start the client process thread
// TODO(wuyi): can make this in a threadpool
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
RPCClient
::
Proceed
,
this
)));
}
// Shuts the client down: waits for outstanding requests, shuts down the
// completion queue, releases every cached channel and joins the event-loop
// thread if one was started.
RPCClient::~RPCClient() {
  Wait();
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
  }
  // client_thread_ is only set by InitEventLoop(); a client constructed
  // without it holds a null pointer here, so guard the join.
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}
}
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
...
@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_
++
;
req_count_
++
;
}
}
bool
RPCClient
::
Wait
()
{
void
RPCClient
::
Wait
()
{
VLOG
(
3
)
<<
"RPCClient begin Wait()"
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
<<
" req_count_:"
<<
req_count_
;
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
if
(
req_count_
<=
0
)
{
return
true
;
}
const
size_t
kReqCnt
=
req_count_
;
bool
a
[
kReqCnt
];
std
::
vector
<
std
::
future
<
void
>>
waits
(
req_count_
);
std
::
mutex
mu
;
for
(
int
i
=
0
;
i
<
req_count_
;
i
++
)
{
waits
[
i
]
=
framework
::
AsyncIO
([
i
,
&
a
,
&
mu
,
this
]
{
bool
ret
=
Proceed
();
std
::
lock_guard
<
std
::
mutex
>
l
(
mu
);
a
[
i
]
=
ret
;
});
}
for
(
int
i
=
0
;
i
<
req_count_
;
i
++
)
{
waits
[
i
].
wait
();
}
int
last_req_count
=
req_count_
;
req_count_
=
0
;
for
(
int
i
=
0
;
i
<
last_req_count
;
i
++
)
{
if
(
!
a
[
i
])
{
return
false
;
}
}
return
true
;
}
}
bool
RPCClient
::
Proceed
()
{
void
RPCClient
::
Proceed
()
{
void
*
tag
=
NULL
;
void
*
tag
=
nullptr
;
bool
ok
=
false
;
bool
ok
=
false
;
// request counts.
while
(
cq_
.
Next
(
&
tag
,
&
ok
))
{
if
(
!
cq_
.
Next
(
&
tag
,
&
ok
))
{
LOG
(
ERROR
)
<<
"Get meets CompletionQueue error"
;
return
false
;
}
GPR_ASSERT
(
ok
);
PADDLE_ENFORCE
(
tag
);
// TODO(gongwb): add more retries.
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
if
(
!
c
->
status_
.
ok
())
{
GPR_ASSERT
(
ok
);
LOG
(
ERROR
)
<<
"proc param error:"
<<
c
->
var_h_
.
String
()
PADDLE_ENFORCE
(
c
);
if
(
c
->
status_
.
ok
())
{
c
->
Process
();
}
else
{
LOG
(
ERROR
)
<<
"var: "
<<
c
->
var_h_
.
String
()
<<
" grpc error:"
<<
c
->
status_
.
error_message
();
<<
" grpc error:"
<<
c
->
status_
.
error_message
();
delete
c
;
return
false
;
}
}
c
->
Process
();
delete
c
;
delete
c
;
return
true
;
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
sync_mutex_
);
req_count_
--
;
}
sync_cond_
.
notify_all
();
}
}
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
// TODO(Yancey1989): make grpc client completely thread-safe
// TODO(Yancey1989): make grpc client completely thread-safe
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_
mutex_
);
auto
it
=
channels_
.
find
(
ep
);
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
if
(
it
!=
channels_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
ea73fb84
...
@@ -17,14 +17,17 @@ limitations under the License. */
...
@@ -17,14 +17,17 @@ limitations under the License. */
#include <time.h>
#include <time.h>
#include <chrono> // NOLINT
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <ctime>
#include <ctime>
#include <functional>
#include <functional>
#include <iostream>
#include <iostream>
#include <map>
#include <map>
#include <mutex> // NOLINT
#include <mutex> // NOLINT
#include <string>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/byte_buffer.h"
...
@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor {
...
@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor {
class
RPCClient
{
class
RPCClient
{
public:
public:
RPCClient
()
{}
RPCClient
()
{}
~
RPCClient
();
static
RPCClient
*
GetInstance
();
static
RPCClient
*
GetInstance
();
...
@@ -192,19 +196,28 @@ class RPCClient {
...
@@ -192,19 +196,28 @@ class RPCClient {
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
600
*
1000
);
int64_t
time_out
=
600
*
1000
);
bool
Wait
();
void
Wait
();
// InitEventLoop should only be called by Init()
void
InitEventLoop
();
private:
private:
bool
Proceed
();
void
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
// Init is called by GetInstance.
// Init is called by GetInstance.
static
void
Init
();
static
void
Init
();
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
unique_ptr
<
std
::
thread
>
client_thread_
;
// mutex for Wait client sync
std
::
mutex
sync_mutex_
;
std
::
condition_variable
sync_cond_
;
std
::
atomic
<
int64_t
>
req_count_
{
0
};
std
::
atomic
<
int64_t
>
req_count_
{
0
};
std
::
mutex
mutex_
;
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
static
std
::
once_flag
init_flag_
;
static
std
::
once_flag
init_flag_
;
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
ea73fb84
...
@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase {
...
@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase {
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
}
virtual
~
RequestSend
()
{}
virtual
~
RequestSend
()
{}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
void
Process
()
override
{
...
@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase {
...
@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase {
framework
::
Variable
*
outvar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
...
@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase {
...
@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase {
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
&
reply_
);
}
}
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
...
@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase {
...
@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
&
reply_
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
status_
=
FINISH
;
}
}
protected:
protected:
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
ea73fb84
...
@@ -113,10 +113,6 @@ void StartServer() {
...
@@ -113,10 +113,6 @@ void StartServer() {
std
::
thread
server_thread
(
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
// FIXME(gongwb): don't use hard time.
sleep
(
10
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
server_thread
.
join
();
}
}
...
@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) {
...
@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) {
std
::
thread
server_thread
(
StartServer
);
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
client
;
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
()
;
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
...
@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) {
...
@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) {
std
::
string
in_var_name
(
"ids"
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
std
::
string
out_var_name
(
"out"
);
client
.
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
.
Wait
();
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
...
@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) {
...
@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) {
}
}
}
}
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
LOG
(
INFO
)
<<
"begin reset"
;
g_rpc_service
.
reset
(
nullptr
);
g_rpc_service
.
reset
(
nullptr
);
...
...
paddle/fluid/operators/detection/box_coder_op.cc
浏览文件 @
ea73fb84
...
@@ -22,21 +22,21 @@ class BoxCoderOp : public framework::OperatorWithKernel {
...
@@ -22,21 +22,21 @@ class BoxCoderOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PriorBox"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PriorBox"
),
"Input(PriorBox) of BoxCoderOp should not be null."
);
"Input(PriorBox) of BoxCoderOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PriorBoxVar"
),
"Input(PriorBoxVar) of BoxCoderOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"TargetBox"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"TargetBox"
),
"Input(TargetBox) of BoxCoderOp should not be null."
);
"Input(TargetBox) of BoxCoderOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutputBox"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutputBox"
),
"Output(OutputBox) of BoxCoderOp should not be null."
);
"Output(OutputBox) of BoxCoderOp should not be null."
);
auto
prior_box_dims
=
ctx
->
GetInputDim
(
"PriorBox"
);
auto
prior_box_dims
=
ctx
->
GetInputDim
(
"PriorBox"
);
auto
prior_box_var_dims
=
ctx
->
GetInputDim
(
"PriorBoxVar"
);
auto
target_box_dims
=
ctx
->
GetInputDim
(
"TargetBox"
);
auto
target_box_dims
=
ctx
->
GetInputDim
(
"TargetBox"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
prior_box_dims
.
size
(),
2
,
"The rank of Input of PriorBoxVar must be 2"
);
"The rank of Input of PriorBoxVar must be 2"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
[
1
],
4
,
"The shape of PriorBox is [N, 4]"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
[
1
],
4
,
"The shape of PriorBox is [N, 4]"
);
if
(
ctx
->
HasInput
(
"PriorBoxVar"
))
{
auto
prior_box_var_dims
=
ctx
->
GetInputDim
(
"PriorBoxVar"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
);
}
auto
code_type
=
GetBoxCodeType
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"code_type"
));
auto
code_type
=
GetBoxCodeType
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"code_type"
));
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
...
@@ -71,9 +71,11 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -71,9 +71,11 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"of the coordinate system. [xmax, ymax] is the right bottom "
"of the coordinate system. [xmax, ymax] is the right bottom "
"coordinate of the anchor box."
);
"coordinate of the anchor box."
);
AddInput
(
"PriorBoxVar"
,
AddInput
(
"PriorBoxVar"
,
"(Tensor, default Tensor<float>) "
"(Tensor, default Tensor<float>
, optional
) "
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group "
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group "
"of variance."
);
"of variance. PriorBoxVar will set all elements to 1 by "
"default."
)
.
AsDispensable
();
AddInput
(
AddInput
(
"TargetBox"
,
"TargetBox"
,
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
...
@@ -131,5 +133,6 @@ width and height.
...
@@ -131,5 +133,6 @@ width and height.
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
box_coder
,
ops
::
BoxCoderOp
,
ops
::
BoxCoderOpMaker
,
REGISTER_OPERATOR
(
box_coder
,
ops
::
BoxCoderOp
,
ops
::
BoxCoderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
box_coder
,
ops
::
BoxCoderKernel
<
float
>
,
REGISTER_OP_CPU_KERNEL
(
ops
::
BoxCoderKernel
<
double
>
);
box_coder
,
ops
::
BoxCoderKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
BoxCoderKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/detection/box_coder_op.cu
浏览文件 @
ea73fb84
...
@@ -48,15 +48,18 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
...
@@ -48,15 +48,18 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
target_box_data
[
row_idx
*
len
+
1
]
+
target_box_data
[
row_idx
*
len
+
1
]
+
(
normalized
==
false
);
(
normalized
==
false
);
output
[
idx
*
len
]
=
(
target_box_center_x
-
prior_box_center_x
)
/
output
[
idx
*
len
]
=
prior_box_width
/
prior_box_var_data
[
col_idx
*
len
];
(
target_box_center_x
-
prior_box_center_x
)
/
prior_box_width
;
output
[
idx
*
len
+
1
]
=
(
target_box_center_y
-
prior_box_center_y
)
/
output
[
idx
*
len
+
1
]
=
prior_box_height
/
(
target_box_center_y
-
prior_box_center_y
)
/
prior_box_height
;
prior_box_var_data
[
col_idx
*
len
+
1
];
output
[
idx
*
len
+
2
]
=
log
(
fabs
(
target_box_width
/
prior_box_width
));
output
[
idx
*
len
+
2
]
=
log
(
fabs
(
target_box_width
/
prior_box_width
))
/
output
[
idx
*
len
+
3
]
=
log
(
fabs
(
target_box_height
/
prior_box_height
));
prior_box_var_data
[
col_idx
*
len
+
2
];
if
(
prior_box_var_data
)
{
output
[
idx
*
len
+
3
]
=
log
(
fabs
(
target_box_height
/
prior_box_height
))
/
output
[
idx
*
len
]
/=
prior_box_var_data
[
col_idx
*
len
];
prior_box_var_data
[
col_idx
*
len
+
3
];
output
[
idx
*
len
+
1
]
/=
prior_box_var_data
[
col_idx
*
len
+
1
];
output
[
idx
*
len
+
2
]
/=
prior_box_var_data
[
col_idx
*
len
+
2
];
output
[
idx
*
len
+
3
]
/=
prior_box_var_data
[
col_idx
*
len
+
3
];
}
}
}
}
}
...
@@ -79,20 +82,31 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
...
@@ -79,20 +82,31 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
T
prior_box_center_y
=
(
prior_box_data
[
col_idx
*
len
+
3
]
+
T
prior_box_center_y
=
(
prior_box_data
[
col_idx
*
len
+
3
]
+
prior_box_data
[
col_idx
*
len
+
1
])
/
prior_box_data
[
col_idx
*
len
+
1
])
/
2
;
2
;
T
target_box_width
,
target_box_height
;
T
target_box_width
=
exp
(
prior_box_var_data
[
col_idx
*
len
+
2
]
*
T
target_box_center_x
,
target_box_center_y
;
if
(
prior_box_var_data
)
{
target_box_width
=
exp
(
prior_box_var_data
[
col_idx
*
len
+
2
]
*
target_box_data
[
idx
*
len
+
2
])
*
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
prior_box_width
;
T
target_box_height
=
exp
(
prior_box_var_data
[
col_idx
*
len
+
3
]
*
target_box_height
=
exp
(
prior_box_var_data
[
col_idx
*
len
+
3
]
*
target_box_data
[
idx
*
len
+
3
])
*
target_box_data
[
idx
*
len
+
3
])
*
prior_box_height
;
prior_box_height
;
T
target_box_center_x
=
prior_box_var_data
[
col_idx
*
len
]
*
target_box_center_x
=
prior_box_var_data
[
col_idx
*
len
]
*
target_box_data
[
idx
*
len
]
*
prior_box_width
+
target_box_data
[
idx
*
len
]
*
prior_box_width
+
prior_box_center_x
;
prior_box_center_x
;
T
target_box_center_y
=
prior_box_var_data
[
col_idx
*
len
+
1
]
*
target_box_center_y
=
prior_box_var_data
[
col_idx
*
len
+
1
]
*
target_box_data
[
idx
*
len
+
1
]
*
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_height
+
prior_box_center_y
;
prior_box_center_y
;
}
else
{
target_box_width
=
exp
(
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_height
=
exp
(
target_box_data
[
idx
*
len
+
3
])
*
prior_box_height
;
target_box_center_x
=
target_box_data
[
idx
*
len
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_center_y
;
}
output
[
idx
*
len
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
idx
*
len
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
idx
*
len
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
output
[
idx
*
len
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
...
@@ -103,7 +117,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
...
@@ -103,7 +117,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
}
}
}
}
template
<
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
BoxCoderCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
BoxCoderCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
@@ -114,6 +128,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
...
@@ -114,6 +128,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
if
(
prior_box_var
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
if
(
target_box
->
lod
().
size
())
{
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1
,
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1
,
"Only support 1 level of LoD."
);
"Only support 1 level of LoD."
);
...
@@ -125,10 +144,6 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
...
@@ -125,10 +144,6 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
int
grid
=
(
row
*
col
+
block
-
1
)
/
block
;
int
grid
=
(
row
*
col
+
block
-
1
)
/
block
;
auto
&
device_ctx
=
context
.
cuda_device_context
();
auto
&
device_ctx
=
context
.
cuda_device_context
();
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
output_box
->
mutable_data
<
T
>
({
row
,
col
,
len
},
context
.
GetPlace
());
output_box
->
mutable_data
<
T
>
({
row
,
col
,
len
},
context
.
GetPlace
());
T
*
output
=
output_box
->
data
<
T
>
();
T
*
output
=
output_box
->
data
<
T
>
();
...
@@ -150,5 +165,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
...
@@ -150,5 +165,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
box_coder
,
ops
::
BoxCoderCUDAKernel
<
float
>
,
REGISTER_OP_CUDA_KERNEL
(
ops
::
BoxCoderCUDAKernel
<
double
>
);
box_coder
,
ops
::
BoxCoderCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
BoxCoderCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/detection/box_coder_op.h
浏览文件 @
ea73fb84
...
@@ -28,19 +28,20 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) {
...
@@ -28,19 +28,20 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) {
PADDLE_THROW
(
"Not support type %s."
,
type
);
PADDLE_THROW
(
"Not support type %s."
,
type
);
}
}
template
<
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
BoxCoderKernel
:
public
framework
::
OpKernel
<
T
>
{
class
BoxCoderKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
EncodeCenterSize
(
const
framework
::
Tensor
&
target_box
,
void
EncodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
&
prior_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
&
prior_box_var
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
T
*
output
)
const
{
const
bool
normalized
,
T
*
output
)
const
{
int64_t
row
=
target_box
.
dims
()[
0
];
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
.
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
.
dims
()[
1
];
int64_t
len
=
prior_box
->
dims
()[
1
];
auto
*
target_box_data
=
target_box
.
data
<
T
>
();
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
.
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
auto
*
prior_box_var_data
=
prior_box_var
.
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
if
(
prior_box_var
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
...
@@ -65,30 +66,35 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -65,30 +66,35 @@ class BoxCoderKernel : public framework::OpKernel<T> {
(
normalized
==
false
);
(
normalized
==
false
);
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
output
[
offset
]
=
(
target_box_center_x
-
prior_box_center_x
)
/
output
[
offset
]
=
prior_box_width
/
prior_box_var_data
[
j
*
len
]
;
(
target_box_center_x
-
prior_box_center_x
)
/
prior_box_width
;
output
[
offset
+
1
]
=
(
target_box_center_y
-
prior_box_center_y
)
/
output
[
offset
+
1
]
=
prior_box_height
/
prior_box_var_data
[
j
*
len
+
1
]
;
(
target_box_center_y
-
prior_box_center_y
)
/
prior_box_height
;
output
[
offset
+
2
]
=
output
[
offset
+
2
]
=
std
::
log
(
std
::
fabs
(
target_box_width
/
prior_box_width
))
/
std
::
log
(
std
::
fabs
(
target_box_width
/
prior_box_width
));
prior_box_var_data
[
j
*
len
+
2
];
output
[
offset
+
3
]
=
output
[
offset
+
3
]
=
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
))
/
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
));
prior_box_var_data
[
j
*
len
+
3
];
if
(
prior_box_var
)
{
output
[
offset
]
/=
prior_box_var_data
[
j
*
len
];
output
[
offset
+
1
]
/=
prior_box_var_data
[
j
*
len
+
1
];
output
[
offset
+
2
]
/=
prior_box_var_data
[
j
*
len
+
2
];
output
[
offset
+
3
]
/=
prior_box_var_data
[
j
*
len
+
3
];
}
}
}
}
}
}
void
DecodeCenterSize
(
const
framework
::
Tensor
&
target_box
,
}
const
framework
::
Tensor
&
prior_box
,
void
DecodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
&
prior_box_var
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
T
*
output
)
const
{
const
bool
normalized
,
T
*
output
)
const
{
int64_t
row
=
target_box
.
dims
()[
0
];
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
.
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
.
dims
()[
1
];
int64_t
len
=
prior_box
->
dims
()[
1
];
auto
*
target_box_data
=
target_box
.
data
<
T
>
();
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
.
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
auto
*
prior_box_var_data
=
prior_box_var
.
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
if
(
prior_box_var
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
...
@@ -103,19 +109,32 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -103,19 +109,32 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T
prior_box_center_y
=
T
prior_box_center_y
=
(
prior_box_data
[
j
*
len
+
3
]
+
prior_box_data
[
j
*
len
+
1
])
/
2
;
(
prior_box_data
[
j
*
len
+
3
]
+
prior_box_data
[
j
*
len
+
1
])
/
2
;
T
target_box_center_x
=
prior_box_var_data
[
j
*
len
]
*
T
target_box_center_x
=
0
,
target_box_center_y
=
0
;
T
target_box_width
=
0
,
target_box_height
=
0
;
if
(
prior_box_var
)
{
target_box_center_x
=
prior_box_var_data
[
j
*
len
]
*
target_box_data
[
offset
]
*
prior_box_width
+
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
prior_box_center_x
;
T
target_box_center_y
=
prior_box_var_data
[
j
*
len
+
1
]
*
target_box_center_y
=
prior_box_var_data
[
j
*
len
+
1
]
*
target_box_data
[
offset
+
1
]
*
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_height
+
prior_box_center_y
;
prior_box_center_y
;
T
target_box_width
=
std
::
exp
(
prior_box_var_data
[
j
*
len
+
2
]
*
target_box_width
=
std
::
exp
(
prior_box_var_data
[
j
*
len
+
2
]
*
target_box_data
[
offset
+
2
])
*
target_box_data
[
offset
+
2
])
*
prior_box_width
;
prior_box_width
;
T
target_box_height
=
std
::
exp
(
prior_box_var_data
[
j
*
len
+
3
]
*
target_box_height
=
std
::
exp
(
prior_box_var_data
[
j
*
len
+
3
]
*
target_box_data
[
offset
+
3
])
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
prior_box_height
;
}
else
{
target_box_center_x
=
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
target_box_data
[
offset
+
3
])
*
prior_box_height
;
}
output
[
offset
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
offset
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
offset
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
output
[
offset
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
...
@@ -147,10 +166,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -147,10 +166,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
T
*
output
=
output_box
->
data
<
T
>
();
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSize
(
*
target_box
,
*
prior_box
,
*
prior_box_var
,
normalized
,
EncodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
output
);
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSize
(
*
target_box
,
*
prior_box
,
*
prior_box_var
,
normalized
,
DecodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
output
);
output
);
}
}
}
}
...
...
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
ea73fb84
...
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
...
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
for
(
auto
&
ep
:
eps
)
{
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"fetch barrier, ep: "
<<
ep
;
VLOG
(
3
)
<<
"fetch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
};
};
...
...
paddle/fluid/operators/pool_cudnn_op.cu.cc
浏览文件 @
ea73fb84
...
@@ -135,7 +135,11 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
...
@@ -135,7 +135,11 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
PoolingMode
pooling_mode
;
PoolingMode
pooling_mode
;
if
(
pooling_type
==
"max"
)
{
if
(
pooling_type
==
"max"
)
{
if
(
FLAGS_cudnn_deterministic
)
{
pooling_mode
=
PoolingMode
::
kMaximumDeterministic
;
}
else
{
pooling_mode
=
PoolingMode
::
kMaximum
;
pooling_mode
=
PoolingMode
::
kMaximum
;
}
}
else
{
}
else
{
pooling_mode
=
PoolingMode
::
kAverage
;
pooling_mode
=
PoolingMode
::
kAverage
;
}
}
...
...
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
ea73fb84
...
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
...
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
};
};
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
ea73fb84
...
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
...
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
if
(
sync_mode
)
{
if
(
sync_mode
)
{
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/reduce_op.h
浏览文件 @
ea73fb84
...
@@ -135,15 +135,16 @@ class ReduceKernel : public framework::OpKernel<T> {
...
@@ -135,15 +135,16 @@ class ReduceKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
int
ndim
=
context
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
();
int
ndim
=
context
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
();
int
rdim
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
).
size
();
int
rdim
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
).
size
();
HANDLE_DIM
(
6
,
5
);
// comments for accelerating compiling temporarily.
HANDLE_DIM
(
6
,
4
);
// HANDLE_DIM(6, 5);
HANDLE_DIM
(
6
,
3
);
// HANDLE_DIM(6, 4);
HANDLE_DIM
(
6
,
2
);
// HANDLE_DIM(6, 3);
HANDLE_DIM
(
6
,
1
);
// HANDLE_DIM(6, 2);
HANDLE_DIM
(
5
,
4
);
// HANDLE_DIM(6, 1);
HANDLE_DIM
(
5
,
3
);
// HANDLE_DIM(5, 4);
HANDLE_DIM
(
5
,
2
);
// HANDLE_DIM(5, 3);
HANDLE_DIM
(
5
,
1
);
// HANDLE_DIM(5, 2);
// HANDLE_DIM(5, 1);
HANDLE_DIM
(
4
,
3
);
HANDLE_DIM
(
4
,
3
);
HANDLE_DIM
(
4
,
2
);
HANDLE_DIM
(
4
,
2
);
HANDLE_DIM
(
4
,
1
);
HANDLE_DIM
(
4
,
1
);
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
ea73fb84
...
@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
// need to wait before sending send_barrier message
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
if
(
sync_mode
)
{
if
(
sync_mode
)
{
for
(
auto
&
ep
:
eps
)
{
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"send barrier, ep: "
<<
ep
;
VLOG
(
3
)
<<
"send barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
ea73fb84
...
@@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase {
...
@@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
if
(
sync_mode
)
{
if
(
sync_mode
)
{
for
(
auto
&
ep
:
endpoints
)
{
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
3
)
<<
"batch barrier, ep: "
<<
ep
;
VLOG
(
3
)
<<
"batch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
if
(
outs
.
size
()
>
0
)
{
if
(
outs
.
size
()
>
0
)
{
...
@@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase {
...
@@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase {
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
// tell pservers that current trainer have called fetch
// tell pservers that current trainer have called fetch
for
(
auto
&
ep
:
endpoints
)
{
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
2
)
<<
"send fetch barrier, ep: "
<<
ep
;
VLOG
(
2
)
<<
"send fetch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
ea73fb84
...
@@ -61,7 +61,6 @@ void StartServer() {
...
@@ -61,7 +61,6 @@ void StartServer() {
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
std
::
cout
<<
"before WaitFanInOfSend"
<<
std
::
endl
;
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
...
@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) {
...
@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) {
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
client
;
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
()
;
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
LOG
(
INFO
)
<<
"connect to server
"
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
->
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
Wait
();
client
->
Wait
();
client
.
AsyncSendBatchBarrier
(
ep
);
client
->
AsyncSendBatchBarrier
(
ep
);
client
.
Wait
();
client
->
Wait
();
server_thread
.
join
();
server_thread
.
join
();
g_rpc_service
.
reset
(
nullptr
);
g_rpc_service
.
reset
(
nullptr
);
...
...
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
ea73fb84
...
@@ -22,6 +22,8 @@ limitations under the License. */
...
@@ -22,6 +22,8 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
DECLARE_bool
(
cudnn_deterministic
);
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
@@ -76,8 +78,22 @@ enum class DataLayout { // Not use
...
@@ -76,8 +78,22 @@ enum class DataLayout { // Not use
enum
class
PoolingMode
{
enum
class
PoolingMode
{
kMaximum
,
kMaximum
,
kAverage
,
kAverage
,
kMaximumDeterministic
,
};
};
inline
cudnnPoolingMode_t
GetPoolingMode
(
const
PoolingMode
&
mode
)
{
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
return
CUDNN_POOLING_MAX_DETERMINISTIC
;
case
PoolingMode
::
kAverage
:
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
case
PoolingMode
::
kMaximum
:
return
CUDNN_POOLING_MAX
;
default:
PADDLE_THROW
(
"Unexpected pooling mode."
);
}
}
template
<
typename
T
>
template
<
typename
T
>
class
CudnnDataType
;
class
CudnnDataType
;
...
@@ -293,9 +309,7 @@ class ScopedPoolingDescriptor {
...
@@ -293,9 +309,7 @@ class ScopedPoolingDescriptor {
PADDLE_ENFORCE_EQ
(
kernel
.
size
(),
pads
.
size
());
PADDLE_ENFORCE_EQ
(
kernel
.
size
(),
pads
.
size
());
PADDLE_ENFORCE_EQ
(
kernel
.
size
(),
strides
.
size
());
PADDLE_ENFORCE_EQ
(
kernel
.
size
(),
strides
.
size
());
PADDLE_ENFORCE
(
dynload
::
cudnnSetPoolingNdDescriptor
(
PADDLE_ENFORCE
(
dynload
::
cudnnSetPoolingNdDescriptor
(
desc_
,
(
mode
==
PoolingMode
::
kMaximum
desc_
,
(
GetPoolingMode
(
mode
)),
?
CUDNN_POOLING_MAX
:
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
),
CUDNN_PROPAGATE_NAN
,
// Always propagate nans.
CUDNN_PROPAGATE_NAN
,
// Always propagate nans.
kernel
.
size
(),
kernel
.
data
(),
pads
.
data
(),
strides
.
data
()));
kernel
.
size
(),
kernel
.
data
(),
pads
.
data
(),
strides
.
data
()));
return
desc_
;
return
desc_
;
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
ea73fb84
...
@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() {
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
void
CUDADeviceContext
::
Wait
()
const
{
std
::
lock_guard
<
std
::
recursive_mutex
>
guard
(
mutex_
);
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaGetLastError
());
PADDLE_ENFORCE
(
cudaGetLastError
());
}
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
ea73fb84
...
@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
guard
(
mutex_
);
callback
();
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
}
...
@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
mutable
std
::
recursive_mutex
mutex_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
cublasHandle_t
cublas_handle_
;
...
...
paddle/fluid/platform/dynload/cublas.h
浏览文件 @
ea73fb84
...
@@ -45,7 +45,7 @@ extern void *cublas_dso_handle;
...
@@ -45,7 +45,7 @@ extern void *cublas_dso_handle;
std::call_once(cublas_dso_flag, []() { \
std::call_once(cublas_dso_flag, []() { \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
}); \
}); \
void *p_##__name = dlsym(cublas_dso_handle, #__name);
\
static void *p_##__name = dlsym(cublas_dso_handle, #__name);
\
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
} \
} \
}; \
}; \
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
ea73fb84
...
@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
...
@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
}); \
}); \
EnforceCUDNNLoaded(#__name); \
EnforceCUDNNLoaded(#__name); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name);
\
static void* p_##__name = dlsym(cudnn_dso_handle, #__name);
\
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \
} \
}; \
}; \
...
...
paddle/fluid/platform/dynload/cupti.h
浏览文件 @
ea73fb84
...
@@ -45,7 +45,7 @@ extern void *cupti_dso_handle;
...
@@ -45,7 +45,7 @@ extern void *cupti_dso_handle;
std::call_once(cupti_dso_flag, []() { \
std::call_once(cupti_dso_flag, []() { \
cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
}); \
}); \
void *p_##__name = dlsym(cupti_dso_handle, #__name);
\
static void *p_##__name = dlsym(cupti_dso_handle, #__name);
\
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
} \
} \
}; \
}; \
...
...
paddle/fluid/platform/dynload/curand.h
浏览文件 @
ea73fb84
...
@@ -34,7 +34,7 @@ extern void *curand_dso_handle;
...
@@ -34,7 +34,7 @@ extern void *curand_dso_handle;
std::call_once(curand_dso_flag, []() { \
std::call_once(curand_dso_flag, []() { \
curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
}); \
}); \
void *p_##__name = dlsym(curand_dso_handle, #__name);
\
static void *p_##__name = dlsym(curand_dso_handle, #__name);
\
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \
} \
}; \
}; \
...
...
paddle/fluid/platform/dynload/nccl.h
浏览文件 @
ea73fb84
...
@@ -37,7 +37,7 @@ extern void* nccl_dso_handle;
...
@@ -37,7 +37,7 @@ extern void* nccl_dso_handle;
std::call_once(nccl_dso_flag, []() { \
std::call_once(nccl_dso_flag, []() { \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
}); \
}); \
void* p_##__name = dlsym(nccl_dso_handle, #__name);
\
static void* p_##__name = dlsym(nccl_dso_handle, #__name);
\
return reinterpret_cast<nccl_func>(p_##__name)(args...); \
return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \
} \
}; \
}; \
...
...
paddle/fluid/platform/dynload/tensorrt.h
浏览文件 @
ea73fb84
...
@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle;
...
@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle;
paddle::platform::dynload::GetTensorRtDsoHandle(); \
paddle::platform::dynload::GetTensorRtDsoHandle(); \
PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
}); \
}); \
void* p_##__name = dlsym(tensorrt_dso_handle, #__name);
\
static void* p_##__name = dlsym(tensorrt_dso_handle, #__name);
\
PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \
PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \
} \
...
...
paddle/fluid/platform/dynload/warpctc.h
浏览文件 @
ea73fb84
...
@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle;
...
@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle;
std::call_once(warpctc_dso_flag, []() { \
std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
}); \
}); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name);
\
static void* p_##_name = dlsym(warpctc_dso_handle, #__name);
\
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \
} \
}; \
}; \
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
ea73fb84
...
@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle.
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
allow_op_delay_
;
},
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
allow_op_delay_
;
},
[](
ExecutionStrategy
&
self
,
bool
allow_op_delay
)
{
[](
ExecutionStrategy
&
self
,
bool
allow_op_delay
)
{
self
.
allow_op_delay_
=
allow_op_delay
;
self
.
allow_op_delay_
=
allow_op_delay
;
})
.
def_property
(
"num_iteration_per_drop_scope"
,
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
num_iteration_per_drop_scope_
;
},
[](
ExecutionStrategy
&
self
,
size_t
num_iteration_per_drop_scope
)
{
self
.
num_iteration_per_drop_scope_
=
num_iteration_per_drop_scope
;
});
});
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
);
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
);
...
...
python/paddle/fluid/__init__.py
浏览文件 @
ea73fb84
...
@@ -120,7 +120,7 @@ def __bootstrap__():
...
@@ -120,7 +120,7 @@ def __bootstrap__():
]
]
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
read_env_flags
+=
[
read_env_flags
+=
[
'fraction_of_gpu_memory_to_use'
,
'cudnn_
algo_use_autotune
'
'fraction_of_gpu_memory_to_use'
,
'cudnn_
deterministic
'
]
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
[
"--tryfromenv="
+
","
.
join
(
read_env_flags
)])
[
"--tryfromenv="
+
","
.
join
(
read_env_flags
)])
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
ea73fb84
...
@@ -81,6 +81,8 @@ __all__ = [
...
@@ -81,6 +81,8 @@ __all__ = [
'label_smooth'
,
'label_smooth'
,
'roi_pool'
,
'roi_pool'
,
'dice_loss'
,
'dice_loss'
,
'image_resize'
,
'image_resize_short'
,
'resize_bilinear'
,
'resize_bilinear'
,
'gather'
,
'gather'
,
'random_crop'
,
'random_crop'
,
...
@@ -3929,22 +3931,25 @@ def dice_loss(input, label, epsilon=0.00001):
...
@@ -3929,22 +3931,25 @@ def dice_loss(input, label, epsilon=0.00001):
return
reduce_mean
(
dice_score
)
return
reduce_mean
(
dice_score
)
def
resize_bilinear
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
def
image_resize
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
,
resample
=
'BILINEAR'
):
"""
"""
The mathematical meaning of resize bilinear layer is
Resize a batch of images.
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
For details, please refer to Wikipedia:
The input must be a tensor of the shape (num_batches, channels, in_h, in_w),
https://en.wikipedia.org/wiki/Bilinear_interpolation
and the resizing only applies to the last two dimensions (height and width).
Supporting resample methods:
'BILINEAR' : Bilinear interpolation
Args:
Args:
input (Variable): The input tensor of
resize bilinear
layer,
input (Variable): The input tensor of
image resize
layer,
This is a 4-D tensor of the shape
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
(num_batches, channels, in_h, in_w).
out_shape(list|tuple|Variable|None): Output shape of
resize bilinear
out_shape(list|tuple|Variable|None): Output shape of
image resize
layer, the shape is (out_h, out_w).
layer, the shape is (out_h, out_w).
Default: None
Default: None
scale(float|None): The multiplier for the input height or width.
scale(float|None): The multiplier for the input height or width.
...
@@ -3953,6 +3958,8 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
...
@@ -3953,6 +3958,8 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
Default: None
Default: None
name(str|None): A name for this layer(optional). If set None, the layer
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
resample(str): The resample method. It can only be 'BILINEAR' currently.
Default: 'BILINEAR'
Returns:
Returns:
out (Variable): The output is a 4-D tensor of the shape
out (Variable): The output is a 4-D tensor of the shape
...
@@ -3961,8 +3968,12 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
...
@@ -3961,8 +3968,12 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
Examples:
Examples:
.. code-block:: python
.. code-block:: python
out = fluid.layers.
resize_bilinear
(input, out_shape=[12, 12])
out = fluid.layers.
image_resize
(input, out_shape=[12, 12])
"""
"""
resample_methods
=
{
'BILINEAR'
:
'bilinear_interp'
}
if
resample
not
in
resample_methods
:
raise
ValueError
(
"The 'resample' of image_resize can only be 'BILINEAR' currently."
)
if
out_shape
is
None
and
scale
is
None
:
if
out_shape
is
None
and
scale
is
None
:
raise
ValueError
(
"One of out_shape and scale must not be None"
)
raise
ValueError
(
"One of out_shape and scale must not be None"
)
helper
=
LayerHelper
(
'bilinear_interp'
,
**
locals
())
helper
=
LayerHelper
(
'bilinear_interp'
,
**
locals
())
...
@@ -3990,7 +4001,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
...
@@ -3990,7 +4001,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
out
=
helper
.
create_tmp_variable
(
dtype
)
out
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"bilinear_interp"
,
type
=
resample_methods
[
resample
]
,
inputs
=
inputs
,
inputs
=
inputs
,
outputs
=
{
"Out"
:
out
},
outputs
=
{
"Out"
:
out
},
attrs
=
{
"out_h"
:
out_h
,
attrs
=
{
"out_h"
:
out_h
,
...
@@ -3998,6 +4009,55 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
...
@@ -3998,6 +4009,55 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
return
out
return
out
def
resize_bilinear
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
"""
This is an alias of layer 'image_resize' with bilinear interpolation.
The mathematical meaning of resize bilinear layer is
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
"""
return
image_resize
(
input
,
out_shape
,
scale
,
name
,
'BILINEAR'
)
def
image_resize_short
(
input
,
out_short_len
,
resample
=
'BILINEAR'
):
"""
Resize a batch of images. The short edge of input images will be
resized to the given 'out_short_len'. The long edge of input images
will be resized proportionately to make images' length-width ratio
constant.
Args:
input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
out_short_len(int): The length of output images' short edge.
Returns:
out (Variable): The output is a 4-D tensor of the shape
(num_batches, channels, out_h, out_w).
"""
in_shape
=
input
.
shape
if
len
(
in_shape
)
!=
4
:
raise
ValueError
(
"The rank of input must be 4 (num_batches, channels, in_h, in_w)."
)
hw
=
in_shape
[
2
:
4
]
short_idx
=
hw
.
index
(
min
(
hw
))
long_idx
=
1
-
short_idx
out_shape
=
list
(
hw
)
out_shape
[
short_idx
]
=
out_short_len
out_shape
[
long_idx
]
=
int
(
float
(
out_shape
[
long_idx
])
*
(
float
(
out_short_len
)
/
float
(
hw
[
short_idx
]))
+
0.5
)
return
image_resize
(
input
=
input
,
out_shape
=
out_shape
,
resample
=
resample
)
def
gather
(
input
,
index
):
def
gather
(
input
,
index
):
"""
"""
Output is obtained by gathering entries of the outer-most dimension
Output is obtained by gathering entries of the outer-most dimension
...
...
python/paddle/fluid/tests/test_concurrency.py
→
python/paddle/fluid/tests/
no_
test_concurrency.py
浏览文件 @
ea73fb84
文件已移动
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
ea73fb84
...
@@ -43,12 +43,10 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op)
...
@@ -43,12 +43,10 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list
(
REMOVE_ITEM TEST_OPS test_dist_train
)
list
(
REMOVE_ITEM TEST_OPS test_dist_train
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor_crf
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor_crf
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed
)
# TODO(wuyi): this test hangs on CI; will add it back later
list
(
REMOVE_ITEM TEST_OPS test_listen_and_serv_op
)
foreach
(
TEST_OP
${
TEST_OPS
}
)
foreach
(
TEST_OP
${
TEST_OPS
}
)
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
)
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
)
endforeach
(
TEST_OP
)
endforeach
(
TEST_OP
)
py_test_modules
(
test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=
${
WARPCTC_LIB_DIR
}
SERIAL
)
py_test_modules
(
test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=
${
WARPCTC_LIB_DIR
}
SERIAL
)
py_test_modules
(
test_dist_train MODULES test_dist_train SERIAL
)
py_test_modules
(
test_dist_train MODULES test_dist_train SERIAL
)
# FIXME(Yancey1989): this test takes much longer on CUDAPlace
# because it has to load the cudnn libraries, so we use a longer timeout
# to keep this unit test stable.
set_tests_properties
(
test_listen_and_serv_op PROPERTIES TIMEOUT 30
)
python/paddle/fluid/tests/unittests/test_box_coder_op.py
浏览文件 @
ea73fb84
...
@@ -120,6 +120,32 @@ class TestBoxCoderOp(OpTest):
...
@@ -120,6 +120,32 @@ class TestBoxCoderOp(OpTest):
self
.
outputs
=
{
'OutputBox'
:
output_box
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithoutBoxVar
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
0
,
1
,
2
,
3
,
4
,
5
]]
prior_box
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
ones
((
10
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
,
10
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
[
0
],
code_type
,
box_normalized
)
self
.
inputs
=
{
'PriorBox'
:
prior_box
,
'TargetBox'
:
target_box
,
}
self
.
attrs
=
{
'code_type'
:
'decode_center_size'
,
'box_normalized'
:
False
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithLoD
(
OpTest
):
class
TestBoxCoderOpWithLoD
(
OpTest
):
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录