Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8855d4a7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8855d4a7
编写于
6月 01, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into dist_recordio
上级
f9556dca
0c0c5df4
变更
90
显示空白变更内容
内联
并排
Showing
90 changed file
with
2218 addition
and
1165 deletion
+2218
-1165
AUTHORS.md
AUTHORS.md
+1
-0
Dockerfile
Dockerfile
+1
-1
benchmark/fluid/Dockerfile
benchmark/fluid/Dockerfile
+22
-0
benchmark/fluid/README.md
benchmark/fluid/README.md
+15
-1
benchmark/fluid/kube_gen_job.py
benchmark/fluid/kube_gen_job.py
+1
-1
benchmark/fluid/run.sh
benchmark/fluid/run.sh
+14
-12
doc/fluid/api/layers.rst
doc/fluid/api/layers.rst
+6
-0
doc/fluid/getstarted/Developer's_Guide_to_Paddle_Fluid.md
doc/fluid/getstarted/Developer's_Guide_to_Paddle_Fluid.md
+34
-34
doc/fluid/getstarted/index_cn.rst
doc/fluid/getstarted/index_cn.rst
+1
-0
doc/fluid/getstarted/index_en.rst
doc/fluid/getstarted/index_en.rst
+1
-0
doc/fluid/getstarted/quickstart_cn.rst
doc/fluid/getstarted/quickstart_cn.rst
+6
-6
doc/fluid/getstarted/quickstart_en.rst
doc/fluid/getstarted/quickstart_en.rst
+6
-6
doc/fluid/howto/index_cn.rst
doc/fluid/howto/index_cn.rst
+1
-1
doc/fluid/howto/index_en.rst
doc/fluid/howto/index_en.rst
+0
-1
doc/fluid/howto/inference/build_and_install_lib_cn.rst
doc/fluid/howto/inference/build_and_install_lib_cn.rst
+96
-0
doc/fluid/howto/inference/index_cn.rst
doc/fluid/howto/inference/index_cn.rst
+8
-0
doc/fluid/howto/inference/inference_support_in_fluid_cn.md
doc/fluid/howto/inference/inference_support_in_fluid_cn.md
+1
-58
paddle/contrib/inference/CMakeLists.txt
paddle/contrib/inference/CMakeLists.txt
+15
-28
paddle/contrib/inference/paddle_inference_api.h
paddle/contrib/inference/paddle_inference_api.h
+32
-22
paddle/contrib/inference/paddle_inference_api_impl.cc
paddle/contrib/inference/paddle_inference_api_impl.cc
+41
-73
paddle/contrib/inference/paddle_inference_api_impl.h
paddle/contrib/inference/paddle_inference_api_impl.h
+5
-13
paddle/contrib/inference/test_paddle_inference_api_impl.cc
paddle/contrib/inference/test_paddle_inference_api_impl.cc
+8
-9
paddle/fluid/framework/block_desc.cc
paddle/fluid/framework/block_desc.cc
+2
-2
paddle/fluid/framework/block_desc.h
paddle/fluid/framework/block_desc.h
+1
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+29
-12
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+1
-1
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+1
-1
paddle/fluid/framework/op_desc.h
paddle/fluid/framework/op_desc.h
+2
-1
paddle/fluid/framework/program_desc.cc
paddle/fluid/framework/program_desc.cc
+19
-6
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+3
-1
paddle/fluid/framework/tensor_impl.h
paddle/fluid/framework/tensor_impl.h
+2
-2
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+8
-3
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+3
-0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
...nference/analysis/data_flow_graph_to_fluid_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
...fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
+3
-1
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
.../fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
+2
-0
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
...nference/analysis/fluid_to_data_flow_graph_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+4
-2
paddle/fluid/inference/analysis/pass.h
paddle/fluid/inference/analysis/pass.h
+1
-0
paddle/fluid/inference/analysis/subgraph_splitter.h
paddle/fluid/inference/analysis/subgraph_splitter.h
+2
-0
paddle/fluid/inference/analysis/ut_helper.h
paddle/fluid/inference/analysis/ut_helper.h
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+2
-0
paddle/fluid/inference/tensorrt/convert/activation_op.cc
paddle/fluid/inference/tensorrt/convert/activation_op.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+2
-1
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+119
-0
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+4
-3
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+28
-13
paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
+46
-0
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
+3
-1
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+5
-2
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+30
-15
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+1
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+3
-1
paddle/fluid/operators/bilinear_interp_op.cc
paddle/fluid/operators/bilinear_interp_op.cc
+23
-0
paddle/fluid/operators/bilinear_interp_op.cu
paddle/fluid/operators/bilinear_interp_op.cu
+23
-2
paddle/fluid/operators/bilinear_interp_op.h
paddle/fluid/operators/bilinear_interp_op.h
+18
-4
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+2
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+2
-0
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+147
-225
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+22
-76
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+53
-34
paddle/fluid/operators/detail/request_handler.h
paddle/fluid/operators/detail/request_handler.h
+127
-0
paddle/fluid/operators/detail/request_handler_impl.cc
paddle/fluid/operators/detail/request_handler_impl.cc
+115
-0
paddle/fluid/operators/detail/request_handler_impl.h
paddle/fluid/operators/detail/request_handler_impl.h
+64
-0
paddle/fluid/operators/detail/rpc_server.cc
paddle/fluid/operators/detail/rpc_server.cc
+113
-0
paddle/fluid/operators/detail/rpc_server.h
paddle/fluid/operators/detail/rpc_server.h
+91
-0
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+2
-2
paddle/fluid/operators/gather_op.cc
paddle/fluid/operators/gather_op.cc
+0
-1
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+13
-8
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+63
-148
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+9
-22
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-0
paddle/fluid/operators/shape_op.cc
paddle/fluid/operators/shape_op.cc
+54
-0
paddle/fluid/operators/shape_op.cu
paddle/fluid/operators/shape_op.cu
+20
-0
paddle/fluid/operators/shape_op.h
paddle/fluid/operators/shape_op.h
+38
-0
paddle/fluid/operators/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt_engine_op.cc
+2
-1
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+33
-26
paddle/fluid/platform/nccl_helper.h
paddle/fluid/platform/nccl_helper.h
+1
-0
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+63
-7
python/paddle/fluid/layers/ops.py
python/paddle/fluid/layers/ops.py
+1
-0
python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
...n/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
+34
-3
python/paddle/fluid/tests/unittests/test_gather_op.py
python/paddle/fluid/tests/unittests/test_gather_op.py
+13
-2
python/paddle/fluid/tests/unittests/test_shape_op.py
python/paddle/fluid/tests/unittests/test_shape_op.py
+47
-0
python/paddle/fluid/tests/unittests/test_split_var.py
python/paddle/fluid/tests/unittests/test_split_var.py
+2
-2
python/paddle/fluid/transpiler/details/__init__.py
python/paddle/fluid/transpiler/details/__init__.py
+16
-0
python/paddle/fluid/transpiler/details/program_utils.py
python/paddle/fluid/transpiler/details/program_utils.py
+37
-0
python/paddle/fluid/transpiler/details/ufind.py
python/paddle/fluid/transpiler/details/ufind.py
+64
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+252
-257
python/setup.py.in
python/setup.py.in
+2
-1
未找到文件。
AUTHORS.md
浏览文件 @
8855d4a7
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
| backyes | Yan-Fei Wang |
| backyes | Yan-Fei Wang |
| baiyfbupt | Yi-Fan Bai |
| baiyfbupt | Yi-Fan Bai |
| beckett1124 | Bin Qi |
| beckett1124 | Bin Qi |
| ChengduoZH | Cheng-Duo Zhao|
| chengxiaohua1105 | Xiao-Hua Cheng |
| chengxiaohua1105 | Xiao-Hua Cheng |
| cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang |
| cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang |
| cxysteven | Xing-Yi Cheng |
| cxysteven | Xing-Yi Cheng |
...
...
Dockerfile
浏览文件 @
8855d4a7
...
@@ -29,7 +29,7 @@ RUN apt-get update && \
...
@@ -29,7 +29,7 @@ RUN apt-get update && \
wget unzip unrar
tar
xz-utils bzip2
gzip
coreutils ntp
\
wget unzip unrar
tar
xz-utils bzip2
gzip
coreutils ntp
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
curl
sed grep
graphviz libjpeg-dev zlib1g-dev
\
python-matplotlib gcc-4.8 g++-4.8
\
python-matplotlib gcc-4.8 g++-4.8
\
automake locales clang-format swig
doxygen
cmake
\
automake locales clang-format swig cmake
\
liblapack-dev liblapacke-dev
\
liblapack-dev liblapacke-dev
\
clang-3.8 llvm-3.8 libclang-3.8-dev
\
clang-3.8 llvm-3.8 libclang-3.8-dev
\
net-tools libtool ccache
&&
\
net-tools libtool ccache
&&
\
...
...
benchmark/fluid/Dockerfile
0 → 100644
浏览文件 @
8855d4a7
FROM
nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
RUN
apt-get update
&&
apt-get
install
-y
python python-pip iputils-ping libgtk2.0-dev wget vim net-tools iftop
RUN
ln
-s
/usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so
&&
ln
-s
/usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/lib/libnccl.so
RUN
pip
install
-U
pip
RUN
pip
install
-U
kubernetes opencv-python paddlepaddle
# IMPORTANT:
# Add "ENV http_proxy=http://ip:port" if your download is slow, and don't forget to unset it at runtime.
RUN
sh
-c
'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()\npaddle.dataset.flowers.fetch()" | python'
RUN
sh
-c
'echo "import paddle.v2 as paddle\npaddle.dataset.mnist.train()\npaddle.dataset.mnist.test()\npaddle.dataset.imdb.fetch()" | python'
RUN
sh
-c
'echo "import paddle.v2 as paddle\npaddle.dataset.imikolov.fetch()" | python'
RUN
pip uninstall
-y
paddlepaddle
&&
mkdir
/workspace
ADD
https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin
ADD
https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root
ADD
*.whl /
RUN
pip
install
/
*
.whl
&&
rm
-f
/
*
.whl
&&
chmod
+x /usr/bin/paddle_k8s
ENV
LD_LIBRARY_PATH=/usr/local/lib
ADD
fluid_benchmark.py dataset.py models/ /workspace/
benchmark/fluid/README.md
浏览文件 @
8855d4a7
...
@@ -44,11 +44,25 @@ Currently supported `--model` argument include:
...
@@ -44,11 +44,25 @@ Currently supported `--model` argument include:
## Run Distributed Benchmark on Kubernetes Cluster
## Run Distributed Benchmark on Kubernetes Cluster
You may need to build a Docker image before submitting a cluster job onto Kubernetes, or you will
have to start all those processes mannually on each node, which is not recommended.
To build the Docker image, you need to choose a paddle "whl" package to run with, you may either
download it from
http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_en.html or
build it by your own. Once you've got the "whl" package, put it under the current directory and run:
```
bash
docker build
-t
[
your docker image name]:[your docker image tag] .
```
Then push the image to a Docker registry that your Kubernetes cluster can reach.
We provide a script
`kube_gen_job.py`
to generate Kubernetes yaml files to submit
We provide a script
`kube_gen_job.py`
to generate Kubernetes yaml files to submit
distributed benchmark jobs to your cluster. To generate a job yaml, just run:
distributed benchmark jobs to your cluster. To generate a job yaml, just run:
```
bash
```
bash
python kube_gen_job.py
--jobname
myjob
--pscpu
4
--cpu
8
--gpu
8
--psmemory
20
--memory
40
--pservers
4
--trainers
4
--entry
"python fluid_benchmark.py --model mnist --
parallel 1
--device GPU --update_method pserver "
--disttype
pserver
python kube_gen_job.py
--jobname
myjob
--pscpu
4
--cpu
8
--gpu
8
--psmemory
20
--memory
40
--pservers
4
--trainers
4
--entry
"python fluid_benchmark.py --model mnist --
gpus 8
--device GPU --update_method pserver "
--disttype
pserver
```
```
Then the yaml files are generated under directory
`myjob`
, you can run:
Then the yaml files are generated under directory
`myjob`
, you can run:
...
...
benchmark/fluid/kube_gen_job.py
浏览文件 @
8855d4a7
...
@@ -49,7 +49,7 @@ def parse_args():
...
@@ -49,7 +49,7 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
'--fluid'
,
default
=
1
,
type
=
int
,
help
=
'whether is fluid job'
)
'--fluid'
,
default
=
1
,
type
=
int
,
help
=
'whether is fluid job'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--rdma'
,
action
=
'store_t
ur
e'
,
help
=
'whether mount rdma libs'
)
'--rdma'
,
action
=
'store_t
ru
e'
,
help
=
'whether mount rdma libs'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--disttype'
,
'--disttype'
,
default
=
"pserver"
,
default
=
"pserver"
,
...
...
benchmark/fluid/run.sh
浏览文件 @
8855d4a7
...
@@ -37,7 +37,8 @@ nohup stdbuf -oL nvidia-smi \
...
@@ -37,7 +37,8 @@ nohup stdbuf -oL nvidia-smi \
-l
1 &
-l
1 &
# mnist
# mnist
# mnist gpu mnist 128
# mnist gpu mnist 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/mnist.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
mnist
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
128
\
--batch_size
=
128
\
--skip_batch_num
=
5
\
--skip_batch_num
=
5
\
...
@@ -46,7 +47,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/mnist.py \
...
@@ -46,7 +47,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/mnist.py \
# vgg16
# vgg16
# gpu cifar10 128
# gpu cifar10 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/vgg16.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
vgg16
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
128
\
--batch_size
=
128
\
--skip_batch_num
=
5
\
--skip_batch_num
=
5
\
...
@@ -54,7 +56,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
...
@@ -54,7 +56,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
2>&1 |
tee
-a
vgg16_gpu_128.log
2>&1 |
tee
-a
vgg16_gpu_128.log
# flowers gpu 128
# flowers gpu 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/vgg16.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
vgg16
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
32
\
--batch_size
=
32
\
--data_set
=
flowers
\
--data_set
=
flowers
\
...
@@ -64,40 +67,39 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
...
@@ -64,40 +67,39 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
# resnet50
# resnet50
# resnet50 gpu cifar10 128
# resnet50 gpu cifar10 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/resnet50.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
resnet50
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
128
\
--batch_size
=
128
\
--data_set
=
cifar10
\
--data_set
=
cifar10
\
--model
=
resnet_cifar10
\
--skip_batch_num
=
5
\
--skip_batch_num
=
5
\
--iterations
=
30
\
--iterations
=
30
\
2>&1 |
tee
-a
resnet50_gpu_128.log
2>&1 |
tee
-a
resnet50_gpu_128.log
# resnet50 gpu flowers 64
# resnet50 gpu flowers 64
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/resnet50.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
resnet50
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
64
\
--batch_size
=
64
\
--data_set
=
flowers
\
--data_set
=
flowers
\
--model
=
resnet_imagenet
\
--skip_batch_num
=
5
\
--skip_batch_num
=
5
\
--iterations
=
30
\
--iterations
=
30
\
2>&1 |
tee
-a
resnet50_gpu_flowers_64.log
2>&1 |
tee
-a
resnet50_gpu_flowers_64.log
# lstm
# lstm
# lstm gpu imdb 32 # tensorflow only support batch=32
# lstm gpu imdb 32 # tensorflow only support batch=32
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/stacked_dynamic_lstm.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
stacked_dynamic_lstm
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
32
\
--batch_size
=
32
\
--skip_batch_num
=
5
\
--skip_batch_num
=
5
\
--iterations
=
30
\
--iterations
=
30
\
--hidden_dim
=
512
\
--emb_dim
=
512
\
--crop_size
=
1500
\
2>&1 |
tee
-a
lstm_gpu_32.log
2>&1 |
tee
-a
lstm_gpu_32.log
# seq2seq
# seq2seq
# seq2seq gpu wmb 128
# seq2seq gpu wmb 128
FLAGS_benchmark
=
true stdbuf
-oL
python fluid/machine_translation.py
\
FLAGS_benchmark
=
true stdbuf
-oL
python fluid_benchmark.py
\
--model
=
machine_translation
\
--device
=
GPU
\
--device
=
GPU
\
--batch_size
=
128
\
--batch_size
=
128
\
--skip_batch_num
=
5
\
--skip_batch_num
=
5
\
...
...
doc/fluid/api/layers.rst
浏览文件 @
8855d4a7
...
@@ -1009,3 +1009,9 @@ ____
...
@@ -1009,3 +1009,9 @@ ____
.. autofunction:: paddle.fluid.layers.upsampling_bilinear2d
.. autofunction:: paddle.fluid.layers.upsampling_bilinear2d
:noindex:
:noindex:
gather
____
.. autofunction:: paddle.fluid.layers.gather
:noindex:
doc/fluid/getstarted/Developer's_Guide_to_Paddle_Fluid.md
浏览文件 @
8855d4a7
...
@@ -86,7 +86,7 @@
...
@@ -86,7 +86,7 @@
<br>
<br>
<p
align=
"center"
>
<p
align=
"center"
>
<img
src=
"https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid
_
compiler.png"
width=
100%
>
<img
src=
"https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid
-
compiler.png"
width=
100%
>
</p>
</p>
---
---
...
...
doc/fluid/getstarted/index_cn.rst
浏览文件 @
8855d4a7
...
@@ -17,3 +17,4 @@
...
@@ -17,3 +17,4 @@
:maxdepth: 1
:maxdepth: 1
concepts/use_concepts_cn.rst
concepts/use_concepts_cn.rst
developer's_guide_to_paddle_fluid.md
doc/fluid/getstarted/index_en.rst
浏览文件 @
8855d4a7
...
@@ -16,3 +16,4 @@ Here is an example of linear regression. It introduces workflow of PaddlePaddle,
...
@@ -16,3 +16,4 @@ Here is an example of linear regression. It introduces workflow of PaddlePaddle,
:maxdepth: 1
:maxdepth: 1
concepts/index_en.rst
concepts/index_en.rst
developer's_guide_to_paddle_fluid.md
doc/fluid/getstarted/quickstart_cn.rst
浏览文件 @
8855d4a7
...
@@ -11,7 +11,7 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
...
@@ -11,7 +11,7 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
pip install paddlepaddle
pip install paddlepaddle
如果需要安装支持GPU的版本(cuda
7.5
_cudnn5_avx_openblas),需要执行:
如果需要安装支持GPU的版本(cuda
8.0
_cudnn5_avx_openblas),需要执行:
.. code-block:: bash
.. code-block:: bash
...
...
doc/fluid/getstarted/quickstart_en.rst
浏览文件 @
8855d4a7
...
@@ -12,7 +12,7 @@ Simply run the following command to install, the version is cpu_avx_openblas:
...
@@ -12,7 +12,7 @@ Simply run the following command to install, the version is cpu_avx_openblas:
pip install paddlepaddle
pip install paddlepaddle
If you need to install GPU version (cuda
7.5
_cudnn5_avx_openblas), run:
If you need to install GPU version (cuda
8.0
_cudnn5_avx_openblas), run:
.. code-block:: bash
.. code-block:: bash
...
...
doc/fluid/howto/index_cn.rst
浏览文件 @
8855d4a7
...
@@ -4,5 +4,5 @@
...
@@ -4,5 +4,5 @@
.. toctree::
.. toctree::
:maxdepth: 1
:maxdepth: 1
inference/index_cn.rst
optimization/index_cn.rst
optimization/index_cn.rst
inference/inference_support_in_fluid.md
doc/fluid/howto/index_en.rst
浏览文件 @
8855d4a7
...
@@ -5,4 +5,3 @@ HOW TO
...
@@ -5,4 +5,3 @@ HOW TO
:maxdepth: 1
:maxdepth: 1
optimization/index_en.rst
optimization/index_en.rst
inference/inference_support_in_fluid.md
doc/fluid/howto/inference/build_and_install_lib_cn.rst
0 → 100644
浏览文件 @
8855d4a7
安装与编译C++预测库
===========================
直接下载安装
-------------
====================== ========================================
版本说明 C++预测库
====================== ========================================
cpu_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/fluid.tgz>`_
cpu_avx_openblas `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxOpenblas/.lastSuccessful/fluid.tgz>`_
cpu_noavx_openblas `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/fluid.tgz>`_
cuda7.5_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda8.0_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda8.0_cudnn7_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/fluid.tgz>`_
====================== ========================================
从源码编译
----------
用户也可以从 PaddlePaddle 核心代码编译C++预测库,只需在编译时配制下面这些编译选项:
================= =========
选项 值
================= =========
CMAKE_BUILD_TYPE Release
FLUID_INSTALL_DIR 安装路径
WITH_FLUID_ONLY ON(推荐)
WITH_SWIG_PY OFF(推荐
WITH_PYTHON OFF(推荐)
WITH_GPU ON/OFF
WITH_MKL ON/OFF
================= =========
建议按照推荐值设置,以避免链接不必要的库。其它可选编译选项按需进行设定。
下面的代码片段从github拉取最新代码,配制编译选项(需要将PADDLE_ROOT替换为PaddlePaddle预测库的安装路径):
.. code-block:: bash
pip install paddlepaddle-gpu
PADDLE_ROOT=/path/of/capi
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir build
cd build
cmake -DFLUID_INSTALL_DIR=$PADDLE_ROOT \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_FLUID_ONLY=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_PYTHON=OFF \
-DWITH_MKL=OFF \
-DWITH_GPU=OFF \
..
make
make inference_lib_dist
成功编译后,使用C++预测库所需的依赖(包括:(1)编译出的PaddlePaddle预测库和头文件;(2)第三方链接库和头文件;(3)版本信息与编译选项信息)
均会存放于PADDLE_ROOT目录中。目录结构如下:
.. code-block:: text
PaddleRoot/
├── CMakeCache.txt
├── paddle
│ └── fluid
│ ├── framework
│ ├── inference
│ ├── memory
│ ├── platform
│ ├── pybind
│ └── string
├── third_party
│ ├── boost
│ │ └── boost
│ ├── eigen3
│ │ ├── Eigen
│ │ └── unsupported
│ └── install
│ ├── gflags
│ ├── glog
│ ├── mklml
│ ├── protobuf
│ ├── snappy
│ ├── snappystream
│ └── zlib
└── version.txt
version.txt 中记录了该预测库的版本信息,包括Git Commit ID、使用OpenBlas或MKL数学库、CUDA/CUDNN版本号,如:
.. code-block:: text
GIT COMMIT ID: c95cd4742f02bb009e651a00b07b21c979637dc8
WITH_MKL: ON
WITH_GPU: ON
CUDA version: 8.0
CUDNN version: v5
doc/fluid/howto/inference/index_cn.rst
0 → 100644
浏览文件 @
8855d4a7
预测库
------------
.. toctree::
:maxdepth: 1
build_and_install_lib_cn.rst
inference_support_in_fluid_cn.md
doc/fluid/howto/inference/inference_support_in_fluid.md
→
doc/fluid/howto/inference/inference_support_in_fluid
_cn
.md
浏览文件 @
8855d4a7
#
Fluid Inference
使用指南
# 使用指南
## 目录:
## 目录:
-
Python Inference API
-
Python Inference API
-
编译Fluid Inference库
-
Inference C++ API
-
Inference C++ API
-
Inference实例
-
Inference实例
-
Inference计算优化
-
Inference计算优化
...
@@ -55,62 +54,6 @@
...
@@ -55,62 +54,6 @@
return
[
program
,
feed_target_names
,
fetch_targets
]
return
[
program
,
feed_target_names
,
fetch_targets
]
```
```
## 编译Fluid Inference库
-
**不需要额外的CMake选项**
-
1、 配置CMake命令,更多配置请参考
[
源码编译PaddlePaddle
](
http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/build_from_source_cn.html
)
```
bash
$
git clone https://github.com/PaddlePaddle/Paddle.git
$
cd
Paddle
$
mkdir
build
$
cd
build
$
cmake
-DCMAKE_INSTALL_PREFIX
=
your/path/to/paddle_inference_lib
\
-DCMAKE_BUILD_TYPE
=
Release
\
-DWITH_PYTHON
=
ON
\
-DWITH_MKL
=
OFF
\
-DWITH_GPU
=
OFF
\
..
```
- 2、 编译PaddlePaddle
```bash
$ make
```
- 3、 部署。执行如下命令将PaddlePaddle Fluid Inference库部署到`your/path/to/paddle_inference_lib`目录。
```bash
$ make inference_lib_dist
```
-
目录结构
```
bash
$
cd
your/path/to/paddle_inference_lib
$
tree
.
|-- paddle
|
`
--
fluid
| |-- framework
| |-- inference
| | |-- io.h
| |
`
--
libpaddle_fluid.so
| |-- memory
| |-- platform
|
`
--
string
|-- third_party
| |-- eigen3
|
`
--
install
| |-- gflags
| |-- glog
|
`
--
protobuf
`
--
...
```
假设
`PADDLE_ROOT=your/path/to/paddle_inference_lib`
。
## 链接Fluid Inference库
## 链接Fluid Inference库
-
示例项目(
[
链接
](
https://github.com/luotao1/fluid_inference_example.git
)
)
-
示例项目(
[
链接
](
https://github.com/luotao1/fluid_inference_example.git
)
)
...
...
paddle/contrib/inference/CMakeLists.txt
浏览文件 @
8855d4a7
...
@@ -17,46 +17,33 @@ if(APPLE)
...
@@ -17,46 +17,33 @@ if(APPLE)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-error=pessimizing-move"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-error=pessimizing-move"
)
endif
(
APPLE
)
endif
(
APPLE
)
function
(
inference_api_test TARGET_NAME
TEST_SRC
)
function
(
inference_api_test TARGET_NAME
)
set
(
options
""
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs ARGS
)
set
(
multiValueArgs ARGS
)
cmake_parse_arguments
(
inference_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
cmake_parse_arguments
(
inference_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
set
(
PYTHON_TESTS_DIR
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests
)
set
(
PYTHON_TESTS_DIR
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests
)
set
(
arg_list
""
)
cc_test
(
test_paddle_inference_
${
TARGET_NAME
}
SRCS test_paddle_inference_
${
TARGET_NAME
}
.cc
DEPS paddle_fluid_api paddle_inference_api
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/
)
if
(
inference_test_ARGS
)
if
(
inference_test_ARGS
)
foreach
(
arg
${
inference_test_ARGS
}
)
set_tests_properties
(
test_paddle_inference_
${
TARGET_NAME
}
list
(
APPEND arg_list
"_
${
arg
}
"
)
PROPERTIES DEPENDS
"
${
inference_test_ARGS
}
"
)
endforeach
()
else
()
list
(
APPEND arg_list
"_"
)
endif
()
endif
()
foreach
(
arg
${
arg_list
}
)
string
(
REGEX REPLACE
"^_$"
""
arg
"
${
arg
}
"
)
cc_test
(
${
TARGET_NAME
}
SRCS
${
TEST_SRC
}
DEPS paddle_fluid_api paddle_inference_api paddle_inference_api_impl
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/
)
# TODO(panyx0178): Figure out how to add word2vec and image_classification
# as deps.
# set_tests_properties(${TARGET_NAME}
# PROPERTIES DEPENDS ${DEP_TEST})
endforeach
()
endfunction
(
inference_api_test
)
endfunction
(
inference_api_test
)
cc_library
(
paddle_inference_api
cc_library
(
paddle_inference_api
SRCS paddle_inference_api.cc
SRCS paddle_inference_api.cc
paddle_inference_api_impl.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
cc_library
(
paddle_inference_api_impl
if
(
WITH_TESTING
)
SRCS paddle_inference_api_impl.cc
cc_test
(
test_paddle_inference_api
DEPS paddle_inference_api paddle_fluid_api
)
cc_test
(
test_paddle_inference_api
SRCS test_paddle_inference_api.cc
SRCS test_paddle_inference_api.cc
DEPS paddle_inference_api
)
DEPS paddle_inference_api
)
inference_api_test
(
test_paddle_inference_api_impl
inference_api_test
(
api_impl
test_paddle_inference_api_impl.cc
)
ARGS test_word2vec test_image_classification
)
endif
()
paddle/contrib/inference/paddle_inference_api.h
浏览文件 @
8855d4a7
...
@@ -40,15 +40,24 @@ struct PaddleBuf {
...
@@ -40,15 +40,24 @@ struct PaddleBuf {
struct
PaddleTensor
{
struct
PaddleTensor
{
std
::
string
name
;
// variable name.
std
::
string
name
;
// variable name.
std
::
vector
<
int
>
shape
;
std
::
vector
<
int
>
shape
;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
PaddleBuf
data
;
// blob of data.
PaddleBuf
data
;
// blob of data.
PaddleDType
dtype
;
PaddleDType
dtype
;
};
};
enum
class
PaddleEngineKind
{
kNative
=
0
,
// Use the native Fluid facility.
// TODO(Superjomn) support following engines latter.
// kAnakin, // Use Anakin for inference.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
/*
/*
* A simple Inference API for Paddle. Currently this API might just be used by
* A simple Inference API for Paddle. Currently this API can be used by
* non-sequence scenerios.
* non-sequence scenerios.
* TODO(Superjomn) Prepare another API for NLP-related usages.
*/
*/
class
PaddlePredictor
{
class
PaddlePredictor
{
public:
public:
struct
Config
;
struct
Config
;
...
@@ -66,34 +75,35 @@ class PaddlePredictor {
...
@@ -66,34 +75,35 @@ class PaddlePredictor {
// be thread-safe.
// be thread-safe.
virtual
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
=
0
;
virtual
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
=
0
;
virtual
bool
InitShared
()
{
return
false
;
}
// Destroy the Predictor.
// Destroy the Predictor.
virtual
~
PaddlePredictor
()
{}
virtual
~
PaddlePredictor
()
{}
friend
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
(
const
PaddlePredictor
::
Config
&
config
);
// The common configs for all the predictors.
// The common configs for all the predictors.
struct
Config
{
struct
Config
{
enum
class
EngineKind
;
std
::
string
model_dir
;
// path to the model directory.
std
::
string
model_dir
;
// path to the model directory.
bool
enable_engine
{
false
};
// Enable to execute (part of) the model on
bool
enable_engine
{
false
};
// Enable to execute (part of) the model on
// third-party engines.
EngineKind
engine_kind
{
Config
::
EngineKind
::
kNone
};
enum
class
EngineKind
{
kNone
=
-
1
,
// Use the native Fluid facility.
kAnakin
,
// Use Anakin for inference.
kTensorRT
,
// Use TensorRT for inference.
kAutoMixedAnakin
,
// Automatically mix Fluid with Anakin.
kAutoMixedTensorRT
,
// Automatically mix Fluid with TensorRT.
};
};
};
};
};
// A factory to help create difference predictor.
struct
NativeConfig
:
public
PaddlePredictor
::
Config
{
template
<
typename
ConfigT
>
// GPU related fields.
bool
use_gpu
{
false
};
int
device
{
0
};
float
fraction_of_gpu_memory
{
-
1.
f
};
// Negative to notify initialization.
std
::
string
prog_file
;
std
::
string
param_file
;
};
// A factory to help create different predictors.
//
// FOR EXTENSION DEVELOPER:
// Different predictors are designated by config type and engine kind. Similar
// configs can be merged, but there shouldn't be a huge config containing
// different fields for more than one kind of predictors.
//
// Similarly, each engine kind should map to a unique predictor implementation.
template
<
typename
ConfigT
,
PaddleEngineKind
engine
=
PaddleEngineKind
::
kNative
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
(
const
ConfigT
&
config
);
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
(
const
ConfigT
&
config
);
}
// namespace paddle
}
// namespace paddle
paddle/contrib/inference/paddle_inference_api_impl.cc
浏览文件 @
8855d4a7
...
@@ -54,11 +54,10 @@ std::string num2str(T a) {
...
@@ -54,11 +54,10 @@ std::string num2str(T a) {
}
}
}
// namespace
}
// namespace
bool
PaddlePredictorImpl
::
Init
()
{
bool
NativePaddlePredictor
::
Init
()
{
VLOG
(
3
)
<<
"Predictor::init()"
;
VLOG
(
3
)
<<
"Predictor::init()"
;
// TODO(panyx0718): Should CPU vs GPU device be decided by id?
if
(
config_
.
use_gpu
)
{
if
(
config_
.
device
>=
0
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
}
else
{
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
place_
=
paddle
::
platform
::
CPUPlace
();
...
@@ -85,18 +84,20 @@ bool PaddlePredictorImpl::Init() {
...
@@ -85,18 +84,20 @@ bool PaddlePredictorImpl::Init() {
}
}
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
// Create variables
// Create temporary variables first, so that the first batch do not need to
// TODO(panyx0718): Why need to test share_variables here?
// create variables in the runtime. This is the logics of the old inference
if
(
config_
.
share_variables
)
{
// API.
// TODO(Superjomn) this should be modified when `Clone` is valid for
// multi-thread application.
executor_
->
CreateVariables
(
*
inference_program_
,
scope_
.
get
(),
0
);
executor_
->
CreateVariables
(
*
inference_program_
,
scope_
.
get
(),
0
);
}
// Get the feed_target_names and fetch_target_names
// Get the feed_target_names and fetch_target_names
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
feed_target_names_
=
inference_program_
->
GetFeedTargetNames
();
fetch_target_names_
=
inference_program_
->
GetFetchTargetNames
();
fetch_target_names_
=
inference_program_
->
GetFetchTargetNames
();
return
true
;
return
true
;
}
}
bool
PaddlePredictorImpl
::
Run
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
bool
NativePaddlePredictor
::
Run
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
std
::
vector
<
PaddleTensor
>
*
output_data
)
{
std
::
vector
<
PaddleTensor
>
*
output_data
)
{
VLOG
(
3
)
<<
"Predictor::predict"
;
VLOG
(
3
)
<<
"Predictor::predict"
;
Timer
timer
;
Timer
timer
;
...
@@ -124,7 +125,7 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
...
@@ -124,7 +125,7 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
scope_
.
get
(),
scope_
.
get
(),
&
feed_targets
,
&
feed_targets
,
&
fetch_targets
,
&
fetch_targets
,
!
config_
.
share_variables
);
false
/* don't create variable eatch time */
);
if
(
!
GetFetch
(
fetchs
,
output_data
))
{
if
(
!
GetFetch
(
fetchs
,
output_data
))
{
LOG
(
ERROR
)
<<
"fail to get fetchs"
;
LOG
(
ERROR
)
<<
"fail to get fetchs"
;
return
false
;
return
false
;
...
@@ -133,58 +134,19 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
...
@@ -133,58 +134,19 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
return
true
;
return
true
;
}
}
std
::
unique_ptr
<
PaddlePredictor
>
PaddlePredictorImpl
::
Clone
()
{
std
::
unique_ptr
<
PaddlePredictor
>
NativePaddlePredictor
::
Clone
()
{
VLOG
(
3
)
<<
"Predictor::clone"
;
VLOG
(
3
)
<<
"Predictor::clone"
;
std
::
unique_ptr
<
PaddlePredictor
>
cls
(
new
PaddlePredictorImpl
(
config_
));
std
::
unique_ptr
<
PaddlePredictor
>
cls
(
new
NativePaddlePredictor
(
config_
));
if
(
!
cls
->
InitShared
())
{
LOG
(
ERROR
)
<<
"fail to call InitShared"
;
if
(
!
dynamic_cast
<
NativePaddlePredictor
*>
(
cls
.
get
())
->
Init
())
{
LOG
(
ERROR
)
<<
"fail to call Init"
;
return
nullptr
;
return
nullptr
;
}
}
// fix manylinux compile error.
// fix manylinux compile error.
return
std
::
move
(
cls
);
return
std
::
move
(
cls
);
}
}
// TODO(panyx0718): Consider merge with Init()?
bool
NativePaddlePredictor
::
SetFeed
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
bool
PaddlePredictorImpl
::
InitShared
()
{
VLOG
(
3
)
<<
"Predictor::init_shared"
;
// 1. Define place, executor, scope
if
(
this
->
config_
.
device
>=
0
)
{
place_
=
platform
::
CUDAPlace
();
}
else
{
place_
=
platform
::
CPUPlace
();
}
this
->
executor_
.
reset
(
new
framework
::
Executor
(
this
->
place_
));
this
->
scope_
.
reset
(
new
framework
::
Scope
());
// Initialize the inference program
if
(
!
this
->
config_
.
model_dir
.
empty
())
{
// Parameters are saved in separate files sited in
// the specified `dirname`.
this
->
inference_program_
=
inference
::
Load
(
this
->
executor_
.
get
(),
this
->
scope_
.
get
(),
this
->
config_
.
model_dir
);
}
else
if
(
!
this
->
config_
.
prog_file
.
empty
()
&&
!
this
->
config_
.
param_file
.
empty
())
{
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
this
->
inference_program_
=
inference
::
Load
(
this
->
executor_
.
get
(),
this
->
scope_
.
get
(),
this
->
config_
.
prog_file
,
this
->
config_
.
param_file
);
}
this
->
ctx_
=
this
->
executor_
->
Prepare
(
*
this
->
inference_program_
,
0
);
// 3. create variables
// TODO(panyx0718): why test share_variables.
if
(
config_
.
share_variables
)
{
this
->
executor_
->
CreateVariables
(
*
this
->
inference_program_
,
this
->
scope_
.
get
(),
0
);
}
// 4. Get the feed_target_names and fetch_target_names
this
->
feed_target_names_
=
this
->
inference_program_
->
GetFeedTargetNames
();
this
->
fetch_target_names_
=
this
->
inference_program_
->
GetFetchTargetNames
();
return
true
;
}
bool
PaddlePredictorImpl
::
SetFeed
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
,
std
::
vector
<
framework
::
LoDTensor
>
*
feeds
)
{
std
::
vector
<
framework
::
LoDTensor
>
*
feeds
)
{
VLOG
(
3
)
<<
"Predictor::set_feed"
;
VLOG
(
3
)
<<
"Predictor::set_feed"
;
if
(
inputs
.
size
()
!=
feed_target_names_
.
size
())
{
if
(
inputs
.
size
()
!=
feed_target_names_
.
size
())
{
...
@@ -213,7 +175,7 @@ bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs,
...
@@ -213,7 +175,7 @@ bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs,
return
true
;
return
true
;
}
}
bool
PaddlePredictorImpl
::
GetFetch
(
bool
NativePaddlePredictor
::
GetFetch
(
const
std
::
vector
<
framework
::
LoDTensor
>
&
fetchs
,
const
std
::
vector
<
framework
::
LoDTensor
>
&
fetchs
,
std
::
vector
<
PaddleTensor
>
*
outputs
)
{
std
::
vector
<
PaddleTensor
>
*
outputs
)
{
VLOG
(
3
)
<<
"Predictor::get_fetch"
;
VLOG
(
3
)
<<
"Predictor::get_fetch"
;
...
@@ -280,10 +242,15 @@ bool PaddlePredictorImpl::GetFetch(
...
@@ -280,10 +242,15 @@ bool PaddlePredictorImpl::GetFetch(
}
}
template
<
>
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
(
std
::
unique_ptr
<
PaddlePredictor
>
const
ConfigImpl
&
config
)
{
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
VLOG
(
3
)
<<
"create PaddlePredictorImpl"
;
const
NativeConfig
&
config
)
{
VLOG
(
3
)
<<
"create NativePaddlePredictor"
;
if
(
config
.
use_gpu
)
{
// 1. GPU memeroy
// 1. GPU memeroy
PADDLE_ENFORCE
(
config
.
fraction_of_gpu_memory
>
0.
f
,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]"
);
std
::
vector
<
std
::
string
>
flags
;
std
::
vector
<
std
::
string
>
flags
;
if
(
config
.
fraction_of_gpu_memory
>=
0.0
f
||
if
(
config
.
fraction_of_gpu_memory
>=
0.0
f
||
config
.
fraction_of_gpu_memory
<=
0.95
f
)
{
config
.
fraction_of_gpu_memory
<=
0.95
f
)
{
...
@@ -294,9 +261,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
...
@@ -294,9 +261,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
VLOG
(
3
)
<<
"set flag: "
<<
flag
;
VLOG
(
3
)
<<
"set flag: "
<<
flag
;
framework
::
InitGflags
(
flags
);
framework
::
InitGflags
(
flags
);
}
}
}
std
::
unique_ptr
<
PaddlePredictor
>
predictor
(
new
PaddlePredictorImpl
(
config
));
std
::
unique_ptr
<
PaddlePredictor
>
predictor
(
new
NativePaddlePredictor
(
config
));
if
(
!
dynamic_cast
<
PaddlePredictorImpl
*>
(
predictor
.
get
())
->
Init
())
{
if
(
!
dynamic_cast
<
NativePaddlePredictor
*>
(
predictor
.
get
())
->
Init
())
{
return
nullptr
;
return
nullptr
;
}
}
return
std
::
move
(
predictor
);
return
std
::
move
(
predictor
);
...
...
paddle/contrib/inference/paddle_inference_api_impl.h
浏览文件 @
8855d4a7
...
@@ -29,17 +29,10 @@
...
@@ -29,17 +29,10 @@
namespace
paddle
{
namespace
paddle
{
struct
ConfigImpl
:
public
PaddlePredictor
::
Config
{
class
NativePaddlePredictor
:
public
PaddlePredictor
{
int
device
;
float
fraction_of_gpu_memory
;
std
::
string
prog_file
;
std
::
string
param_file
;
bool
share_variables
;
};
class
PaddlePredictorImpl
:
public
PaddlePredictor
{
public:
public:
explicit
PaddlePredictorImpl
(
const
ConfigImpl
&
config
)
:
config_
(
config
)
{}
explicit
NativePaddlePredictor
(
const
NativeConfig
&
config
)
:
config_
(
config
)
{}
bool
Init
();
bool
Init
();
...
@@ -48,16 +41,15 @@ class PaddlePredictorImpl : public PaddlePredictor {
...
@@ -48,16 +41,15 @@ class PaddlePredictorImpl : public PaddlePredictor {
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
override
;
std
::
unique_ptr
<
PaddlePredictor
>
Clone
()
override
;
~
PaddlePredictorImpl
()
override
{};
~
NativePaddlePredictor
()
override
{};
private:
private:
bool
InitShared
()
override
;
bool
SetFeed
(
const
std
::
vector
<
PaddleTensor
>
&
input_datas
,
bool
SetFeed
(
const
std
::
vector
<
PaddleTensor
>
&
input_datas
,
std
::
vector
<
framework
::
LoDTensor
>
*
feeds
);
std
::
vector
<
framework
::
LoDTensor
>
*
feeds
);
bool
GetFetch
(
const
std
::
vector
<
framework
::
LoDTensor
>
&
fetchs
,
bool
GetFetch
(
const
std
::
vector
<
framework
::
LoDTensor
>
&
fetchs
,
std
::
vector
<
PaddleTensor
>
*
output_data
);
std
::
vector
<
PaddleTensor
>
*
output_data
);
ConfigImpl
config_
;
NativeConfig
config_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
std
::
unique_ptr
<
framework
::
Executor
>
executor_
;
std
::
unique_ptr
<
framework
::
Executor
>
executor_
;
std
::
unique_ptr
<
framework
::
Scope
>
scope_
;
std
::
unique_ptr
<
framework
::
Scope
>
scope_
;
...
...
paddle/contrib/inference/test_paddle_inference_api_impl.cc
浏览文件 @
8855d4a7
...
@@ -40,19 +40,19 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
...
@@ -40,19 +40,19 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
return
pt
;
return
pt
;
}
}
ConfigImpl
GetConfig
()
{
NativeConfig
GetConfig
()
{
ConfigImpl
config
;
NativeConfig
config
;
config
.
model_dir
=
FLAGS_dirname
+
"word2vec.inference.model"
;
config
.
model_dir
=
FLAGS_dirname
+
"word2vec.inference.model"
;
LOG
(
INFO
)
<<
"dirname "
<<
config
.
model_dir
;
LOG
(
INFO
)
<<
"dirname "
<<
config
.
model_dir
;
config
.
fraction_of_gpu_memory
=
0.15
;
config
.
fraction_of_gpu_memory
=
0.15
;
config
.
use_gpu
=
true
;
config
.
device
=
0
;
config
.
device
=
0
;
config
.
share_variables
=
true
;
return
config
;
return
config
;
}
}
TEST
(
paddle_inference_api_impl
,
word2vec
)
{
TEST
(
paddle_inference_api_impl
,
word2vec
)
{
ConfigImpl
config
=
GetConfig
();
NativeConfig
config
=
GetConfig
();
std
::
unique_ptr
<
PaddlePredictor
>
predictor
=
CreatePaddlePredictor
(
config
);
auto
predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
);
framework
::
LoDTensor
first_word
,
second_word
,
third_word
,
fourth_word
;
framework
::
LoDTensor
first_word
,
second_word
,
third_word
,
fourth_word
;
framework
::
LoD
lod
{{
0
,
1
}};
framework
::
LoD
lod
{{
0
,
1
}};
...
@@ -104,7 +104,7 @@ TEST(paddle_inference_api_impl, image_classification) {
...
@@ -104,7 +104,7 @@ TEST(paddle_inference_api_impl, image_classification) {
int
batch_size
=
2
;
int
batch_size
=
2
;
bool
use_mkldnn
=
false
;
bool
use_mkldnn
=
false
;
bool
repeat
=
false
;
bool
repeat
=
false
;
ConfigImpl
config
=
GetConfig
();
NativeConfig
config
=
GetConfig
();
config
.
model_dir
=
config
.
model_dir
=
FLAGS_dirname
+
"image_classification_resnet.inference.model"
;
FLAGS_dirname
+
"image_classification_resnet.inference.model"
;
...
@@ -133,7 +133,7 @@ TEST(paddle_inference_api_impl, image_classification) {
...
@@ -133,7 +133,7 @@ TEST(paddle_inference_api_impl, image_classification) {
is_combined
,
is_combined
,
use_mkldnn
);
use_mkldnn
);
std
::
unique_ptr
<
PaddlePredictor
>
predictor
=
CreatePaddlePredictor
(
config
);
auto
predictor
=
CreatePaddlePredictor
(
config
);
std
::
vector
<
PaddleTensor
>
paddle_tensor_feeds
;
std
::
vector
<
PaddleTensor
>
paddle_tensor_feeds
;
paddle_tensor_feeds
.
push_back
(
LodTensorToPaddleTensor
(
&
input
));
paddle_tensor_feeds
.
push_back
(
LodTensorToPaddleTensor
(
&
input
));
...
@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) {
...
@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) {
float
*
data
=
static_cast
<
float
*>
(
outputs
[
0
].
data
.
data
);
float
*
data
=
static_cast
<
float
*>
(
outputs
[
0
].
data
.
data
);
float
*
lod_data
=
output1
.
data
<
float
>
();
float
*
lod_data
=
output1
.
data
<
float
>
();
for
(
size_t
j
=
0
;
j
<
len
/
sizeof
(
float
);
++
j
)
{
for
(
size_t
j
=
0
;
j
<
len
/
sizeof
(
float
);
++
j
)
{
EXPECT_LT
(
lod_data
[
j
]
-
data
[
j
],
1e-10
);
EXPECT_NEAR
(
lod_data
[
j
],
data
[
j
],
1e-3
);
EXPECT_GT
(
lod_data
[
j
]
-
data
[
j
],
-
1e-10
);
}
}
free
(
data
);
free
(
data
);
}
}
...
...
paddle/fluid/framework/block_desc.cc
浏览文件 @
8855d4a7
...
@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
...
@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
vars_
[
var_desc
.
name
()].
reset
(
new
VarDesc
(
var_desc
));
vars_
[
var_desc
.
name
()].
reset
(
new
VarDesc
(
var_desc
));
}
}
for
(
const
proto
::
OpDesc
&
op_desc
:
desc_
->
ops
())
{
for
(
const
proto
::
OpDesc
&
op_desc
:
desc_
->
ops
())
{
ops_
.
emplace_back
(
new
OpDesc
(
op_desc
,
prog
,
this
));
ops_
.
emplace_back
(
new
OpDesc
(
op_desc
,
this
));
}
}
}
}
...
@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
...
@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
:
prog_
(
prog
),
desc_
(
desc
)
{
:
prog_
(
prog
),
desc_
(
desc
)
{
need_update_
=
true
;
need_update_
=
true
;
for
(
auto
&
op
:
other
.
ops_
)
{
for
(
auto
&
op
:
other
.
ops_
)
{
ops_
.
emplace_back
(
new
OpDesc
(
*
op
->
Proto
(),
prog
,
this
));
ops_
.
emplace_back
(
new
OpDesc
(
*
op
,
this
));
}
}
for
(
auto
&
it
:
other
.
vars_
)
{
for
(
auto
&
it
:
other
.
vars_
)
{
auto
*
var
=
new
VarDesc
(
*
it
.
second
);
auto
*
var
=
new
VarDesc
(
*
it
.
second
);
...
...
paddle/fluid/framework/block_desc.h
浏览文件 @
8855d4a7
...
@@ -105,7 +105,7 @@ class BlockDesc {
...
@@ -105,7 +105,7 @@ class BlockDesc {
size_t
OpSize
()
const
{
return
ops_
.
size
();
}
size_t
OpSize
()
const
{
return
ops_
.
size
();
}
OpDesc
*
Op
(
int
idx
)
{
return
ops_
.
at
(
idx
).
get
();
}
OpDesc
*
Op
(
int
idx
)
const
{
return
ops_
.
at
(
idx
).
get
();
}
void
Flush
();
void
Flush
();
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
8855d4a7
...
@@ -11,11 +11,15 @@
...
@@ -11,11 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include
"paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include
<algorithm>
#include <fstream>
#include <fstream>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
...
@@ -26,9 +30,6 @@
...
@@ -26,9 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
#endif
#include <string>
#include <vector>
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
"the ssa graph path only print with GLOG_v=10,"
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"
);
"default /tmp/graph.dot"
);
...
@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
...
@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
var_type
s
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_var
s
;
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
var_types
[
var
->
Name
()]
=
var
->
GetType
()
;
all_vars
[
var
->
Name
()]
=
var
;
}
}
auto
graph
=
new
SSAGraph
();
auto
graph
=
new
SSAGraph
();
...
@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto
send_vars
=
FindDistTrainSendVars
(
program
);
auto
send_vars
=
FindDistTrainSendVars
(
program
);
auto
recv_vars
=
FindDistTrainRecvVars
(
program
);
auto
recv_vars
=
FindDistTrainRecvVars
(
program
);
size_t
cur_device_id
=
0
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
var_name_on_devices
.
resize
(
places_
.
size
());
var_name_on_devices
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
size_t
cur_device_id
=
0
;
std
::
vector
<
int64_t
>
balance_grads
(
places_
.
size
(),
0
);
auto
get_appropriate_dev
=
[
&
](
std
::
string
&
g_name
)
->
size_t
{
auto
var_desc
=
all_vars
.
at
(
g_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
int64_t
numel
=
framework
::
product
(
dim
);
PADDLE_ENFORCE_GE
(
numel
,
0
);
auto
smallest
=
std
::
min_element
(
std
::
begin
(
balance_grads
),
std
::
end
(
balance_grads
));
size_t
dev_id
=
static_cast
<
size_t
>
(
std
::
distance
(
std
::
begin
(
balance_grads
),
smallest
));
balance_grads
[
dev_id
]
+=
numel
;
return
dev_id
;
};
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
boost
::
get
<
int
>
(
if
(
boost
::
get
<
int
>
(
...
@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch
(
strategy_
.
reduce_
)
{
switch
(
strategy_
.
reduce_
)
{
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
get_appropriate_dev
(
g_name
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
var_name_on_devices
[
cur_device_id
].
emplace
(
g_name
);
var_name_on_devices
[
cur_device_id
].
emplace
(
g_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
cur_device_id
=
(
cur_device_id
+
1
)
%
places_
.
size
();
break
;
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
if
(
IsSparseGradient
(
var_type
s
,
g_name
))
{
if
(
IsSparseGradient
(
all_var
s
,
g_name
))
{
CreateReduceOp
(
&
result
,
g_name
,
0
);
CreateReduceOp
(
&
result
,
g_name
,
0
);
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
}
else
{
}
else
{
...
@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
const
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
&
var_type
s
,
const
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
&
all_var
s
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
PADDLE_ENFORCE
(
var_type
s
.
count
(
og
)
!=
0
);
PADDLE_ENFORCE
(
all_var
s
.
count
(
og
)
!=
0
);
if
(
var_types
.
at
(
og
)
==
proto
::
VarType
::
SELECTED_ROWS
)
{
if
(
all_vars
.
at
(
og
)
->
GetType
(
)
==
proto
::
VarType
::
SELECTED_ROWS
)
{
return
true
;
return
true
;
}
}
return
false
;
return
false
;
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
8855d4a7
...
@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t
src_dev_id
)
const
;
size_t
src_dev_id
)
const
;
bool
IsSparseGradient
(
bool
IsSparseGradient
(
const
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
&
var_type
s
,
const
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
&
all_var
s
,
const
std
::
string
&
og
)
const
;
const
std
::
string
&
og
)
const
;
private:
private:
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
8855d4a7
...
@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
...
@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
need_update_
=
true
;
need_update_
=
true
;
}
}
OpDesc
::
OpDesc
(
const
proto
::
OpDesc
&
desc
,
ProgramDesc
*
prog
,
BlockDesc
*
block
)
OpDesc
::
OpDesc
(
const
proto
::
OpDesc
&
desc
,
BlockDesc
*
block
)
:
desc_
(
desc
),
need_update_
(
false
)
{
:
desc_
(
desc
),
need_update_
(
false
)
{
// restore inputs_
// restore inputs_
int
input_size
=
desc_
.
inputs_size
();
int
input_size
=
desc_
.
inputs_size
();
...
...
paddle/fluid/framework/op_desc.h
浏览文件 @
8855d4a7
...
@@ -33,13 +33,14 @@ class OpDesc {
...
@@ -33,13 +33,14 @@ class OpDesc {
OpDesc
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OpDesc
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
OpDesc
(
const
proto
::
OpDesc
&
desc
,
ProgramDesc
*
prog
,
BlockDesc
*
block
);
OpDesc
(
const
proto
::
OpDesc
&
desc
,
BlockDesc
*
block
);
explicit
OpDesc
(
BlockDesc
*
block
)
:
block_
(
block
)
{}
explicit
OpDesc
(
BlockDesc
*
block
)
:
block_
(
block
)
{}
OpDesc
(
const
OpDesc
&
other
,
BlockDesc
*
block
)
{
OpDesc
(
const
OpDesc
&
other
,
BlockDesc
*
block
)
{
*
this
=
other
;
*
this
=
other
;
block_
=
block
;
block_
=
block
;
need_update_
=
true
;
}
}
void
CopyFrom
(
const
OpDesc
&
op_desc
);
void
CopyFrom
(
const
OpDesc
&
op_desc
);
...
...
paddle/fluid/framework/program_desc.cc
浏览文件 @
8855d4a7
...
@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
...
@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
auto
*
block
=
desc_
.
mutable_blocks
(
i
);
auto
*
block
=
desc_
.
mutable_blocks
(
i
);
blocks_
.
emplace_back
(
new
BlockDesc
(
*
o
.
blocks_
[
i
],
block
,
this
));
blocks_
.
emplace_back
(
new
BlockDesc
(
*
o
.
blocks_
[
i
],
block
,
this
));
}
}
for
(
auto
&
block
:
blocks_
)
{
for
(
size_t
block_id
=
0
;
block_id
<
blocks_
.
size
();
++
block_id
)
{
for
(
auto
*
op
:
block
->
AllOps
())
{
auto
all_ops
=
blocks_
[
block_id
]
->
AllOps
();
for
(
const
auto
&
attr
:
op
->
Proto
()
->
attrs
())
{
for
(
size_t
op_id
=
0
;
op_id
<
all_ops
.
size
();
++
op_id
)
{
if
(
attr
.
type
()
==
proto
::
AttrType
::
BLOCK
)
{
auto
&
op
=
all_ops
[
op_id
];
size_t
blk_idx
=
attr
.
block_idx
();
for
(
const
std
::
string
&
attr_name
:
op
->
AttrNames
())
{
op
->
SetBlockAttr
(
attr
.
name
(),
this
->
MutableBlock
(
blk_idx
));
if
(
op
->
GetAttrType
(
attr_name
)
==
proto
::
AttrType
::
BLOCK
)
{
int
sub_block_id
=
o
.
Block
(
block_id
).
Op
(
op_id
)
->
GetBlockAttr
(
attr_name
);
op
->
SetBlockAttr
(
attr_name
,
MutableBlock
(
sub_block_id
));
}
}
}
}
}
}
...
@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
...
@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
blocks_
.
emplace_back
(
new
BlockDesc
(
this
,
&
block_desc
));
blocks_
.
emplace_back
(
new
BlockDesc
(
this
,
&
block_desc
));
}
}
for
(
auto
&
block
:
blocks_
)
{
for
(
auto
*
op
:
block
->
AllOps
())
{
for
(
const
auto
&
attr
:
op
->
Proto
()
->
attrs
())
{
if
(
attr
.
type
()
==
proto
::
AttrType
::
BLOCK
)
{
size_t
blk_idx
=
attr
.
block_idx
();
op
->
SetBlockAttr
(
attr
.
name
(),
this
->
MutableBlock
(
blk_idx
));
}
}
}
}
}
}
const
std
::
vector
<
std
::
string
>
ProgramDesc
::
GetFeedTargetNames
()
{
const
std
::
vector
<
std
::
string
>
ProgramDesc
::
GetFeedTargetNames
()
{
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
8855d4a7
...
@@ -25,8 +25,10 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
...
@@ -25,8 +25,10 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
if
(
out
->
empty
())
{
if
(
out
->
empty
())
{
return
;
return
;
}
}
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
auto
&
actual
=
out
->
at
(
i
)
.
dims
();
auto
&
actual
=
(
*
out
)[
i
]
.
dims
();
auto
&
expect
=
dims_
[
i
];
auto
&
expect
=
dims_
[
i
];
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
...
...
paddle/fluid/framework/tensor_impl.h
浏览文件 @
8855d4a7
...
@@ -39,7 +39,7 @@ template <typename T>
...
@@ -39,7 +39,7 @@ template <typename T>
inline
const
T
*
Tensor
::
data
()
const
{
inline
const
T
*
Tensor
::
data
()
const
{
check_memory_size
();
check_memory_size
();
PADDLE_ENFORCE
(
std
::
is_same
<
T
,
void
>::
value
||
PADDLE_ENFORCE
(
std
::
is_same
<
T
,
void
>::
value
||
holder_
->
type
()
.
hash_code
()
==
typeid
(
T
).
hash_code
(
),
holder_
->
type
()
==
std
::
type_index
(
typeid
(
T
)
),
"Tensor holds the wrong type, it holds %s"
,
"Tensor holds the wrong type, it holds %s"
,
this
->
holder_
->
type
().
name
());
this
->
holder_
->
type
().
name
());
...
@@ -53,7 +53,7 @@ template <typename T>
...
@@ -53,7 +53,7 @@ template <typename T>
inline
T
*
Tensor
::
data
()
{
inline
T
*
Tensor
::
data
()
{
check_memory_size
();
check_memory_size
();
PADDLE_ENFORCE
(
std
::
is_same
<
T
,
void
>::
value
||
PADDLE_ENFORCE
(
std
::
is_same
<
T
,
void
>::
value
||
holder_
->
type
()
.
hash_code
()
==
typeid
(
T
).
hash_code
(
),
holder_
->
type
()
==
std
::
type_index
(
typeid
(
T
)
),
"Tensor holds the wrong type, it holds %s"
,
"Tensor holds the wrong type, it holds %s"
,
this
->
holder_
->
type
().
name
());
this
->
holder_
->
type
().
name
());
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
8855d4a7
...
@@ -5,14 +5,19 @@ cc_library(paddle_fluid_api
...
@@ -5,14 +5,19 @@ cc_library(paddle_fluid_api
SRCS io.cc
SRCS io.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
# Create static library
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
cc_library
(
paddle_fluid DEPS
${
fluid_modules
}
)
if
(
WITH_CONTRIB
)
set
(
fluid_modules
"
${
fluid_modules
}
"
paddle_inference_api
)
endif
()
# Create static library
cc_library
(
paddle_fluid DEPS
${
fluid_modules
}
paddle_fluid_api
)
# Create shared library
# Create shared library
cc_library
(
paddle_fluid_shared SHARED
cc_library
(
paddle_fluid_shared SHARED
SRCS io.cc
SRCS io.cc
DEPS
${
fluid_modules
}
)
DEPS
${
fluid_modules
}
paddle_fluid_api
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
if
(
NOT APPLE
)
if
(
NOT APPLE
)
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
...
...
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
8855d4a7
...
@@ -21,7 +21,10 @@ limitations under the License. */
...
@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque>
#include <deque>
#include <stack>
#include <stack>
#include <string>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/node.h"
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
浏览文件 @
8855d4a7
...
@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
...
@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG
(
INFO
)
<<
graph
.
nodes
.
size
();
LOG
(
INFO
)
<<
graph
.
nodes
.
size
();
}
}
}
//
analysis
}
;
// namespace
analysis
}
//
inference
}
;
// namespace
inference
}
//
paddle
}
;
// namespace
paddle
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
浏览文件 @
8855d4a7
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include
"paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include
<string>
#include <vector>
#include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
浏览文件 @
8855d4a7
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
#pragma once
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass.h"
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
浏览文件 @
8855d4a7
...
@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
...
@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG
(
INFO
)
<<
'\n'
<<
graph
.
DotString
();
LOG
(
INFO
)
<<
'\n'
<<
graph
.
DotString
();
}
}
}
// analysis
}
//
namespace
analysis
}
// inference
}
//
namespace
inference
}
// paddle
}
//
namespace
paddle
paddle/fluid/inference/analysis/helper.h
浏览文件 @
8855d4a7
...
@@ -50,7 +50,7 @@ struct DataTypeNamer {
...
@@ -50,7 +50,7 @@ struct DataTypeNamer {
return
dic_
.
at
(
x
);
return
dic_
.
at
(
x
);
}
}
const
std
::
string
&
repr
(
size_t
&
hash
)
const
{
const
std
::
string
&
repr
(
size_t
&
hash
)
const
{
// NOLINT
PADDLE_ENFORCE
(
dic_
.
count
(
hash
),
"unknown type for representation"
);
PADDLE_ENFORCE
(
dic_
.
count
(
hash
),
"unknown type for representation"
);
return
dic_
.
at
(
hash
);
return
dic_
.
at
(
hash
);
}
}
...
@@ -62,7 +62,9 @@ struct DataTypeNamer {
...
@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE
(
float
);
SET_TYPE
(
float
);
}
}
std
::
unordered_map
<
decltype
(
typeid
(
int
).
hash_code
()),
std
::
string
>
dic_
;
std
::
unordered_map
<
decltype
(
typeid
(
int
).
hash_code
()),
// NOLINT
std
::
string
>
dic_
;
};
};
#undef SET_TYPE
#undef SET_TYPE
...
...
paddle/fluid/inference/analysis/pass.h
浏览文件 @
8855d4a7
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <glog/logging.h>
#include <iosfwd>
#include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
...
...
paddle/fluid/inference/analysis/subgraph_splitter.h
浏览文件 @
8855d4a7
...
@@ -18,6 +18,8 @@ limitations under the License. */
...
@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once
#pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/node.h"
...
...
paddle/fluid/inference/analysis/ut_helper.h
浏览文件 @
8855d4a7
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
8855d4a7
...
@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
...
@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
nv_test
(
test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor
)
nv_test
(
test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor
)
nv_test
(
test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
nv_test
(
test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
nv_test
(
test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
paddle/fluid/inference/tensorrt/convert/activation_op.cc
浏览文件 @
8855d4a7
...
@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter {
...
@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter {
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
)
override
{
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
)
override
{
// Here the two nullptr looks strange, that's because the
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
// framework::OpDesc's constructor is strange.
framework
::
OpDesc
op_desc
(
op
,
nullptr
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
LOG
(
INFO
)
<<
"convert a fluid relu op to tensorrt activation layer whose "
LOG
(
INFO
)
<<
"convert a fluid relu op to tensorrt activation layer whose "
"type is Relu"
;
"type is Relu"
;
const
nvinfer1
::
ITensor
*
input_tensor
=
const
nvinfer1
::
ITensor
*
input_tensor
=
...
...
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
8855d4a7
...
@@ -21,7 +21,8 @@ namespace tensorrt {
...
@@ -21,7 +21,8 @@ namespace tensorrt {
class
Conv2dOpConverter
:
public
OpConverter
{
class
Conv2dOpConverter
:
public
OpConverter
{
public:
public:
Conv2dOpConverter
()
{}
Conv2dOpConverter
()
{}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
)
override
{
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
)
override
{
LOG
(
INFO
)
LOG
(
INFO
)
<<
"convert a fluid conv2d op to tensorrt conv layer without bias"
;
<<
"convert a fluid conv2d op to tensorrt conv layer without bias"
;
}
}
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
0 → 100644
浏览文件 @
8855d4a7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
// Reorder the elements from istrides to ostrides, borrowed from TRT convert in
// tensorflow.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorrt/convert/convert_nodes.cc#L318
template
<
typename
T
>
void
Reorder2
(
nvinfer1
::
DimsHW
shape
,
const
T
*
idata
,
nvinfer1
::
DimsHW
istrides
,
T
*
odata
,
nvinfer1
::
DimsHW
ostrides
)
{
for
(
int
h
=
0
;
h
<
shape
.
h
();
++
h
)
{
for
(
int
w
=
0
;
w
<
shape
.
w
();
++
w
)
{
odata
[
h
*
ostrides
.
h
()
+
w
*
ostrides
.
w
()]
=
idata
[
h
*
ostrides
.
h
()
+
w
*
ostrides
.
w
()];
}
}
}
// Reorder the data layout from CK to KC.
void
ReorderCKtoKC
(
TensorRTEngine
::
Weight
&
iweights
,
TensorRTEngine
::
Weight
*
oweights
)
{
int
c
=
iweights
.
dims
[
0
];
int
k
=
iweights
.
dims
[
1
];
oweights
->
dims
.
assign
({
k
,
c
});
nvinfer1
::
DimsHW
istrides
=
{
1
,
k
};
nvinfer1
::
DimsHW
ostrides
=
{
c
,
1
};
Reorder2
({
k
,
c
},
static_cast
<
float
const
*>
(
iweights
.
get
().
values
),
istrides
,
static_cast
<
float
*>
(
const_cast
<
void
*>
(
oweights
->
get
().
values
)),
ostrides
);
}
/*
* FC converter convert a MUL op in Fluid to a FC layer in TRT.
*/
class
FcOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
)
override
{
VLOG
(
4
)
<<
"convert a fluid fc op to tensorrt fc layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
,
nullptr
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Y"
).
size
(),
1
);
// Y is a weight
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
// Declare inputs
auto
*
X
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
).
front
());
// Declare weights
auto
*
Y_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Y"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
Y_v
);
auto
*
Y_t
=
Y_v
->
GetMutable
<
framework
::
LoDTensor
>
();
// This may trigger a GPU->CPU copy, because TRT's weight can only be
// assigned from CPU memory, that can't be avoided.
auto
*
weight_data
=
Y_t
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_EQ
(
Y_t
->
dims
().
size
(),
2UL
);
// a matrix
size_t
n_output
=
Y_t
->
dims
()[
1
];
framework
::
LoDTensor
tmp
;
tmp
.
Resize
(
Y_t
->
dims
());
memcpy
(
tmp
.
mutable_data
<
float
>
(
platform
::
CPUPlace
()),
Y_t
->
data
<
float
>
(),
Y_t
->
dims
()[
0
]
*
Y_t
->
dims
()[
1
]);
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
Y_t
->
memory_size
()
/
sizeof
(
float
)};
TensorRTEngine
::
Weight
tmp_weight
(
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
tmp
.
data
<
float
>
()),
Y_t
->
memory_size
()
/
sizeof
(
float
));
weight
.
dims
.
assign
({
Y_t
->
dims
()[
0
],
Y_t
->
dims
()[
1
]});
tmp_weight
.
dims
=
weight
.
dims
;
// The data layout of TRT FC layer's weight is different from fluid's FC,
// need to reorder the elements.
ReorderCKtoKC
(
tmp_weight
,
&
weight
);
// Currently, the framework can only handle one fluid op -> one TRT layer,
// but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
// handle `mul`, leave `add` as another layer.
// DEBUG
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
n_output
,
weight
.
get
(),
bias
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
engine_
->
DeclareOutput
(
layer
,
0
,
output_name
);
}
};
REGISTER_TRT_OP_CONVERTER
(
fc
,
FcOpConverter
);
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
USE_OP
(
mul
);
paddle/fluid/inference/tensorrt/convert/mul_op.cc
浏览文件 @
8855d4a7
...
@@ -24,10 +24,11 @@ namespace tensorrt {
...
@@ -24,10 +24,11 @@ namespace tensorrt {
class
MulOpConverter
:
public
OpConverter
{
class
MulOpConverter
:
public
OpConverter
{
public:
public:
MulOpConverter
()
{}
MulOpConverter
()
{}
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
)
override
{
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
VLOG
(
4
)
<<
"convert a fluid mul op to tensorrt fc layer without bias"
;
const
framework
::
Scope
&
scope
)
override
{
VLOG
(
4
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
// Declare inputs
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
8855d4a7
...
@@ -31,27 +31,42 @@ namespace tensorrt {
...
@@ -31,27 +31,42 @@ namespace tensorrt {
class
OpConverter
{
class
OpConverter
{
public:
public:
OpConverter
()
{}
OpConverter
()
{}
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
)
{}
void
Run
(
const
framework
::
proto
::
OpDesc
&
op
,
TensorRTEngine
*
engine
)
{
// Converter logic for an op.
std
::
string
type
=
op
.
type
();
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
auto
*
it
=
Registry
<
OpConverter
>::
Lookup
(
type
);
const
framework
::
Scope
&
scope
)
{}
PADDLE_ENFORCE_NOT_NULL
(
it
,
"no OpConverter for optype [%s]"
,
type
);
it
->
SetEngine
(
engine
);
// Convert a single fluid operaotr and add the corresponding layer to TRT.
(
*
it
)(
op
);
void
ConvertOp
(
const
framework
::
proto
::
OpDesc
&
op
,
}
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
,
nullptr
);
OpConverter
*
it
{
nullptr
};
// convert fluid op to tensorrt layer
if
(
op_desc
.
Type
()
==
"mul"
)
{
void
ConvertOp
(
const
framework
::
proto
::
OpDesc
&
op
,
TensorRTEngine
*
engine
)
{
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"Y"
).
size
(),
1UL
);
OpConverter
::
Run
(
op
,
engine
);
std
::
string
Y
=
op_desc
.
Input
(
"Y"
)[
0
];
if
(
parameters
.
count
(
Y
))
{
it
=
Registry
<
OpConverter
>::
Lookup
(
"fc"
);
}
}
if
(
!
it
)
{
it
=
Registry
<
OpConverter
>::
Lookup
(
op_desc
.
Type
());
}
PADDLE_ENFORCE_NOT_NULL
(
it
,
"no OpConverter for optype [%s]"
,
op_desc
.
Type
());
it
->
SetEngine
(
engine
);
(
*
it
)(
op
,
scope
);
}
}
// convert fluid block to tensorrt network
// convert fluid block to tensorrt network
void
ConvertBlock
(
const
framework
::
proto
::
BlockDesc
&
block
,
void
ConvertBlock
(
const
framework
::
proto
::
BlockDesc
&
block
,
TensorRTEngine
*
engine
)
{
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
framework
::
Scope
&
scope
,
TensorRTEngine
*
engine
)
{
for
(
int
i
=
0
;
i
<
block
.
ops_size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
block
.
ops_size
();
i
++
)
{
const
auto
&
op
=
block
.
ops
(
i
);
const
auto
&
op
=
block
.
ops
(
i
);
OpConverter
::
Run
(
op
,
engine
);
ConvertOp
(
op
,
parameters
,
scope
,
engine
);
}
}
}
}
...
...
paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
0 → 100644
浏览文件 @
8855d4a7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
TEST
(
fc_op
,
test
)
{
std
::
unordered_set
<
std
::
string
>
parameters
({
"mul-Y"
});
framework
::
Scope
scope
;
TRTConvertValidation
validator
(
20
,
parameters
,
scope
,
1000
);
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims4
(
8
,
3
,
1
,
1
));
validator
.
DeclParamVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
3
,
2
));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
8
,
2
));
// Prepare Op description
framework
::
OpDesc
desc
;
desc
.
SetType
(
"mul"
);
desc
.
SetInput
(
"X"
,
{
"mul-X"
});
desc
.
SetInput
(
"Y"
,
{
"mul-Y"
});
desc
.
SetOutput
(
"Out"
,
{
"mul-Out"
});
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
Execute
(
10
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
浏览文件 @
8855d4a7
...
@@ -21,7 +21,9 @@ namespace inference {
...
@@ -21,7 +21,9 @@ namespace inference {
namespace
tensorrt
{
namespace
tensorrt
{
TEST
(
MulOpConverter
,
main
)
{
TEST
(
MulOpConverter
,
main
)
{
TRTConvertValidation
validator
(
10
,
1000
);
framework
::
Scope
scope
;
std
::
unordered_set
<
std
::
string
>
parameters
;
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
);
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims2
(
10
,
6
));
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims2
(
10
,
6
));
validator
.
DeclInputVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
6
,
10
));
validator
.
DeclInputVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
6
,
10
));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
10
,
10
));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
10
,
10
));
...
...
paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
浏览文件 @
8855d4a7
...
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) {
...
@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op
->
SetType
(
"conv2d"
);
conv2d_op
->
SetType
(
"conv2d"
);
OpConverter
converter
;
OpConverter
converter
;
converter
.
ConvertBlock
(
*
block
->
Proto
(),
nullptr
/*TensorRTEngine*/
);
framework
::
Scope
scope
;
converter
.
ConvertBlock
(
*
block
->
Proto
(),
{},
scope
,
nullptr
/*TensorRTEngine*/
);
}
}
}
// namespace tensorrt
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
8855d4a7
...
@@ -19,6 +19,9 @@ limitations under the License. */
...
@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
...
@@ -58,7 +61,10 @@ class TRTConvertValidation {
...
@@ -58,7 +61,10 @@ class TRTConvertValidation {
public:
public:
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
(
int
batch_size
,
int
workspace_size
=
1
<<
10
)
{
TRTConvertValidation
(
int
batch_size
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
framework
::
Scope
&
scope
,
int
workspace_size
=
1
<<
10
)
:
parameters_
(
parameters
),
scope_
(
scope
)
{
// create engine.
// create engine.
engine_
.
reset
(
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
));
engine_
.
reset
(
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
));
engine_
->
InitNetwork
();
engine_
->
InitNetwork
();
...
@@ -73,19 +79,22 @@ class TRTConvertValidation {
...
@@ -73,19 +79,22 @@ class TRTConvertValidation {
engine_
->
DeclareInput
(
name
,
nvinfer1
::
DataType
::
kFLOAT
,
dims
);
engine_
->
DeclareInput
(
name
,
nvinfer1
::
DataType
::
kFLOAT
,
dims
);
}
}
// Declare a parameter varaible in the scope.
void
DeclParamVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
DeclVar
(
name
,
dims
);
}
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
DeclVar
(
name
,
dims
);
DeclVar
(
name
,
dims
);
}
}
// Declare a variable in a fluid Scope.
void
DeclVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
// Init Fluid tensor.
// Init Fluid tensor.
std
::
vector
<
int
>
dim_vec
(
dims
.
nbDims
);
std
::
vector
<
int
>
dim_vec
(
dims
.
d
,
dims
.
d
+
dims
.
nbDims
);
for
(
int
i
=
0
;
i
<
dims
.
nbDims
;
i
++
)
{
dim_vec
[
i
]
=
dims
.
d
[
i
];
}
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
...
@@ -96,20 +105,22 @@ class TRTConvertValidation {
...
@@ -96,20 +105,22 @@ class TRTConvertValidation {
op_
=
framework
::
OpRegistry
::
CreateOp
(
desc
);
op_
=
framework
::
OpRegistry
::
CreateOp
(
desc
);
OpConverter
op_converter
;
OpConverter
op_converter
;
op_converter
.
ConvertOp
(
desc
,
engine_
.
get
());
op_converter
.
ConvertOp
(
desc
,
parameters_
,
scope_
,
engine_
.
get
());
engine_
->
FreezeNetwork
();
engine_
->
FreezeNetwork
();
// Declare outputs.
// Declare outputs.
op_desc_
.
reset
(
new
framework
::
OpDesc
(
desc
,
nullptr
,
nullptr
));
op_desc_
.
reset
(
new
framework
::
OpDesc
(
desc
,
nullptr
));
// Set Inputs.
// Set Inputs.
for
(
const
auto
&
input
:
op_desc_
->
InputArgumentNames
())
{
for
(
const
auto
&
input
:
op_desc_
->
InputArgumentNames
())
{
if
(
parameters_
.
count
(
input
))
continue
;
auto
*
var
=
scope_
.
FindVar
(
input
);
auto
*
var
=
scope_
.
FindVar
(
input
);
PADDLE_ENFORCE
(
var
);
PADDLE_ENFORCE
(
var
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
engine_
->
SetInputFromCPU
(
engine_
->
SetInputFromCPU
(
input
,
static_cast
<
void
*>
(
tensor
->
data
<
float
>
()),
input
,
static_cast
<
void
*>
(
tensor
->
data
<
void
>
()),
sizeof
(
float
)
*
sizeof
(
float
)
*
analysis
::
AccuDims
(
tensor
->
dims
(),
tensor
->
dims
().
size
()));
analysis
::
AccuDims
(
tensor
->
dims
(),
tensor
->
dims
().
size
()));
}
}
...
@@ -117,18 +128,21 @@ class TRTConvertValidation {
...
@@ -117,18 +128,21 @@ class TRTConvertValidation {
void
Execute
(
int
batch_size
)
{
void
Execute
(
int
batch_size
)
{
// Execute Fluid Op
// Execute Fluid Op
// Execute TRT
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
engine_
->
Execute
(
batch_size
);
op_
->
Run
(
scope_
,
place
);
op_
->
Run
(
scope_
,
place
);
// Execute TRT.
engine_
->
Execute
(
batch_size
);
cudaStreamSynchronize
(
*
engine_
->
stream
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
ASSERT_FALSE
(
op_desc_
->
OutputArgumentNames
().
empty
());
const
size_t
output_space_size
=
200
;
for
(
const
auto
&
output
:
op_desc_
->
OutputArgumentNames
())
{
for
(
const
auto
&
output
:
op_desc_
->
OutputArgumentNames
())
{
std
::
vector
<
float
>
fluid_out
;
std
::
vector
<
float
>
fluid_out
;
std
::
vector
<
float
>
trt_out
(
200
);
std
::
vector
<
float
>
trt_out
(
output_space_size
);
engine_
->
GetOutputInCPU
(
output
,
&
trt_out
[
0
],
200
*
sizeof
(
float
));
engine_
->
GetOutputInCPU
(
output
,
&
trt_out
[
0
],
output_space_size
*
sizeof
(
float
));
cudaStreamSynchronize
(
*
engine_
->
stream
());
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
...
@@ -136,7 +150,7 @@ class TRTConvertValidation {
...
@@ -136,7 +150,7 @@ class TRTConvertValidation {
// Compare two output
// Compare two output
ASSERT_FALSE
(
fluid_out
.
empty
());
ASSERT_FALSE
(
fluid_out
.
empty
());
for
(
size_t
i
=
0
;
i
<
fluid_out
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
fluid_out
.
size
();
i
++
)
{
EXPECT_LT
(
std
::
abs
(
fluid_out
[
i
]
-
trt_out
[
i
]),
0.001
);
EXPECT_LT
(
std
::
abs
(
fluid_out
[
i
]
-
trt_out
[
i
]),
1e-6
);
}
}
}
}
}
}
...
@@ -146,9 +160,10 @@ class TRTConvertValidation {
...
@@ -146,9 +160,10 @@ class TRTConvertValidation {
private:
private:
std
::
unique_ptr
<
TensorRTEngine
>
engine_
;
std
::
unique_ptr
<
TensorRTEngine
>
engine_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
framework
::
Scope
scope_
;
std
::
unique_ptr
<
framework
::
OperatorBase
>
op_
;
std
::
unique_ptr
<
framework
::
OperatorBase
>
op_
;
std
::
unique_ptr
<
framework
::
OpDesc
>
op_desc_
;
std
::
unique_ptr
<
framework
::
OpDesc
>
op_desc_
;
const
std
::
unordered_set
<
std
::
string
>&
parameters_
;
framework
::
Scope
&
scope_
;
};
};
}
// namespace tensorrt
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
8855d4a7
...
@@ -106,6 +106,7 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
...
@@ -106,6 +106,7 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
name
);
name
);
auto
*
output
=
layer
->
getOutput
(
offset
);
auto
*
output
=
layer
->
getOutput
(
offset
);
SetITensor
(
name
,
output
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
output
->
setName
(
name
.
c_str
());
output
->
setName
(
name
.
c_str
());
infer_network_
->
markOutput
(
*
output
);
infer_network_
->
markOutput
(
*
output
);
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
8855d4a7
...
@@ -37,13 +37,15 @@ class TensorRTEngine : public EngineBase {
...
@@ -37,13 +37,15 @@ class TensorRTEngine : public EngineBase {
// Weight is model parameter.
// Weight is model parameter.
class
Weight
{
class
Weight
{
public:
public:
Weight
(
nvinfer1
::
DataType
dtype
,
void
*
value
,
in
t
num_elem
)
{
Weight
(
nvinfer1
::
DataType
dtype
,
void
*
value
,
size_
t
num_elem
)
{
w_
.
type
=
dtype
;
w_
.
type
=
dtype
;
w_
.
values
=
value
;
w_
.
values
=
value
;
w_
.
count
=
num_elem
;
w_
.
count
=
num_elem
;
}
}
const
nvinfer1
::
Weights
&
get
()
{
return
w_
;
}
const
nvinfer1
::
Weights
&
get
()
{
return
w_
;
}
std
::
vector
<
int64_t
>
dims
;
private:
private:
nvinfer1
::
Weights
w_
;
nvinfer1
::
Weights
w_
;
};
};
...
...
paddle/fluid/operators/bilinear_interp_op.cc
浏览文件 @
8855d4a7
...
@@ -34,9 +34,22 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
...
@@ -34,9 +34,22 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
int
out_w
=
ctx
->
Attrs
().
Get
<
int
>
(
"out_w"
);
int
out_w
=
ctx
->
Attrs
().
Get
<
int
>
(
"out_w"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"X's dimension must be 4"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"X's dimension must be 4"
);
if
(
ctx
->
HasInput
(
"OutSize"
))
{
auto
out_size_dim
=
ctx
->
GetInputDim
(
"OutSize"
);
PADDLE_ENFORCE_EQ
(
out_size_dim
.
size
(),
1
,
"OutSize's dimension size must be 1"
);
PADDLE_ENFORCE_EQ
(
out_size_dim
[
0
],
2
,
"OutSize's dim[0] must be 2"
);
}
std
::
vector
<
int64_t
>
dim_out
({
dim_x
[
0
],
dim_x
[
1
],
out_h
,
out_w
});
std
::
vector
<
int64_t
>
dim_out
({
dim_x
[
0
],
dim_x
[
1
],
out_h
,
out_w
});
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
dim_out
));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
dim_out
));
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
}
};
};
class
BilinearInterpOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
BilinearInterpOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -45,6 +58,10 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -45,6 +58,10 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"X"
,
AddInput
(
"X"
,
"(Tensor) The input tensor of bilinear interpolation, "
"(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"
);
"This is a 4-D tensor with shape of (N x C x h x w)"
);
AddInput
(
"OutSize"
,
"(Tensor) This is a 1-D tensor with two number. "
"The first number is height and the second number is width."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
AddOutput
(
"Out"
,
"(Tensor) The dimension of output is (N x C x out_h x out_w]"
);
"(Tensor) The dimension of output is (N x C x out_h x out_w]"
);
...
@@ -78,6 +95,12 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel {
...
@@ -78,6 +95,12 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
dim_x
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
dim_x
);
}
}
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/bilinear_interp_op.cu
浏览文件 @
8855d4a7
...
@@ -102,10 +102,21 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -102,10 +102,21 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
auto
*
input_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
input_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
*
input
=
input_t
->
data
<
T
>
();
auto
*
input
=
input_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_dims
=
output_t
->
dims
();
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
Tensor
sizes
;
framework
::
TensorCopy
(
*
out_size_t
,
platform
::
CPUPlace
(),
&
sizes
);
auto
size_data
=
sizes
.
data
<
int
>
();
out_h
=
size_data
[
0
];
out_w
=
size_data
[
1
];
}
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
{
out_dims
[
0
],
out_dims
[
1
],
out_h
,
out_w
},
ctx
.
GetPlace
());
int
batch_size
=
input_t
->
dims
()[
0
];
int
batch_size
=
input_t
->
dims
()[
0
];
int
channels
=
input_t
->
dims
()[
1
];
int
channels
=
input_t
->
dims
()[
1
];
int
in_h
=
input_t
->
dims
()[
2
];
int
in_h
=
input_t
->
dims
()[
2
];
...
@@ -139,8 +150,8 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -139,8 +150,8 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_input_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_input_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_output_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_output_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_input
=
d_input_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_output
=
d_output_t
->
data
<
T
>
();
auto
*
d_output
=
d_output_t
->
data
<
T
>
();
auto
*
d_input
=
d_input_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
device_ctx
=
auto
&
device_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
...
@@ -149,6 +160,16 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -149,6 +160,16 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
Tensor
sizes
;
framework
::
TensorCopy
(
*
out_size_t
,
platform
::
CPUPlace
(),
&
sizes
);
auto
size_data
=
sizes
.
data
<
int
>
();
out_h
=
size_data
[
0
];
out_w
=
size_data
[
1
];
}
int
batch_size
=
d_input_t
->
dims
()[
0
];
int
batch_size
=
d_input_t
->
dims
()[
0
];
int
channels
=
d_input_t
->
dims
()[
1
];
int
channels
=
d_input_t
->
dims
()[
1
];
int
in_h
=
d_input_t
->
dims
()[
2
];
int
in_h
=
d_input_t
->
dims
()[
2
];
...
...
paddle/fluid/operators/bilinear_interp_op.h
浏览文件 @
8855d4a7
...
@@ -24,11 +24,18 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
...
@@ -24,11 +24,18 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
input_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
out_dims
=
output_t
->
dims
();
auto
*
input
=
input_t
->
data
<
T
>
();
auto
*
input
=
input_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
auto
out_size_data
=
out_size_t
->
data
<
int
>
();
out_h
=
out_size_data
[
0
];
out_w
=
out_size_data
[
1
];
}
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
{
out_dims
[
0
],
out_dims
[
1
],
out_h
,
out_w
},
ctx
.
GetPlace
());
int
batch_size
=
input_t
->
dims
()[
0
];
int
batch_size
=
input_t
->
dims
()[
0
];
int
channels
=
input_t
->
dims
()[
1
];
int
channels
=
input_t
->
dims
()[
1
];
int
in_h
=
input_t
->
dims
()[
2
];
int
in_h
=
input_t
->
dims
()[
2
];
...
@@ -83,9 +90,8 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
...
@@ -83,9 +90,8 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_input_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_input_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_output_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_output_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_input
=
d_input_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
d_output
=
d_output_t
->
data
<
T
>
();
auto
*
d_output
=
d_output_t
->
data
<
T
>
();
auto
*
d_input
=
d_input_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
device_ctx
=
auto
&
device_ctx
=
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>();
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>();
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
zero
;
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
zero
;
...
@@ -93,6 +99,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
...
@@ -93,6 +99,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
auto
out_size_data
=
out_size_t
->
data
<
int
>
();
out_h
=
out_size_data
[
0
];
out_w
=
out_size_data
[
1
];
}
int
batch_size
=
d_input_t
->
dims
()[
0
];
int
batch_size
=
d_input_t
->
dims
()[
0
];
int
channels
=
d_input_t
->
dims
()[
1
];
int
channels
=
d_input_t
->
dims
()[
1
];
int
in_h
=
d_input_t
->
dims
()[
2
];
int
in_h
=
d_input_t
->
dims
()[
2
];
...
...
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
8855d4a7
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
8855d4a7
...
@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
}
}
bool
RPCClient
::
Wait
()
{
bool
RPCClient
::
Wait
()
{
VLOG
(
3
)
<<
"RPCClient begin Wait()"
<<
" req_count_:"
<<
req_count_
;
if
(
req_count_
<=
0
)
{
if
(
req_count_
<=
0
)
{
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
8855d4a7
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <limits>
#include <string>
#include <string>
using
::
grpc
::
ServerAsyncResponseWriter
;
#include "paddle/fluid/operators/detail/grpc_server.h"
DEFINE_int32
(
rpc_server_handle_send_threads
,
20
,
using
::
grpc
::
ServerAsyncResponseWriter
;
"Number of threads used to handle send at rpc server."
);
DEFINE_int32
(
rpc_server_handle_get_threads
,
20
,
"Number of threads used to handle get at rpc server."
);
DEFINE_int32
(
rpc_server_handle_prefetch_threads
,
1
,
"Number of threads used to handle prefetch at rpc server."
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
...
@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
class
RequestBase
{
class
RequestBase
{
public:
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
const
platform
::
DeviceContext
*
dev_ctx
)
RequestHandler
*
request_handler
,
int
req_id
)
:
service_
(
service
),
:
service_
(
service
),
cq_
(
cq
),
cq_
(
cq
),
sync_mode_
(
sync_mode
),
status_
(
PROCESS
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
request_handler_
(
request_handler
),
req_id_
(
req_id
)
{
PADDLE_ENFORCE
(
cq_
);
PADDLE_ENFORCE
(
cq_
);
}
}
virtual
~
RequestBase
()
{}
virtual
~
RequestBase
()
{}
virtual
void
Process
()
{
assert
(
false
);
}
virtual
void
Process
()
=
0
;
CallStatus
Status
()
{
return
status_
;
}
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
virtual
std
::
string
GetReqName
()
=
0
;
assert
(
false
);
return
""
;
}
protected:
protected:
::
grpc
::
ServerContext
ctx_
;
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
const
bool
sync_mode_
;
CallStatus
status_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
RequestHandler
*
request_handler_
;
int
req_id_
;
};
};
class
RequestSend
final
:
public
RequestBase
{
class
RequestSend
final
:
public
RequestBase
{
public:
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
RequestHandler
*
request_handler
,
int
req_id
)
const
platform
::
DeviceContext
*
dev_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
queue_
(
queue
),
request_handler
->
dev_ctx
(),
responder_
(
&
ctx_
),
!
request_handler
->
sync_mode
()));
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
...
@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
virtual
~
RequestSend
()
{}
virtual
~
RequestSend
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
std
::
string
varname
=
GetReqName
();
VLOG
(
3
)
<<
"RequestSend var_name:"
<<
varname
;
virtual
void
Process
()
{
auto
scope
=
request_
->
GetMutableLocalScope
();
std
::
string
var_name
=
GetReqName
();
auto
invar
=
request_
->
GetVar
();
VLOG
(
3
)
<<
"RequestSend "
<<
var_name
;
framework
::
Variable
*
outvar
=
nullptr
;
queue_
->
Push
(
std
::
make_pair
(
var_name
,
request_
));
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
...
@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
...
@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
protected:
protected:
sendrecv
::
VoidMessage
reply_
;
sendrecv
::
VoidMessage
reply_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
ReceivedQueue
*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
int
req_id_
;
};
};
class
RequestGet
final
:
public
RequestBase
{
class
RequestGet
final
:
public
RequestBase
{
public:
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
RequestHandler
*
request_handler
,
int
req_id
)
const
platform
::
DeviceContext
*
dev_ctx
,
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
framework
::
BlockingQueue
<
MessageWithName
>*
queue
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
),
req_id_
(
req_id
)
{
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
}
virtual
~
RequestGet
()
{}
virtual
~
RequestGet
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
v
irtual
void
Process
()
{
v
oid
Process
()
override
{
// proc request.
// proc request.
std
::
string
var_name
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
VLOG
(
3
)
<<
"RequestGet "
<<
var_name
;
VLOG
(
3
)
<<
"RequestGet "
<<
varname
;
auto
*
var
=
scope_
->
FindVar
(
var_name
);
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
auto
scope
=
request_handler_
->
scope
();
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
if
(
outvar
)
{
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
}
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
if
(
var_name
==
FETCH_BARRIER_MESSAGE
)
{
sendrecv
::
VariableMessage
msg
;
MessageWithName
msg_with_name
=
std
::
make_pair
(
var_name
,
msg
);
queue_
->
Push
(
msg_with_name
);
}
}
}
protected:
protected:
sendrecv
::
VariableMessage
request_
;
sendrecv
::
VariableMessage
request_
;
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
BlockingQueue
<
MessageWithName
>*
queue_
;
int
req_id_
;
};
};
class
RequestPrefetch
final
:
public
RequestBase
{
class
RequestPrefetch
final
:
public
RequestBase
{
public:
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
RequestHandler
*
request_handler
,
int
req_id
)
const
platform
::
DeviceContext
*
dev_ctx
,
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
local_scope_
(
nullptr
)
{
executor_
(
executor
),
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
program_
(
program
),
request_handler
->
dev_ctx
(),
true
));
prefetch_ctx_
(
prefetch_ctx
),
req_id_
(
req_id
)
{
// prefetch always create a new sub scope
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
}
virtual
~
RequestPrefetch
()
{}
virtual
~
RequestPrefetch
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
v
irtual
void
Process
()
{
v
oid
Process
()
override
{
// prefetch process...
// prefetch process...
std
::
string
varname
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"RequestPrefetch "
<<
varname
;
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
std
::
string
var_name
=
request_
->
OutVarname
();
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
VLOG
(
3
)
<<
"RequestPrefetch "
<<
var_name
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
request_
->
GetMutableLocalScope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
local_scope
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
...
@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
...
@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
local_scope_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
req_id_
;
};
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
int
fetch_barriers
=
0
;
while
(
fetch_barriers
<
count
)
{
auto
msg
=
var_get_queue_
.
Pop
();
if
(
msg
.
first
==
FETCH_BARRIER_MESSAGE
)
{
fetch_barriers
++
;
}
}
}
void
AsyncGRPCServer
::
WaitServerReady
()
{
void
AsyncGRPCServer
::
WaitServerReady
()
{
VLOG
(
3
)
<<
"AsyncGRPCServer is wait server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
VLOG
(
3
)
<<
"AsyncGRPCServer WaitSeverReady"
;
}
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
void
AsyncGRPCServer
::
StartServer
()
{
::
grpc
::
ServerBuilder
builder
;
::
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
::
grpc
::
InsecureServerCredentials
(),
builder
.
AddListeningPort
(
bind_
address_
,
::
grpc
::
InsecureServerCredentials
(),
&
selected_port_
);
&
selected_port_
);
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
RegisterService
(
&
service_
);
builder
.
RegisterService
(
&
service_
);
cq_send_
=
builder
.
AddCompletionQueue
();
for
(
auto
t
:
rpc_call_map_
)
{
cq_get_
=
builder
.
AddCompletionQueue
(
);
rpc_cq_
[
t
.
first
].
reset
(
builder
.
AddCompletionQueue
().
release
()
);
cq_prefetch_
=
builder
.
AddCompletionQueue
();
}
server_
=
builder
.
BuildAndStart
();
server_
=
builder
.
BuildAndStart
();
LOG
(
INFO
)
<<
"Server listening on "
<<
address_
LOG
(
INFO
)
<<
"Server listening on "
<<
bind_
address_
<<
" selected port: "
<<
selected_port_
;
<<
" selected port: "
<<
selected_port_
;
std
::
function
<
void
(
int
)
>
send_register
=
std
::
bind
(
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
f
=
&
AsyncGRPCServer
::
TryToRegisterNewSendOne
,
this
,
std
::
placeholders
::
_1
);
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewOne
,
this
,
std
::
function
<
void
(
int
)
>
get_register
=
std
::
bind
(
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
);
&
AsyncGRPCServer
::
TryToRegisterNewGetOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
int
)
>
prefetch_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
,
this
,
std
::
placeholders
::
_1
);
for
(
int
i
=
0
;
i
<
kSendReqsBufSize
;
++
i
)
{
for
(
auto
&
t
:
rpc_call_map_
)
{
TryToRegisterNewSendOne
(
i
);
auto
&
rpc_name
=
t
.
first
;
}
auto
&
cq
=
rpc_cq_
[
rpc_name
];
for
(
int
i
=
0
;
i
<
kGetReqsBufSize
;
++
i
)
{
auto
threadnum
=
rpc_thread_num_
[
rpc_name
];
TryToRegisterNewGetOne
(
i
);
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
}
for
(
int
i
=
0
;
i
<
kPrefetchReqsBufSize
;
++
i
)
{
reqs
.
reserve
(
kRequestBufSize
);
TryToRegisterNewPrefetchOne
(
i
);
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kRequestBufSize
;
i
++
)
{
t_sends_
.
emplace_back
(
TryToRegisterNewOne
(
rpc_name
,
i
);
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
}
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
t_gets_
.
emplace_back
(
for
(
int
i
=
0
;
i
<
threadnum
;
i
++
)
{
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
rpc_threads_
[
rpc_name
].
emplace_back
(
new
std
::
thread
(
std
::
bind
(
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq
.
get
(),
rpc_name
,
f
)));
VLOG
(
3
)
<<
t
.
first
<<
" creates threads!"
;
}
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
t_prefetchs_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
}
}
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
ready_
=
1
;
ready_
=
1
;
}
}
condition_ready_
.
notify_all
();
condition_ready_
.
notify_all
();
// wait server
// wait server
server_
->
Wait
();
server_
->
Wait
();
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
t_sends_
[
i
]
->
join
();
for
(
auto
&
t
:
rpc_threads_
)
{
}
auto
&
threads
=
t
.
second
;
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
++
i
)
{
t_gets_
[
i
]
->
join
();
threads
[
i
]
->
join
();
VLOG
(
3
)
<<
t
.
first
<<
" threads ends!"
;
}
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
t_prefetchs_
[
i
]
->
join
();
}
}
}
}
void
AsyncGRPCServer
::
ShutdownQueue
()
{
void
AsyncGRPCServer
::
ShutdownQueue
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
for
(
auto
&
t
:
rpc_cq_
)
{
cq_send_
->
Shutdown
();
t
.
second
->
Shutdown
();
cq_get_
->
Shutdown
()
;
VLOG
(
3
)
<<
t
.
first
<<
" shutdown!"
;
cq_prefetch_
->
Shutdown
();
}
}
}
// This URL explains why shutdown is complicate:
void
AsyncGRPCServer
::
ShutDownImpl
()
{
void
AsyncGRPCServer
::
ShutDown
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
is_shut_down_
=
true
;
is_shut_down_
=
true
;
ShutdownQueue
();
ShutdownQueue
();
VLOG
(
3
)
<<
"server_ shutdown!"
;
server_
->
Shutdown
();
server_
->
Shutdown
();
}
}
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
(
int
i
)
{
void
AsyncGRPCServer
::
TryToRegisterNewOne
(
const
std
::
string
&
rpc_name
,
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
return
;
}
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
sync_mode_
,
scope_
,
&
var_recv_queue_
,
dev_ctx_
,
i
);
send_reqs_
[
i
]
=
static_cast
<
RequestBase
*>
(
send
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
(
int
req_id
)
{
VLOG
(
4
)
<<
"register send rpc_name:"
<<
rpc_name
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
<<
", handler:"
<<
rpc_call_map_
[
kRequestSend
];
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
&
var_get_queue_
,
req_id
);
get_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
get
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
(
int
req_id
)
{
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
if
(
is_shut_down_
)
{
auto
&
cq
=
rpc_cq_
[
rpc_name
];
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
return
;
RequestBase
*
b
=
nullptr
;
if
(
rpc_name
==
kRequestSend
)
{
b
=
new
RequestSend
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestGet
)
{
b
=
new
RequestGet
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestPrefetch
)
{
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
{
PADDLE_ENFORCE
(
false
,
"not surpported rpc"
);
}
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
.
get
(),
req_id
);
prefetch_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
prefetch
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
reqs
[
req_id
]
=
b
;
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
b
->
Status
();
}
}
// FIXME(typhoonzero): change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq
_name
,
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
rpc
_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
)
{
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
TryToRegisterNewOne
)
{
void
*
tag
=
NULL
;
void
*
tag
=
NULL
;
bool
ok
=
false
;
bool
ok
=
false
;
while
(
true
)
{
while
(
true
)
{
VLOG
(
3
)
<<
"HandleRequest
for "
<<
cq_name
<<
" wait N
ext"
;
VLOG
(
3
)
<<
"HandleRequest
"
<<
rpc_name
<<
" wait n
ext"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" CompletionQueue
shutdown!"
;
LOG
(
INFO
)
<<
"CompletionQueue "
<<
rpc_name
<<
"
shutdown!"
;
break
;
break
;
}
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" get Next"
;
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
if
(
sync_mode_
)
{
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
// FIXME(typhoonzero): de-couple the barriers with recv_op
VLOG
(
3
)
<<
"HandleRequest "
<<
rpc_name
<<
", req_id:"
<<
req_id
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
<<
" get next"
;
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" after WaitCond"
;
}
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
RequestBase
*
base
=
nullptr
;
RequestBase
*
base
=
nullptr
;
{
{
std
::
lock_guard
<
std
::
mutex
>
l
(
cq_mutex_
);
PADDLE_ENFORCE
(
req_id
>=
0
&&
req_id
<
kRequestBufSize
);
if
(
cq_name
==
"cq_get"
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
base
=
get_reqs_
[
req_id
];
base
=
reqs
[
req_id
];
}
else
if
(
cq_name
==
"cq_send"
)
{
base
=
send_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_prefetch"
)
{
base
=
prefetch_reqs_
[
req_id
];
}
}
}
// reference:
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if
(
!
ok
)
{
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name["
LOG
(
WARNING
)
<<
"completion queue:"
<<
rpc_name
<<
" recv no regular event:argument name["
<<
base
->
GetReqName
()
<<
"]"
;
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
(
req_id
);
TryToRegisterNewOne
(
r
pc_name
,
r
eq_id
);
delete
base
;
delete
base
;
continue
;
continue
;
}
}
VLOG
(
3
)
<<
"queue id:"
<<
rpc_name
<<
", req_id:"
<<
req_id
<<
", status:"
<<
base
->
Status
();
switch
(
base
->
Status
())
{
switch
(
base
->
Status
())
{
case
PROCESS
:
{
case
PROCESS
:
{
base
->
Process
();
base
->
Process
();
VLOG
(
4
)
<<
cq_name
<<
" PROCESS status:"
<<
base
->
Status
();
break
;
break
;
}
}
case
FINISH
:
{
case
FINISH
:
{
TryToRegisterNewOne
(
req_id
);
TryToRegisterNewOne
(
rpc_name
,
req_id
);
VLOG
(
4
)
<<
cq_name
<<
" FINISH status:"
<<
base
->
Status
();
delete
base
;
delete
base
;
break
;
break
;
}
}
...
@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
...
@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
}
}
}
}
void
AsyncGRPCServer
::
WaitCond
(
int
cond
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
barrier_cond_step_
==
cond
;
});
}
void
AsyncGRPCServer
::
SetCond
(
int
cond
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_cond_step_
=
cond
;
}
barrier_condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace detail
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
8855d4a7
...
@@ -14,6 +14,8 @@ limitations under the License. */
...
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#pragma once
#include <map>
#include <set>
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <utility>
#include <utility>
...
@@ -28,6 +30,8 @@ limitations under the License. */
...
@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...
@@ -37,106 +41,48 @@ namespace paddle {
...
@@ -37,106 +41,48 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
std
::
shared_ptr
<
VariableResponse
>>
ReceivedMessage
;
typedef
framework
::
BlockingQueue
<
ReceivedMessage
>
ReceivedQueue
;
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestBase
;
class
RequestBase
;
class
AsyncGRPCServer
final
{
class
AsyncGRPCServer
final
:
public
RPCServer
{
public:
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
address_
(
address
),
sync_mode_
(
sync_mode
),
ready_
(
0
)
{}
:
RPCServer
(
address
,
client_num
),
ready_
(
0
)
{}
~
AsyncGRPCServer
()
{}
void
WaitServerReady
();
void
RunSyncUpdate
();
// functions to sync server barrier status.
void
WaitCond
(
int
cond
);
void
SetCond
(
int
cond
);
void
WaitClientGet
(
int
count
);
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepared
)
{
prefetch_ctx_
.
reset
(
prepared
.
release
());
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
const
ReceivedMessage
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
v
oid
Push
(
const
std
::
string
&
msg_name
)
{
v
irtual
~
AsyncGRPCServer
()
{}
this
->
var_recv_queue_
.
Push
(
std
::
make_pair
(
msg_name
,
nullptr
))
;
void
WaitServerReady
()
override
;
}
void
StartServer
()
override
;
void
ShutDown
();
private:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
rpc_name
,
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
TryToRegisterNewOne
);
protected:
void
TryToRegisterNewOne
(
const
std
::
string
&
rpc_name
,
int
req_id
);
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
(
int
req_id
);
void
TryToRegisterNewGetOne
(
int
req_id
);
void
TryToRegisterNewPrefetchOne
(
int
req_id
);
void
ShutdownQueue
();
void
ShutdownQueue
();
void
ShutDownImpl
()
override
;
private:
private:
static
const
int
kSendReqsBufSize
=
100
;
static
const
int
kRequestBufSize
=
100
;
static
const
int
kGetReqsBufSize
=
100
;
static
const
int
kPrefetchReqsBufSize
=
10
;
std
::
mutex
cq_mutex_
;
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_get_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_prefetch_
;
RequestBase
*
send_reqs_
[
kSendReqsBufSize
];
RequestBase
*
get_reqs_
[
kGetReqsBufSize
];
RequestBase
*
prefetch_reqs_
[
kPrefetchReqsBufSize
];
GrpcService
::
AsyncService
service_
;
GrpcService
::
AsyncService
service_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
const
bool
sync_mode_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
framework
::
BlockingQueue
<
MessageWithName
>
var_get_queue_
;
// client send variable to this queue.
ReceivedQueue
var_recv_queue_
;
// condition of the sub program
// condition of the sub program
std
::
mutex
barrier_mutex_
;
std
::
mutex
barrier_mutex_
;
mutable
int
barrier_cond_step_
;
mutable
int
barrier_cond_step_
;
std
::
condition_variable
barrier_condition_
;
std
::
condition_variable
barrier_condition_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_sends_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_gets_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_prefetchs_
;
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
framework
::
ProgramDesc
*
program_
;
framework
::
Executor
*
executor_
;
int
selected_port_
;
std
::
mutex
mutex_ready_
;
std
::
mutex
mutex_ready_
;
std
::
condition_variable
condition_ready_
;
std
::
condition_variable
condition_ready_
;
int
ready_
;
int
ready_
;
std
::
map
<
std
::
string
,
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>>
rpc_cq_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>>
rpc_threads_
;
std
::
map
<
std
::
string
,
std
::
vector
<
RequestBase
*>>
rpc_reqs_
;
};
};
};
// namespace detail
};
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
8855d4a7
...
@@ -24,13 +24,16 @@ limitations under the License. */
...
@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace
framework
=
paddle
::
framework
;
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
platform
=
paddle
::
platform
;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
detail
=
paddle
::
operators
::
detail
;
USE_OP
(
lookup_table
);
USE_OP
(
lookup_table
);
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
auto
root_block
=
program
->
MutableBlock
(
0
);
auto
root_block
=
program
->
MutableBlock
(
0
);
...
@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
...
@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
}
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
void
StartServer
()
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
true
));
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
...
@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
...
@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto
prepared
=
exe
.
Prepare
(
program
,
block
->
ID
());
auto
prepared
=
exe
.
Prepare
(
program
,
block
->
ID
());
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
rpc_service_
->
SetProgram
(
&
program
);
g_req_handler
->
SetProgram
(
&
program
);
rpc_service_
->
SetPrefetchPreparedCtx
(
std
::
move
(
prepared
));
g_req_handler
->
SetPrefetchPreparedCtx
(
std
::
move
(
prepared
));
rpc_service_
->
SetDevCtx
(
&
ctx
);
g_req_handler
->
SetDevCtx
(
&
ctx
);
rpc_service_
->
SetScope
(
&
scope
);
g_req_handler
->
SetScope
(
&
scope
);
rpc_service_
->
SetExecutor
(
&
exe
);
g_req_handler
->
SetExecutor
(
&
exe
);
g_rpc_service
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
rpc_service_
->
RunSyncUpdate
();
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
// FIXME(gongwb): don't use hard time.
sleep
(
10
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
}
}
TEST
(
PREFETCH
,
DISABLED_CPU
)
{
TEST
(
PREFETCH
,
CPU
)
{
// start up a server instance backend
g_req_handler
.
reset
(
new
detail
::
RequestPrefetchHandler
(
true
));
std
::
thread
server_thread
(
StartServer
,
"127.0.0.1:8889"
);
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
sleep
(
2
);
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
client
;
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
{
// create var on local scope
// create var on local scope
int64_t
rows_numel
=
5
;
int64_t
rows_numel
=
5
;
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
std
::
string
out_var_name
(
"out"
);
auto
client
=
detail
::
RPCClient
::
GetInstance
();
client
.
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
client
.
Wait
();
out_var_name
);
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
rpc_service_
->
ShutDown
();
server_thread
.
join
();
rpc_service_
.
reset
(
nullptr
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
}
}
}
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
}
paddle/fluid/operators/detail/request_handler.h
0 → 100644
浏览文件 @
8855d4a7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
constexpr
char
kRequestSend
[]
=
"RequestSend"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
class
RPCServer
;
class
RequestHandler
{
public:
explicit
RequestHandler
(
bool
sync_mode
)
:
sync_mode_
(
sync_mode
),
dev_ctx_
(
nullptr
),
executor_
(
nullptr
),
scope_
(
nullptr
),
program_
(
nullptr
),
rpc_server_
(
nullptr
)
{}
virtual
~
RequestHandler
()
{}
// Set attributes.
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepared
)
{
prefetch_ctx_
.
reset
(
prepared
.
release
());
}
// Used for async.
void
SetGradToPreparedCtx
(
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
g
)
{
grad_to_prepared_ctx_
=
g
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
bool
sync_mode
()
{
return
sync_mode_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
framework
::
ExecutorPrepareContext
*
prefetch_ctx
()
{
return
prefetch_ctx_
.
get
();
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
framework
::
Executor
*
executor
()
{
return
executor_
;
}
std
::
vector
<
framework
::
Variable
*>&
sparse_vars
()
{
return
sparse_vars_
;
}
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
=
0
;
protected:
const
bool
sync_mode_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
framework
::
Executor
*
executor_
;
framework
::
Scope
*
scope_
;
framework
::
ProgramDesc
*
program_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
// Used for async.
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
grad_to_prepared_ctx_
;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std
::
vector
<
framework
::
Variable
*>
sparse_vars_
;
RPCServer
*
rpc_server_
;
std
::
mutex
sparse_var_mutex_
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/request_handler_impl.cc
0 → 100644
浏览文件 @
8855d4a7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Async
if
(
!
sync_mode_
)
{
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"async: run sub program error "
<<
e
.
what
();
return
false
;
}
return
true
;
}
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv batch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestSend
);
}
if
(
invar
==
nullptr
)
{
LOG
(
ERROR
)
<<
"sync: Can not find server side var: "
<<
varname
;
PADDLE_THROW
(
"sync: Can not find server side var"
);
return
false
;
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
sparse_var_mutex_
);
sparse_vars_
.
push_back
(
invar
);
}
}
return
true
;
}
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
if
(
varname
!=
FETCH_BARRIER_MESSAGE
)
{
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestGet
);
}
*
outvar
=
scope_
->
FindVar
(
varname
);
return
true
;
}
// FETCH_BARRIER_MESSAGE
if
(
sync_mode_
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
return
true
;
}
bool
RequestPrefetchHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
varname
);
*
outvar
=
scope
->
FindVar
(
varname
);
InitializeVariable
(
*
outvar
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
.
get
(),
scope
);
return
true
;
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/request_handler_impl.h
0 → 100644
浏览文件 @
8855d4a7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RequestSendHandler
final
:
public
RequestHandler
{
public:
explicit
RequestSendHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestSendHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
class
RequestGetHandler
final
:
public
RequestHandler
{
public:
explicit
RequestGetHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestGetHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
public:
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_server.cc
0 → 100644
浏览文件 @
8855d4a7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
void
RPCServer
::
ShutDown
()
{
LOG
(
INFO
)
<<
"RPCServer ShutDown "
;
ShutDownImpl
();
exit_flag_
=
true
;
barrier_cond_
.
notify_all
();
rpc_cond_
.
notify_all
();
}
void
RPCServer
::
SavePort
()
const
{
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
;
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
=
]
{
return
(
barrier_counter_
[
rpc_name
]
>=
client_num_
||
exit_flag_
.
load
());
});
VLOG
(
3
)
<<
"batch_barrier_:"
<<
barrier_counter_
[
rpc_name
];
}
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
int
b
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
b
=
++
barrier_counter_
[
rpc_name
];
}
VLOG
(
3
)
<<
"RPCServer IncreaseBatchBarrier "
<<
rpc_name
<<
", barrier_count:"
<<
b
<<
", fan_in"
<<
client_num_
;
if
(
b
>=
client_num_
)
{
barrier_cond_
.
notify_all
();
}
}
void
RPCServer
::
ResetBarrierCounter
()
{
VLOG
(
3
)
<<
"RPCServer ResetBarrierCounter "
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
for
(
auto
&
t
:
barrier_counter_
)
{
t
.
second
=
0
;
}
}
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
)
{
rpc_call_map_
[
rpc_name
]
=
handler
;
rpc_thread_num_
[
rpc_name
]
=
thread_num
;
static
int
cond
=
-
1
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
VLOG
(
4
)
<<
"RegisterRPC rpc_name:"
<<
rpc_name
<<
", handler:"
<<
handler
<<
", cond:"
<<
rpc_cond_map_
[
rpc_name
];
}
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer SetCond "
<<
rpc_name
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cur_cond_
=
rpc_cond_map_
[
rpc_name
];
}
rpc_cond_
.
notify_all
();
}
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer WaitCond "
<<
rpc_name
;
int
cond
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cond
=
rpc_cond_map_
[
rpc_name
];
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_server.h
0 → 100644
浏览文件 @
8855d4a7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RPCServer
{
public:
explicit
RPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
cur_cond_
(
0
),
bind_address_
(
address
),
exit_flag_
(
false
),
selected_port_
(
0
),
client_num_
(
client_num
)
{}
virtual
~
RPCServer
()
{}
virtual
void
StartServer
()
=
0
;
virtual
void
WaitServerReady
()
=
0
;
void
ShutDown
();
bool
IsExit
()
{
return
exit_flag_
.
load
();
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
void
SavePort
()
const
;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
=
5
);
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void
WaitBarrier
(
const
std
::
string
&
rpc_name
);
void
SetCond
(
const
std
::
string
&
rpc_name
);
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
ResetBarrierCounter
();
protected:
virtual
void
ShutDownImpl
()
=
0
;
private:
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
int
>
barrier_counter_
;
std
::
condition_variable
barrier_cond_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_cond_map_
;
std
::
atomic
<
int
>
cur_cond_
;
std
::
condition_variable
rpc_cond_
;
protected:
std
::
string
bind_address_
;
std
::
atomic
<
int
>
exit_flag_
;
int
selected_port_
;
const
int
client_num_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
friend
class
RequestHandler
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
8855d4a7
...
@@ -67,8 +67,8 @@ class VariableResponse {
...
@@ -67,8 +67,8 @@ class VariableResponse {
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
inline
std
::
string
Varname
()
const
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
{
return
meta_
.
out_varname
();
}
inline
std
::
string
OutVarname
()
const
{
return
meta_
.
out_varname
();
}
// should call parse first.
// should call parse first.
framework
::
Variable
*
GetVar
()
{
framework
::
Variable
*
GetVar
()
{
...
...
paddle/fluid/operators/gather_op.cc
浏览文件 @
8855d4a7
...
@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel {
auto
index_dims
=
ctx
->
GetInputDim
(
"Index"
);
auto
index_dims
=
ctx
->
GetInputDim
(
"Index"
);
PADDLE_ENFORCE
(
index_dims
.
size
()
==
1
);
PADDLE_ENFORCE
(
index_dims
.
size
()
==
1
);
int
batch_size
=
ctx
->
GetInputDim
(
"Index"
)[
0
];
int
batch_size
=
ctx
->
GetInputDim
(
"Index"
)[
0
];
PADDLE_ENFORCE_GE
(
batch_size
,
0
,
"Batch size must be >0"
);
framework
::
DDim
output_dims
(
ctx
->
GetInputDim
(
"X"
));
framework
::
DDim
output_dims
(
ctx
->
GetInputDim
(
"X"
));
output_dims
[
0
]
=
batch_size
;
output_dims
[
0
]
=
batch_size
;
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
8855d4a7
...
@@ -23,6 +23,7 @@ limitations under the License. */
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
true
);
detail
::
RequestSendHandler
rpc_h
(
true
);
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
1
);
rpc_service
.
RegisterRPC
(
detail
::
kRequestSend
,
&
rpc_h
);
rpc_h
.
SetRPCServer
(
&
rpc_service
);
framework
::
ProgramDesc
empty_program
;
framework
::
ProgramDesc
empty_program
;
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_
service
.
SetScope
(
scope
);
rpc_
h
.
SetScope
(
scope
);
rpc_
service
.
SetDevCtx
(
&
dev_ctx
);
rpc_
h
.
SetDevCtx
(
&
dev_ctx
);
rpc_
service
.
SetProgram
(
&
empty_program
);
rpc_
h
.
SetProgram
(
&
empty_program
);
rpc_
service
.
SetExecutor
(
&
executor
);
rpc_
h
.
SetExecutor
(
&
executor
);
std
::
thread
server_thread
(
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
&
rpc_service
));
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
&
rpc_service
));
rpc_service
.
SetCond
(
0
);
rpc_service
.
SetCond
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
auto
recv
=
rpc_service
.
Get
(
);
rpc_service
.
WaitBarrier
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
rpc_service
.
ShutDown
();
rpc_service
.
ShutDown
();
VLOG
(
3
)
<<
"rpc server stopped"
;
VLOG
(
3
)
<<
"rpc server stopped"
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
8855d4a7
...
@@ -19,14 +19,16 @@ limitations under the License. */
...
@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncG
RPCServer
>
service
)
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
)
{
service
->
RunSyncUpdate
();
service
->
StartServer
();
VLOG
(
4
)
<<
"RunServer thread end"
;
VLOG
(
4
)
<<
"RunServer thread end"
;
}
}
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
...
@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
...
@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for
(
size_t
i
=
0
;
i
<
fs
.
size
();
++
i
)
fs
[
i
].
wait
();
for
(
size_t
i
=
0
;
i
<
fs
.
size
();
++
i
)
fs
[
i
].
wait
();
}
}
std
::
atomic_int
ListenAndServOp
::
selected_port_
{
0
};
ListenAndServOp
::
ListenAndServOp
(
const
std
::
string
&
type
,
ListenAndServOp
::
ListenAndServOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
...
@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
...
@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp
::~
ListenAndServOp
()
{
Stop
();
}
ListenAndServOp
::~
ListenAndServOp
()
{
Stop
();
}
void
ListenAndServOp
::
Stop
()
{
void
ListenAndServOp
::
Stop
()
{
rpc_service_
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
rpc_service_
->
ShutDown
();
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
server_thread_
->
join
();
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
...
@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
...
@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void
ListenAndServOp
::
SavePort
()
const
{
void
ListenAndServOp
::
SavePort
()
const
{
// NOTE: default write file to /tmp/paddle.selected_port
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_
=
rpc_service_
->
GetSelectedPort
();
rpc_service_
->
SavePort
();
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
.
load
();
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
ListenAndServOp
::
WaitServerReady
()
{
while
(
selected_port_
.
load
()
==
0
)
{
}
}
}
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
{
framework
::
BlockDesc
*
prefetch_block
)
const
{
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
"server program should have at least 2 blocks"
);
...
@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
...
@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared
.
begin
(),
optimize_prepared
.
begin
(),
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
bool
exit_flag
=
false
;
rpc_service_
->
ResetBarrierCounter
()
;
// Record received sparse variables, so that
// Record received sparse variables, so that
// we could reset those after execute optimize program
// we could reset those after execute optimize program
std
::
vector
<
framework
::
Variable
*>
sparse_vars
;
std
::
vector
<
framework
::
Variable
*>
sparse_vars
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
()
)
{
while
(
true
)
{
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
0
);
rpc_service_
->
SetCond
(
detail
::
kRequestSend
);
size_t
recv_var_cnt
=
0
;
rpc_service_
->
WaitBarrier
(
detail
::
kRequestSend
);
int
batch_barrier
=
0
;
while
(
batch_barrier
!=
fan_in
)
{
if
(
rpc_service_
->
IsExit
())
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
LOG
(
WARNING
)
<<
"get exit!rpc_processor break!"
;
auto
recv_var_name
=
v
.
first
;
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
if
(
recv_var_name
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"recv batch barrier message"
;
batch_barrier
++
;
continue
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
recv_var_cnt
++
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
sparse_vars
.
push_back
(
var
);
}
}
}
if
(
exit_flag
)
{
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
ShutDown
();
break
;
break
;
}
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t
last_parent_blkid
=
program
->
Block
(
1
).
Parent
();
int32_t
last_parent_blkid
=
program
->
Block
(
1
).
Parent
();
...
@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
...
@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_
->
WaitBarrier
(
detail
::
kRequestGet
);
rpc_service_
->
WaitClientGet
(
fan_in
);
rpc_service_
->
ResetBarrierCounter
();
sparse_vars
.
clear
();
}
// while(true)
}
// while(true)
}
}
static
void
AsyncUpdateThread
(
const
std
::
string
&
var_name
,
const
bool
&
exit_flag
,
const
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
&
queue
,
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
)
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" started"
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
())
{
const
detail
::
ReceivedMessage
v
=
queue
->
Pop
();
if
(
SignalHandler
::
IsProgramExit
())
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" exit"
;
break
;
}
auto
recv_var_name
=
v
.
first
;
VLOG
(
4
)
<<
"async update "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
auto
fs
=
framework
::
Async
([
var_name
,
&
executor
,
&
v
,
prepared
]
{
try
{
executor
->
RunPreparedContext
(
prepared
,
v
.
second
->
GetMutableLocalScope
());
}
catch
(
const
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
fs
.
wait
();
}
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
)
const
{
framework
::
ProgramDesc
*
program
)
const
{
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
// grad name to block id
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
detail
::
ReceivedQueue
>>
grad_to_queue
;
auto
grad_to_block_id_str
=
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
...
@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG
(
3
)
<<
"after split, grad = "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
VLOG
(
3
)
<<
"after split, grad = "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
queue
=
std
::
make_shared
<
detail
::
ReceivedQueue
>
();
grad_to_queue
[
pieces
[
0
]]
=
queue
;
// record blocking queue in SignalHandler
SignalHandler
::
RegisterBlockingQueue
(
queue
);
id_to_grad
[
block_id
]
=
pieces
[
0
];
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
}
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
...
@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
}
bool
exit_flag
=
false
;
request_send_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_get_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_prefetch_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
VLOG
(
3
)
<<
"start async optimize threads"
;
std
::
vector
<
std
::
future
<
void
>>
fs
;
for
(
auto
iter
=
grad_to_queue
.
begin
();
iter
!=
grad_to_queue
.
end
();
iter
++
)
{
std
::
string
grad_name
=
iter
->
first
;
VLOG
(
3
)
<<
"create async update thread for "
<<
grad_name
;
fs
.
push_back
(
framework
::
AsyncIO
([
grad_name
,
&
exit_flag
,
&
executor
,
&
grad_to_queue
,
&
grad_to_prepared_ctx
]()
{
AsyncUpdateThread
(
grad_name
,
exit_flag
,
grad_to_queue
[
grad_name
],
executor
,
grad_to_prepared_ctx
[
grad_name
].
get
());
}));
}
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
())
{
while
(
true
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
if
(
rpc_service_
->
IsExit
())
{
auto
recv_var_name
=
v
.
first
;
LOG
(
INFO
)
<<
"get exit!rpc_processor break!"
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
grad_to_queue
[
recv_var_name
]
->
Push
(
v
);
}
}
if
(
exit_flag
)
{
sleep
(
1
);
rpc_service_
->
ShutDown
();
break
;
}
}
// while(true)
}
// while(true)
}
}
static
void
FillRequestCtx
(
detail
::
RequestHandler
*
h
,
framework
::
Scope
*
scope
,
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
detail
::
RPCServer
*
rpc_server
)
{
h
->
SetScope
(
scope
);
h
->
SetDevCtx
(
dev_ctx
);
h
->
SetExecutor
(
executor
);
h
->
SetProgram
(
program
);
h
->
SetPrefetchPreparedCtx
(
std
::
move
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
(
prefetch_ctx
)));
h
->
SetRPCServer
(
rpc_server
);
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
const
platform
::
Place
&
dev_place
)
const
{
// Mark this as PS that it should decide profiling by listening from trainer.
// Mark this as PS that it should decide profiling by listening from trainer.
...
@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
sync_mode
));
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
new
detail
::
RequestSendHandler
(
sync_mode
));
request_get_handler_
.
reset
(
new
detail
::
RequestGetHandler
(
sync_mode
));
request_prefetch_handler_
.
reset
(
new
detail
::
RequestPrefetchHandler
(
sync_mode
));
rpc_service_
->
RegisterRPC
(
detail
::
kRequestSend
,
request_send_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestGet
,
request_get_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
request_prefetch_handler_
.
get
());
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
program
=
optimize_block
->
Program
();
auto
*
program
=
optimize_block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
framework
::
Executor
executor
(
dev_place
);
// prepare rpc_service
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
rpc_service_
->
SetProgram
(
program
);
rpc_service_
->
SetExecutor
(
&
executor
);
// prepare for prefetch
// prepare for prefetch
VLOG
(
3
)
<<
"prefetch block id is "
<<
prefetch_block
->
ID
();
VLOG
(
3
)
<<
"prefetch block id is "
<<
prefetch_block
->
ID
();
auto
prefetch_prepared
=
executor
.
Prepare
(
*
program
,
prefetch_block
->
ID
());
auto
prefetch_prepared
=
executor
.
Prepare
(
*
program
,
prefetch_block
->
ID
());
rpc_service_
->
SetPrefetchPreparedCtx
(
std
::
move
(
prefetch_prepared
));
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
program
,
prefetch_prepared
.
release
(),
rpc_service_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
// start the server listening after all member initialized.
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
// Write to a file of server selected port for python use.
// Write to a file of server selected port for python use.
std
::
string
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.selected_port"
,
static_cast
<
int
>
(
::
getpid
()));
SavePort
();
SavePort
();
if
(
sync_mode
)
{
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
...
@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
bool
SignalHandler
::
program_exit_flag_
=
false
;
SignalHandler
::
BlockingQueueSet
SignalHandler
::
blocking_queue_set_
{};
void
SignalHandler
::
StopAndExit
(
int
signal_num
)
{
void
SignalHandler
::
StopAndExit
(
int
signal_num
)
{
VLOG
(
3
)
<<
"Catch interrupt signal: "
<<
signal_num
<<
", program will exit"
;
VLOG
(
3
)
<<
"Catch interrupt signal: "
<<
signal_num
<<
", program will exit"
;
exit
(
0
);
program_exit_flag_
=
true
;
// awake all blocking queues
for
(
BlockingQueueSet
::
iterator
iter
=
blocking_queue_set_
.
begin
();
iter
!=
blocking_queue_set_
.
end
();
iter
++
)
{
iter
->
get
()
->
Push
(
std
::
make_pair
(
std
::
string
(
LISTEN_TERMINATE_MESSAGE
),
nullptr
));
}
exit
(
EXIT_SUCCESS
);
}
void
SignalHandler
::
RegisterBlockingQueue
(
BlockingQueue
&
queue
)
{
blocking_queue_set_
.
insert
(
queue
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
8855d4a7
...
@@ -23,7 +23,8 @@ limitations under the License. */
...
@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -31,7 +32,7 @@ namespace operators {
...
@@ -31,7 +32,7 @@ namespace operators {
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kPrefetchBlock
[]
=
"PrefetchBlock"
;
constexpr
char
kPrefetchBlock
[]
=
"PrefetchBlock"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncG
RPCServer
>
service
);
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
);
class
ListenAndServOp
:
public
framework
::
OperatorBase
{
class
ListenAndServOp
:
public
framework
::
OperatorBase
{
public:
public:
...
@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void
SavePort
()
const
;
void
SavePort
()
const
;
void
WaitServerReady
();
int
GetSelectedPort
()
{
return
rpc_service_
->
GetSelectedPort
();
}
int
GetSelectedPort
()
{
return
selected_port_
;
}
void
Stop
()
override
;
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
const
platform
::
Place
&
dev_place
)
const
override
;
static
void
ResetPort
()
{
selected_port_
=
0
;
}
protected:
protected:
mutable
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
detail
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
// FIXME(wuyi): it's static so that the operator can be cloned.
static
std
::
atomic_int
selected_port_
;
};
};
class
SignalHandler
{
class
SignalHandler
{
public:
typedef
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
BlockingQueue
;
typedef
std
::
unordered_set
<
BlockingQueue
>
BlockingQueueSet
;
public:
public:
static
void
StopAndExit
(
int
signal_num
);
static
void
StopAndExit
(
int
signal_num
);
static
void
RegisterBlockingQueue
(
BlockingQueue
&
);
static
inline
bool
IsProgramExit
()
{
return
program_exit_flag_
;
}
private:
private:
static
bool
program_exit_flag_
;
static
BlockingQueueSet
blocking_queue_set_
;
DISABLE_COPY_AND_ASSIGN
(
SignalHandler
);
DISABLE_COPY_AND_ASSIGN
(
SignalHandler
);
};
};
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
8855d4a7
...
@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
// need to wait before sending send_barrier message
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
sync_mode
)
{
if
(
sync_mode
)
{
...
...
paddle/fluid/operators/shape_op.cc
0 → 100644
浏览文件 @
8855d4a7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
ShapeOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input (Input) of get_shape op should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output (Out) of get_shape op should not be null."
);
auto
in_dim
=
ctx
->
GetInputDim
(
"Input"
);
ctx
->
SetOutputDim
(
"Out"
,
{
in_dim
.
size
()});
}
};
class
ShapeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(Tensor), The input tensor."
);
AddOutput
(
"Out"
,
"(Tensor), The shape of input tensor."
);
AddComment
(
R"DOC(
Shape Operator.
Get the shape of input tensor.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
shape
,
ops
::
ShapeOp
,
ops
::
ShapeOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
shape
,
ops
::
ShapeKernel
<
int
>
,
ops
::
ShapeKernel
<
int64_t
>
,
ops
::
ShapeKernel
<
float
>
,
ops
::
ShapeKernel
<
double
>
);
paddle/fluid/operators/shape_op.cu
0 → 100644
浏览文件 @
8855d4a7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
REGISTER_OP_CUDA_KERNEL
(
shape
,
paddle
::
operators
::
ShapeKernel
<
int
>
,
paddle
::
operators
::
ShapeKernel
<
int64_t
>
,
paddle
::
operators
::
ShapeKernel
<
float
>
,
paddle
::
operators
::
ShapeKernel
<
double
>
);
paddle/fluid/operators/shape_op.h
0 → 100644
浏览文件 @
8855d4a7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
ShapeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in_t
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
out_data
=
out_t
->
mutable_data
<
int64_t
>
(
platform
::
CPUPlace
());
auto
in_dims
=
in_t
->
dims
();
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
++
i
)
{
out_data
[
i
]
=
in_dims
[
i
];
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/tensorrt_engine_op.cc
浏览文件 @
8855d4a7
...
@@ -31,8 +31,9 @@ void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
...
@@ -31,8 +31,9 @@ void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
auto
max_workspace
=
context
.
Attr
<
int
>
(
"max_workspace"
);
engine_
.
reset
(
new
inference
::
tensorrt
::
TensorRTEngine
(
engine_
.
reset
(
new
inference
::
tensorrt
::
TensorRTEngine
(
max_batch_
,
max_workspace
,
nullptr
));
max_batch_
,
max_workspace
,
nullptr
));
// TODO(Superjomn) parameters should be passed after analysised from outside.
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
().
ConvertBlock
(
inference
::
Singleton
<
inference
::
tensorrt
::
OpConverter
>::
Global
().
ConvertBlock
(
block
,
engine_
.
get
());
block
,
{},
context
.
scope
(),
engine_
.
get
());
engine_
->
FreezeNetwork
();
engine_
->
FreezeNetwork
();
}
}
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
8855d4a7
...
@@ -21,6 +21,8 @@ limitations under the License. */
...
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
...
@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
...
@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
string
=
paddle
::
string
;
namespace
string
=
paddle
::
string
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
void
StartServer
(
std
::
atomic
<
bool
>*
initialized
)
{
void
StartServer
()
{
f
::
Scope
scope
;
f
::
Scope
scope
;
p
::
CPUPlace
place
;
p
::
CPUPlace
place
;
scope
.
Var
(
NCCL_ID_VARNAME
);
scope
.
Var
(
NCCL_ID_VARNAME
);
p
::
DeviceContextPool
&
pool
=
p
::
DeviceContextPool
::
Instance
();
p
::
DeviceContextPool
&
pool
=
p
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
true
));
f
::
ProgramDesc
empty_program
;
f
::
ProgramDesc
empty_program
;
f
::
Executor
executor
(
dev_ctx
.
GetPlace
());
f
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_service
->
SetScope
(
&
scope
);
g_req_handler
->
SetScope
(
&
scope
);
rpc_service
->
SetDevCtx
(
&
dev_ctx
);
g_req_handler
->
SetDevCtx
(
&
dev_ctx
);
rpc_service
->
SetProgram
(
&
empty_program
);
g_req_handler
->
SetProgram
(
&
empty_program
);
rpc_service
->
SetExecutor
(
&
executor
);
g_req_handler
->
SetExecutor
(
&
executor
);
g_rpc_service
->
RegisterRPC
(
detail
::
kRequestSend
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
rpc_service
.
get
()));
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
*
initialized
=
true
;
rpc_service
->
SetCond
(
0
);
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
auto
recv
=
rpc_service
->
Get
();
std
::
cout
<<
"before WaitFanInOfSend"
<<
std
::
endl
;
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
rpc_service
->
ShutDown
();
g_
rpc_service
->
ShutDown
();
server_thread
.
join
();
server_thread
.
join
();
}
}
TEST
(
SendNcclId
,
DISABLED_Normal
)
{
TEST
(
SendNcclId
,
GrpcServer
)
{
std
::
atomic
<
bool
>
initialized
{
false
};
g_req_handler
.
reset
(
new
detail
::
RequestSendHandler
(
true
));
std
::
thread
server_thread
(
StartServer
,
&
initialized
);
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
while
(
!
initialized
)
{
}
std
::
thread
server_thread
(
StartServer
);
// wait server to start
g_rpc_service
->
WaitServerReady
();
// sleep(2);
rpc_service
->
WaitServerReady
();
f
::
Scope
scope
;
f
::
Scope
scope
;
p
::
CPUPlace
place
;
p
::
CPUPlace
place
;
...
@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
...
@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
var
=
scope
.
Var
(
NCCL_ID_VARNAME
);
auto
var
=
scope
.
Var
(
NCCL_ID_VARNAME
);
// var->SetType(f::proto::VarType_Type_RAW);
auto
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
auto
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
p
::
dynload
::
ncclGetUniqueId
(
id
);
p
::
dynload
::
ncclGetUniqueId
(
id
);
int
port
=
rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
client
;
detail
::
RPCClient
client
;
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
Wait
();
client
.
Wait
();
client
.
AsyncSendBatchBarrier
(
ep
);
client
.
Wait
();
server_thread
.
join
();
server_thread
.
join
();
auto
*
ptr
=
rpc_service
.
release
(
);
g_rpc_service
.
reset
(
nullptr
);
delete
ptr
;
g_req_handler
.
reset
(
nullptr
)
;
}
}
paddle/fluid/platform/nccl_helper.h
浏览文件 @
8855d4a7
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <stdio.h>
#include <stdio.h>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <typeindex>
#include <typeindex>
#include <vector>
#include <vector>
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
8855d4a7
...
@@ -183,7 +183,7 @@ function build() {
...
@@ -183,7 +183,7 @@ function build() {
============================================
============================================
EOF
EOF
make clean
make clean
make
-j
`
nproc
`
make
install
-j
`
nproc
`
}
}
function
build_android
()
{
function
build_android
()
{
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
8855d4a7
...
@@ -82,6 +82,7 @@ __all__ = [
...
@@ -82,6 +82,7 @@ __all__ = [
'roi_pool'
,
'roi_pool'
,
'dice_loss'
,
'dice_loss'
,
'upsampling_bilinear2d'
,
'upsampling_bilinear2d'
,
'gather'
,
'random_crop'
,
'random_crop'
,
]
]
...
@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
...
@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
def
dice_loss
(
input
,
label
,
epsilon
=
0.00001
):
def
dice_loss
(
input
,
label
,
epsilon
=
0.00001
):
"""
"""
**Dice loss Layer**
Dice loss for comparing the similarity of two batch of data,
Dice loss for comparing the similarity of two batch of data,
usually is used for binary image segmentation i.e. labels are binary.
usually is used for binary image segmentation i.e. labels are binary.
The dice loss can be defined as below equation:
The dice loss can be defined as below equation:
...
@@ -3944,7 +3944,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
...
@@ -3944,7 +3944,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
input (Variable): The input tensor of bilinear interpolation,
input (Variable): The input tensor of bilinear interpolation,
This is a 4-D tensor of the shape
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
(num_batches, channels, in_h, in_w).
out_shape(list|tuple|None): Output shape of bilinear interpolation
out_shape(list|tuple|
Variable|
None): Output shape of bilinear interpolation
layer, the shape is (out_h, out_w).
layer, the shape is (out_h, out_w).
Default: None
Default: None
scale(int|None): The multiplier for the input height or width.
scale(int|None): The multiplier for the input height or width.
...
@@ -3971,13 +3971,20 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
...
@@ -3971,13 +3971,20 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
def
_is_list_or_turple_
(
data
):
def
_is_list_or_turple_
(
data
):
return
(
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
))
return
(
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
))
out_h
=
0
out_w
=
0
inputs
=
{
"X"
:
input
}
if
out_shape
is
not
None
:
if
out_shape
is
not
None
:
if
not
(
_is_list_or_turple_
(
out_shape
)
and
len
(
out_shape
)
==
2
):
if
not
(
_is_list_or_turple_
(
out_shape
)
and
len
(
out_shape
)
==
2
)
and
(
out_shape
is
not
Variable
):
raise
ValueError
(
'out_shape should be a list or tuple '
,
raise
ValueError
(
'out_shape should be a list or tuple '
,
'with length 2, (out_h, out_w).'
)
'with length 2, (out_h, out_w).'
)
if
_is_list_or_turple_
(
out_shape
):
out_shape
=
list
(
map
(
int
,
out_shape
))
out_shape
=
list
(
map
(
int
,
out_shape
))
out_h
=
out_shape
[
0
]
out_h
=
out_shape
[
0
]
out_w
=
out_shape
[
1
]
out_w
=
out_shape
[
1
]
else
:
inputs
[
'OutSize'
]
=
out_shape
else
:
else
:
out_h
=
int
(
input
.
shape
[
2
]
*
scale
)
out_h
=
int
(
input
.
shape
[
2
]
*
scale
)
out_w
=
int
(
input
.
shape
[
3
]
*
scale
)
out_w
=
int
(
input
.
shape
[
3
]
*
scale
)
...
@@ -3985,13 +3992,62 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
...
@@ -3985,13 +3992,62 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
out
=
helper
.
create_tmp_variable
(
dtype
)
out
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"bilinear_interp"
,
type
=
"bilinear_interp"
,
inputs
=
{
"X"
:
input
}
,
inputs
=
inputs
,
outputs
=
{
"Out"
:
out
},
outputs
=
{
"Out"
:
out
},
attrs
=
{
"out_h"
:
out_h
,
attrs
=
{
"out_h"
:
out_h
,
"out_w"
:
out_w
})
"out_w"
:
out_w
})
return
out
return
out
def
gather
(
input
,
index
):
"""
Output is obtained by gathering entries of the outer-most dimension
of X indexed by `index` and concatenate them together.
.. math::
Out = X[Index]
.. code-block:: text
Given:
X = [[1, 2],
[3, 4],
[5, 6]]
Index = [1, 2]
Then:
Out = [[3, 4],
[5, 6]]
Args:
input (Variable): The source input with rank>=1.
index (Variable): The index input with rank=1.
Returns:
output (Variable): The output is a tensor with the same rank as input.
Examples:
.. code-block:: python
output = fluid.layers.gather(x, index)
"""
helper
=
LayerHelper
(
'gather'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
out
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
type
=
"gather"
,
inputs
=
{
"X"
:
input
,
"Index"
:
index
},
outputs
=
{
"Out"
:
out
})
return
out
def
random_crop
(
input
,
shape
,
seed
=
1
):
def
random_crop
(
input
,
shape
,
seed
=
1
):
helper
=
LayerHelper
(
"random_crop"
,
**
locals
())
helper
=
LayerHelper
(
"random_crop"
,
**
locals
())
dtype
=
helper
.
input_dtype
()
dtype
=
helper
.
input_dtype
()
...
...
python/paddle/fluid/layers/ops.py
浏览文件 @
8855d4a7
...
@@ -71,6 +71,7 @@ __all__ = [
...
@@ -71,6 +71,7 @@ __all__ = [
'cumsum'
,
'cumsum'
,
'scatter'
,
'scatter'
,
'sum'
,
'sum'
,
'shape'
,
]
+
__activations__
]
+
__activations__
for
_OP
in
set
(
__all__
):
for
_OP
in
set
(
__all__
):
...
...
python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
浏览文件 @
8855d4a7
...
@@ -17,7 +17,10 @@ import numpy as np
...
@@ -17,7 +17,10 @@ import numpy as np
from
op_test
import
OpTest
from
op_test
import
OpTest
def
bilinear_interp_np
(
input
,
out_h
,
out_w
):
def
bilinear_interp_np
(
input
,
out_h
,
out_w
,
out_size
):
if
out_size
is
not
None
:
out_h
=
out_size
[
0
]
out_w
=
out_size
[
1
]
batch_size
,
channel
,
in_h
,
in_w
=
input
.
shape
batch_size
,
channel
,
in_h
,
in_w
=
input
.
shape
if
out_h
>
1
:
if
out_h
>
1
:
ratio_h
=
(
in_h
-
1.0
)
/
(
out_h
-
1.0
)
ratio_h
=
(
in_h
-
1.0
)
/
(
out_h
-
1.0
)
...
@@ -49,12 +52,15 @@ def bilinear_interp_np(input, out_h, out_w):
...
@@ -49,12 +52,15 @@ def bilinear_interp_np(input, out_h, out_w):
class
TestBilinearInterpOp
(
OpTest
):
class
TestBilinearInterpOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
out_size
=
None
self
.
init_test_case
()
self
.
init_test_case
()
self
.
op_type
=
"bilinear_interp"
self
.
op_type
=
"bilinear_interp"
input_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
)
input_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
)
output_np
=
bilinear_interp_np
(
input_np
,
self
.
out_h
,
self
.
out_w
)
output_np
=
bilinear_interp_np
(
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
self
.
inputs
=
{
'X'
:
input_np
}
self
.
inputs
=
{
'X'
:
input_np
}
if
self
.
out_size
is
not
None
:
self
.
inputs
[
'OutSize'
]
=
self
.
out_size
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
}
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
}
self
.
outputs
=
{
'Out'
:
output_np
}
self
.
outputs
=
{
'Out'
:
output_np
}
...
@@ -68,6 +74,7 @@ class TestBilinearInterpOp(OpTest):
...
@@ -68,6 +74,7 @@ class TestBilinearInterpOp(OpTest):
self
.
input_shape
=
[
2
,
3
,
4
,
4
]
self
.
input_shape
=
[
2
,
3
,
4
,
4
]
self
.
out_h
=
2
self
.
out_h
=
2
self
.
out_w
=
2
self
.
out_w
=
2
self
.
out_size
=
np
.
array
([
3
,
3
]).
astype
(
"int32"
)
class
TestCase1
(
TestBilinearInterpOp
):
class
TestCase1
(
TestBilinearInterpOp
):
...
@@ -91,5 +98,29 @@ class TestCase3(TestBilinearInterpOp):
...
@@ -91,5 +98,29 @@ class TestCase3(TestBilinearInterpOp):
self
.
out_w
=
128
self
.
out_w
=
128
class
TestCase4
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
self
.
out_size
=
np
.
array
([
2
,
2
]).
astype
(
"int32"
)
class
TestCase5
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
self
.
out_size
=
np
.
array
([
11
,
11
]).
astype
(
"int32"
)
class
TestCase6
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
self
.
out_size
=
np
.
array
([
65
,
129
]).
astype
(
"int32"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_gather_op.py
浏览文件 @
8855d4a7
...
@@ -20,8 +20,9 @@ from op_test import OpTest
...
@@ -20,8 +20,9 @@ from op_test import OpTest
class
TestGatherOp
(
OpTest
):
class
TestGatherOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"gather"
self
.
op_type
=
"gather"
xnp
=
np
.
random
.
random
((
10
,
20
)).
astype
(
"float32"
)
self
.
config
()
self
.
inputs
=
{
'X'
:
xnp
,
'Index'
:
np
.
array
([
1
,
3
,
5
]).
astype
(
"int32"
)}
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
xnp
,
'Index'
:
np
.
array
(
self
.
index
).
astype
(
"int32"
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
"X"
][
self
.
inputs
[
"Index"
]]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
"X"
][
self
.
inputs
[
"Index"
]]}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
@@ -30,6 +31,16 @@ class TestGatherOp(OpTest):
...
@@ -30,6 +31,16 @@ class TestGatherOp(OpTest):
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
self
.
check_grad
([
'X'
],
'Out'
)
def
config
(
self
):
self
.
x_shape
=
(
10
,
20
)
self
.
index
=
[
1
,
3
,
5
]
class
TestCase1
(
TestGatherOp
):
def
config
(
self
):
self
.
x_shape
=
(
10
)
self
.
index
=
[
1
,
3
,
5
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_shape_op.py
0 → 100644
浏览文件 @
8855d4a7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestShapeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"shape"
self
.
config
()
self
.
shape
=
[
2
,
3
]
input
=
np
.
zeros
(
self
.
shape
)
self
.
inputs
=
{
'Input'
:
input
}
self
.
outputs
=
{
'Out'
:
np
.
array
(
self
.
shape
)}
def
config
(
self
):
self
.
shape
=
[
2
,
3
]
def
test_check_output
(
self
):
self
.
check_output
()
class
case1
(
TestShapeOp
):
def
config
(
self
):
self
.
shape
=
[
2
]
class
case2
(
TestShapeOp
):
def
config
(
self
):
self
.
shape
=
[
1
,
2
,
3
]
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_split_var.py
浏览文件 @
8855d4a7
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
math
import
math
import
unittest
import
unittest
from
paddle.fluid.transpiler.distribute_transpiler
import
split_
dense_
variable
from
paddle.fluid.transpiler.distribute_transpiler
import
split_variable
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
import
random
import
random
...
@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
...
@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR,
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape
=
shape
)
shape
=
shape
)
var_list
.
append
(
var
)
var_list
.
append
(
var
)
blocks
=
split_
dense_
variable
(
var_list
,
10
,
min_size
)
blocks
=
split_variable
(
var_list
,
10
,
min_size
)
all_sizes
=
[]
all_sizes
=
[]
for
s
in
expected_sizes
:
for
s
in
expected_sizes
:
for
s2
in
s
:
for
s2
in
s
:
...
...
python/paddle/fluid/transpiler/details/__init__.py
0 → 100644
浏览文件 @
8855d4a7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
program_utils
import
*
from
ufind
import
*
python/paddle/fluid/transpiler/details/program_utils.py
0 → 100644
浏览文件 @
8855d4a7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def
delete_ops
(
block
,
ops
):
try
:
start
=
list
(
block
.
ops
).
index
(
ops
[
0
])
end
=
list
(
block
.
ops
).
index
(
ops
[
-
1
])
[
block
.
remove_op
(
start
)
for
_
in
xrange
(
end
-
start
+
1
)]
except
Exception
,
e
:
raise
e
block
.
program
.
sync_with_cpp
()
def
find_op_by_input_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
input_arg_names
:
return
index
return
-
1
def
find_op_by_output_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
output_arg_names
:
return
index
return
-
1
python/paddle/fluid/transpiler/details/ufind.py
0 → 100644
浏览文件 @
8855d4a7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
UnionFind
(
object
):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def
__init__
(
self
,
elementes
=
None
):
self
.
_parents
=
[]
# index -> parent index
self
.
_index
=
{}
# element -> index
self
.
_curr_idx
=
0
if
not
elementes
:
elementes
=
[]
for
ele
in
elementes
:
self
.
_parents
.
append
(
self
.
_curr_idx
)
self
.
_index
.
update
({
ele
:
self
.
_curr_idx
})
self
.
_curr_idx
+=
1
def
find
(
self
,
x
):
# Find the root index of given element x,
# execute the path compress while findind the root index
if
not
x
in
self
.
_index
:
return
-
1
idx
=
self
.
_index
[
x
]
while
idx
!=
self
.
_parents
[
idx
]:
t
=
self
.
_parents
[
idx
]
self
.
_parents
[
idx
]
=
self
.
_parents
[
t
]
idx
=
t
return
idx
def
union
(
self
,
x
,
y
):
# Union two given element
x_root
=
self
.
find
(
x
)
y_root
=
self
.
find
(
y
)
if
x_root
==
y_root
:
return
self
.
_parents
[
x_root
]
=
y_root
def
is_connected
(
self
,
x
,
y
):
# If two given elements have the same root index,
# then they are connected.
return
self
.
find
(
x
)
==
self
.
find
(
y
)
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
8855d4a7
...
@@ -11,6 +11,30 @@
...
@@ -11,6 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
"""
from
__future__
import
print_function
from
__future__
import
print_function
...
@@ -21,9 +45,11 @@ from .. import core, framework
...
@@ -21,9 +45,11 @@ from .. import core, framework
from
..framework
import
Program
,
default_main_program
,
\
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
\
default_startup_program
,
\
Variable
,
Parameter
,
grad_var_name
Variable
,
Parameter
,
grad_var_name
from
details
import
*
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME
=
core
.
op_proto_and_checker_maker
.
kOpRoleVarAttrName
()
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
)
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
...
@@ -40,62 +66,11 @@ class VarBlock:
...
@@ -40,62 +66,11 @@ class VarBlock:
return
"%s:%d:%d"
%
(
self
.
varname
,
self
.
offset
,
self
.
size
)
return
"%s:%d:%d"
%
(
self
.
varname
,
self
.
offset
,
self
.
size
)
class
UnionFind
(
object
):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def
__init__
(
self
,
elementes
=
None
):
self
.
_parents
=
[]
# index -> parent index
self
.
_index
=
{}
# element -> index
self
.
_curr_idx
=
0
if
not
elementes
:
elementes
=
[]
for
ele
in
elementes
:
self
.
_parents
.
append
(
self
.
_curr_idx
)
self
.
_index
.
update
({
ele
:
self
.
_curr_idx
})
self
.
_curr_idx
+=
1
def
find
(
self
,
x
):
# Find the root index of given element x,
# execute the path compress while findind the root index
if
not
x
in
self
.
_index
:
return
-
1
idx
=
self
.
_index
[
x
]
while
idx
!=
self
.
_parents
[
idx
]:
t
=
self
.
_parents
[
idx
]
self
.
_parents
[
idx
]
=
self
.
_parents
[
t
]
idx
=
t
return
idx
def
union
(
self
,
x
,
y
):
# Union two given element
x_root
=
self
.
find
(
x
)
y_root
=
self
.
find
(
y
)
if
x_root
==
y_root
:
return
self
.
_parents
[
x_root
]
=
y_root
def
is_connected
(
self
,
x
,
y
):
# If two given elements have the same root index,
# then they are connected.
return
self
.
find
(
x
)
==
self
.
find
(
y
)
def
same_or_split_var
(
p_name
,
var_name
):
def
same_or_split_var
(
p_name
,
var_name
):
return
p_name
==
var_name
or
p_name
.
startswith
(
var_name
+
".block"
)
return
p_name
==
var_name
or
p_name
.
startswith
(
var_name
+
".block"
)
def
split_
dense_
variable
(
var_list
,
service_count
,
min_block_size
=
8192
):
def
split_variable
(
var_list
,
service_count
,
min_block_size
=
8192
):
"""
"""
We may need to split dense tensor to one or more blocks and put
We may need to split dense tensor to one or more blocks and put
them equally onto parameter server. One block is a sub-tensor
them equally onto parameter server. One block is a sub-tensor
...
@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
...
@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
return
blocks
return
blocks
def
delete_ops
(
block
,
ops
):
try
:
start
=
list
(
block
.
ops
).
index
(
ops
[
0
])
end
=
list
(
block
.
ops
).
index
(
ops
[
-
1
])
[
block
.
remove_op
(
start
)
for
_
in
xrange
(
end
-
start
+
1
)]
except
Exception
,
e
:
raise
e
block
.
program
.
sync_with_cpp
()
def
find_op_by_input_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
input_arg_names
:
return
index
return
-
1
def
find_op_by_output_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
output_arg_names
:
return
index
return
-
1
class
DistributeTranspiler
:
class
DistributeTranspiler
:
def
transpile
(
self
,
def
_has_distributed_lookup_table
(
self
):
trainer_id
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
split_method
=
RoundRobin
,
sync_mode
=
True
):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert
(
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
if
program
is
None
:
program
=
default_main_program
()
self
.
origin_program
=
program
self
.
trainer_num
=
trainers
self
.
sync_mode
=
sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
self
.
trainer_id
=
trainer_id
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
optimize_ops
,
params_grads
=
self
.
_get_optimize_pass
()
ps_dispatcher
=
split_method
(
pserver_endpoints
)
# process lookup_table_op
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
# 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops
=
[]
distributed_lookup_table_ops
=
[]
# support only one distributed_lookup_table now
# support only one distributed_lookup_table now
self
.
table_name
=
None
self
.
table_name
=
None
for
op
in
program
.
global_block
().
ops
:
for
op
in
self
.
origin_
program
.
global_block
().
ops
:
if
op
.
type
==
LOOKUP_TABLE_TYPE
:
if
op
.
type
==
LOOKUP_TABLE_TYPE
:
if
op
.
attrs
[
'is_distributed'
]
is
True
:
if
op
.
attrs
[
'is_distributed'
]
is
True
:
if
self
.
table_name
is
None
:
if
self
.
table_name
is
None
:
...
@@ -246,20 +137,13 @@ class DistributeTranspiler:
...
@@ -246,20 +137,13 @@ class DistributeTranspiler:
if
self
.
table_name
is
not
None
:
if
self
.
table_name
is
not
None
:
assert
op
.
input
(
"W"
)[
0
]
!=
self
.
table_name
assert
op
.
input
(
"W"
)[
0
]
!=
self
.
table_name
self
.
has_distributed_lookup_table
=
len
(
return
len
(
distributed_lookup_table_ops
)
>
0
distributed_lookup_table_ops
)
>
0
# step1: For large parameters and gradients, split them into smaller
# blocks.
param_list
=
[]
grad_list
=
[]
for
p
,
g
in
params_grads
:
# skip parameter marked not trainable
if
type
(
p
)
==
Parameter
and
p
.
trainable
==
False
:
continue
param_list
.
append
(
p
)
grad_list
.
append
(
g
)
def
_update_dist_lookup_table_vars
(
self
,
param_list
,
grad_list
,
params_grads
):
# TODO(wuyi): put find a way to put dist lookup table stuff all together.
# update self.table_param_grad and self.trainer_side_table_grad_list
program
=
self
.
origin_program
if
self
.
has_distributed_lookup_table
:
if
self
.
has_distributed_lookup_table
:
param_list
=
[
param_list
=
[
param
for
param
in
param_list
if
param
.
name
!=
self
.
table_name
param
for
param
in
param_list
if
param
.
name
!=
self
.
table_name
...
@@ -277,7 +161,7 @@ class DistributeTranspiler:
...
@@ -277,7 +161,7 @@ class DistributeTranspiler:
self
.
trainer_side_table_grad_list
=
[
self
.
trainer_side_table_grad_list
=
[
program
.
global_block
().
create_var
(
program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d.pserver_%d"
%
name
=
"%s.trainer_%d.pserver_%d"
%
(
table_grad_var
.
name
,
trainer_id
,
index
),
(
table_grad_var
.
name
,
self
.
trainer_id
,
index
),
type
=
table_grad_var
.
type
,
type
=
table_grad_var
.
type
,
shape
=
table_grad_var
.
shape
,
shape
=
table_grad_var
.
shape
,
dtype
=
table_grad_var
.
dtype
)
dtype
=
table_grad_var
.
dtype
)
...
@@ -293,23 +177,41 @@ class DistributeTranspiler:
...
@@ -293,23 +177,41 @@ class DistributeTranspiler:
for
index
in
range
(
len
(
self
.
pserver_endpoints
))
for
index
in
range
(
len
(
self
.
pserver_endpoints
))
]
]
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
def
_init_splited_vars
(
self
,
split_method
):
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
# update these mappings for further transpile:
# 1. param_var_mapping: param var name -> [splited params vars]
# 2. grad_var_mapping: grad var name -> [splited grads vars]
# 3. grad_param_mapping: grad.blockx -> param.blockx
# 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
param_list
=
[]
grad_list
=
[]
for
p
,
g
in
self
.
params_grads
:
# skip parameter marked not trainable
if
type
(
p
)
==
Parameter
and
p
.
trainable
==
False
:
continue
param_list
.
append
(
p
)
grad_list
.
append
(
g
)
self
.
_update_dist_lookup_table_vars
(
param_list
,
grad_list
,
self
.
params_grads
)
grad_blocks
=
split_variable
(
grad_list
,
len
(
self
.
pserver_endpoints
))
param_blocks
=
split_variable
(
param_list
,
len
(
self
.
pserver_endpoints
))
assert
(
len
(
grad_blocks
)
==
len
(
param_blocks
))
assert
(
len
(
grad_blocks
)
==
len
(
param_blocks
))
# step2: Create new vars for the parameters and gradients blocks and
# origin_varname -> [splited_var]
# add ops to do the split.
self
.
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
self
.
origin_program
,
param_blocks
)
param_blocks
)
self
.
grad_var_mapping
=
self
.
_create_vars_from_blocklist
(
grad_var_mapping
=
self
.
_create_vars_from_blocklist
(
self
.
origin_program
,
program
,
grad_blocks
,
add_trainer_suffix
=
self
.
trainer_num
>
1
)
grad_blocks
,
grad_param_mapping
=
dict
()
add_trainer_suffix
=
self
.
trainer_num
>
1
)
self
.
grad_param_mapping
=
dict
()
for
g
,
p
in
zip
(
grad_blocks
,
param_blocks
):
for
g
,
p
in
zip
(
grad_blocks
,
param_blocks
):
g_name
,
g_bid
,
_
=
g
.
split
(
":"
)
g_name
,
g_bid
,
_
=
g
.
split
(
":"
)
p_name
,
p_bid
,
_
=
p
.
split
(
":"
)
p_name
,
p_bid
,
_
=
p
.
split
(
":"
)
grad_param_mapping
[
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
self
.
grad_param_mapping
[
self
.
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
param_var_mapping
[
p_name
][
int
(
p_bid
)]
self
.
param_var_mapping
[
p_name
][
int
(
p_bid
)]
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program
# create mapping of endpoint -> split var to create pserver side program
self
.
param_grad_ep_mapping
=
dict
()
self
.
param_grad_ep_mapping
=
dict
()
...
@@ -322,10 +224,50 @@ class DistributeTranspiler:
...
@@ -322,10 +224,50 @@ class DistributeTranspiler:
})
for
ep
in
self
.
pserver_endpoints
})
for
ep
in
self
.
pserver_endpoints
]
]
def
transpile
(
self
,
trainer_id
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
split_method
=
RoundRobin
,
sync_mode
=
True
):
"""
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert
(
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
if
program
is
None
:
program
=
default_main_program
()
self
.
origin_program
=
program
self
.
trainer_num
=
trainers
self
.
sync_mode
=
sync_mode
self
.
trainer_id
=
trainer_id
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
optimize_ops
,
self
.
params_grads
=
self
.
_get_optimize_pass
()
ps_dispatcher
=
split_method
(
self
.
pserver_endpoints
)
self
.
has_distributed_lookup_table
=
self
.
_has_distributed_lookup_table
()
# split and create vars, then put splited vars in dicts for later use.
self
.
_init_splited_vars
(
split_method
)
# step 3.1: insert send op to send gradient vars to parameter servers
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher
.
reset
()
ps_dispatcher
.
reset
()
send_vars
=
[]
send_vars
=
[]
for
orig_varname
,
splited_vars
in
grad_var_mapping
.
items
():
for
orig_varname
,
splited_vars
in
self
.
grad_var_mapping
.
items
():
eplist
=
ps_dispatcher
.
dispatch
(
splited_vars
)
eplist
=
ps_dispatcher
.
dispatch
(
splited_vars
)
if
len
(
splited_vars
)
==
1
:
if
len
(
splited_vars
)
==
1
:
orig_varname
=
splited_vars
[
0
].
name
orig_varname
=
splited_vars
[
0
].
name
...
@@ -367,7 +309,7 @@ class DistributeTranspiler:
...
@@ -367,7 +309,7 @@ class DistributeTranspiler:
# step 3.2: insert recv op to receive parameters from parameter server
# step 3.2: insert recv op to receive parameters from parameter server
recv_vars
=
[]
recv_vars
=
[]
for
_
,
var
in
enumerate
(
send_vars
):
for
_
,
var
in
enumerate
(
send_vars
):
recv_vars
.
append
(
grad_param_mapping
[
var
])
recv_vars
.
append
(
self
.
grad_param_mapping
[
var
])
ps_dispatcher
.
reset
()
ps_dispatcher
.
reset
()
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
...
@@ -375,7 +317,7 @@ class DistributeTranspiler:
...
@@ -375,7 +317,7 @@ class DistributeTranspiler:
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
for
varname
,
splited_var
in
self
.
param_var_mapping
.
iteritems
():
eps
=
[]
eps
=
[]
for
var
in
splited_var
:
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
...
@@ -399,7 +341,7 @@ class DistributeTranspiler:
...
@@ -399,7 +341,7 @@ class DistributeTranspiler:
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
})
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
for
varname
,
splited_var
in
self
.
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
if
len
(
splited_var
)
<=
1
:
continue
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
orig_param
=
program
.
global_block
().
vars
[
varname
]
...
@@ -440,7 +382,6 @@ class DistributeTranspiler:
...
@@ -440,7 +382,6 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives.
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# change client side var name to origin name by
# removing ".trainer_%d" suffix
# removing ".trainer_%d" suffix
suff_idx
=
v
.
name
.
find
(
".trainer_"
)
suff_idx
=
v
.
name
.
find
(
".trainer_"
)
if
suff_idx
>=
0
:
if
suff_idx
>=
0
:
orig_var_name
=
v
.
name
[:
suff_idx
]
orig_var_name
=
v
.
name
[:
suff_idx
]
...
@@ -477,24 +418,14 @@ class DistributeTranspiler:
...
@@ -477,24 +418,14 @@ class DistributeTranspiler:
# located on current pserver
# located on current pserver
opt_op_on_pserver
=
[]
opt_op_on_pserver
=
[]
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
if
self
.
_is_opt_op
(
op
)
and
self
.
_is_opt_op_on_pserver
(
endpoint
,
op
):
if
self
.
_is_optimizer_op
(
op
)
and
self
.
_is_opt_op_on_pserver
(
endpoint
,
op
):
opt_op_on_pserver
.
append
(
op
)
opt_op_on_pserver
.
append
(
op
)
# step 3.3
# step 3.3
# Iterate through the ops, and if an op and the optimize ops
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# which located on current pserver are in one set, then
# append it into the sub program.
# append it into the sub program.
# We try to put optimization program run parallelly, assume
# optimization program always looks like:
#
# prevop -> prevop -> opt op -> following op -> following op; ->
# prevop -> prevop -> opt op -> following op -> following op; ->
# global op -> global op
#
# we put operators that can run parallelly to many program blocks.
# in above example, we seperate ops by the ";". Global ops must run
# after all the optimize ops finished.
global_ops
=
[]
global_ops
=
[]
# HACK: optimization global ops only used to scale beta1 and beta2
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
# replace it with dependency engine.
...
@@ -502,12 +433,18 @@ class DistributeTranspiler:
...
@@ -502,12 +433,18 @@ class DistributeTranspiler:
if
self
.
_is_adam_connected_op
(
op
):
if
self
.
_is_adam_connected_op
(
op
):
global_ops
.
append
(
op
)
global_ops
.
append
(
op
)
def
__append_optimize_op__
(
op
,
block
,
grad_to_block_id
):
def
__append_optimize_op__
(
op
,
block
,
grad_to_block_id
,
merged_var
):
if
self
.
_is_opt_op
(
op
):
if
self
.
_is_opt
imizer
_op
(
op
):
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
grad_to_block_id
,
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
grad_to_block_id
,
self
.
origin_program
)
self
.
origin_program
,
merged_var
)
else
:
else
:
self
.
_append_pserver_non_opt_ops
(
block
,
op
)
self
.
_append_pserver_non_opt_ops
(
block
,
op
,
endpoint
)
def
__op_have_grad_input__
(
op
):
for
varname
in
op
.
input_arg_names
:
if
varname
.
find
(
"@GRAD"
)
>=
0
:
return
varname
return
""
# append lr decay ops to the child block if exists
# append lr decay ops to the child block if exists
lr_ops
=
self
.
_get_lr_ops
()
lr_ops
=
self
.
_get_lr_ops
()
...
@@ -515,17 +452,26 @@ class DistributeTranspiler:
...
@@ -515,17 +452,26 @@ class DistributeTranspiler:
lr_decay_block
=
pserver_program
.
create_block
(
lr_decay_block
=
pserver_program
.
create_block
(
pserver_program
.
num_blocks
-
1
)
pserver_program
.
num_blocks
-
1
)
for
_
,
op
in
enumerate
(
lr_ops
):
for
_
,
op
in
enumerate
(
lr_ops
):
self
.
_append_pserver_non_opt_ops
(
lr_decay_block
,
op
)
self
.
_append_pserver_non_opt_ops
(
lr_decay_block
,
op
,
endpoint
)
# append op to the current block
# append op to the current block
grad_to_block_id
=
[]
grad_to_block_id
=
[]
pre_block_idx
=
pserver_program
.
num_blocks
-
1
pre_block_idx
=
pserver_program
.
num_blocks
-
1
for
idx
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
for
idx
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
per_opt_block
=
pserver_program
.
create_block
(
pre_block_idx
)
per_opt_block
=
pserver_program
.
create_block
(
pre_block_idx
)
# append grad merging ops before clip and weight decay
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
# find the origin @GRAD var before clipping
grad_varname_for_block
=
__op_have_grad_input__
(
op
)
if
ufind
.
is_connected
(
op
,
opt_op
)
and
grad_varname_for_block
:
merged_var
=
self
.
_append_pserver_grad_merge_ops
(
per_opt_block
,
grad_varname_for_block
,
endpoint
,
grad_to_block_id
,
self
.
origin_program
)
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
# optimizer is connected to itself
# optimizer is connected to itself
if
ufind
.
is_connected
(
op
,
opt_op
)
and
op
not
in
global_ops
:
if
ufind
.
is_connected
(
op
,
opt_op
)
and
op
not
in
global_ops
:
__append_optimize_op__
(
op
,
per_opt_block
,
grad_to_block_id
)
__append_optimize_op__
(
op
,
per_opt_block
,
grad_to_block_id
,
merged_var
)
# append global ops
# append global ops
if
global_ops
:
if
global_ops
:
...
@@ -533,15 +479,7 @@ class DistributeTranspiler:
...
@@ -533,15 +479,7 @@ class DistributeTranspiler:
pserver_program
.
num_blocks
-
1
)
pserver_program
.
num_blocks
-
1
)
for
glb_op
in
global_ops
:
for
glb_op
in
global_ops
:
__append_optimize_op__
(
glb_op
,
opt_state_block
,
__append_optimize_op__
(
glb_op
,
opt_state_block
,
grad_to_block_id
)
grad_to_block_id
,
None
)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
# process distributed lookup_table
# process distributed lookup_table
prefetch_block
=
None
prefetch_block
=
None
...
@@ -631,6 +569,8 @@ class DistributeTranspiler:
...
@@ -631,6 +569,8 @@ class DistributeTranspiler:
attrs
=
op
.
attrs
)
attrs
=
op
.
attrs
)
return
s_prog
return
s_prog
# ====================== private transpiler functions =====================
# transpiler function for dis lookup_table
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
pserver_endpoints
):
pserver_endpoints
):
...
@@ -836,7 +776,6 @@ class DistributeTranspiler:
...
@@ -836,7 +776,6 @@ class DistributeTranspiler:
return
table_opt_block
return
table_opt_block
# ====================== private transpiler functions =====================
def
_create_vars_from_blocklist
(
self
,
def
_create_vars_from_blocklist
(
self
,
program
,
program
,
block_list
,
block_list
,
...
@@ -979,44 +918,57 @@ class DistributeTranspiler:
...
@@ -979,44 +918,57 @@ class DistributeTranspiler:
pass
pass
return
orig_shape
return
orig_shape
def
_
orig_varname
(
self
,
varname
):
def
_
get_varname_parts
(
self
,
varname
):
suff_idx
=
varname
.
find
(
".trainer_"
)
# returns origin, blockid, trainerid
orig_var_name
=
""
orig_var_name
=
""
if
suff_idx
>=
0
:
trainer_part
=
""
orig_var_name
=
varname
[:
suff_idx
]
block_part
=
""
trainer_idx
=
varname
.
find
(
".trainer_"
)
if
trainer_idx
>=
0
:
trainer_part
=
varname
[
trainer_idx
+
1
:]
else
:
trainer_idx
=
len
(
varname
)
block_index
=
varname
.
find
(
".block"
)
if
block_index
>=
0
:
block_part
=
varname
[
block_index
+
1
:
trainer_idx
]
else
:
else
:
orig_var_name
=
varname
block_index
=
len
(
varname
)
return
orig_var_name
orig_var_name
=
varname
[
0
:
min
(
block_index
,
trainer_idx
)]
return
orig_var_name
,
block_part
,
trainer_part
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
def
_orig_varname
(
self
,
varname
):
orig
,
_
,
_
=
self
.
_get_varname_parts
(
varname
)
return
orig
def
_append_pserver_grad_merge_ops
(
self
,
optimize_block
,
grad_varname_for_block
,
endpoint
,
grad_to_block_id
,
origin_program
):
grad_to_block_id
,
origin_program
):
program
=
optimize_block
.
program
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
pserver_block
=
program
.
global_block
()
new_inputs
=
dict
()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for
key
in
opt_op
.
input_names
:
if
key
==
"Grad"
:
grad_block
=
None
grad_block
=
None
for
g
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
for
g
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
if
same_or_split_var
(
if
self
.
_orig_varname
(
g
.
name
)
==
\
self
.
_orig_varname
(
g
.
name
),
self
.
_orig_varname
(
grad_varname_for_block
):
self
.
_orig_varname
(
opt_op
.
input
(
key
)[
0
])):
grad_block
=
g
grad_block
=
g
break
break
if
not
grad_block
:
if
not
grad_block
:
# do not append this op if current endpoint
# do not append this op if current endpoint
# is not dealing with this grad block
# is not dealing with this grad block
return
return
orig_varname
,
block_name
,
trainer_name
=
self
.
_get_varname_parts
(
grad_block
.
name
)
if
block_name
:
merged_var_name
=
'.'
.
join
([
orig_varname
,
block_name
])
else
:
merged_var_name
=
orig_varname
merged_var
=
\
merged_var
=
\
pserver_block
.
vars
[
self
.
_orig_varname
(
grad_block
.
name
)]
pserver_block
.
vars
[
merged_var_name
]
grad_to_block_id
.
append
(
merged_var
.
name
+
":"
+
str
(
grad_to_block_id
.
append
(
merged_var
.
name
+
":"
+
str
(
optimize_block
.
idx
))
optimize_block
.
idx
))
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
vars2merge
=
[]
vars2merge
=
[]
for
i
in
xrange
(
self
.
trainer_num
):
for
i
in
xrange
(
self
.
trainer_num
):
per_trainer_name
=
"%s.trainer_%d"
%
\
per_trainer_name
=
"%s.trainer_%d"
%
\
(
self
.
_orig_varname
(
grad_block
.
name
)
,
i
)
(
merged_var_name
,
i
)
vars2merge
.
append
(
pserver_block
.
vars
[
per_trainer_name
])
vars2merge
.
append
(
pserver_block
.
vars
[
per_trainer_name
])
optimize_block
.
append_op
(
optimize_block
.
append_op
(
...
@@ -1030,7 +982,17 @@ class DistributeTranspiler:
...
@@ -1030,7 +982,17 @@ class DistributeTranspiler:
inputs
=
{
"X"
:
merged_var
},
inputs
=
{
"X"
:
merged_var
},
outputs
=
{
"Out"
:
merged_var
},
outputs
=
{
"Out"
:
merged_var
},
attrs
=
{
"scale"
:
1.0
/
float
(
self
.
trainer_num
)})
attrs
=
{
"scale"
:
1.0
/
float
(
self
.
trainer_num
)})
return
merged_var
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
grad_to_block_id
,
origin_program
,
merged_var
):
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
new_inputs
=
dict
()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for
key
in
opt_op
.
input_names
:
if
key
==
"Grad"
:
new_inputs
[
key
]
=
merged_var
new_inputs
[
key
]
=
merged_var
elif
key
==
"Param"
:
elif
key
==
"Param"
:
# param is already created on global program
# param is already created on global program
...
@@ -1089,17 +1051,31 @@ class DistributeTranspiler:
...
@@ -1089,17 +1051,31 @@ class DistributeTranspiler:
outputs
=
outputs
,
outputs
=
outputs
,
attrs
=
opt_op
.
attrs
)
attrs
=
opt_op
.
attrs
)
def
_append_pserver_non_opt_ops
(
self
,
optimize_block
,
opt_op
):
def
_is_splited_grad_var
(
self
,
var
,
var_dict
):
grad_block
=
None
for
_
,
g
in
var_dict
.
iteritems
():
if
self
.
_orig_varname
(
g
.
name
)
==
self
.
_orig_varname
(
var
.
name
):
if
g
.
name
.
find
(
".trainer_"
)
==
-
1
:
grad_block
=
g
break
return
grad_block
def
_append_pserver_non_opt_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
):
program
=
optimize_block
.
program
program
=
optimize_block
.
program
# Append the ops for parameters that do not need to be optimized/updated
# Append the ops for parameters that do not need to be optimized/updated
inputs
=
self
.
_get_input_map_from_op
(
inputs
=
self
.
_get_input_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
for
varlist
in
inputs
.
itervalue
s
():
for
key
,
varlist
in
inputs
.
iteritem
s
():
if
not
isinstance
(
varlist
,
list
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
varlist
=
[
varlist
]
for
var
in
varlist
:
for
var
in
varlist
:
if
not
program
.
global_block
().
vars
.
has_key
(
var
.
name
):
# for ops like clipping and weight decay, get the splited var
# for inputs/outputs
grad_block
=
self
.
_is_splited_grad_var
(
var
,
program
.
global_block
().
vars
)
if
grad_block
:
inputs
[
key
]
=
grad_block
elif
not
program
.
global_block
().
vars
.
has_key
(
var
.
name
):
program
.
global_block
().
create_var
(
program
.
global_block
().
create_var
(
name
=
var
.
name
,
name
=
var
.
name
,
persistable
=
var
.
persistable
,
persistable
=
var
.
persistable
,
...
@@ -1108,12 +1084,15 @@ class DistributeTranspiler:
...
@@ -1108,12 +1084,15 @@ class DistributeTranspiler:
outputs
=
self
.
_get_output_map_from_op
(
outputs
=
self
.
_get_output_map_from_op
(
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
self
.
origin_program
.
global_block
().
vars
,
opt_op
)
for
key
,
varlist
in
outputs
.
iteritems
():
for
varlist
in
outputs
.
itervalues
():
if
not
isinstance
(
varlist
,
list
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
varlist
=
[
varlist
]
for
var
in
varlist
:
for
var
in
varlist
:
grad_block
=
self
.
_is_splited_grad_var
(
var
,
program
.
global_block
().
vars
)
if
grad_block
:
outputs
[
key
]
=
grad_block
elif
not
program
.
global_block
().
vars
.
has_key
(
var
.
name
):
program
.
global_block
().
clone_variable
(
var
)
program
.
global_block
().
clone_variable
(
var
)
optimize_block
.
append_op
(
optimize_block
.
append_op
(
...
@@ -1160,9 +1139,17 @@ class DistributeTranspiler:
...
@@ -1160,9 +1139,17 @@ class DistributeTranspiler:
ufind
.
union
(
op1
,
op2
)
ufind
.
union
(
op1
,
op2
)
return
ufind
return
ufind
def
_is_opt_op
(
self
,
op
):
def
_is_opt_role_op
(
self
,
op
):
# NOTE: It's a HACK implement.
# NOTE: depend on oprole to find out whether this op is for
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
# optimize
op_maker
=
core
.
op_proto_and_checker_maker
optimize_role
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
if
op_maker
.
kOpRoleAttrName
()
in
op
.
attrs
and
\
int
(
op
.
attrs
[
op_maker
.
kOpRoleAttrName
()])
==
int
(
optimize_role
):
return
True
return
False
def
_is_optimizer_op
(
self
,
op
):
if
"Param"
in
op
.
input_names
and
\
if
"Param"
in
op
.
input_names
and
\
"LearningRate"
in
op
.
input_names
:
"LearningRate"
in
op
.
input_names
:
return
True
return
True
...
@@ -1212,7 +1199,7 @@ class DistributeTranspiler:
...
@@ -1212,7 +1199,7 @@ class DistributeTranspiler:
# find learning rate variables by optimize op
# find learning rate variables by optimize op
lr_vars
=
set
()
lr_vars
=
set
()
for
op
in
self
.
optimize_ops
:
for
op
in
self
.
optimize_ops
:
if
self
.
_is_opt_op
(
op
):
if
self
.
_is_opt
imizer
_op
(
op
):
lr_vars
.
add
(
op
.
input
(
"LearningRate"
)[
0
])
lr_vars
.
add
(
op
.
input
(
"LearningRate"
)[
0
])
find_ops
=
[]
find_ops
=
[]
...
@@ -1229,7 +1216,7 @@ class DistributeTranspiler:
...
@@ -1229,7 +1216,7 @@ class DistributeTranspiler:
# NOTE: we need to skip all optimize ops, since it is connected
# NOTE: we need to skip all optimize ops, since it is connected
# with forward/backward ops and lr ops, we only need the lr ops.
# with forward/backward ops and lr ops, we only need the lr ops.
if
op1
!=
op2
and
self
.
_is_op_connected
(
op1
,
op2
)
and
\
if
op1
!=
op2
and
self
.
_is_op_connected
(
op1
,
op2
)
and
\
not
self
.
_is_opt
_op
(
op1
)
and
not
self
.
_is_opt
_op
(
op2
):
not
self
.
_is_opt
imizer_op
(
op1
)
and
not
self
.
_is_optimizer
_op
(
op2
):
ufind
.
union
(
op1
,
op2
)
ufind
.
union
(
op1
,
op2
)
# find all ops which is related with lr var
# find all ops which is related with lr var
for
op1
in
block
.
ops
:
for
op1
in
block
.
ops
:
...
@@ -1250,13 +1237,21 @@ class DistributeTranspiler:
...
@@ -1250,13 +1237,21 @@ class DistributeTranspiler:
block
=
self
.
origin_program
.
global_block
()
block
=
self
.
origin_program
.
global_block
()
opt_ops
=
[]
opt_ops
=
[]
params_grads
=
[]
params_grads
=
[]
origin_var_dict
=
self
.
origin_program
.
global_block
().
vars
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
if
self
.
_is_opt_op
(
op
):
if
self
.
_is_opt_
role_
op
(
op
):
opt_ops
.
append
(
op
)
opt_ops
.
append
(
op
)
params_grads
.
append
((
self
.
origin_program
.
global_block
().
var
(
# HACK(wuyi): if we find grad vars from input of optimize
op
.
input
(
"Param"
)[
0
]),
# ops, we may get the output of clip op. Use syntax "@GRAD"
self
.
origin_program
.
global_block
().
var
(
# and op_role_var to get the pair.
op
.
input
(
"Grad"
)[
0
])))
for
input_name
in
op
.
input_arg_names
:
if
input_name
.
find
(
"@GRAD"
)
!=
-
1
and
\
op
.
attrs
[
RPC_OP_ROLE_ATTR_NAME
]:
param_name
=
op
.
attrs
[
OP_ROLE_VAR_ATTR_NAME
][
0
]
params_grads
.
append
([
origin_var_dict
[
param_name
],
origin_var_dict
[
input_name
]
])
elif
self
.
_is_adam_connected_op
(
op
):
elif
self
.
_is_adam_connected_op
(
op
):
opt_ops
.
append
(
op
)
opt_ops
.
append
(
op
)
else
:
else
:
...
...
python/setup.py.in
浏览文件 @
8855d4a7
...
@@ -69,7 +69,8 @@ packages=['paddle',
...
@@ -69,7 +69,8 @@ packages=['paddle',
'paddle.fluid.proto',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.proto.profiler',
'paddle.fluid.layers',
'paddle.fluid.layers',
'paddle.fluid.transpiler']
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
if '${WITH_FLUID_ONLY}'== 'OFF':
if '${WITH_FLUID_ONLY}'== 'OFF':
packages+=['paddle.proto',
packages+=['paddle.proto',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录