提交 8855d4a7 编写于 作者: Y Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into dist_recordio

...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
| backyes | Yan-Fei Wang | | backyes | Yan-Fei Wang |
| baiyfbupt | Yi-Fan Bai | | baiyfbupt | Yi-Fan Bai |
| beckett1124 | Bin Qi | | beckett1124 | Bin Qi |
| ChengduoZH | Cheng-Duo Zhao|
| chengxiaohua1105 | Xiao-Hua Cheng | | chengxiaohua1105 | Xiao-Hua Cheng |
| cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang | | cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang |
| cxysteven | Xing-Yi Cheng | | cxysteven | Xing-Yi Cheng |
......
...@@ -29,7 +29,7 @@ RUN apt-get update && \ ...@@ -29,7 +29,7 @@ RUN apt-get update && \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-matplotlib gcc-4.8 g++-4.8 \ python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format swig doxygen cmake \ automake locales clang-format swig cmake \
liblapack-dev liblapacke-dev \ liblapack-dev liblapacke-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools libtool ccache && \ net-tools libtool ccache && \
......
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
RUN apt-get update && apt-get install -y python python-pip iputils-ping libgtk2.0-dev wget vim net-tools iftop
RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/lib/libnccl.so
RUN pip install -U pip
RUN pip install -U kubernetes opencv-python paddlepaddle
# IMPORTANT:
# Add "ENV http_proxy=http://ip:port" if your download is slow, and don't forget to unset it at runtime.
RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()\npaddle.dataset.flowers.fetch()" | python'
RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.mnist.train()\npaddle.dataset.mnist.test()\npaddle.dataset.imdb.fetch()" | python'
RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.imikolov.fetch()" | python'
RUN pip uninstall -y paddlepaddle && mkdir /workspace
ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin
ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root
ADD *.whl /
RUN pip install /*.whl && rm -f /*.whl && chmod +x /usr/bin/paddle_k8s
ENV LD_LIBRARY_PATH=/usr/local/lib
ADD fluid_benchmark.py dataset.py models/ /workspace/
...@@ -44,11 +44,25 @@ Currently supported `--model` argument include: ...@@ -44,11 +44,25 @@ Currently supported `--model` argument include:
## Run Distributed Benchmark on Kubernetes Cluster ## Run Distributed Benchmark on Kubernetes Cluster
You may need to build a Docker image before submitting a cluster job onto Kubernetes, or you will
have to start all those processes mannually on each node, which is not recommended.
To build the Docker image, you need to choose a paddle "whl" package to run with, you may either
download it from
http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_en.html or
build it by your own. Once you've got the "whl" package, put it under the current directory and run:
```bash
docker build -t [your docker image name]:[your docker image tag] .
```
Then push the image to a Docker registry that your Kubernetes cluster can reach.
We provide a script `kube_gen_job.py` to generate Kubernetes yaml files to submit We provide a script `kube_gen_job.py` to generate Kubernetes yaml files to submit
distributed benchmark jobs to your cluster. To generate a job yaml, just run: distributed benchmark jobs to your cluster. To generate a job yaml, just run:
```bash ```bash
python kube_gen_job.py --jobname myjob --pscpu 4 --cpu 8 --gpu 8 --psmemory 20 --memory 40 --pservers 4 --trainers 4 --entry "python fluid_benchmark.py --model mnist --parallel 1 --device GPU --update_method pserver " --disttype pserver python kube_gen_job.py --jobname myjob --pscpu 4 --cpu 8 --gpu 8 --psmemory 20 --memory 40 --pservers 4 --trainers 4 --entry "python fluid_benchmark.py --model mnist --gpus 8 --device GPU --update_method pserver " --disttype pserver
``` ```
Then the yaml files are generated under directory `myjob`, you can run: Then the yaml files are generated under directory `myjob`, you can run:
......
...@@ -49,7 +49,7 @@ def parse_args(): ...@@ -49,7 +49,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--fluid', default=1, type=int, help='whether is fluid job') '--fluid', default=1, type=int, help='whether is fluid job')
parser.add_argument( parser.add_argument(
'--rdma', action='store_ture', help='whether mount rdma libs') '--rdma', action='store_true', help='whether mount rdma libs')
parser.add_argument( parser.add_argument(
'--disttype', '--disttype',
default="pserver", default="pserver",
......
...@@ -37,7 +37,8 @@ nohup stdbuf -oL nvidia-smi \ ...@@ -37,7 +37,8 @@ nohup stdbuf -oL nvidia-smi \
-l 1 & -l 1 &
# mnist # mnist
# mnist gpu mnist 128 # mnist gpu mnist 128
FLAGS_benchmark=true stdbuf -oL python fluid/mnist.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=mnist \
--device=GPU \ --device=GPU \
--batch_size=128 \ --batch_size=128 \
--skip_batch_num=5 \ --skip_batch_num=5 \
...@@ -46,7 +47,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/mnist.py \ ...@@ -46,7 +47,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/mnist.py \
# vgg16 # vgg16
# gpu cifar10 128 # gpu cifar10 128
FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=vgg16 \
--device=GPU \ --device=GPU \
--batch_size=128 \ --batch_size=128 \
--skip_batch_num=5 \ --skip_batch_num=5 \
...@@ -54,7 +56,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \ ...@@ -54,7 +56,8 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
2>&1 | tee -a vgg16_gpu_128.log 2>&1 | tee -a vgg16_gpu_128.log
# flowers gpu 128 # flowers gpu 128
FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=vgg16 \
--device=GPU \ --device=GPU \
--batch_size=32 \ --batch_size=32 \
--data_set=flowers \ --data_set=flowers \
...@@ -64,40 +67,39 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \ ...@@ -64,40 +67,39 @@ FLAGS_benchmark=true stdbuf -oL python fluid/vgg16.py \
# resnet50 # resnet50
# resnet50 gpu cifar10 128 # resnet50 gpu cifar10 128
FLAGS_benchmark=true stdbuf -oL python fluid/resnet50.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=resnet50 \
--device=GPU \ --device=GPU \
--batch_size=128 \ --batch_size=128 \
--data_set=cifar10 \ --data_set=cifar10 \
--model=resnet_cifar10 \
--skip_batch_num=5 \ --skip_batch_num=5 \
--iterations=30 \ --iterations=30 \
2>&1 | tee -a resnet50_gpu_128.log 2>&1 | tee -a resnet50_gpu_128.log
# resnet50 gpu flowers 64 # resnet50 gpu flowers 64
FLAGS_benchmark=true stdbuf -oL python fluid/resnet50.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=resnet50 \
--device=GPU \ --device=GPU \
--batch_size=64 \ --batch_size=64 \
--data_set=flowers \ --data_set=flowers \
--model=resnet_imagenet \
--skip_batch_num=5 \ --skip_batch_num=5 \
--iterations=30 \ --iterations=30 \
2>&1 | tee -a resnet50_gpu_flowers_64.log 2>&1 | tee -a resnet50_gpu_flowers_64.log
# lstm # lstm
# lstm gpu imdb 32 # tensorflow only support batch=32 # lstm gpu imdb 32 # tensorflow only support batch=32
FLAGS_benchmark=true stdbuf -oL python fluid/stacked_dynamic_lstm.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=stacked_dynamic_lstm \
--device=GPU \ --device=GPU \
--batch_size=32 \ --batch_size=32 \
--skip_batch_num=5 \ --skip_batch_num=5 \
--iterations=30 \ --iterations=30 \
--hidden_dim=512 \
--emb_dim=512 \
--crop_size=1500 \
2>&1 | tee -a lstm_gpu_32.log 2>&1 | tee -a lstm_gpu_32.log
# seq2seq # seq2seq
# seq2seq gpu wmb 128 # seq2seq gpu wmb 128
FLAGS_benchmark=true stdbuf -oL python fluid/machine_translation.py \ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
--model=machine_translation \
--device=GPU \ --device=GPU \
--batch_size=128 \ --batch_size=128 \
--skip_batch_num=5 \ --skip_batch_num=5 \
......
...@@ -1009,3 +1009,9 @@ ____ ...@@ -1009,3 +1009,9 @@ ____
.. autofunction:: paddle.fluid.layers.upsampling_bilinear2d .. autofunction:: paddle.fluid.layers.upsampling_bilinear2d
:noindex: :noindex:
gather
____
.. autofunction:: paddle.fluid.layers.gather
:noindex:
...@@ -86,7 +86,7 @@ ...@@ -86,7 +86,7 @@
<br> <br>
<p align="center"> <p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid_compiler.png" width=100%> <img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid-compiler.png" width=100%>
</p> </p>
--- ---
...@@ -123,12 +123,12 @@ ...@@ -123,12 +123,12 @@
<font size=5> <font size=5>
- 在科学计算领域,计算图是一种描述计算的经典方式。下图展示了从前向计算图(蓝色)开始,通过添加反向(红色)和优化算法相关(绿色)操作,构建出整个计算图的过程: - 在科学计算领域,计算图是一种描述计算的经典方式。下图展示了从前向计算图(蓝色)开始,通过添加反向(红色)和优化算法相关(绿色)操作,构建出整个计算图的过程:
- -
<p align="center"> <p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/graph_construction_example_all.png" width=60%> <img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/graph_construction_example_all.png" width=60%>
</p> </p>
- Fluid ==使用`Program`而不是计算图==来描述模型和优化过程。`Program``Block``Operator``Variable`构成,相关概念会在后文详细展开。 - Fluid ==使用`Program`而不是计算图==来描述模型和优化过程。`Program``Block``Operator``Variable`构成,相关概念会在后文详细展开。
- 编译时 Fluid 接受前向计算(这里可以先简单的理解为是一段有序的计算流)`Program`,为这段前向计算按照:前向 -> 反向 -> 梯度 clip -> 正则 -> 优化 的顺序,添加相关 `Operator``Variable``Program`到完整的计算。 - 编译时 Fluid 接受前向计算(这里可以先简单的理解为是一段有序的计算流)`Program`,为这段前向计算按照:前向 -> 反向 -> 梯度 clip -> 正则 -> 优化 的顺序,添加相关 `Operator``Variable``Program`到完整的计算。
...@@ -328,7 +328,7 @@ ...@@ -328,7 +328,7 @@
</font> </font>
--- ---
### 编译时概念 :==**[Transpiler](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/motivation/fluid_compiler.md)**== ### 编译时概念 :==**[Transpiler](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/motivation/fluid_compiler.md)**==
<font size=5> <font size=5>
...@@ -402,7 +402,7 @@ ...@@ -402,7 +402,7 @@
- `Scope` - `Scope`
- 计算相关 - 计算相关
- `Block` - `Block`
- `Kernel``OpWithKernel``OpWithoutKernel` - `Kernel``OpWithKernel``OpWithoutKernel`
<table> <table>
...@@ -439,7 +439,7 @@ ...@@ -439,7 +439,7 @@
</tbody> </tbody>
</table> </table>
- 执行相关 :`Executor` - 执行相关 :`Executor`
</font> </font>
...@@ -798,7 +798,7 @@ class GPUAllocator : public SystemAllocator { ...@@ -798,7 +798,7 @@ class GPUAllocator : public SystemAllocator {
- step 1:添加Place类型,<span style="background-color:#DAB1D5;">由用户实现添加到框架</span> - step 1:添加Place类型,<span style="background-color:#DAB1D5;">由用户实现添加到框架</span>
- 可以将Place类型理解为一个整数加上一个枚举型,包括:设备号 + 设备类型 - 可以将Place类型理解为一个整数加上一个枚举型,包括:设备号 + 设备类型
<p align="center"> <p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/place.png" width=40%> <img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/place.png" width=40%>
</p> </p>
...@@ -824,7 +824,7 @@ class GPUAllocator : public SystemAllocator { ...@@ -824,7 +824,7 @@ class GPUAllocator : public SystemAllocator {
1. DataType 执行数据类型 FP32/FP64/INT32/INT64 1. DataType 执行数据类型 FP32/FP64/INT32/INT64
1. Memory layout: 运行时 Tensor 在内存中的排布格式 NCHW、 NHWC 1. Memory layout: 运行时 Tensor 在内存中的排布格式 NCHW、 NHWC
1. 使用的库 1. 使用的库
来区分Kernel,为同一个operator注册多个 Kernel。 来区分Kernel,为同一个operator注册多个 Kernel。
```cpp ```cpp
...@@ -876,7 +876,7 @@ step 3: 运行时的 KernelType 推断和Kernel切换,<span style="background- ...@@ -876,7 +876,7 @@ step 3: 运行时的 KernelType 推断和Kernel切换,<span style="background-
namespace framework { namespace framework {
using LoDTensorArray = std::vector<LoDTensor>; using LoDTensorArray = std::vector<LoDTensor>;
} }
} }
``` ```
- 每一次循环,从原始输入中“切出”一个片段 - 每一次循环,从原始输入中“切出”一个片段
- LoDTensorArray 在Python端暴露,是Fluid支持的基础数据结构之一,用户可以直接创建并使用 - LoDTensorArray 在Python端暴露,是Fluid支持的基础数据结构之一,用户可以直接创建并使用
...@@ -910,7 +910,7 @@ void Run(const framework::Scope &scope, ...@@ -910,7 +910,7 @@ void Run(const framework::Scope &scope,
false /*create_local_scope*/); false /*create_local_scope*/);
} }
} }
``` ```
</font> </font>
...@@ -951,7 +951,7 @@ void Run(const framework::Scope &scope, ...@@ -951,7 +951,7 @@ void Run(const framework::Scope &scope,
--- ---
#### dynamicRNN 中的 Memory #### dynamicRNN 中的 Memory
<font size=5> <font size=5>
...@@ -961,7 +961,7 @@ void Run(const framework::Scope &scope, ...@@ -961,7 +961,7 @@ void Run(const framework::Scope &scope,
- `memory` 在 operator A 前向计算之后,进行前向计算 - `memory` 在 operator A 前向计算之后,进行前向计算
-`memory` 的前向计算会 "指向" A 的输出 LoDTensor -`memory` 的前向计算会 "指向" A 的输出 LoDTensor
- `memory` 的输出可以是另一个 operator 的输入,于是形成了“循环”连接 - `memory` 的输出可以是另一个 operator 的输入,于是形成了“循环”连接
</font> </font>
--- ---
...@@ -1107,7 +1107,7 @@ void Run(const framework::Scope &scope, ...@@ -1107,7 +1107,7 @@ void Run(const framework::Scope &scope,
<td> <td>
<p align="center"> <p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid_module_1.png" width=60%> <img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid_module_1.png" width=60%>
</p> </p>
</td> </td>
<td> <td>
<p align="center"> <p align="center">
...@@ -1127,13 +1127,13 @@ void Run(const framework::Scope &scope, ...@@ -1127,13 +1127,13 @@ void Run(const framework::Scope &scope,
<font size=5> <font size=5>
- 设计概览 - 设计概览
- 重构概览 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/refactorization.md) - 重构概览 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/refactorization.md)
- fluid [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/fluid.md) - fluid [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/fluid.md)
- fluid_compiler [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/motivation/fluid_compiler.md) - fluid_compiler [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/motivation/fluid_compiler.md)
- 核心概念 - 核心概念
- variable 描述 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/var_desc.md) - variable 描述 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/var_desc.md)
- Tensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.md) - Tensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.md)
- LoDTensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) - LoDTensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
- TensorArray [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md) - TensorArray [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md)
- Program [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md) - Program [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md)
- Block [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md) - Block [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md)
...@@ -1152,7 +1152,7 @@ void Run(const framework::Scope &scope, ...@@ -1152,7 +1152,7 @@ void Run(const framework::Scope &scope,
- 支持新设硬件设备库 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md) - 支持新设硬件设备库 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md)
- 添加新的Operator [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_cn.md) - 添加新的Operator [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_cn.md)
- 添加新的Kernel [->]( - 添加新的Kernel [->](
https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_en.md) https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_en.md)
</font> </font>
...@@ -1167,10 +1167,10 @@ https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_ ...@@ -1167,10 +1167,10 @@ https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_
<font size=5> <font size=5>
Docker编译PaddlePaddle源码: [->](http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/docker_install_cn.html) Docker编译PaddlePaddle源码: [->](http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/docker_install_cn.html)
PaddlePaddle 在 Dockerhub 地址:[->]( PaddlePaddle 在 Dockerhub 地址:[->](
https://hub.docker.com/r/paddlepaddle/paddle/tags/) https://hub.docker.com/r/paddlepaddle/paddle/tags/)
1. 获取PaddlePaddle的Docker镜像 1. 获取PaddlePaddle的Docker镜像
```bash ```bash
docker pull paddlepaddle/paddle:latest-dev docker pull paddlepaddle/paddle:latest-dev
...@@ -1183,7 +1183,7 @@ PaddlePaddle 在 Dockerhub 地址:[->]( ...@@ -1183,7 +1183,7 @@ PaddlePaddle 在 Dockerhub 地址:[->](
``` ```
1. 进入docker container后,从源码编译,请参考文档 [->]( http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/build_from_source_cn.html) 1. 进入docker container后,从源码编译,请参考文档 [->]( http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/build_from_source_cn.html)
</font> </font>
--- ---
...@@ -1196,7 +1196,7 @@ PaddlePaddle 在 Dockerhub 地址:[->]( ...@@ -1196,7 +1196,7 @@ PaddlePaddle 在 Dockerhub 地址:[->](
1. 开发推荐使用tag为`latest-dev`的镜像,其中打包了所有编译依赖。`latest``lastest-gpu`是production镜像,主要用于运行PaddlePaddle程序。 1. 开发推荐使用tag为`latest-dev`的镜像,其中打包了所有编译依赖。`latest``lastest-gpu`是production镜像,主要用于运行PaddlePaddle程序。
2. 在Docker中运行GPU程序,推荐使用nvidia-docker,[否则需要将CUDA库和设备挂载到Docker容器内](http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/docker_install_cn.html) 2. 在Docker中运行GPU程序,推荐使用nvidia-docker,[否则需要将CUDA库和设备挂载到Docker容器内](http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/docker_install_cn.html)
<font size=4> <font size=4>
```bash ```bash
nvidia-docker run -it -v $PWD/Paddle:/paddle paddlepaddle/paddle:latest-dev /bin/bash nvidia-docker run -it -v $PWD/Paddle:/paddle paddlepaddle/paddle:latest-dev /bin/bash
``` ```
...@@ -1353,9 +1353,9 @@ Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,CUDA实 ...@@ -1353,9 +1353,9 @@ Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,CUDA实
} }
}; };
``` ```
</font> </font>
--- ---
###### 实现带Kernel的Operator <span style="background-color:#c4e1e1;">step2</span>: 定义Operator类 ###### 实现带Kernel的Operator <span style="background-color:#c4e1e1;">step2</span>: 定义Operator类
...@@ -1420,11 +1420,11 @@ class ClipOp : public framework::OperatorWithKernel { ...@@ -1420,11 +1420,11 @@ class ClipOp : public framework::OperatorWithKernel {
2. override InferShape函数(参考 [clip_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.cc#L24) 2. override InferShape函数(参考 [clip_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.cc#L24)
1. 什么是`functor` ? 1. 什么是`functor` ?
- 类或结构体仅重载了`()`,一般是可被多个kernel复用的计算函数。 - 类或结构体仅重载了`()`,一般是可被多个kernel复用的计算函数。
<font size=4> <font size=4>
```cpp ```cpp
template <typename T> template <typename T>
class CrossEntropyFunctor<platform::CPUDeviceContext, T> { class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
...@@ -1438,9 +1438,9 @@ class ClipOp : public framework::OperatorWithKernel { ...@@ -1438,9 +1438,9 @@ class ClipOp : public framework::OperatorWithKernel {
}; };
``` ```
</font> </font>
- 在 clip_op 内也会看到将一段计算函数抽象为functor的使用法: [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.h#L27)。 - 在 clip_op 内也会看到将一段计算函数抽象为functor的使用法: [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.h#L27)。
</font> </font>
--- ---
...@@ -1504,7 +1504,7 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -1504,7 +1504,7 @@ class ClipKernel : public framework::OpKernel<T> {
- 需要注意,<span style="background-color:#e1c4c4;">Fluid中,不区分Cost Op和中间层Op,所有Op都必须正确处理接收到的梯度</span> - 需要注意,<span style="background-color:#e1c4c4;">Fluid中,不区分Cost Op和中间层Op,所有Op都必须正确处理接收到的梯度</span>
2. 反向Op的输出 2. 反向Op的输出
- 对可学习参数的求导结果 - 对可学习参数的求导结果
- 对所有输入的求导结果 - 对所有输入的求导结果
</font> </font>
...@@ -1520,7 +1520,7 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -1520,7 +1520,7 @@ class ClipKernel : public framework::OpKernel<T> {
1.`.cc`文件中注册前向、反向Op类,注册CPU Kernel。 1.`.cc`文件中注册前向、反向Op类,注册CPU Kernel。
<font size=4> <font size=4>
```cpp ```cpp
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad, REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
...@@ -1530,13 +1530,13 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -1530,13 +1530,13 @@ class ClipKernel : public framework::OpKernel<T> {
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>); clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>);
``` ```
- 在上面的代码片段中: - 在上面的代码片段中:
1. `REGISTER_OP` : 注册`ops::ClipOp`类,类型名为`clip`,该类的`ProtoMaker`为`ops::ClipOpMaker`,注册`ops::ClipOpGrad`,类型名为`clip_grad` 1. `REGISTER_OP` : 注册`ops::ClipOp`类,类型名为`clip`,该类的`ProtoMaker`为`ops::ClipOpMaker`,注册`ops::ClipOpGrad`,类型名为`clip_grad`
1. `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op,例如:优化算法相关的Op 1. `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op,例如:优化算法相关的Op
1. `REGISTER_OP_CPU_KERNEL` :注册`ops::ClipKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::ClipGradKernel`类 1. `REGISTER_OP_CPU_KERNEL` :注册`ops::ClipKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::ClipGradKernel`类
</font> </font>
1. 按照同样方法,在`.cu`文件中注册GPU Kernel 1. 按照同样方法,在`.cu`文件中注册GPU Kernel
- <span style="background-color:#e1c4c4;">如果CUDA Kernel的实现基于Eigen,需在 `.cu`的开始加上宏定义 `#define EIGEN_USE_GPU` </span> - <span style="background-color:#e1c4c4;">如果CUDA Kernel的实现基于Eigen,需在 `.cu`的开始加上宏定义 `#define EIGEN_USE_GPU` </span>
...@@ -1593,7 +1593,7 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -1593,7 +1593,7 @@ class ClipKernel : public framework::OpKernel<T> {
```bash ```bash
make test ARGS="-R test_mul_op -V" make test ARGS="-R test_mul_op -V"
``` ```
或者: 或者:
``` ```
...@@ -1613,7 +1613,7 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -1613,7 +1613,7 @@ class ClipKernel : public framework::OpKernel<T> {
- 如果多个Op依赖一些共用的函数,可以创建非`*_op.*`格式的文件来存放,如`gather.h`文件。 - 如果多个Op依赖一些共用的函数,可以创建非`*_op.*`格式的文件来存放,如`gather.h`文件。
</font> </font>
--- ---
### ==10.== 使用相关问题 ### ==10.== 使用相关问题
...@@ -1735,7 +1735,7 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -1735,7 +1735,7 @@ class ClipKernel : public framework::OpKernel<T> {
y_data = np.random.randint(0, 8, [1]).astype("int32") y_data = np.random.randint(0, 8, [1]).astype("int32")
y_tensor = core.Tensor() y_tensor = core.Tensor()
y_tensor.set(y_data, place) y_tensor.set(y_data, place)
x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32") x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
x_tensor = core.Tensor() x_tensor = core.Tensor()
x_tensor.set(x_data, place) x_tensor.set(x_data, place)
......
...@@ -17,3 +17,4 @@ ...@@ -17,3 +17,4 @@
:maxdepth: 1 :maxdepth: 1
concepts/use_concepts_cn.rst concepts/use_concepts_cn.rst
developer's_guide_to_paddle_fluid.md
...@@ -16,3 +16,4 @@ Here is an example of linear regression. It introduces workflow of PaddlePaddle, ...@@ -16,3 +16,4 @@ Here is an example of linear regression. It introduces workflow of PaddlePaddle,
:maxdepth: 1 :maxdepth: 1
concepts/index_en.rst concepts/index_en.rst
developer's_guide_to_paddle_fluid.md
...@@ -11,7 +11,7 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14. ...@@ -11,7 +11,7 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
pip install paddlepaddle pip install paddlepaddle
如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: 如果需要安装支持GPU的版本(cuda8.0_cudnn5_avx_openblas),需要执行:
.. code-block:: bash .. code-block:: bash
...@@ -28,18 +28,18 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14. ...@@ -28,18 +28,18 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
import paddle.dataset.uci_housing as uci_housing import paddle.dataset.uci_housing as uci_housing
import paddle.fluid as fluid import paddle.fluid as fluid
with fluid.scope_guard(fluid.core.Scope()): with fluid.scope_guard(fluid.core.Scope()):
# initialize executor with cpu # initialize executor with cpu
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
# load inference model # load inference model
[inference_program, feed_target_names,fetch_targets] = \ [inference_program, feed_target_names,fetch_targets] = \
fluid.io.load_inference_model(uci_housing.fluid_model(), exe) fluid.io.load_inference_model(uci_housing.fluid_model(), exe)
# run inference # run inference
result = exe.run(inference_program, result = exe.run(inference_program,
feed={feed_target_names[0]: uci_housing.predict_reader()}, feed={feed_target_names[0]: uci_housing.predict_reader()},
fetch_list=fetch_targets) fetch_list=fetch_targets)
# print predicted price is $12,273.97 # print predicted price is $12,273.97
print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000) print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000)
执行 :code:`python housing.py` 瞧! 它应该打印出预测住房数据的清单。 执行 :code:`python housing.py` 瞧! 它应该打印出预测住房数据的清单。
...@@ -12,7 +12,7 @@ Simply run the following command to install, the version is cpu_avx_openblas: ...@@ -12,7 +12,7 @@ Simply run the following command to install, the version is cpu_avx_openblas:
pip install paddlepaddle pip install paddlepaddle
If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run: If you need to install GPU version (cuda8.0_cudnn5_avx_openblas), run:
.. code-block:: bash .. code-block:: bash
...@@ -31,18 +31,18 @@ code: ...@@ -31,18 +31,18 @@ code:
import paddle.dataset.uci_housing as uci_housing import paddle.dataset.uci_housing as uci_housing
import paddle.fluid as fluid import paddle.fluid as fluid
with fluid.scope_guard(fluid.core.Scope()): with fluid.scope_guard(fluid.core.Scope()):
# initialize executor with cpu # initialize executor with cpu
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
# load inference model # load inference model
[inference_program, feed_target_names,fetch_targets] = \ [inference_program, feed_target_names,fetch_targets] = \
fluid.io.load_inference_model(uci_housing.fluid_model(), exe) fluid.io.load_inference_model(uci_housing.fluid_model(), exe)
# run inference # run inference
result = exe.run(inference_program, result = exe.run(inference_program,
feed={feed_target_names[0]: uci_housing.predict_reader()}, feed={feed_target_names[0]: uci_housing.predict_reader()},
fetch_list=fetch_targets) fetch_list=fetch_targets)
# print predicted price is $12,273.97 # print predicted price is $12,273.97
print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000) print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000)
Run :code:`python housing.py` and voila! It should print out a list of predictions Run :code:`python housing.py` and voila! It should print out a list of predictions
......
...@@ -4,5 +4,5 @@ ...@@ -4,5 +4,5 @@
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
inference/index_cn.rst
optimization/index_cn.rst optimization/index_cn.rst
inference/inference_support_in_fluid.md
...@@ -5,4 +5,3 @@ HOW TO ...@@ -5,4 +5,3 @@ HOW TO
:maxdepth: 1 :maxdepth: 1
optimization/index_en.rst optimization/index_en.rst
inference/inference_support_in_fluid.md
安装与编译C++预测库
===========================
直接下载安装
-------------
====================== ========================================
版本说明 C++预测库
====================== ========================================
cpu_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/fluid.tgz>`_
cpu_avx_openblas `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxOpenblas/.lastSuccessful/fluid.tgz>`_
cpu_noavx_openblas `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/fluid.tgz>`_
cuda7.5_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda8.0_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda8.0_cudnn7_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/fluid.tgz>`_
====================== ========================================
从源码编译
----------
用户也可以从 PaddlePaddle 核心代码编译C++预测库,只需在编译时配制下面这些编译选项:
================= =========
选项 值
================= =========
CMAKE_BUILD_TYPE Release
FLUID_INSTALL_DIR 安装路径
WITH_FLUID_ONLY ON(推荐)
WITH_SWIG_PY OFF(推荐
WITH_PYTHON OFF(推荐)
WITH_GPU ON/OFF
WITH_MKL ON/OFF
================= =========
建议按照推荐值设置,以避免链接不必要的库。其它可选编译选项按需进行设定。
下面的代码片段从github拉取最新代码,配制编译选项(需要将PADDLE_ROOT替换为PaddlePaddle预测库的安装路径):
.. code-block:: bash
pip install paddlepaddle-gpu
PADDLE_ROOT=/path/of/capi
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir build
cd build
cmake -DFLUID_INSTALL_DIR=$PADDLE_ROOT \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_FLUID_ONLY=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_PYTHON=OFF \
-DWITH_MKL=OFF \
-DWITH_GPU=OFF \
..
make
make inference_lib_dist
成功编译后,使用C++预测库所需的依赖(包括:(1)编译出的PaddlePaddle预测库和头文件;(2)第三方链接库和头文件;(3)版本信息与编译选项信息)
均会存放于PADDLE_ROOT目录中。目录结构如下:
.. code-block:: text
PaddleRoot/
├── CMakeCache.txt
├── paddle
│   └── fluid
│   ├── framework
│   ├── inference
│   ├── memory
│   ├── platform
│   ├── pybind
│   └── string
├── third_party
│   ├── boost
│   │   └── boost
│   ├── eigen3
│   │   ├── Eigen
│   │   └── unsupported
│   └── install
│   ├── gflags
│   ├── glog
│   ├── mklml
│   ├── protobuf
│   ├── snappy
│   ├── snappystream
│   └── zlib
└── version.txt
version.txt 中记录了该预测库的版本信息,包括Git Commit ID、使用OpenBlas或MKL数学库、CUDA/CUDNN版本号,如:
.. code-block:: text
GIT COMMIT ID: c95cd4742f02bb009e651a00b07b21c979637dc8
WITH_MKL: ON
WITH_GPU: ON
CUDA version: 8.0
CUDNN version: v5
预测库
------------
.. toctree::
:maxdepth: 1
build_and_install_lib_cn.rst
inference_support_in_fluid_cn.md
# Fluid Inference使用指南 # 使用指南
## 目录: ## 目录:
- Python Inference API - Python Inference API
- 编译Fluid Inference库
- Inference C++ API - Inference C++ API
- Inference实例 - Inference实例
- Inference计算优化 - Inference计算优化
...@@ -55,62 +54,6 @@ ...@@ -55,62 +54,6 @@
return [program, feed_target_names, fetch_targets] return [program, feed_target_names, fetch_targets]
``` ```
## 编译Fluid Inference库
- **不需要额外的CMake选项**
- 1、 配置CMake命令,更多配置请参考[源码编译PaddlePaddle](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/build_from_source_cn.html)
```bash
$ git clone https://github.com/PaddlePaddle/Paddle.git
$ cd Paddle
$ mkdir build
$ cd build
$ cmake -DCMAKE_INSTALL_PREFIX=your/path/to/paddle_inference_lib \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_PYTHON=ON \
-DWITH_MKL=OFF \
-DWITH_GPU=OFF \
..
```
- 2、 编译PaddlePaddle
```bash
$ make
```
- 3、 部署。执行如下命令将PaddlePaddle Fluid Inference库部署到`your/path/to/paddle_inference_lib`目录。
```bash
$ make inference_lib_dist
```
- 目录结构
```bash
$ cd your/path/to/paddle_inference_lib
$ tree
.
|-- paddle
| `-- fluid
| |-- framework
| |-- inference
| | |-- io.h
| | `-- libpaddle_fluid.so
| |-- memory
| |-- platform
| `-- string
|-- third_party
| |-- eigen3
| `-- install
| |-- gflags
| |-- glog
| `-- protobuf
`-- ...
```
假设`PADDLE_ROOT=your/path/to/paddle_inference_lib`
## 链接Fluid Inference库 ## 链接Fluid Inference库
- 示例项目([链接](https://github.com/luotao1/fluid_inference_example.git)) - 示例项目([链接](https://github.com/luotao1/fluid_inference_example.git))
......
...@@ -17,46 +17,33 @@ if(APPLE) ...@@ -17,46 +17,33 @@ if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE) endif(APPLE)
function(inference_api_test TARGET_NAME TEST_SRC) function(inference_api_test TARGET_NAME)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs ARGS) set(multiValueArgs ARGS)
cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
set(arg_list "") cc_test(test_paddle_inference_${TARGET_NAME}
SRCS test_paddle_inference_${TARGET_NAME}.cc
DEPS paddle_fluid_api paddle_inference_api
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
if(inference_test_ARGS) if(inference_test_ARGS)
foreach(arg ${inference_test_ARGS}) set_tests_properties(test_paddle_inference_${TARGET_NAME}
list(APPEND arg_list "_${arg}") PROPERTIES DEPENDS "${inference_test_ARGS}")
endforeach()
else()
list(APPEND arg_list "_")
endif() endif()
foreach(arg ${arg_list})
string(REGEX REPLACE "^_$" "" arg "${arg}")
cc_test(${TARGET_NAME}
SRCS ${TEST_SRC}
DEPS paddle_fluid_api paddle_inference_api paddle_inference_api_impl
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
# TODO(panyx0178): Figure out how to add word2vec and image_classification
# as deps.
# set_tests_properties(${TARGET_NAME}
# PROPERTIES DEPENDS ${DEP_TEST})
endforeach()
endfunction(inference_api_test) endfunction(inference_api_test)
cc_library(paddle_inference_api cc_library(paddle_inference_api
SRCS paddle_inference_api.cc SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
cc_library(paddle_inference_api_impl if(WITH_TESTING)
SRCS paddle_inference_api_impl.cc cc_test(test_paddle_inference_api
DEPS paddle_inference_api paddle_fluid_api) SRCS test_paddle_inference_api.cc
DEPS paddle_inference_api)
cc_test(test_paddle_inference_api inference_api_test(api_impl
SRCS test_paddle_inference_api.cc ARGS test_word2vec test_image_classification)
DEPS paddle_inference_api) endif()
inference_api_test(test_paddle_inference_api_impl
test_paddle_inference_api_impl.cc)
...@@ -40,15 +40,24 @@ struct PaddleBuf { ...@@ -40,15 +40,24 @@ struct PaddleBuf {
struct PaddleTensor { struct PaddleTensor {
std::string name; // variable name. std::string name; // variable name.
std::vector<int> shape; std::vector<int> shape;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
PaddleBuf data; // blob of data. PaddleBuf data; // blob of data.
PaddleDType dtype; PaddleDType dtype;
}; };
enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility.
// TODO(Superjomn) support following engines latter.
// kAnakin, // Use Anakin for inference.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
/* /*
* A simple Inference API for Paddle. Currently this API might just be used by * A simple Inference API for Paddle. Currently this API can be used by
* non-sequence scenerios. * non-sequence scenerios.
* TODO(Superjomn) Prepare another API for NLP-related usages. */
*/
class PaddlePredictor { class PaddlePredictor {
public: public:
struct Config; struct Config;
...@@ -66,34 +75,35 @@ class PaddlePredictor { ...@@ -66,34 +75,35 @@ class PaddlePredictor {
// be thread-safe. // be thread-safe.
virtual std::unique_ptr<PaddlePredictor> Clone() = 0; virtual std::unique_ptr<PaddlePredictor> Clone() = 0;
virtual bool InitShared() { return false; }
// Destroy the Predictor. // Destroy the Predictor.
virtual ~PaddlePredictor() {} virtual ~PaddlePredictor() {}
friend std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
const PaddlePredictor::Config& config);
// The common configs for all the predictors. // The common configs for all the predictors.
struct Config { struct Config {
enum class EngineKind;
std::string model_dir; // path to the model directory. std::string model_dir; // path to the model directory.
bool enable_engine{false}; // Enable to execute (part of) the model on bool enable_engine{false}; // Enable to execute (part of) the model on
// third-party engines.
EngineKind engine_kind{Config::EngineKind::kNone};
enum class EngineKind {
kNone = -1, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
kTensorRT, // Use TensorRT for inference.
kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
}; };
}; };
// A factory to help create difference predictor. struct NativeConfig : public PaddlePredictor::Config {
template <typename ConfigT> // GPU related fields.
bool use_gpu{false};
int device{0};
float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization.
std::string prog_file;
std::string param_file;
};
// A factory to help create different predictors.
//
// FOR EXTENSION DEVELOPER:
// Different predictors are designated by config type and engine kind. Similar
// configs can be merged, but there shouldn't be a huge config containing
// different fields for more than one kind of predictors.
//
// Similarly, each engine kind should map to a unique predictor implementation.
template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config); std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
} // namespace paddle } // namespace paddle
...@@ -54,11 +54,10 @@ std::string num2str(T a) { ...@@ -54,11 +54,10 @@ std::string num2str(T a) {
} }
} // namespace } // namespace
bool PaddlePredictorImpl::Init() { bool NativePaddlePredictor::Init() {
VLOG(3) << "Predictor::init()"; VLOG(3) << "Predictor::init()";
// TODO(panyx0718): Should CPU vs GPU device be decided by id? if (config_.use_gpu) {
if (config_.device >= 0) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
} else { } else {
place_ = paddle::platform::CPUPlace(); place_ = paddle::platform::CPUPlace();
...@@ -85,19 +84,21 @@ bool PaddlePredictorImpl::Init() { ...@@ -85,19 +84,21 @@ bool PaddlePredictorImpl::Init() {
} }
ctx_ = executor_->Prepare(*inference_program_, 0); ctx_ = executor_->Prepare(*inference_program_, 0);
// Create variables // Create temporary variables first, so that the first batch do not need to
// TODO(panyx0718): Why need to test share_variables here? // create variables in the runtime. This is the logics of the old inference
if (config_.share_variables) { // API.
executor_->CreateVariables(*inference_program_, scope_.get(), 0); // TODO(Superjomn) this should be modified when `Clone` is valid for
} // multi-thread application.
executor_->CreateVariables(*inference_program_, scope_.get(), 0);
// Get the feed_target_names and fetch_target_names // Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames(); feed_target_names_ = inference_program_->GetFeedTargetNames();
fetch_target_names_ = inference_program_->GetFetchTargetNames(); fetch_target_names_ = inference_program_->GetFetchTargetNames();
return true; return true;
} }
bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs, bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) { std::vector<PaddleTensor> *output_data) {
VLOG(3) << "Predictor::predict"; VLOG(3) << "Predictor::predict";
Timer timer; Timer timer;
timer.tic(); timer.tic();
...@@ -124,7 +125,7 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs, ...@@ -124,7 +125,7 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
scope_.get(), scope_.get(),
&feed_targets, &feed_targets,
&fetch_targets, &fetch_targets,
!config_.share_variables); false /* don't create variable eatch time */);
if (!GetFetch(fetchs, output_data)) { if (!GetFetch(fetchs, output_data)) {
LOG(ERROR) << "fail to get fetchs"; LOG(ERROR) << "fail to get fetchs";
return false; return false;
...@@ -133,59 +134,20 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs, ...@@ -133,59 +134,20 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
return true; return true;
} }
std::unique_ptr<PaddlePredictor> PaddlePredictorImpl::Clone() { std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
VLOG(3) << "Predictor::clone"; VLOG(3) << "Predictor::clone";
std::unique_ptr<PaddlePredictor> cls(new PaddlePredictorImpl(config_)); std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
if (!cls->InitShared()) {
LOG(ERROR) << "fail to call InitShared"; if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init()) {
LOG(ERROR) << "fail to call Init";
return nullptr; return nullptr;
} }
// fix manylinux compile error. // fix manylinux compile error.
return std::move(cls); return std::move(cls);
} }
// TODO(panyx0718): Consider merge with Init()? bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
bool PaddlePredictorImpl::InitShared() { std::vector<framework::LoDTensor> *feeds) {
VLOG(3) << "Predictor::init_shared";
// 1. Define place, executor, scope
if (this->config_.device >= 0) {
place_ = platform::CUDAPlace();
} else {
place_ = platform::CPUPlace();
}
this->executor_.reset(new framework::Executor(this->place_));
this->scope_.reset(new framework::Scope());
// Initialize the inference program
if (!this->config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
this->inference_program_ = inference::Load(
this->executor_.get(), this->scope_.get(), this->config_.model_dir);
} else if (!this->config_.prog_file.empty() &&
!this->config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
this->inference_program_ = inference::Load(this->executor_.get(),
this->scope_.get(),
this->config_.prog_file,
this->config_.param_file);
}
this->ctx_ = this->executor_->Prepare(*this->inference_program_, 0);
// 3. create variables
// TODO(panyx0718): why test share_variables.
if (config_.share_variables) {
this->executor_->CreateVariables(
*this->inference_program_, this->scope_.get(), 0);
}
// 4. Get the feed_target_names and fetch_target_names
this->feed_target_names_ = this->inference_program_->GetFeedTargetNames();
this->fetch_target_names_ = this->inference_program_->GetFetchTargetNames();
return true;
}
bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs,
std::vector<framework::LoDTensor> *feeds) {
VLOG(3) << "Predictor::set_feed"; VLOG(3) << "Predictor::set_feed";
if (inputs.size() != feed_target_names_.size()) { if (inputs.size() != feed_target_names_.size()) {
LOG(ERROR) << "wrong feed input size."; LOG(ERROR) << "wrong feed input size.";
...@@ -213,7 +175,7 @@ bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -213,7 +175,7 @@ bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs,
return true; return true;
} }
bool PaddlePredictorImpl::GetFetch( bool NativePaddlePredictor::GetFetch(
const std::vector<framework::LoDTensor> &fetchs, const std::vector<framework::LoDTensor> &fetchs,
std::vector<PaddleTensor> *outputs) { std::vector<PaddleTensor> *outputs) {
VLOG(3) << "Predictor::get_fetch"; VLOG(3) << "Predictor::get_fetch";
...@@ -280,23 +242,29 @@ bool PaddlePredictorImpl::GetFetch( ...@@ -280,23 +242,29 @@ bool PaddlePredictorImpl::GetFetch(
} }
template <> template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor( std::unique_ptr<PaddlePredictor>
const ConfigImpl &config) { CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
VLOG(3) << "create PaddlePredictorImpl"; const NativeConfig &config) {
// 1. GPU memeroy VLOG(3) << "create NativePaddlePredictor";
std::vector<std::string> flags; if (config.use_gpu) {
if (config.fraction_of_gpu_memory >= 0.0f || // 1. GPU memeroy
config.fraction_of_gpu_memory <= 0.95f) { PADDLE_ENFORCE(
flags.push_back("dummpy"); config.fraction_of_gpu_memory > 0.f,
std::string flag = "--fraction_of_gpu_memory_to_use=" + "fraction_of_gpu_memory in the config should be set to range (0., 1.]");
num2str<float>(config.fraction_of_gpu_memory); std::vector<std::string> flags;
flags.push_back(flag); if (config.fraction_of_gpu_memory >= 0.0f ||
VLOG(3) << "set flag: " << flag; config.fraction_of_gpu_memory <= 0.95f) {
framework::InitGflags(flags); flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
num2str<float>(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
} }
std::unique_ptr<PaddlePredictor> predictor(new PaddlePredictorImpl(config)); std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
if (!dynamic_cast<PaddlePredictorImpl *>(predictor.get())->Init()) { if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init()) {
return nullptr; return nullptr;
} }
return std::move(predictor); return std::move(predictor);
......
...@@ -29,17 +29,10 @@ ...@@ -29,17 +29,10 @@
namespace paddle { namespace paddle {
struct ConfigImpl : public PaddlePredictor::Config { class NativePaddlePredictor : public PaddlePredictor {
int device;
float fraction_of_gpu_memory;
std::string prog_file;
std::string param_file;
bool share_variables;
};
class PaddlePredictorImpl : public PaddlePredictor {
public: public:
explicit PaddlePredictorImpl(const ConfigImpl &config) : config_(config) {} explicit NativePaddlePredictor(const NativeConfig &config)
: config_(config) {}
bool Init(); bool Init();
...@@ -48,16 +41,15 @@ class PaddlePredictorImpl : public PaddlePredictor { ...@@ -48,16 +41,15 @@ class PaddlePredictorImpl : public PaddlePredictor {
std::unique_ptr<PaddlePredictor> Clone() override; std::unique_ptr<PaddlePredictor> Clone() override;
~PaddlePredictorImpl() override{}; ~NativePaddlePredictor() override{};
private: private:
bool InitShared() override;
bool SetFeed(const std::vector<PaddleTensor> &input_datas, bool SetFeed(const std::vector<PaddleTensor> &input_datas,
std::vector<framework::LoDTensor> *feeds); std::vector<framework::LoDTensor> *feeds);
bool GetFetch(const std::vector<framework::LoDTensor> &fetchs, bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
std::vector<PaddleTensor> *output_data); std::vector<PaddleTensor> *output_data);
ConfigImpl config_; NativeConfig config_;
platform::Place place_; platform::Place place_;
std::unique_ptr<framework::Executor> executor_; std::unique_ptr<framework::Executor> executor_;
std::unique_ptr<framework::Scope> scope_; std::unique_ptr<framework::Scope> scope_;
......
...@@ -40,19 +40,19 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { ...@@ -40,19 +40,19 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
return pt; return pt;
} }
ConfigImpl GetConfig() { NativeConfig GetConfig() {
ConfigImpl config; NativeConfig config;
config.model_dir = FLAGS_dirname + "word2vec.inference.model"; config.model_dir = FLAGS_dirname + "word2vec.inference.model";
LOG(INFO) << "dirname " << config.model_dir; LOG(INFO) << "dirname " << config.model_dir;
config.fraction_of_gpu_memory = 0.15; config.fraction_of_gpu_memory = 0.15;
config.use_gpu = true;
config.device = 0; config.device = 0;
config.share_variables = true;
return config; return config;
} }
TEST(paddle_inference_api_impl, word2vec) { TEST(paddle_inference_api_impl, word2vec) {
ConfigImpl config = GetConfig(); NativeConfig config = GetConfig();
std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config); auto predictor = CreatePaddlePredictor<NativeConfig>(config);
framework::LoDTensor first_word, second_word, third_word, fourth_word; framework::LoDTensor first_word, second_word, third_word, fourth_word;
framework::LoD lod{{0, 1}}; framework::LoD lod{{0, 1}};
...@@ -104,7 +104,7 @@ TEST(paddle_inference_api_impl, image_classification) { ...@@ -104,7 +104,7 @@ TEST(paddle_inference_api_impl, image_classification) {
int batch_size = 2; int batch_size = 2;
bool use_mkldnn = false; bool use_mkldnn = false;
bool repeat = false; bool repeat = false;
ConfigImpl config = GetConfig(); NativeConfig config = GetConfig();
config.model_dir = config.model_dir =
FLAGS_dirname + "image_classification_resnet.inference.model"; FLAGS_dirname + "image_classification_resnet.inference.model";
...@@ -133,7 +133,7 @@ TEST(paddle_inference_api_impl, image_classification) { ...@@ -133,7 +133,7 @@ TEST(paddle_inference_api_impl, image_classification) {
is_combined, is_combined,
use_mkldnn); use_mkldnn);
std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config); auto predictor = CreatePaddlePredictor(config);
std::vector<PaddleTensor> paddle_tensor_feeds; std::vector<PaddleTensor> paddle_tensor_feeds;
paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&input)); paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&input));
...@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) { ...@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) {
float* data = static_cast<float*>(outputs[0].data.data); float* data = static_cast<float*>(outputs[0].data.data);
float* lod_data = output1.data<float>(); float* lod_data = output1.data<float>();
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
EXPECT_LT(lod_data[j] - data[j], 1e-10); EXPECT_NEAR(lod_data[j], data[j], 1e-3);
EXPECT_GT(lod_data[j] - data[j], -1e-10);
} }
free(data); free(data);
} }
......
...@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc) ...@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
vars_[var_desc.name()].reset(new VarDesc(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
for (const proto::OpDesc &op_desc : desc_->ops()) { for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDesc(op_desc, prog, this)); ops_.emplace_back(new OpDesc(op_desc, this));
} }
} }
...@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ...@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
: prog_(prog), desc_(desc) { : prog_(prog), desc_(desc) {
need_update_ = true; need_update_ = true;
for (auto &op : other.ops_) { for (auto &op : other.ops_) {
ops_.emplace_back(new OpDesc(*op->Proto(), prog, this)); ops_.emplace_back(new OpDesc(*op, this));
} }
for (auto &it : other.vars_) { for (auto &it : other.vars_) {
auto *var = new VarDesc(*it.second); auto *var = new VarDesc(*it.second);
......
...@@ -105,7 +105,7 @@ class BlockDesc { ...@@ -105,7 +105,7 @@ class BlockDesc {
size_t OpSize() const { return ops_.size(); } size_t OpSize() const { return ops_.size(); }
OpDesc *Op(int idx) { return ops_.at(idx).get(); } OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
void Flush(); void Flush();
......
...@@ -11,11 +11,15 @@ ...@@ -11,11 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include <algorithm>
#include <fstream> #include <fstream>
#include <string>
#include <utility> #include <utility>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
...@@ -26,9 +30,6 @@ ...@@ -26,9 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif #endif
#include <string>
#include <vector>
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot", DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10," "the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"); "default /tmp/graph.dot");
...@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types; std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType(); all_vars[var->Name()] = var;
} }
auto graph = new SSAGraph(); auto graph = new SSAGraph();
...@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(program);
size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices; std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size()); var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>( if (boost::get<int>(
...@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name);
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name); var_name_on_devices[cur_device_id].emplace(g_name);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
cur_device_id = (cur_device_id + 1) % places_.size();
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(var_types, g_name)) { if (IsSparseGradient(all_vars, g_name)) {
CreateReduceOp(&result, g_name, 0); CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0);
} else { } else {
...@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient( bool MultiDevSSAGraphBuilder::IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const { const std::string &og) const {
PADDLE_ENFORCE(var_types.count(og) != 0); PADDLE_ENFORCE(all_vars.count(og) != 0);
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) { if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true; return true;
} }
return false; return false;
......
...@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient( bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const; const std::string &og) const;
private: private:
......
...@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { ...@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
need_update_ = true; need_update_ = true;
} }
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block) OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
: desc_(desc), need_update_(false) { : desc_(desc), need_update_(false) {
// restore inputs_ // restore inputs_
int input_size = desc_.inputs_size(); int input_size = desc_.inputs_size();
......
...@@ -33,13 +33,14 @@ class OpDesc { ...@@ -33,13 +33,14 @@ class OpDesc {
OpDesc(const std::string &type, const VariableNameMap &inputs, OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs); const VariableNameMap &outputs, const AttributeMap &attrs);
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block); OpDesc(const proto::OpDesc &desc, BlockDesc *block);
explicit OpDesc(BlockDesc *block) : block_(block) {} explicit OpDesc(BlockDesc *block) : block_(block) {}
OpDesc(const OpDesc &other, BlockDesc *block) { OpDesc(const OpDesc &other, BlockDesc *block) {
*this = other; *this = other;
block_ = block; block_ = block;
need_update_ = true;
} }
void CopyFrom(const OpDesc &op_desc); void CopyFrom(const OpDesc &op_desc);
......
...@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ...@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
auto *block = desc_.mutable_blocks(i); auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
} }
for (auto &block : blocks_) { for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) {
for (auto *op : block->AllOps()) { auto all_ops = blocks_[block_id]->AllOps();
for (const auto &attr : op->Proto()->attrs()) { for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
if (attr.type() == proto::AttrType::BLOCK) { auto &op = all_ops[op_id];
size_t blk_idx = attr.block_idx(); for (const std::string &attr_name : op->AttrNames()) {
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
int sub_block_id =
o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name);
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
} }
} }
} }
...@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { ...@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
}
} }
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
......
...@@ -25,8 +25,10 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) { ...@@ -25,8 +25,10 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
if (out->empty()) { if (out->empty()) {
return; return;
} }
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) { for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims(); auto &actual = (*out)[i].dims();
auto &expect = dims_[i]; auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size()); PADDLE_ENFORCE_EQ(actual.size(), expect.size());
......
...@@ -39,7 +39,7 @@ template <typename T> ...@@ -39,7 +39,7 @@ template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
PADDLE_ENFORCE(std::is_same<T, void>::value || PADDLE_ENFORCE(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code(), holder_->type() == std::type_index(typeid(T)),
"Tensor holds the wrong type, it holds %s", "Tensor holds the wrong type, it holds %s",
this->holder_->type().name()); this->holder_->type().name());
...@@ -53,7 +53,7 @@ template <typename T> ...@@ -53,7 +53,7 @@ template <typename T>
inline T* Tensor::data() { inline T* Tensor::data() {
check_memory_size(); check_memory_size();
PADDLE_ENFORCE(std::is_same<T, void>::value || PADDLE_ENFORCE(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code(), holder_->type() == std::type_index(typeid(T)),
"Tensor holds the wrong type, it holds %s", "Tensor holds the wrong type, it holds %s",
this->holder_->type().name()); this->holder_->type().name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
......
...@@ -5,14 +5,19 @@ cc_library(paddle_fluid_api ...@@ -5,14 +5,19 @@ cc_library(paddle_fluid_api
SRCS io.cc SRCS io.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
# Create static library
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
cc_library(paddle_fluid DEPS ${fluid_modules})
if(WITH_CONTRIB)
set(fluid_modules "${fluid_modules}" paddle_inference_api)
endif()
# Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api)
# Create shared library # Create shared library
cc_library(paddle_fluid_shared SHARED cc_library(paddle_fluid_shared SHARED
SRCS io.cc SRCS io.cc
DEPS ${fluid_modules}) DEPS ${fluid_modules} paddle_fluid_api)
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
if(NOT APPLE) if(NOT APPLE)
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac. # TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
......
...@@ -21,7 +21,10 @@ limitations under the License. */ ...@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque> #include <deque>
#include <stack> #include <stack>
#include <string>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h" #include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
......
...@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) { ...@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG(INFO) << graph.nodes.size(); LOG(INFO) << graph.nodes.size();
} }
} // analysis }; // namespace analysis
} // inference }; // namespace inference
} // paddle }; // namespace paddle
...@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass.h"
......
...@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) { ...@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG(INFO) << '\n' << graph.DotString(); LOG(INFO) << '\n' << graph.DotString();
} }
} // analysis } // namespace analysis
} // inference } // namespace inference
} // paddle } // namespace paddle
...@@ -50,7 +50,7 @@ struct DataTypeNamer { ...@@ -50,7 +50,7 @@ struct DataTypeNamer {
return dic_.at(x); return dic_.at(x);
} }
const std::string &repr(size_t &hash) const { const std::string &repr(size_t &hash) const { // NOLINT
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation"); PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
return dic_.at(hash); return dic_.at(hash);
} }
...@@ -62,7 +62,9 @@ struct DataTypeNamer { ...@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE(float); SET_TYPE(float);
} }
std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_; std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
std::string>
dic_;
}; };
#undef SET_TYPE #undef SET_TYPE
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <iosfwd> #include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
......
...@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ...@@ -8,3 +8,5 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
...@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter { ...@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
// Here the two nullptr looks strange, that's because the // Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange. // framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr, nullptr); framework::OpDesc op_desc(op, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu"; "type is Relu";
const nvinfer1::ITensor* input_tensor = const nvinfer1::ITensor* input_tensor =
......
...@@ -21,7 +21,8 @@ namespace tensorrt { ...@@ -21,7 +21,8 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter { class Conv2dOpConverter : public OpConverter {
public: public:
Conv2dOpConverter() {} Conv2dOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
LOG(INFO) LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias"; << "convert a fluid conv2d op to tensorrt conv layer without bias";
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// Reorder the elements from istrides to ostrides, borrowed from TRT convert in
// tensorflow.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorrt/convert/convert_nodes.cc#L318
template <typename T>
void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
T* odata, nvinfer1::DimsHW ostrides) {
for (int h = 0; h < shape.h(); ++h) {
for (int w = 0; w < shape.w(); ++w) {
odata[h * ostrides.h() + w * ostrides.w()] =
idata[h * ostrides.h() + w * ostrides.w()];
}
}
}
// Reorder the data layout from CK to KC.
void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
TensorRTEngine::Weight* oweights) {
int c = iweights.dims[0];
int k = iweights.dims[1];
oweights->dims.assign({k, c});
nvinfer1::DimsHW istrides = {1, k};
nvinfer1::DimsHW ostrides = {c, 1};
Reorder2({k, c}, static_cast<float const*>(iweights.get().values), istrides,
static_cast<float*>(const_cast<void*>(oweights->get().values)),
ostrides);
}
/*
* FC converter convert a MUL op in Fluid to a FC layer in TRT.
*/
class FcOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";
framework::OpDesc op_desc(op, nullptr, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
// Declare inputs
auto* X = engine_->GetITensor(op_desc.Input("X").front());
// Declare weights
auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
// This may trigger a GPU->CPU copy, because TRT's weight can only be
// assigned from CPU memory, that can't be avoided.
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix
size_t n_output = Y_t->dims()[1];
framework::LoDTensor tmp;
tmp.Resize(Y_t->dims());
memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
Y_t->dims()[0] * Y_t->dims()[1]);
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)};
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
static_cast<void*>(tmp.data<float>()),
Y_t->memory_size() / sizeof(float));
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
tmp_weight.dims = weight.dims;
// The data layout of TRT FC layer's weight is different from fluid's FC,
// need to reorder the elements.
ReorderCKtoKC(tmp_weight, &weight);
// Currently, the framework can only handle one fluid op -> one TRT layer,
// but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
// handle `mul`, leave `add` as another layer.
// DEBUG
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
*const_cast<nvinfer1::ITensor*>(X),
n_output, weight.get(), bias.get());
auto output_name = op_desc.Output("Out").front();
engine_->DeclareOutput(layer, 0, output_name);
}
};
REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(mul);
...@@ -24,10 +24,11 @@ namespace tensorrt { ...@@ -24,10 +24,11 @@ namespace tensorrt {
class MulOpConverter : public OpConverter { class MulOpConverter : public OpConverter {
public: public:
MulOpConverter() {} MulOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op,
VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias"; const framework::Scope& scope) override {
VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";
framework::OpDesc op_desc(op, nullptr, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
......
...@@ -31,27 +31,42 @@ namespace tensorrt { ...@@ -31,27 +31,42 @@ namespace tensorrt {
class OpConverter { class OpConverter {
public: public:
OpConverter() {} OpConverter() {}
virtual void operator()(const framework::proto::OpDesc& op) {}
void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) { // Converter logic for an op.
std::string type = op.type(); virtual void operator()(const framework::proto::OpDesc& op,
auto* it = Registry<OpConverter>::Lookup(type); const framework::Scope& scope) {}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine); // Convert a single fluid operaotr and add the corresponding layer to TRT.
(*it)(op); void ConvertOp(const framework::proto::OpDesc& op,
} const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
framework::OpDesc op_desc(op, nullptr, nullptr);
OpConverter* it{nullptr};
// convert fluid op to tensorrt layer if (op_desc.Type() == "mul") {
void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) { PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
OpConverter::Run(op, engine); std::string Y = op_desc.Input("Y")[0];
if (parameters.count(Y)) {
it = Registry<OpConverter>::Lookup("fc");
}
}
if (!it) {
it = Registry<OpConverter>::Lookup(op_desc.Type());
}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
it->SetEngine(engine);
(*it)(op, scope);
} }
// convert fluid block to tensorrt network // convert fluid block to tensorrt network
void ConvertBlock(const framework::proto::BlockDesc& block, void ConvertBlock(const framework::proto::BlockDesc& block,
TensorRTEngine* engine) { const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
for (int i = 0; i < block.ops_size(); i++) { for (int i = 0; i < block.ops_size(); i++) {
const auto& op = block.ops(i); const auto& op = block.ops(i);
OpConverter::Run(op, engine); ConvertOp(op, parameters, scope, engine);
} }
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(fc_op, test) {
std::unordered_set<std::string> parameters({"mul-Y"});
framework::Scope scope;
TRTConvertValidation validator(20, parameters, scope, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {"mul-X"});
desc.SetInput("Y", {"mul-Y"});
desc.SetOutput("Out", {"mul-Out"});
validator.SetOp(*desc.Proto());
validator.Execute(10);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -21,7 +21,9 @@ namespace inference { ...@@ -21,7 +21,9 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
TEST(MulOpConverter, main) { TEST(MulOpConverter, main) {
TRTConvertValidation validator(10, 1000); framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6)); validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10)); validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10)); validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
......
...@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) { ...@@ -27,7 +28,9 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d"); conv2d_op->SetType("conv2d");
OpConverter converter; OpConverter converter;
converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/); framework::Scope scope;
converter.ConvertBlock(*block->Proto(), {}, scope,
nullptr /*TensorRTEngine*/);
} }
} // namespace tensorrt } // namespace tensorrt
......
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
...@@ -58,7 +61,10 @@ class TRTConvertValidation { ...@@ -58,7 +61,10 @@ class TRTConvertValidation {
public: public:
TRTConvertValidation() = delete; TRTConvertValidation() = delete;
TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) { TRTConvertValidation(int batch_size,
const std::unordered_set<std::string>& parameters,
framework::Scope& scope, int workspace_size = 1 << 10)
: parameters_(parameters), scope_(scope) {
// create engine. // create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_)); engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_->InitNetwork(); engine_->InitNetwork();
...@@ -73,19 +79,22 @@ class TRTConvertValidation { ...@@ -73,19 +79,22 @@ class TRTConvertValidation {
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims); engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
} }
// Declare a parameter varaible in the scope.
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) { void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims); DeclVar(name, dims);
} }
// Declare a variable in a fluid Scope.
void DeclVar(const std::string& name, const nvinfer1::Dims& dims) { void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// Init Fluid tensor. // Init Fluid tensor.
std::vector<int> dim_vec(dims.nbDims); std::vector<int> dim_vec(dims.d, dims.d + dims.nbDims);
for (int i = 0; i < dims.nbDims; i++) {
dim_vec[i] = dims.d[i];
}
auto* x = scope_.Var(name); auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>(); auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec)); x_tensor->Resize(framework::make_ddim(dim_vec));
...@@ -96,20 +105,22 @@ class TRTConvertValidation { ...@@ -96,20 +105,22 @@ class TRTConvertValidation {
op_ = framework::OpRegistry::CreateOp(desc); op_ = framework::OpRegistry::CreateOp(desc);
OpConverter op_converter; OpConverter op_converter;
op_converter.ConvertOp(desc, engine_.get()); op_converter.ConvertOp(desc, parameters_, scope_, engine_.get());
engine_->FreezeNetwork(); engine_->FreezeNetwork();
// Declare outputs. // Declare outputs.
op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr)); op_desc_.reset(new framework::OpDesc(desc, nullptr));
// Set Inputs. // Set Inputs.
for (const auto& input : op_desc_->InputArgumentNames()) { for (const auto& input : op_desc_->InputArgumentNames()) {
if (parameters_.count(input)) continue;
auto* var = scope_.FindVar(input); auto* var = scope_.FindVar(input);
PADDLE_ENFORCE(var); PADDLE_ENFORCE(var);
auto tensor = var->GetMutable<framework::LoDTensor>(); auto tensor = var->GetMutable<framework::LoDTensor>();
engine_->SetInputFromCPU( engine_->SetInputFromCPU(
input, static_cast<void*>(tensor->data<float>()), input, static_cast<void*>(tensor->data<void>()),
sizeof(float) * sizeof(float) *
analysis::AccuDims(tensor->dims(), tensor->dims().size())); analysis::AccuDims(tensor->dims(), tensor->dims().size()));
} }
...@@ -117,18 +128,21 @@ class TRTConvertValidation { ...@@ -117,18 +128,21 @@ class TRTConvertValidation {
void Execute(int batch_size) { void Execute(int batch_size) {
// Execute Fluid Op // Execute Fluid Op
// Execute TRT
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
engine_->Execute(batch_size);
op_->Run(scope_, place); op_->Run(scope_, place);
// Execute TRT.
engine_->Execute(batch_size);
cudaStreamSynchronize(*engine_->stream());
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty()); ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
const size_t output_space_size = 200;
for (const auto& output : op_desc_->OutputArgumentNames()) { for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out; std::vector<float> fluid_out;
std::vector<float> trt_out(200); std::vector<float> trt_out(output_space_size);
engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float)); engine_->GetOutputInCPU(output, &trt_out[0],
output_space_size * sizeof(float));
cudaStreamSynchronize(*engine_->stream());
auto* var = scope_.FindVar(output); auto* var = scope_.FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>(); auto tensor = var->GetMutable<framework::LoDTensor>();
...@@ -136,7 +150,7 @@ class TRTConvertValidation { ...@@ -136,7 +150,7 @@ class TRTConvertValidation {
// Compare two output // Compare two output
ASSERT_FALSE(fluid_out.empty()); ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) { for (size_t i = 0; i < fluid_out.size(); i++) {
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001); EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 1e-6);
} }
} }
} }
...@@ -146,9 +160,10 @@ class TRTConvertValidation { ...@@ -146,9 +160,10 @@ class TRTConvertValidation {
private: private:
std::unique_ptr<TensorRTEngine> engine_; std::unique_ptr<TensorRTEngine> engine_;
cudaStream_t stream_; cudaStream_t stream_;
framework::Scope scope_;
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_; std::unique_ptr<framework::OpDesc> op_desc_;
const std::unordered_set<std::string>& parameters_;
framework::Scope& scope_;
}; };
} // namespace tensorrt } // namespace tensorrt
......
...@@ -106,6 +106,7 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, ...@@ -106,6 +106,7 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
name); name);
auto* output = layer->getOutput(offset); auto* output = layer->getOutput(offset);
SetITensor(name, output);
PADDLE_ENFORCE(output != nullptr); PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str()); output->setName(name.c_str());
infer_network_->markOutput(*output); infer_network_->markOutput(*output);
......
...@@ -37,13 +37,15 @@ class TensorRTEngine : public EngineBase { ...@@ -37,13 +37,15 @@ class TensorRTEngine : public EngineBase {
// Weight is model parameter. // Weight is model parameter.
class Weight { class Weight {
public: public:
Weight(nvinfer1::DataType dtype, void* value, int num_elem) { Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) {
w_.type = dtype; w_.type = dtype;
w_.values = value; w_.values = value;
w_.count = num_elem; w_.count = num_elem;
} }
const nvinfer1::Weights& get() { return w_; } const nvinfer1::Weights& get() { return w_; }
std::vector<int64_t> dims;
private: private:
nvinfer1::Weights w_; nvinfer1::Weights w_;
}; };
......
...@@ -34,9 +34,22 @@ class BilinearInterpOp : public framework::OperatorWithKernel { ...@@ -34,9 +34,22 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
int out_w = ctx->Attrs().Get<int>("out_w"); int out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
if (ctx->HasInput("OutSize")) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w}); std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
}
}; };
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -45,6 +58,10 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -45,6 +58,10 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of bilinear interpolation, " "(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"); "This is a 4-D tensor with shape of (N x C x h x w)");
AddInput("OutSize",
"(Tensor) This is a 1-D tensor with two number. "
"The first number is height and the second number is width.")
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"(Tensor) The dimension of output is (N x C x out_h x out_w]"); "(Tensor) The dimension of output is (N x C x out_h x out_w]");
...@@ -78,6 +95,12 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel { ...@@ -78,6 +95,12 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x); ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -102,10 +102,21 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> { ...@@ -102,10 +102,21 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
auto* input_t = ctx.Input<Tensor>("X"); // float tensor auto* input_t = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
auto* input = input_t->data<T>(); auto* input = input_t->data<T>();
auto* output = output_t->mutable_data<T>(ctx.GetPlace());
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
auto out_dims = output_t->dims();
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0]; int batch_size = input_t->dims()[0];
int channels = input_t->dims()[1]; int channels = input_t->dims()[1];
int in_h = input_t->dims()[2]; int in_h = input_t->dims()[2];
...@@ -139,8 +150,8 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -139,8 +150,8 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X")); auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto* d_output = d_output_t->data<T>(); auto* d_output = d_output_t->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto& device_ctx = auto& device_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
...@@ -149,6 +160,16 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -149,6 +160,16 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
int batch_size = d_input_t->dims()[0]; int batch_size = d_input_t->dims()[0];
int channels = d_input_t->dims()[1]; int channels = d_input_t->dims()[1];
int in_h = d_input_t->dims()[2]; int in_h = d_input_t->dims()[2];
......
...@@ -24,11 +24,18 @@ class BilinearInterpKernel : public framework::OpKernel<T> { ...@@ -24,11 +24,18 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_t = ctx.Input<Tensor>("X"); // float tensor auto* input_t = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
auto out_dims = output_t->dims();
auto* input = input_t->data<T>(); auto* input = input_t->data<T>();
auto* output = output_t->mutable_data<T>(ctx.GetPlace());
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0]; int batch_size = input_t->dims()[0];
int channels = input_t->dims()[1]; int channels = input_t->dims()[1];
int in_h = input_t->dims()[2]; int in_h = input_t->dims()[2];
...@@ -83,9 +90,8 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -83,9 +90,8 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X")); auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto* d_output = d_output_t->data<T>(); auto* d_output = d_output_t->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto& device_ctx = auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>(); ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero; math::SetConstant<platform::CPUDeviceContext, T> zero;
...@@ -93,6 +99,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -93,6 +99,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
int batch_size = d_input_t->dims()[0]; int batch_size = d_input_t->dims()[0];
int channels = d_input_t->dims()[1]; int channels = d_input_t->dims()[1];
int in_h = d_input_t->dims()[2]; int in_h = d_input_t->dims()[2];
......
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
......
...@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
} }
bool RPCClient::Wait() { bool RPCClient::Wait() {
VLOG(3) << "RPCClient begin Wait()"
<< " req_count_:" << req_count_;
if (req_count_ <= 0) { if (req_count_ <= 0) {
return true; return true;
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits> #include <limits>
#include <string> #include <string>
using ::grpc::ServerAsyncResponseWriter; #include "paddle/fluid/operators/detail/grpc_server.h"
DEFINE_int32(rpc_server_handle_send_threads, 20, using ::grpc::ServerAsyncResponseWriter;
"Number of threads used to handle send at rpc server.");
DEFINE_int32(rpc_server_handle_get_threads, 20,
"Number of threads used to handle get at rpc server.");
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
"Number of threads used to handle prefetch at rpc server.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH }; ...@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
class RequestBase { class RequestBase {
public: public:
explicit RequestBase(GrpcService::AsyncService* service, explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
const platform::DeviceContext* dev_ctx) RequestHandler* request_handler, int req_id)
: service_(service), : service_(service),
cq_(cq), cq_(cq),
sync_mode_(sync_mode),
status_(PROCESS), status_(PROCESS),
dev_ctx_(dev_ctx) { request_handler_(request_handler),
req_id_(req_id) {
PADDLE_ENFORCE(cq_); PADDLE_ENFORCE(cq_);
} }
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() { assert(false); } virtual void Process() = 0;
CallStatus Status() { return status_; } CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; } void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { virtual std::string GetReqName() = 0;
assert(false);
return "";
}
protected: protected:
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_; GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
const bool sync_mode_;
CallStatus status_; CallStatus status_;
const platform::DeviceContext* dev_ctx_; RequestHandler* request_handler_;
int req_id_;
}; };
class RequestSend final : public RequestBase { class RequestSend final : public RequestBase {
public: public:
explicit RequestSend(GrpcService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, ReceivedQueue* queue, RequestHandler* request_handler, int req_id)
const platform::DeviceContext* dev_ctx, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
: RequestBase(service, cq, sync_mode, dev_ctx), request_.reset(new VariableResponse(request_handler->scope(),
queue_(queue), request_handler->dev_ctx(),
responder_(&ctx_), !request_handler->sync_mode()));
req_id_(req_id) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
...@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase { ...@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
virtual ~RequestSend() {} virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_->Varname(); } std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string varname = GetReqName();
VLOG(3) << "RequestSend var_name:" << varname;
virtual void Process() { auto scope = request_->GetMutableLocalScope();
std::string var_name = GetReqName(); auto invar = request_->GetVar();
VLOG(3) << "RequestSend " << var_name; framework::Variable* outvar = nullptr;
queue_->Push(std::make_pair(var_name, request_));
request_handler_->Handle(varname, scope, invar, &outvar);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
...@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase { ...@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
protected: protected:
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
int req_id_;
}; };
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(GrpcService::AsyncService* service, explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, RequestHandler* request_handler, int req_id)
const platform::DeviceContext* dev_ctx, : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
framework::BlockingQueue<MessageWithName>* queue,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
queue_(queue),
req_id_(req_id) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_, method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestGet() {} virtual ~RequestGet() {}
virtual std::string GetReqName() { return request_.varname(); } std::string GetReqName() override { return request_.varname(); }
virtual void Process() { void Process() override {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string varname = request_.varname();
VLOG(3) << "RequestGet " << var_name; VLOG(3) << "RequestGet " << varname;
auto* var = scope_->FindVar(var_name);
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
if (var_name != FETCH_BARRIER_MESSAGE) { request_handler_->Handle(varname, scope, invar, &outvar);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_);
} }
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
if (var_name == FETCH_BARRIER_MESSAGE) {
sendrecv::VariableMessage msg;
MessageWithName msg_with_name = std::make_pair(var_name, msg);
queue_->Push(msg_with_name);
}
} }
protected: protected:
sendrecv::VariableMessage request_; sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::BlockingQueue<MessageWithName>* queue_;
int req_id_;
}; };
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
public: public:
explicit RequestPrefetch(GrpcService::AsyncService* service, explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, RequestHandler* request_handler, int req_id)
const platform::DeviceContext* dev_ctx, : RequestBase(service, cq, request_handler, req_id),
framework::Executor* executor,
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), local_scope_(nullptr) {
executor_(executor), request_.reset(new VariableResponse(request_handler->scope(),
program_(program), request_handler->dev_ctx(), true));
prefetch_ctx_(prefetch_ctx),
req_id_(req_id) {
// prefetch always create a new sub scope
request_.reset(new VariableResponse(scope, dev_ctx_, true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestPrefetch() {} virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_->Varname(); } std::string GetReqName() override { return request_->Varname(); }
virtual void Process() { void Process() override {
// prefetch process... // prefetch process...
std::string varname = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << varname;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
std::string var_name = request_->OutVarname(); request_handler_->Handle(varname, scope, invar, &outvar);
VLOG(3) << "RequestPrefetch " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = request_->GetMutableLocalScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, local_scope);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
...@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase { ...@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* local_scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int req_id_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) {
int fetch_barriers = 0;
while (fetch_barriers < count) {
auto msg = var_get_queue_.Pop();
if (msg.first == FETCH_BARRIER_MESSAGE) {
fetch_barriers++;
}
}
}
void AsyncGRPCServer::WaitServerReady() { void AsyncGRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_); std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
} }
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::StartServer() {
::grpc::ServerBuilder builder; ::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(), builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
&selected_port_); &selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_); builder.RegisterService(&service_);
cq_send_ = builder.AddCompletionQueue(); for (auto t : rpc_call_map_) {
cq_get_ = builder.AddCompletionQueue(); rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
cq_prefetch_ = builder.AddCompletionQueue(); }
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ LOG(INFO) << "Server listening on " << bind_address_
<< " selected port: " << selected_port_; << " selected port: " << selected_port_;
std::function<void(int)> send_register = std::bind( std::function<void(const std::string&, int)> f =
&AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1); std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
std::function<void(int)> get_register = std::bind( std::placeholders::_1, std::placeholders::_2);
&AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
std::function<void(int)> prefetch_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
std::placeholders::_1);
for (int i = 0; i < kSendReqsBufSize; ++i) { for (auto& t : rpc_call_map_) {
TryToRegisterNewSendOne(i); auto& rpc_name = t.first;
} auto& cq = rpc_cq_[rpc_name];
for (int i = 0; i < kGetReqsBufSize; ++i) { auto threadnum = rpc_thread_num_[rpc_name];
TryToRegisterNewGetOne(i); auto& reqs = rpc_reqs_[rpc_name];
}
for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
TryToRegisterNewPrefetchOne(i);
}
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { reqs.reserve(kRequestBufSize);
t_sends_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, for (int i = 0; i < kRequestBufSize; i++) {
cq_send_.get(), "cq_send", send_register))); TryToRegisterNewOne(rpc_name, i);
} }
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_.emplace_back( for (int i = 0; i < threadnum; i++) {
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
cq_get_.get(), "cq_get", get_register))); &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
} VLOG(3) << t.first << " creates threads!";
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { }
t_prefetchs_.emplace_back(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
} }
{ {
std::lock_guard<std::mutex> lock(this->mutex_ready_); std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1; ready_ = 1;
} }
condition_ready_.notify_all(); condition_ready_.notify_all();
// wait server // wait server
server_->Wait(); server_->Wait();
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_[i]->join(); for (auto& t : rpc_threads_) {
} auto& threads = t.second;
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { for (size_t i = 0; i < threads.size(); ++i) {
t_gets_[i]->join(); threads[i]->join();
} VLOG(3) << t.first << " threads ends!";
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { }
t_prefetchs_[i]->join();
} }
} }
void AsyncGRPCServer::ShutdownQueue() { void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_); for (auto& t : rpc_cq_) {
cq_send_->Shutdown(); t.second->Shutdown();
cq_get_->Shutdown(); VLOG(3) << t.first << " shutdown!";
cq_prefetch_->Shutdown(); }
} }
// This URL explains why shutdown is complicate: void AsyncGRPCServer::ShutDownImpl() {
void AsyncGRPCServer::ShutDown() { std::unique_lock<std::mutex> lock(cq_mutex_);
is_shut_down_ = true; is_shut_down_ = true;
ShutdownQueue(); ShutdownQueue();
VLOG(3) << "server_ shutdown!";
server_->Shutdown(); server_->Shutdown();
} }
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
scope_, &var_recv_queue_, dev_ctx_, i);
send_reqs_[i] = static_cast<RequestBase*>(send);
VLOG(4) << "Create RequestSend status:" << send->Status();
}
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) { VLOG(4) << "register send rpc_name:" << rpc_name
std::unique_lock<std::mutex> lock(cq_mutex_); << ", handler:" << rpc_call_map_[kRequestSend];
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; auto& reqs = rpc_reqs_[rpc_name];
return; auto& handler = rpc_call_map_[rpc_name];
auto& cq = rpc_cq_[rpc_name];
RequestBase* b = nullptr;
if (rpc_name == kRequestSend) {
b = new RequestSend(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not surpported rpc");
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_, req_id);
get_reqs_[req_id] = static_cast<RequestBase*>(get);
VLOG(4) << "Create RequestGet status:" << get->Status();
}
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { reqs[req_id] = b;
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return;
}
RequestPrefetch* prefetch = new RequestPrefetch(
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
program_, prefetch_ctx_.get(), req_id);
prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestSend status:" << b->Status();
} }
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest( void AsyncGRPCServer::HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& cq_name, ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(int)> TryToRegisterNewOne) { std::function<void(const std::string&, int)> TryToRegisterNewOne) {
void* tag = NULL; void* tag = NULL;
bool ok = false; bool ok = false;
while (true) { while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " wait Next"; VLOG(3) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) { if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!"; LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!";
break; break;
} }
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
if (sync_mode_) { int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
// FIXME(typhoonzero): de-couple the barriers with recv_op VLOG(3) << "HandleRequest " << rpc_name << ", req_id:" << req_id
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); << " get next";
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
}
auto& reqs = rpc_reqs_[rpc_name];
RequestBase* base = nullptr; RequestBase* base = nullptr;
{ {
std::lock_guard<std::mutex> l(cq_mutex_); PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
if (cq_name == "cq_get") { std::unique_lock<std::mutex> lock(cq_mutex_);
base = get_reqs_[req_id]; base = reqs[req_id];
} else if (cq_name == "cq_send") {
base = send_reqs_[req_id];
} else if (cq_name == "cq_prefetch") {
base = prefetch_reqs_[req_id];
}
} }
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) { if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name[" LOG(WARNING) << "completion queue:" << rpc_name
<< " recv no regular event:argument name["
<< base->GetReqName() << "]"; << base->GetReqName() << "]";
TryToRegisterNewOne(req_id); TryToRegisterNewOne(rpc_name, req_id);
delete base; delete base;
continue; continue;
} }
VLOG(3) << "queue id:" << rpc_name << ", req_id:" << req_id
<< ", status:" << base->Status();
switch (base->Status()) { switch (base->Status()) {
case PROCESS: { case PROCESS: {
base->Process(); base->Process();
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
break; break;
} }
case FINISH: { case FINISH: {
TryToRegisterNewOne(req_id); TryToRegisterNewOne(rpc_name, req_id);
VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base; delete base;
break; break;
} }
...@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest( ...@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
} }
} }
void AsyncGRPCServer::WaitCond(int cond) {
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
barrier_condition_.wait(lock,
[=] { return this->barrier_cond_step_ == cond; });
}
void AsyncGRPCServer::SetCond(int cond) {
{
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
barrier_condition_.notify_all();
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <map>
#include <set>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
...@@ -28,6 +30,8 @@ limitations under the License. */ ...@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h" #include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...@@ -37,106 +41,48 @@ namespace paddle { ...@@ -37,106 +41,48 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase; class RequestBase;
class AsyncGRPCServer final { class AsyncGRPCServer final : public RPCServer {
public: public:
explicit AsyncGRPCServer(const std::string &address, bool sync_mode) explicit AsyncGRPCServer(const std::string& address, int client_num)
: address_(address), sync_mode_(sync_mode), ready_(0) {} : RPCServer(address, client_num), ready_(0) {}
~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
// functions to sync server barrier status.
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
int GetSelectedPort() const { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const std::string &msg_name) { virtual ~AsyncGRPCServer() {}
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); void WaitServerReady() override;
} void StartServer() override;
void ShutDown(); private:
void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne);
protected: void TryToRegisterNewOne(const std::string& rpc_name, int req_id);
void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(int req_id);
void ShutdownQueue(); void ShutdownQueue();
void ShutDownImpl() override;
private: private:
static const int kSendReqsBufSize = 100; static const int kRequestBufSize = 100;
static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
std::string address_;
const bool sync_mode_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
framework::BlockingQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
// condition of the sub program // condition of the sub program
std::mutex barrier_mutex_; std::mutex barrier_mutex_;
mutable int barrier_cond_step_; mutable int barrier_cond_step_;
std::condition_variable barrier_condition_; std::condition_variable barrier_condition_;
std::vector<std::unique_ptr<std::thread>> t_sends_;
std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
std::mutex mutex_ready_; std::mutex mutex_ready_;
std::condition_variable condition_ready_; std::condition_variable condition_ready_;
int ready_; int ready_;
std::map<std::string, std::unique_ptr<::grpc::ServerCompletionQueue>> rpc_cq_;
std::map<std::string, std::vector<std::unique_ptr<std::thread>>> rpc_threads_;
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
}; };
}; // namespace detail }; // namespace detail
......
...@@ -24,13 +24,16 @@ limitations under the License. */ ...@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_; std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
...@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, ...@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
} }
} }
void StartServer(const std::string& endpoint) { void StartServer() {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
...@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) { ...@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto prepared = exe.Prepare(program, block->ID()); auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10); InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program); g_req_handler->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(std::move(prepared)); g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
rpc_service_->SetDevCtx(&ctx); g_req_handler->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope); g_req_handler->SetScope(&scope);
rpc_service_->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
rpc_service_->RunSyncUpdate(); // FIXME(gongwb): don't use hard time.
sleep(10);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join();
} }
TEST(PREFETCH, DISABLED_CPU) { TEST(PREFETCH, CPU) {
// start up a server instance backend g_req_handler.reset(new detail::RequestPrefetchHandler(true));
std::thread server_thread(StartServer, "127.0.0.1:8889"); g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
sleep(2);
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
detail::RPCClient client;
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// create var on local scope {
int64_t rows_numel = 5; // create var on local scope
InitTensorsOnClient(&scope, &place, rows_numel); int64_t rows_numel = 5;
std::string in_var_name("ids"); InitTensorsOnClient(&scope, &place, rows_numel);
std::string out_var_name("out"); std::string in_var_name("ids");
std::string out_var_name("out");
auto client = detail::RPCClient::GetInstance();
client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
out_var_name); client.Wait();
client->Wait(); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto var = scope.Var(out_var_name); auto ptr = value.mutable_data<float>(place);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place); for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
rpc_service_->ShutDown(); }
server_thread.join();
rpc_service_.reset(nullptr);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
} }
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
} }
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
class RPCServer;
class RequestHandler {
public:
explicit RequestHandler(bool sync_mode)
: sync_mode_(sync_mode),
dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
program_(nullptr),
rpc_server_(nullptr) {}
virtual ~RequestHandler() {}
// Set attributes.
void SetScope(framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
grad_to_prepared_ctx_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var,
framework::Variable** outvar) = 0;
protected:
const bool sync_mode_;
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// Used for async.
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle {
namespace operators {
namespace detail {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestSendHandler:" << varname;
// Async
if (!sync_mode_) {
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true;
}
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
rpc_server_->WaitCond(kRequestSend);
}
if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname;
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
sparse_vars_.push_back(invar);
}
}
return true;
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) {
if (sync_mode_) {
rpc_server_->WaitCond(kRequestGet);
}
*outvar = scope_->FindVar(varname);
return true;
}
// FETCH_BARRIER_MESSAGE
if (sync_mode_) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
return true;
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
class RequestSendHandler final : public RequestHandler {
public:
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
class RequestGetHandler final : public RequestHandler {
public:
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle {
namespace operators {
namespace detail {
void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown ";
ShutDownImpl();
exit_flag_ = true;
barrier_cond_.notify_all();
rpc_cond_.notify_all();
}
void RPCServer::SavePort() const {
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_;
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
});
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
}
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;
if (b >= client_num_) {
barrier_cond_.notify_all();
}
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
for (auto& t : barrier_counter_) {
t.second = 0;
}
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
rpc_call_map_[rpc_name] = handler;
rpc_thread_num_[rpc_name] = thread_num;
static int cond = -1;
rpc_cond_map_[rpc_name] = ++cond;
VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
<< ", cond:" << rpc_cond_map_[rpc_name];
}
void RPCServer::SetCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer SetCond " << rpc_name;
{
std::unique_lock<std::mutex> lock(mutex_);
cur_cond_ = rpc_cond_map_[rpc_name];
}
rpc_cond_.notify_all();
}
void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer WaitCond " << rpc_name;
int cond = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
cond = rpc_cond_map_[rpc_name];
}
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace paddle {
namespace operators {
namespace detail {
class RPCServer {
public:
explicit RPCServer(const std::string& address, int client_num)
: cur_cond_(0),
bind_address_(address),
exit_flag_(false),
selected_port_(0),
client_num_(client_num) {}
virtual ~RPCServer() {}
virtual void StartServer() = 0;
virtual void WaitServerReady() = 0;
void ShutDown();
bool IsExit() { return exit_flag_.load(); }
int GetSelectedPort() const { return selected_port_; }
void SavePort() const;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void WaitBarrier(const std::string& rpc_name);
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter();
protected:
virtual void ShutDownImpl() = 0;
private:
std::mutex mutex_;
std::unordered_map<std::string, int> barrier_counter_;
std::condition_variable barrier_cond_;
std::unordered_map<std::string, int> rpc_cond_map_;
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;
protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
int selected_port_;
const int client_num_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
friend class RequestHandler;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
...@@ -67,8 +67,8 @@ class VariableResponse { ...@@ -67,8 +67,8 @@ class VariableResponse {
framework::Scope* GetMutableLocalScope() const { return local_scope_; } framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
......
...@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel {
auto index_dims = ctx->GetInputDim("Index"); auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE(index_dims.size() == 1); PADDLE_ENFORCE(index_dims.size() == 1);
int batch_size = ctx->GetInputDim("Index")[0]; int batch_size = ctx->GetInputDim("Index")[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx->GetInputDim("X")); framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims); ctx->SetOutputDim("Out", output_dims);
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
...@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
detail::AsyncGRPCServer rpc_service(endpoint, true); detail::RequestSendHandler rpc_h(true);
detail::AsyncGRPCServer rpc_service(endpoint, 1);
rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(&rpc_service);
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
rpc_service.SetScope(scope); rpc_h.SetScope(scope);
rpc_service.SetDevCtx(&dev_ctx); rpc_h.SetDevCtx(&dev_ctx);
rpc_service.SetProgram(&empty_program); rpc_h.SetProgram(&empty_program);
rpc_service.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service)); std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service));
rpc_service.SetCond(0); rpc_service.SetCond(detail::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service.Get(); rpc_service.WaitBarrier(detail::kRequestSend);
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown(); rpc_service.ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
......
...@@ -19,14 +19,16 @@ limitations under the License. */ ...@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { void RunServer(std::shared_ptr<detail::RPCServer> service) {
service->RunSyncUpdate(); service->StartServer();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void split(const std::string &str, char sep, static void split(const std::string &str, char sep,
...@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks( ...@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} }
std::atomic_int ListenAndServOp::selected_port_{0};
ListenAndServOp::ListenAndServOp(const std::string &type, ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
...@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type, ...@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp::~ListenAndServOp() { Stop(); } ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() { void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown(); rpc_service_->ShutDown();
server_thread_->join(); server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
...@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() { ...@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void ListenAndServOp::SavePort() const { void ListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port // NOTE: default write file to /tmp/paddle.selected_port
selected_port_ = rpc_service_->GetSelectedPort(); rpc_service_->SavePort();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_.load();
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void ListenAndServOp::WaitServerReady() {
while (selected_port_.load() == 0) {
}
} }
void ListenAndServOp::RunSyncLoop(framework::Executor *executor, void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const { framework::BlockDesc *prefetch_block) const {
auto fan_in = Attr<int>("Fanin");
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
...@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared.begin(), optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr)); std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
bool exit_flag = false; rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that // Record received sparse variables, so that
// we could reset those after execute optimize program // we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars; std::vector<framework::Variable *> sparse_vars;
while (!exit_flag && !SignalHandler::IsProgramExit()) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0); rpc_service_->SetCond(detail::kRequestSend);
size_t recv_var_cnt = 0; rpc_service_->WaitBarrier(detail::kRequestSend);
int batch_barrier = 0;
while (batch_barrier != fan_in) { if (rpc_service_->IsExit()) {
const detail::ReceivedMessage v = rpc_service_->Get(); LOG(WARNING) << "get exit!rpc_processor break!";
auto recv_var_name = v.first; rpc_service_->SetCond(detail::kRequestGet);
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
}
}
if (exit_flag) {
rpc_service_->SetCond(1);
rpc_service_->ShutDown();
break; break;
} }
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work. // and this will still work.
// The optimize blocks which have the same parent ID would run parallel // The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future // TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = program->Block(1).Parent();
...@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} }
rpc_service_->SetCond(1); rpc_service_->SetCond(detail::kRequestGet);
// FIXME(typhoonzero): use another condition to sync wait clients get. rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->WaitClientGet(fan_in); rpc_service_->ResetBarrierCounter();
sparse_vars.clear();
} // while(true) } // while(true)
} }
static void AsyncUpdateThread(
const std::string &var_name, const bool &exit_flag,
const std::shared_ptr<detail::ReceivedQueue> &queue,
framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = queue->Pop();
if (SignalHandler::IsProgramExit()) {
VLOG(3) << "update thread for " << var_name << " exit";
break;
}
auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
auto fs = framework::Async([var_name, &executor, &v, prepared] {
try {
executor->RunPreparedContext(prepared,
v.second->GetMutableLocalScope());
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
fs.wait();
}
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const { framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in"; VLOG(3) << "RunAsyncLoop in";
// grad name to block id // grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id; std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
grad_to_queue;
auto grad_to_block_id_str = auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id"); Attr<std::vector<std::string>>("grad_to_block_id");
...@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0); PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id; grad_to_block_id[pieces[0]] = block_id;
std::shared_ptr<detail::ReceivedQueue> queue =
std::make_shared<detail::ReceivedQueue>();
grad_to_queue[pieces[0]] = queue;
// record blocking queue in SignalHandler
SignalHandler::RegisterBlockingQueue(queue);
id_to_grad[block_id] = pieces[0]; id_to_grad[block_id] = pieces[0];
} }
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
...@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i]; grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
} }
bool exit_flag = false; request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
VLOG(3) << "start async optimize threads";
std::vector<std::future<void>> fs;
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
std::string grad_name = iter->first;
VLOG(3) << "create async update thread for " << grad_name;
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
&grad_to_queue, &grad_to_prepared_ctx]() {
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
executor, grad_to_prepared_ctx[grad_name].get());
}));
}
VLOG(3) << "RunAsyncLoop into while"; VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag && !SignalHandler::IsProgramExit()) { while (true) {
const detail::ReceivedMessage v = rpc_service_->Get(); if (rpc_service_->IsExit()) {
auto recv_var_name = v.first; LOG(INFO) << "get exit!rpc_processor break!";
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break; break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
grad_to_queue[recv_var_name]->Push(v);
} }
if (exit_flag) { sleep(1);
rpc_service_->ShutDown();
break;
}
} // while(true) } // while(true)
} }
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx,
detail::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(std::move(
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx)));
h->SetRPCServer(rpc_server);
}
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer. // Mark this as PS that it should decide profiling by listening from trainer.
...@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode"); bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock); auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program(); auto *program = optimize_block->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare rpc_service
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
rpc_service_->SetProgram(program);
rpc_service_->SetExecutor(&executor);
// prepare for prefetch // prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID(); VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(),
rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
...@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGTERM, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid()));
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
...@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
bool SignalHandler::program_exit_flag_ = false;
SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
void SignalHandler::StopAndExit(int signal_num) { void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit"; VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
exit(0);
program_exit_flag_ = true;
// awake all blocking queues
for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
iter != blocking_queue_set_.end(); iter++) {
iter->get()->Push(
std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
}
exit(EXIT_SUCCESS);
}
void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
blocking_queue_set_.insert(queue);
} }
} // namespace operators } // namespace operators
......
...@@ -23,7 +23,8 @@ limitations under the License. */ ...@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -31,7 +32,7 @@ namespace operators { ...@@ -31,7 +32,7 @@ namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock"; constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service); void RunServer(std::shared_ptr<detail::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase { class ListenAndServOp : public framework::OperatorBase {
public: public:
...@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void SavePort() const; void SavePort() const;
void WaitServerReady(); int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
int GetSelectedPort() { return selected_port_; }
void Stop() override; void Stop() override;
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
static void ResetPort() { selected_port_ = 0; }
protected: protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_; mutable std::shared_ptr<detail::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
// FIXME(wuyi): it's static so that the operator can be cloned.
static std::atomic_int selected_port_;
}; };
class SignalHandler { class SignalHandler {
public:
typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
public: public:
static void StopAndExit(int signal_num); static void StopAndExit(int signal_num);
static void RegisterBlockingQueue(BlockingQueue&);
static inline bool IsProgramExit() { return program_exit_flag_; }
private: private:
static bool program_exit_flag_;
static BlockingQueueSet blocking_queue_set_;
DISABLE_COPY_AND_ASSIGN(SignalHandler); DISABLE_COPY_AND_ASSIGN(SignalHandler);
}; };
......
...@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto rpc_client = detail::RPCClient::GetInstance(); auto rpc_client = detail::RPCClient::GetInstance();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) { if (sync_mode) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ShapeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input (Input) of get_shape op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output (Out) of get_shape op should not be null.");
auto in_dim = ctx->GetInputDim("Input");
ctx->SetOutputDim("Out", {in_dim.size()});
}
};
class ShapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "(Tensor), The input tensor.");
AddOutput("Out", "(Tensor), The shape of input tensor.");
AddComment(R"DOC(
Shape Operator.
Get the shape of input tensor.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shape, ops::ShapeOp, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int64_t>,
ops::ShapeKernel<float>, ops::ShapeKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel<int>,
paddle::operators::ShapeKernel<int64_t>,
paddle::operators::ShapeKernel<float>,
paddle::operators::ShapeKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ShapeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
auto in_dims = in_t->dims();
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
}
};
} // namespace operators
} // namespace paddle
...@@ -31,8 +31,9 @@ void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare( ...@@ -31,8 +31,9 @@ void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
auto max_workspace = context.Attr<int>("max_workspace"); auto max_workspace = context.Attr<int>("max_workspace");
engine_.reset(new inference::tensorrt::TensorRTEngine( engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_, max_workspace, nullptr)); max_batch_, max_workspace, nullptr));
// TODO(Superjomn) parameters should be passed after analysised from outside.
inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock( inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
block, engine_.get()); block, {}, context.scope(), engine_.get());
engine_->FreezeNetwork(); engine_->FreezeNetwork();
} }
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -35,42 +37,44 @@ namespace m = paddle::operators::math; ...@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
namespace string = paddle::string; namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service; std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
void StartServer(std::atomic<bool>* initialized) { void StartServer() {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
scope.Var(NCCL_ID_VARNAME); scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace()); auto& dev_ctx = *pool.Get(p::CPUPlace());
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
f::ProgramDesc empty_program; f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace()); f::Executor executor(dev_ctx.GetPlace());
rpc_service->SetScope(&scope); g_req_handler->SetScope(&scope);
rpc_service->SetDevCtx(&dev_ctx); g_req_handler->SetDevCtx(&dev_ctx);
rpc_service->SetProgram(&empty_program); g_req_handler->SetProgram(&empty_program);
rpc_service->SetExecutor(&executor); g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get())); std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
*initialized = true;
rpc_service->SetCond(0); g_rpc_service->SetCond(detail::kRequestSend);
auto recv = rpc_service->Get(); std::cout << "before WaitFanInOfSend" << std::endl;
g_rpc_service->WaitBarrier(detail::kRequestSend);
LOG(INFO) << "got nccl id and stop server..."; LOG(INFO) << "got nccl id and stop server...";
rpc_service->ShutDown(); g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
} }
TEST(SendNcclId, DISABLED_Normal) { TEST(SendNcclId, GrpcServer) {
std::atomic<bool> initialized{false}; g_req_handler.reset(new detail::RequestSendHandler(true));
std::thread server_thread(StartServer, &initialized); g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
while (!initialized) {
} std::thread server_thread(StartServer);
// wait server to start g_rpc_service->WaitServerReady();
// sleep(2);
rpc_service->WaitServerReady();
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
...@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) { ...@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto& dev_ctx = *pool.Get(p::CPUPlace()); auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var(NCCL_ID_VARNAME); auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>(); auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id); p::dynload::ncclGetUniqueId(id);
int port = rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client; detail::RPCClient client;
LOG(INFO) << "connect to server" << ep;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait(); client.Wait();
client.AsyncSendBatchBarrier(ep);
client.Wait();
server_thread.join(); server_thread.join();
auto* ptr = rpc_service.release(); g_rpc_service.reset(nullptr);
delete ptr; g_req_handler.reset(nullptr);
} }
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <stdio.h> #include <stdio.h>
#include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
......
...@@ -183,7 +183,7 @@ function build() { ...@@ -183,7 +183,7 @@ function build() {
============================================ ============================================
EOF EOF
make clean make clean
make -j `nproc` make install -j `nproc`
} }
function build_android() { function build_android() {
......
...@@ -82,6 +82,7 @@ __all__ = [ ...@@ -82,6 +82,7 @@ __all__ = [
'roi_pool', 'roi_pool',
'dice_loss', 'dice_loss',
'upsampling_bilinear2d', 'upsampling_bilinear2d',
'gather',
'random_crop', 'random_crop',
] ]
...@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0): ...@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
def dice_loss(input, label, epsilon=0.00001): def dice_loss(input, label, epsilon=0.00001):
""" """
**Dice loss Layer**
Dice loss for comparing the similarity of two batch of data, Dice loss for comparing the similarity of two batch of data,
usually is used for binary image segmentation i.e. labels are binary. usually is used for binary image segmentation i.e. labels are binary.
The dice loss can be defined as below equation: The dice loss can be defined as below equation:
...@@ -3944,7 +3944,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): ...@@ -3944,7 +3944,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
input (Variable): The input tensor of bilinear interpolation, input (Variable): The input tensor of bilinear interpolation,
This is a 4-D tensor of the shape This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w). (num_batches, channels, in_h, in_w).
out_shape(list|tuple|None): Output shape of bilinear interpolation out_shape(list|tuple|Variable|None): Output shape of bilinear interpolation
layer, the shape is (out_h, out_w). layer, the shape is (out_h, out_w).
Default: None Default: None
scale(int|None): The multiplier for the input height or width. scale(int|None): The multiplier for the input height or width.
...@@ -3971,13 +3971,20 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): ...@@ -3971,13 +3971,20 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
def _is_list_or_turple_(data): def _is_list_or_turple_(data):
return (isinstance(data, list) or isinstance(data, tuple)) return (isinstance(data, list) or isinstance(data, tuple))
out_h = 0
out_w = 0
inputs = {"X": input}
if out_shape is not None: if out_shape is not None:
if not (_is_list_or_turple_(out_shape) and len(out_shape) == 2): if not (_is_list_or_turple_(out_shape) and len(out_shape) == 2) and (
out_shape is not Variable):
raise ValueError('out_shape should be a list or tuple ', raise ValueError('out_shape should be a list or tuple ',
'with length 2, (out_h, out_w).') 'with length 2, (out_h, out_w).')
out_shape = list(map(int, out_shape)) if _is_list_or_turple_(out_shape):
out_h = out_shape[0] out_shape = list(map(int, out_shape))
out_w = out_shape[1] out_h = out_shape[0]
out_w = out_shape[1]
else:
inputs['OutSize'] = out_shape
else: else:
out_h = int(input.shape[2] * scale) out_h = int(input.shape[2] * scale)
out_w = int(input.shape[3] * scale) out_w = int(input.shape[3] * scale)
...@@ -3985,13 +3992,62 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): ...@@ -3985,13 +3992,62 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
out = helper.create_tmp_variable(dtype) out = helper.create_tmp_variable(dtype)
helper.append_op( helper.append_op(
type="bilinear_interp", type="bilinear_interp",
inputs={"X": input}, inputs=inputs,
outputs={"Out": out}, outputs={"Out": out},
attrs={"out_h": out_h, attrs={"out_h": out_h,
"out_w": out_w}) "out_w": out_w})
return out return out
def gather(input, index):
"""
Output is obtained by gathering entries of the outer-most dimension
of X indexed by `index` and concatenate them together.
.. math::
Out = X[Index]
.. code-block:: text
Given:
X = [[1, 2],
[3, 4],
[5, 6]]
Index = [1, 2]
Then:
Out = [[3, 4],
[5, 6]]
Args:
input (Variable): The source input with rank>=1.
index (Variable): The index input with rank=1.
Returns:
output (Variable): The output is a tensor with the same rank as input.
Examples:
.. code-block:: python
output = fluid.layers.gather(x, index)
"""
helper = LayerHelper('gather', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
helper.append_op(
type="gather",
inputs={"X": input,
"Index": index},
outputs={"Out": out})
return out
def random_crop(input, shape, seed=1): def random_crop(input, shape, seed=1):
helper = LayerHelper("random_crop", **locals()) helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
......
...@@ -71,6 +71,7 @@ __all__ = [ ...@@ -71,6 +71,7 @@ __all__ = [
'cumsum', 'cumsum',
'scatter', 'scatter',
'sum', 'sum',
'shape',
] + __activations__ ] + __activations__
for _OP in set(__all__): for _OP in set(__all__):
......
...@@ -17,7 +17,10 @@ import numpy as np ...@@ -17,7 +17,10 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def bilinear_interp_np(input, out_h, out_w): def bilinear_interp_np(input, out_h, out_w, out_size):
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
batch_size, channel, in_h, in_w = input.shape batch_size, channel, in_h, in_w = input.shape
if out_h > 1: if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0) ratio_h = (in_h - 1.0) / (out_h - 1.0)
...@@ -49,12 +52,15 @@ def bilinear_interp_np(input, out_h, out_w): ...@@ -49,12 +52,15 @@ def bilinear_interp_np(input, out_h, out_w):
class TestBilinearInterpOp(OpTest): class TestBilinearInterpOp(OpTest):
def setUp(self): def setUp(self):
self.out_size = None
self.init_test_case() self.init_test_case()
self.op_type = "bilinear_interp" self.op_type = "bilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32") input_np = np.random.random(self.input_shape).astype("float32")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w) output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
self.inputs = {'X': input_np} self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.outputs = {'Out': output_np} self.outputs = {'Out': output_np}
...@@ -68,6 +74,7 @@ class TestBilinearInterpOp(OpTest): ...@@ -68,6 +74,7 @@ class TestBilinearInterpOp(OpTest):
self.input_shape = [2, 3, 4, 4] self.input_shape = [2, 3, 4, 4]
self.out_h = 2 self.out_h = 2
self.out_w = 2 self.out_w = 2
self.out_size = np.array([3, 3]).astype("int32")
class TestCase1(TestBilinearInterpOp): class TestCase1(TestBilinearInterpOp):
...@@ -91,5 +98,29 @@ class TestCase3(TestBilinearInterpOp): ...@@ -91,5 +98,29 @@ class TestCase3(TestBilinearInterpOp):
self.out_w = 128 self.out_w = 128
class TestCase4(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestCase5(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestCase6(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -20,8 +20,9 @@ from op_test import OpTest ...@@ -20,8 +20,9 @@ from op_test import OpTest
class TestGatherOp(OpTest): class TestGatherOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "gather" self.op_type = "gather"
xnp = np.random.random((10, 20)).astype("float32") self.config()
self.inputs = {'X': xnp, 'Index': np.array([1, 3, 5]).astype("int32")} xnp = np.random.random(self.x_shape).astype("float32")
self.inputs = {'X': xnp, 'Index': np.array(self.index).astype("int32")}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self): def test_check_output(self):
...@@ -30,6 +31,16 @@ class TestGatherOp(OpTest): ...@@ -30,6 +31,16 @@ class TestGatherOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def config(self):
self.x_shape = (10, 20)
self.index = [1, 3, 5]
class TestCase1(TestGatherOp):
def config(self):
self.x_shape = (10)
self.index = [1, 3, 5]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestShapeOp(OpTest):
def setUp(self):
self.op_type = "shape"
self.config()
self.shape = [2, 3]
input = np.zeros(self.shape)
self.inputs = {'Input': input}
self.outputs = {'Out': np.array(self.shape)}
def config(self):
self.shape = [2, 3]
def test_check_output(self):
self.check_output()
class case1(TestShapeOp):
def config(self):
self.shape = [2]
class case2(TestShapeOp):
def config(self):
self.shape = [1, 2, 3]
if __name__ == '__main__':
unittest.main()
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import math import math
import unittest import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable from paddle.fluid.transpiler.distribute_transpiler import split_variable
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import random import random
...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR, # dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape) shape=shape)
var_list.append(var) var_list.append(var)
blocks = split_dense_variable(var_list, 10, min_size) blocks = split_variable(var_list, 10, min_size)
all_sizes = [] all_sizes = []
for s in expected_sizes: for s in expected_sizes:
for s2 in s: for s2 in s:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from program_utils import *
from ufind import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of given element x,
# execute the path compress while findind the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union two given element
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
...@@ -11,6 +11,30 @@ ...@@ -11,6 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
"""
from __future__ import print_function from __future__ import print_function
...@@ -21,9 +45,11 @@ from .. import core, framework ...@@ -21,9 +45,11 @@ from .. import core, framework
from ..framework import Program, default_main_program, \ from ..framework import Program, default_main_program, \
default_startup_program, \ default_startup_program, \
Variable, Parameter, grad_var_name Variable, Parameter, grad_var_name
from details import *
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
) )
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
...@@ -40,62 +66,11 @@ class VarBlock: ...@@ -40,62 +66,11 @@ class VarBlock:
return "%s:%d:%d" % (self.varname, self.offset, self.size) return "%s:%d:%d" % (self.varname, self.offset, self.size)
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of given element x,
# execute the path compress while findind the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union two given element
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
def same_or_split_var(p_name, var_name): def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block") return p_name == var_name or p_name.startswith(var_name + ".block")
def split_dense_variable(var_list, service_count, min_block_size=8192): def split_variable(var_list, service_count, min_block_size=8192):
""" """
We may need to split dense tensor to one or more blocks and put We may need to split dense tensor to one or more blocks and put
them equally onto parameter server. One block is a sub-tensor them equally onto parameter server. One block is a sub-tensor
...@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192): ...@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
return blocks return blocks
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def _has_distributed_lookup_table(self):
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=RoundRobin,
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
# process lookup_table_op # process lookup_table_op
# 1. check all lookup_table_op is distributed # 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table. # 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = [] distributed_lookup_table_ops = []
# support only one distributed_lookup_table now # support only one distributed_lookup_table now
self.table_name = None self.table_name = None
for op in program.global_block().ops: for op in self.origin_program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
if op.attrs['is_distributed'] is True: if op.attrs['is_distributed'] is True:
if self.table_name is None: if self.table_name is None:
...@@ -246,20 +137,13 @@ class DistributeTranspiler: ...@@ -246,20 +137,13 @@ class DistributeTranspiler:
if self.table_name is not None: if self.table_name is not None:
assert op.input("W")[0] != self.table_name assert op.input("W")[0] != self.table_name
self.has_distributed_lookup_table = len( return len(distributed_lookup_table_ops) > 0
distributed_lookup_table_ops) > 0
# step1: For large parameters and gradients, split them into smaller
# blocks.
param_list = []
grad_list = []
for p, g in params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
def _update_dist_lookup_table_vars(self, param_list, grad_list,
params_grads):
# TODO(wuyi): put find a way to put dist lookup table stuff all together.
# update self.table_param_grad and self.trainer_side_table_grad_list
program = self.origin_program
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
param_list = [ param_list = [
param for param in param_list if param.name != self.table_name param for param in param_list if param.name != self.table_name
...@@ -277,7 +161,7 @@ class DistributeTranspiler: ...@@ -277,7 +161,7 @@ class DistributeTranspiler:
self.trainer_side_table_grad_list = [ self.trainer_side_table_grad_list = [
program.global_block().create_var( program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" % name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index), (table_grad_var.name, self.trainer_id, index),
type=table_grad_var.type, type=table_grad_var.type,
shape=table_grad_var.shape, shape=table_grad_var.shape,
dtype=table_grad_var.dtype) dtype=table_grad_var.dtype)
...@@ -293,23 +177,41 @@ class DistributeTranspiler: ...@@ -293,23 +177,41 @@ class DistributeTranspiler:
for index in range(len(self.pserver_endpoints)) for index in range(len(self.pserver_endpoints))
] ]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) def _init_splited_vars(self, split_method):
param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) # update these mappings for further transpile:
# 1. param_var_mapping: param var name -> [splited params vars]
# 2. grad_var_mapping: grad var name -> [splited grads vars]
# 3. grad_param_mapping: grad.blockx -> param.blockx
# 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
param_list = []
grad_list = []
for p, g in self.params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads)
grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
param_blocks = split_variable(param_list, len(self.pserver_endpoints))
assert (len(grad_blocks) == len(param_blocks)) assert (len(grad_blocks) == len(param_blocks))
# step2: Create new vars for the parameters and gradients blocks and # origin_varname -> [splited_var]
# add ops to do the split. self.param_var_mapping = self._create_vars_from_blocklist(
param_var_mapping = self._create_vars_from_blocklist(program, self.origin_program, param_blocks)
param_blocks) self.grad_var_mapping = self._create_vars_from_blocklist(
grad_var_mapping = self._create_vars_from_blocklist( self.origin_program,
program, grad_blocks, add_trainer_suffix=self.trainer_num > 1) grad_blocks,
grad_param_mapping = dict() add_trainer_suffix=self.trainer_num > 1)
self.grad_param_mapping = dict()
for g, p in zip(grad_blocks, param_blocks): for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":") g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":") p_name, p_bid, _ = p.split(":")
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \ self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)] self.param_var_mapping[p_name][int(p_bid)]
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program # create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict() self.param_grad_ep_mapping = dict()
...@@ -322,10 +224,50 @@ class DistributeTranspiler: ...@@ -322,10 +224,50 @@ class DistributeTranspiler:
}) for ep in self.pserver_endpoints }) for ep in self.pserver_endpoints
] ]
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=RoundRobin,
sync_mode=True):
"""
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(self.pserver_endpoints)
self.has_distributed_lookup_table = self._has_distributed_lookup_table()
# split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars(split_method)
# step 3.1: insert send op to send gradient vars to parameter servers # step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset() ps_dispatcher.reset()
send_vars = [] send_vars = []
for orig_varname, splited_vars in grad_var_mapping.items(): for orig_varname, splited_vars in self.grad_var_mapping.items():
eplist = ps_dispatcher.dispatch(splited_vars) eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) == 1: if len(splited_vars) == 1:
orig_varname = splited_vars[0].name orig_varname = splited_vars[0].name
...@@ -367,7 +309,7 @@ class DistributeTranspiler: ...@@ -367,7 +309,7 @@ class DistributeTranspiler:
# step 3.2: insert recv op to receive parameters from parameter server # step 3.2: insert recv op to receive parameters from parameter server
recv_vars = [] recv_vars = []
for _, var in enumerate(send_vars): for _, var in enumerate(send_vars):
recv_vars.append(grad_param_mapping[var]) recv_vars.append(self.grad_param_mapping[var])
ps_dispatcher.reset() ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars) eplist = ps_dispatcher.dispatch(recv_vars)
...@@ -375,7 +317,7 @@ class DistributeTranspiler: ...@@ -375,7 +317,7 @@ class DistributeTranspiler:
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in self.param_var_mapping.iteritems():
eps = [] eps = []
for var in splited_var: for var in splited_var:
index = [v.name for v in recv_vars].index(var.name) index = [v.name for v in recv_vars].index(var.name)
...@@ -399,7 +341,7 @@ class DistributeTranspiler: ...@@ -399,7 +341,7 @@ class DistributeTranspiler:
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in self.param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
orig_param = program.global_block().vars[varname] orig_param = program.global_block().vars[varname]
...@@ -440,7 +382,6 @@ class DistributeTranspiler: ...@@ -440,7 +382,6 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives. # we don't need to create them when grad arrives.
# change client side var name to origin name by # change client side var name to origin name by
# removing ".trainer_%d" suffix # removing ".trainer_%d" suffix
suff_idx = v.name.find(".trainer_") suff_idx = v.name.find(".trainer_")
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = v.name[:suff_idx] orig_var_name = v.name[:suff_idx]
...@@ -477,24 +418,14 @@ class DistributeTranspiler: ...@@ -477,24 +418,14 @@ class DistributeTranspiler:
# located on current pserver # located on current pserver
opt_op_on_pserver = [] opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op): if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
endpoint, op):
opt_op_on_pserver.append(op) opt_op_on_pserver.append(op)
# step 3.3 # step 3.3
# Iterate through the ops, and if an op and the optimize ops # Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then # which located on current pserver are in one set, then
# append it into the sub program. # append it into the sub program.
# We try to put optimization program run parallelly, assume
# optimization program always looks like:
#
# prevop -> prevop -> opt op -> following op -> following op; ->
# prevop -> prevop -> opt op -> following op -> following op; ->
# global op -> global op
#
# we put operators that can run parallelly to many program blocks.
# in above example, we seperate ops by the ";". Global ops must run
# after all the optimize ops finished.
global_ops = [] global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2 # HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine. # replace it with dependency engine.
...@@ -502,12 +433,18 @@ class DistributeTranspiler: ...@@ -502,12 +433,18 @@ class DistributeTranspiler:
if self._is_adam_connected_op(op): if self._is_adam_connected_op(op):
global_ops.append(op) global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id): def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
if self._is_opt_op(op): if self._is_optimizer_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id, self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program) self.origin_program, merged_var)
else: else:
self._append_pserver_non_opt_ops(block, op) self._append_pserver_non_opt_ops(block, op, endpoint)
def __op_have_grad_input__(op):
for varname in op.input_arg_names:
if varname.find("@GRAD") >= 0:
return varname
return ""
# append lr decay ops to the child block if exists # append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
...@@ -515,17 +452,26 @@ class DistributeTranspiler: ...@@ -515,17 +452,26 @@ class DistributeTranspiler:
lr_decay_block = pserver_program.create_block( lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op) self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1 pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver): for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx) per_opt_block = pserver_program.create_block(pre_block_idx)
# append grad merging ops before clip and weight decay
for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping
grad_varname_for_block = __op_have_grad_input__(op)
if ufind.is_connected(op, opt_op) and grad_varname_for_block:
merged_var = self._append_pserver_grad_merge_ops(
per_opt_block, grad_varname_for_block, endpoint,
grad_to_block_id, self.origin_program)
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself # optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops: if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block, grad_to_block_id) __append_optimize_op__(op, per_opt_block, grad_to_block_id,
merged_var)
# append global ops # append global ops
if global_ops: if global_ops:
...@@ -533,15 +479,7 @@ class DistributeTranspiler: ...@@ -533,15 +479,7 @@ class DistributeTranspiler:
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
for glb_op in global_ops: for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block, __append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id) grad_to_block_id, None)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
# process distributed lookup_table # process distributed lookup_table
prefetch_block = None prefetch_block = None
...@@ -631,6 +569,8 @@ class DistributeTranspiler: ...@@ -631,6 +569,8 @@ class DistributeTranspiler:
attrs=op.attrs) attrs=op.attrs)
return s_prog return s_prog
# ====================== private transpiler functions =====================
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
...@@ -836,7 +776,6 @@ class DistributeTranspiler: ...@@ -836,7 +776,6 @@ class DistributeTranspiler:
return table_opt_block return table_opt_block
# ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self, def _create_vars_from_blocklist(self,
program, program,
block_list, block_list,
...@@ -979,17 +918,74 @@ class DistributeTranspiler: ...@@ -979,17 +918,74 @@ class DistributeTranspiler:
pass pass
return orig_shape return orig_shape
def _orig_varname(self, varname): def _get_varname_parts(self, varname):
suff_idx = varname.find(".trainer_") # returns origin, blockid, trainerid
orig_var_name = "" orig_var_name = ""
if suff_idx >= 0: trainer_part = ""
orig_var_name = varname[:suff_idx] block_part = ""
trainer_idx = varname.find(".trainer_")
if trainer_idx >= 0:
trainer_part = varname[trainer_idx + 1:]
else:
trainer_idx = len(varname)
block_index = varname.find(".block")
if block_index >= 0:
block_part = varname[block_index + 1:trainer_idx]
else: else:
orig_var_name = varname block_index = len(varname)
return orig_var_name orig_var_name = varname[0:min(block_index, trainer_idx)]
return orig_var_name, block_part, trainer_part
def _orig_varname(self, varname):
orig, _, _ = self._get_varname_parts(varname)
return orig
def _append_pserver_grad_merge_ops(self, optimize_block,
grad_varname_for_block, endpoint,
grad_to_block_id, origin_program):
program = optimize_block.program
pserver_block = program.global_block()
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if self._orig_varname(g.name) == \
self._orig_varname(grad_varname_for_block):
grad_block = g
break
if not grad_block:
# do not append this op if current endpoint
# is not dealing with this grad block
return
orig_varname, block_name, trainer_name = self._get_varname_parts(
grad_block.name)
if block_name:
merged_var_name = '.'.join([orig_varname, block_name])
else:
merged_var_name = orig_varname
merged_var = \
pserver_block.vars[merged_var_name]
grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
(merged_var_name, i)
vars2merge.append(pserver_block.vars[per_trainer_name])
optimize_block.append_op(
type="sum",
inputs={"X": vars2merge},
outputs={"Out": merged_var})
# TODO(panyx0718): What if it's SELECTED_ROWS.
if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
optimize_block.append_op(
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)})
return merged_var
def _append_pserver_ops(self, optimize_block, opt_op, endpoint, def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
grad_to_block_id, origin_program): grad_to_block_id, origin_program, merged_var):
program = optimize_block.program program = optimize_block.program
pserver_block = program.global_block() pserver_block = program.global_block()
new_inputs = dict() new_inputs = dict()
...@@ -997,40 +993,6 @@ class DistributeTranspiler: ...@@ -997,40 +993,6 @@ class DistributeTranspiler:
# moment can use the updated shape # moment can use the updated shape
for key in opt_op.input_names: for key in opt_op.input_names:
if key == "Grad": if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var(
self._orig_varname(g.name),
self._orig_varname(opt_op.input(key)[0])):
grad_block = g
break
if not grad_block:
# do not append this op if current endpoint
# is not dealing with this grad block
return
merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)]
grad_to_block_id.append(merged_var.name + ":" + str(
optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
(self._orig_varname(grad_block.name), i)
vars2merge.append(pserver_block.vars[per_trainer_name])
optimize_block.append_op(
type="sum",
inputs={"X": vars2merge},
outputs={"Out": merged_var})
# TODO(panyx0718): What if it's SELECTED_ROWS.
if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
optimize_block.append_op(
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)})
new_inputs[key] = merged_var new_inputs[key] = merged_var
elif key == "Param": elif key == "Param":
# param is already created on global program # param is already created on global program
...@@ -1089,17 +1051,31 @@ class DistributeTranspiler: ...@@ -1089,17 +1051,31 @@ class DistributeTranspiler:
outputs=outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _is_splited_grad_var(self, var, var_dict):
grad_block = None
for _, g in var_dict.iteritems():
if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1:
grad_block = g
break
return grad_block
def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
program = optimize_block.program program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for varlist in inputs.itervalues(): for key, varlist in inputs.iteritems():
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
if not program.global_block().vars.has_key(var.name): # for ops like clipping and weight decay, get the splited var
# for inputs/outputs
grad_block = self._is_splited_grad_var(
var, program.global_block().vars)
if grad_block:
inputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name):
program.global_block().create_var( program.global_block().create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
...@@ -1108,13 +1084,16 @@ class DistributeTranspiler: ...@@ -1108,13 +1084,16 @@ class DistributeTranspiler:
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in outputs.iteritems():
for varlist in outputs.itervalues():
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
program.global_block().clone_variable(var) grad_block = self._is_splited_grad_var(
var, program.global_block().vars)
if grad_block:
outputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var)
optimize_block.append_op( optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
...@@ -1160,9 +1139,17 @@ class DistributeTranspiler: ...@@ -1160,9 +1139,17 @@ class DistributeTranspiler:
ufind.union(op1, op2) ufind.union(op1, op2)
return ufind return ufind
def _is_opt_op(self, op): def _is_opt_role_op(self, op):
# NOTE: It's a HACK implement. # NOTE: depend on oprole to find out whether this op is for
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc... # optimize
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attrs and \
int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
return False
def _is_optimizer_op(self, op):
if "Param" in op.input_names and \ if "Param" in op.input_names and \
"LearningRate" in op.input_names: "LearningRate" in op.input_names:
return True return True
...@@ -1212,7 +1199,7 @@ class DistributeTranspiler: ...@@ -1212,7 +1199,7 @@ class DistributeTranspiler:
# find learning rate variables by optimize op # find learning rate variables by optimize op
lr_vars = set() lr_vars = set()
for op in self.optimize_ops: for op in self.optimize_ops:
if self._is_opt_op(op): if self._is_optimizer_op(op):
lr_vars.add(op.input("LearningRate")[0]) lr_vars.add(op.input("LearningRate")[0])
find_ops = [] find_ops = []
...@@ -1229,7 +1216,7 @@ class DistributeTranspiler: ...@@ -1229,7 +1216,7 @@ class DistributeTranspiler:
# NOTE: we need to skip all optimize ops, since it is connected # NOTE: we need to skip all optimize ops, since it is connected
# with forward/backward ops and lr ops, we only need the lr ops. # with forward/backward ops and lr ops, we only need the lr ops.
if op1 != op2 and self._is_op_connected(op1, op2) and \ if op1 != op2 and self._is_op_connected(op1, op2) and \
not self._is_opt_op(op1) and not self._is_opt_op(op2): not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
ufind.union(op1, op2) ufind.union(op1, op2)
# find all ops which is related with lr var # find all ops which is related with lr var
for op1 in block.ops: for op1 in block.ops:
...@@ -1250,13 +1237,21 @@ class DistributeTranspiler: ...@@ -1250,13 +1237,21 @@ class DistributeTranspiler:
block = self.origin_program.global_block() block = self.origin_program.global_block()
opt_ops = [] opt_ops = []
params_grads = [] params_grads = []
origin_var_dict = self.origin_program.global_block().vars
for op in block.ops: for op in block.ops:
if self._is_opt_op(op): if self._is_opt_role_op(op):
opt_ops.append(op) opt_ops.append(op)
params_grads.append((self.origin_program.global_block().var( # HACK(wuyi): if we find grad vars from input of optimize
op.input("Param")[0]), # ops, we may get the output of clip op. Use syntax "@GRAD"
self.origin_program.global_block().var( # and op_role_var to get the pair.
op.input("Grad")[0]))) for input_name in op.input_arg_names:
if input_name.find("@GRAD") != -1 and \
op.attrs[RPC_OP_ROLE_ATTR_NAME]:
param_name = op.attrs[OP_ROLE_VAR_ATTR_NAME][0]
params_grads.append([
origin_var_dict[param_name],
origin_var_dict[input_name]
])
elif self._is_adam_connected_op(op): elif self._is_adam_connected_op(op):
opt_ops.append(op) opt_ops.append(op)
else: else:
......
...@@ -69,7 +69,8 @@ packages=['paddle', ...@@ -69,7 +69,8 @@ packages=['paddle',
'paddle.fluid.proto', 'paddle.fluid.proto',
'paddle.fluid.proto.profiler', 'paddle.fluid.proto.profiler',
'paddle.fluid.layers', 'paddle.fluid.layers',
'paddle.fluid.transpiler'] 'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
if '${WITH_FLUID_ONLY}'== 'OFF': if '${WITH_FLUID_ONLY}'== 'OFF':
packages+=['paddle.proto', packages+=['paddle.proto',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册