提交 4e9d7ddb 编写于 作者: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into fix-beam_search-dev

...@@ -23,7 +23,7 @@ repos: ...@@ -23,7 +23,7 @@ repos:
- id: clang-format-with-version-check - id: clang-format-with-version-check
name: clang-format name: clang-format
description: Format files with ClangFormat. description: Format files with ClangFormat.
entry: bash ./.clang_format.hook -i entry: bash ./tools/codestyle/clang_format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: local - repo: local
...@@ -52,7 +52,7 @@ repos: ...@@ -52,7 +52,7 @@ repos:
hooks: hooks:
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python ./.copyright.hook entry: python ./tools/codestyle/copyright.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
...@@ -76,7 +76,8 @@ RUN easy_install -U pip && \ ...@@ -76,7 +76,8 @@ RUN easy_install -U pip && \
pip install sphinx-rtd-theme==0.1.9 recommonmark pip install sphinx-rtd-theme==0.1.9 recommonmark
RUN pip install pre-commit 'ipython==5.3.0' && \ RUN pip install pre-commit 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python
#For docstring checker #For docstring checker
RUN pip install pylint pytest astroid isort RUN pip install pylint pytest astroid isort
......
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
# Use UBUNTU_MIRROR can speed up apt-get speed.
# ARG UBUNTU_MIRROR
# RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com/ubuntu#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
RUN apt-get update && apt-get install -y python python-pip iputils-ping libgtk2.0-dev wget vim net-tools iftop python-opencv RUN apt-get update && apt-get install -y python python-pip iputils-ping libgtk2.0-dev wget vim net-tools iftop python-opencv
RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/lib/libnccl.so RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/lib/libnccl.so
RUN pip install -U pip
RUN pip install -U kubernetes paddlepaddle
# IMPORTANT: # IMPORTANT:
# Add "ENV http_proxy=http://ip:port" if your download is slow, and don't forget to unset it at runtime. # Add "ENV http_proxy=http://ip:port" if your download is slow, and don't forget to unset it at runtime.
# exmaple: unset http_proxy && unset https_proxy && python fluid_benchmark.py ...
RUN pip install -U pip
RUN pip install -U kubernetes paddlepaddle
RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()\npaddle.dataset.flowers.fetch()" | python' RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()\npaddle.dataset.flowers.fetch()" | python'
RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.mnist.train()\npaddle.dataset.mnist.test()\npaddle.dataset.imdb.fetch()" | python' RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.mnist.train()\npaddle.dataset.mnist.test()\npaddle.dataset.imdb.fetch()" | python'
...@@ -14,9 +21,11 @@ RUN pip uninstall -y paddlepaddle && mkdir /workspace ...@@ -14,9 +21,11 @@ RUN pip uninstall -y paddlepaddle && mkdir /workspace
ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin
ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root
RUN chmod +x /usr/bin/paddle_k8s
ADD *.whl / ADD *.whl /
RUN pip install /*.whl && rm -f /*.whl && chmod +x /usr/bin/paddle_k8s RUN pip install /*.whl && rm -f /*.whl
ENV LD_LIBRARY_PATH=/usr/local/lib ENV LD_LIBRARY_PATH=/usr/local/lib
ADD fluid_benchmark.py recordio_converter.py models/ /workspace/ ADD fluid_benchmark.py recordio_converter.py args.py recordio_converter.py run.sh run_fluid_benchmark.sh /workspace/
ADD models/ /workspace/models/
...@@ -97,7 +97,7 @@ def dist_transpile(trainer_id, args): ...@@ -97,7 +97,7 @@ def dist_transpile(trainer_id, args):
return train_program, fluid.default_startup_program() return train_program, fluid.default_startup_program()
else: else:
raise ValueError( raise ValueError(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER' 'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
) )
...@@ -264,8 +264,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -264,8 +264,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
break break
else: else:
loss, = exe.run([avg_loss.name], feed=feeder.feed(data)) loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
if args.update_method == "pserver":
exe.bcast_params()
if args.use_reader_op: if args.use_reader_op:
num_samples += args.batch_size * args.gpus num_samples += args.batch_size * args.gpus
else: else:
...@@ -301,9 +299,18 @@ def print_train_time(start_time, end_time, num_samples): ...@@ -301,9 +299,18 @@ def print_train_time(start_time, end_time, num_samples):
(num_samples, train_elapsed, examples_per_sec)) (num_samples, train_elapsed, examples_per_sec))
def print_paddle_envs():
print('----------- Configuration envs -----------')
for k in os.environ:
if "PADDLE_" in k:
print "ENV %s:%s" % (k, os.environ[k])
print('------------------------------------------------')
def main(): def main():
args = parse_args() args = parse_args()
print_arguments(args) print_arguments(args)
print_paddle_envs()
# the unique trainer id, starting from 0, needed by trainer # the unique trainer id, starting from 0, needed by trainer
# only # only
......
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import argparse import argparse
import random import random
import os import os
import copy
from kube_templates import pserver, trainer, envs from kube_templates import pserver, trainer, envs
...@@ -108,10 +109,9 @@ def gen_job(): ...@@ -108,10 +109,9 @@ def gen_job():
tn_container["ports"][0]["containerPort"] = spreadport tn_container["ports"][0]["containerPort"] = spreadport
envs.append({"name": "PADDLE_JOB_NAME", "value": args.jobname}) envs.append({"name": "PADDLE_JOB_NAME", "value": args.jobname})
envs.append({"name": "TRAINERS", "value": str(args.trainers)}) envs.append({"name": "PADDLE_TRAINERS", "value": str(args.trainers)})
envs.append({"name": "PSERVERS", "value": str(args.pservers)}) envs.append({"name": "PADDLE_PSERVERS", "value": str(args.pservers)})
envs.append({"name": "ENTRY", "value": args.entry}) envs.append({"name": "ENTRY", "value": args.entry})
envs.append({"name": "PADDLE_INIT_PORT", "value": str(args.port)})
envs.append({"name": "PADDLE_PSERVER_PORT", "value": str(args.port)}) envs.append({"name": "PADDLE_PSERVER_PORT", "value": str(args.port)})
# NOTE: these directories below are cluster specific, please modify # NOTE: these directories below are cluster specific, please modify
# this settings before you run on your own cluster. # this settings before you run on your own cluster.
...@@ -166,17 +166,23 @@ def gen_job(): ...@@ -166,17 +166,23 @@ def gen_job():
tn["spec"]["template"]["spec"]["volumes"] = volumes tn["spec"]["template"]["spec"]["volumes"] = volumes
tn_container["volumeMounts"] = volumeMounts tn_container["volumeMounts"] = volumeMounts
ps_container["env"] = envs ps_container["env"] = copy.deepcopy(envs)
ps_container["env"].append({"name": "TRAINING_ROLE", "value": "PSERVER"}) ps_container["env"].append({
"name": "PADDLE_TRAINING_ROLE",
"value": "PSERVER"
})
tn_container["env"] = envs tn_container["env"] = envs
if args.disttype == "pserver": if args.disttype == "pserver":
tn_container["env"].append({ tn_container["env"].append({
"name": "TRAINING_ROLE", "name": "PADDLE_TRAINING_ROLE",
"value": "TRAINER" "value": "TRAINER"
}) })
elif args.disttype == "nccl2" or args.disttype == "local": elif args.disttype == "nccl2" or args.disttype == "local":
# NCCL2 have no training role, set to plain WORKER # NCCL2 have no training role, set to plain WORKER
tn_container["env"].append({"name": "TRAINING_ROLE", "value": "WORKER"}) tn_container["env"].append({
"name": "PADDLE_TRAINING_ROLE",
"value": "WORKER"
})
os.mkdir(args.jobname) os.mkdir(args.jobname)
if args.disttype == "pserver": if args.disttype == "pserver":
......
...@@ -45,7 +45,8 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML") ...@@ -45,7 +45,8 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
ELSE() ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN") MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF() ENDIF()
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-unused-result") SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
ExternalProject_Add( ExternalProject_Add(
...@@ -53,7 +54,7 @@ ExternalProject_Add( ...@@ -53,7 +54,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS} DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "db3424ad44901513c03a1ea31ccaacdf633fbe9f" GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51"
PREFIX ${MKLDNN_SOURCES_DIR} PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
......
#!/bin/bash #!/bin/bash
python gen_doc.py layers --submodules control_flow device io nn ops tensor detection learning_rate_scheduler metric > layers.rst python gen_doc.py layers --submodules control_flow device io nn ops tensor detection learning_rate_scheduler metric > layers.rst
for module in data_feeder clip metrics executor initializer io nets optimizer param_attr profiler regularizer for module in data_feeder clip metrics executor initializer io nets optimizer param_attr profiler regularizer transpiler
do do
python gen_doc.py ${module} > ${module}.rst python gen_doc.py ${module} > ${module}.rst
done done
.. THIS FILE IS GENERATED BY `gen_doc.{py|sh}`
!DO NOT EDIT THIS FILE MANUALLY!
==========
transpiler
==========
DistributeTranspiler
--------------------
.. autoclass:: paddle.fluid.transpiler.DistributeTranspiler
:members:
:noindex:
InferenceTranspiler
-------------------
.. autoclass:: paddle.fluid.transpiler.InferenceTranspiler
:members:
:noindex:
memory_optimize
---------------
.. autofunction:: paddle.fluid.transpiler.memory_optimize
:noindex:
release_memory
--------------
.. autofunction:: paddle.fluid.transpiler.release_memory
:noindex:
HashName
--------
.. autoclass:: paddle.fluid.transpiler.HashName
:members:
:noindex:
RoundRobin
----------
.. autoclass:: paddle.fluid.transpiler.RoundRobin
:members:
:noindex:
...@@ -168,13 +168,13 @@ cd /paddle/python/paddle/fluid/tests/book ...@@ -168,13 +168,13 @@ cd /paddle/python/paddle/fluid/tests/book
第二步,启动Parameter Server: 第二步,启动Parameter Server:
```bash ```bash
PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py PADDLE_PSERVER_PORT=6174 PADDLE_PSERVER_IPS=192.168.1.2 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=192.168.1.2 PADDLE_TRAINER_ID=1 PADDLE_TRAINING_ROLE=PSERVER python test_fit_a_line.py
``` ```
执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ```, 表示Paramter Server已经正常启动。 执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ```, 表示Paramter Server已经正常启动。
第三步,启动Trainer: 第三步,启动Trainer:
```bash ```bash
PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py PADDLE_PSERVER_PORT=6174 PADDLE_PSERVER_IPS=192.168.1.3 PADDLE_TRAINERS=2 PADDLE_CURRENT_IPP=192.168.1.3 PADDLE_TRAINER_ID=1 PADDLE_TRAINING_ROLE=TRAINER python test_fit_a_line.py
``` ```
由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。 由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。
......
...@@ -114,8 +114,8 @@ def gen_train_list(file_pattern, trainers, trainer_id): ...@@ -114,8 +114,8 @@ def gen_train_list(file_pattern, trainers, trainer_id):
ret_list.append(f) ret_list.append(f)
return ret_list return ret_list
trainers = int(os.getenv("TRAINERS")) trainers = int(os.getenv("PADDLE_TRAINERS"))
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
data_file = fluid.layers.io.open_files( data_file = fluid.layers.io.open_files(
filenames=gen_train_list("./mnist-[0-9]*.recordio", 2, 0), filenames=gen_train_list("./mnist-[0-9]*.recordio", 2, 0),
thread_num=1, thread_num=1,
......
...@@ -13,6 +13,7 @@ cpu_noavx_openblas `fluid.tgz <https://guest:@paddleci.ngrok.io/repository ...@@ -13,6 +13,7 @@ cpu_noavx_openblas `fluid.tgz <https://guest:@paddleci.ngrok.io/repository
cuda7.5_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_ cuda7.5_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda8.0_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_ cuda8.0_cudnn5_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda8.0_cudnn7_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/fluid.tgz>`_ cuda8.0_cudnn7_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/fluid.tgz>`_
cuda9.0_cudnn7_avx_mkl `fluid.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda90cudnn7avxMkl/.lastSuccessful/fluid.tgz>`_
====================== ======================================== ====================== ========================================
从源码编译 从源码编译
......
...@@ -213,3 +213,12 @@ virtualenv本身也是Python的一个包,可以用pip进行安装: ...@@ -213,3 +213,12 @@ virtualenv本身也是Python的一个包,可以用pip进行安装:
保存并关闭文件。 保存并关闭文件。
这样,每次打开终端时就会自动启动名为‘paddle’的Python环境了。 这样,每次打开终端时就会自动启动名为‘paddle’的Python环境了。
10. 通过pip安装的PaddlePaddle在 :code:`import paddle.fluid` 报找不到 :code:`libmkldnn.so` 或 :code:`libmklml_intel.so`
------------------------------------------------------------------------------------------
出现这种问题的原因是在导入 :code:`paddle.fluid` 时需要加载 :code:`libmkldnn.so` 和 :code:`libmklml_intel.so`,
但是系统没有找到该文件。一般通过pip安装PaddlePaddle时会将 :code:`libmkldnn.so` 和 :code:`libmklml_intel.so`
拷贝到 :code:`/usr/local/lib` 路径下,所以解决办法是将该路径加到 :code:`LD_LIBRARY_PATH` 环境变量下,
即: :code:`export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH` 。
**注意**:如果是在虚拟环境中安装PaddlePaddle, :code:`libmkldnn.so` 和 :code:`libmklml_intel.so` 可能不在 :code:`/usr/local/lib` 路径下。
\ No newline at end of file
...@@ -40,10 +40,9 @@ void Main(bool use_gpu) { ...@@ -40,10 +40,9 @@ void Main(bool use_gpu) {
//# 2. Prepare input. //# 2. Prepare input.
int64_t data[4] = {1, 2, 3, 4}; int64_t data[4] = {1, 2, 3, 4};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "", PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}), .shape = std::vector<int>({4, 1}),
.data = buf, .data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64}; .dtype = PaddleDType::INT64};
// For simplicity, we set all the slots with the same data. // For simplicity, we set all the slots with the same data.
...@@ -55,14 +54,12 @@ void Main(bool use_gpu) { ...@@ -55,14 +54,12 @@ void Main(bool use_gpu) {
//# 4. Get output. //# 4. Get output.
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "output buffer size: " << outputs.front().data.length; LOG(INFO) << "output buffer size: " << outputs.front().data.length();
const size_t num_elements = outputs.front().data.length / sizeof(float); const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory. // The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) { for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i]; LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
} }
// TODO(Superjomn): this is should be free automatically
free(outputs[0].data.data);
} }
} }
...@@ -86,10 +83,9 @@ void MainThreads(int num_threads, bool use_gpu) { ...@@ -86,10 +83,9 @@ void MainThreads(int num_threads, bool use_gpu) {
for (int batch_id = 0; batch_id < num_batches; ++batch_id) { for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
// 2. Dummy Input Data // 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4}; int64_t data[4] = {1, 2, 3, 4};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "", PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}), .shape = std::vector<int>({4, 1}),
.data = buf, .data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64}; .dtype = PaddleDType::INT64};
std::vector<PaddleTensor> inputs(4, tensor); std::vector<PaddleTensor> inputs(4, tensor);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
...@@ -99,13 +95,13 @@ void MainThreads(int num_threads, bool use_gpu) { ...@@ -99,13 +95,13 @@ void MainThreads(int num_threads, bool use_gpu) {
// 4. Get output. // 4. Get output.
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "TID: " << tid << ", " LOG(INFO) << "TID: " << tid << ", "
<< "output buffer size: " << outputs.front().data.length; << "output buffer size: " << outputs.front().data.length();
const size_t num_elements = outputs.front().data.length / sizeof(float); const size_t num_elements =
outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory. // The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) { for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i]; LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
} }
free(outputs[0].data.data);
} }
}); });
} }
......
# Inference High-level APIs
This document describes the high-level inference APIs one can use to easily deploy a Paddle model for an application.
The APIs are described in `paddle_inference_api.h`, just one header file, and two libaries `libpaddle_fluid.so` and `libpaddle_fluid_api.so` are needed.
## PaddleTensor
We provide the `PaddleTensor` data structure is to give a general tensor interface.
The definition is
```c++
struct PaddleTensor {
std::string name; // variable name.
std::vector<int> shape;
PaddleBuf data; // blob of data.
PaddleDType dtype;
};
```
The data is stored in a continuous memory `PaddleBuf`, and tensor's data type is specified by a `PaddleDType`.
The `name` field is used to specify the name of input variable,
that is important when there are multiple inputs and need to distiuish which variable to set.
## engine
The inference APIs has two different underlying implementation, currently there are two valid engines:
- the native engine, which is consists of the native operators and framework,
- the Anakin engine, which is a Anakin library embeded.
The native engine takes a native Paddle model as input, and supports any model that trained by Paddle,
but the Anakin engine can only take the Anakin model as input(user need to manully transform the format first) and currently not all Paddle models are supported.
```c++
enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
};
```
## PaddlePredictor and how to create one
The main interface is `PaddlePredictor`, there are following methods
- `bool Run(const std::vector<PaddleTensor>& inputs, std::vector<PaddleTensor>* output_data)`
- take inputs and output `output_data`
- `Clone` to clone a predictor from an existing one, with model parameter shared.
There is a factory method to help create a predictor, and the user takes the ownership of this object.
```c++
template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
```
By specifying the engine kind and config, one can get an specific implementation.
## Reference
- [paddle_inference_api.h](./paddle_inference_api.h)
- [demos](./demo)
...@@ -13,3 +13,53 @@ See the License for the specific language governing permissions and ...@@ -13,3 +13,53 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/contrib/inference/paddle_inference_api.h" #include "paddle/contrib/inference/paddle_inference_api.h"
namespace paddle {
PaddleBuf::PaddleBuf(PaddleBuf&& other)
: data_(other.data_),
length_(other.length_),
memory_owned_(other.memory_owned_) {
other.memory_owned_ = false;
other.data_ = nullptr;
other.length_ = 0;
}
PaddleBuf::PaddleBuf(const PaddleBuf& other) { *this = other; }
PaddleBuf& PaddleBuf::operator=(const PaddleBuf& other) {
// only the buffer with external memory can be copied
assert(!other.memory_owned_);
data_ = other.data_;
length_ = other.length_;
memory_owned_ = other.memory_owned_;
return *this;
}
void PaddleBuf::Resize(size_t length) {
// Only the owned memory can be reset, the external memory can't be changed.
if (length_ == length) return;
assert(memory_owned_);
Free();
data_ = new char[length];
length_ = length;
memory_owned_ = true;
}
void PaddleBuf::Reset(void* data, size_t length) {
Free();
memory_owned_ = false;
data_ = data;
length_ = length;
}
void PaddleBuf::Free() {
if (memory_owned_ && data_) {
assert(length_ > 0);
delete static_cast<char*>(data_);
data_ = nullptr;
length_ = 0;
}
}
} // namespace paddle
\ No newline at end of file
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cassert>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -32,12 +33,38 @@ enum PaddleDType { ...@@ -32,12 +33,38 @@ enum PaddleDType {
INT64, INT64,
}; };
struct PaddleBuf { class PaddleBuf {
void* data; // pointer to the data memory. public:
size_t length; // number of memory bytes. PaddleBuf() = default;
PaddleBuf(PaddleBuf&& other);
// Copy only available when memory is managed externally.
explicit PaddleBuf(const PaddleBuf&);
PaddleBuf& operator=(const PaddleBuf&);
// Do not own the memory.
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
void Resize(size_t length);
// Reset to external memory.
void Reset(void* data, size_t length);
bool empty() const { return length_ == 0; }
void* data() const { return data_; }
size_t length() const { return length_; }
~PaddleBuf() { Free(); }
private:
void Free();
void* data_{nullptr}; // pointer to the data memory.
size_t length_{0}; // number of memory bytes.
bool memory_owned_{true};
}; };
struct PaddleTensor { struct PaddleTensor {
PaddleTensor() = default;
std::string name; // variable name. std::string name; // variable name.
std::vector<int> shape; std::vector<int> shape;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed. // TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
...@@ -67,8 +94,9 @@ class PaddlePredictor { ...@@ -67,8 +94,9 @@ class PaddlePredictor {
// Predict an record. // Predict an record.
// The caller should be responsible for allocating and releasing the memory of // The caller should be responsible for allocating and releasing the memory of
// `inputs`. `inputs` should be alive until Run returns. caller should be // `inputs`. `inputs` should be available until Run returns. Caller should be
// responsible for releasing the memory of `output_data`. // responsible for the output tensor's buffer, either allocated or passed from
// outside.
virtual bool Run(const std::vector<PaddleTensor>& inputs, virtual bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data) = 0; std::vector<PaddleTensor>* output_data) = 0;
...@@ -81,8 +109,7 @@ class PaddlePredictor { ...@@ -81,8 +109,7 @@ class PaddlePredictor {
// The common configs for all the predictors. // The common configs for all the predictors.
struct Config { struct Config {
std::string model_dir; // path to the model directory. std::string model_dir; // path to the model directory.
bool enable_engine{false}; // Enable to execute (part of) the model on
}; };
}; };
......
...@@ -48,7 +48,7 @@ bool PaddleInferenceAnakinPredictor::Run( ...@@ -48,7 +48,7 @@ bool PaddleInferenceAnakinPredictor::Run(
auto d_tensor_in_p = executor_.get_in(input.name); auto d_tensor_in_p = executor_.get_in(input.name);
float *d_data_p = d_tensor_in_p->mutable_data(); float *d_data_p = d_tensor_in_p->mutable_data();
if (cudaMemcpy(d_data_p, if (cudaMemcpy(d_data_p,
static_cast<float *>(input.data.data), static_cast<float *>(input.data.data()),
d_tensor_in_p->valid_size() * sizeof(float), d_tensor_in_p->valid_size() * sizeof(float),
cudaMemcpyHostToDevice) != 0) { cudaMemcpyHostToDevice) != 0) {
LOG(ERROR) << "copy data from CPU to GPU error"; LOG(ERROR) << "copy data from CPU to GPU error";
...@@ -65,8 +65,11 @@ bool PaddleInferenceAnakinPredictor::Run( ...@@ -65,8 +65,11 @@ bool PaddleInferenceAnakinPredictor::Run(
for (auto &output : *output_data) { for (auto &output : *output_data) {
auto *tensor = executor_.get_out(output.name); auto *tensor = executor_.get_out(output.name);
output.shape = tensor->shape(); output.shape = tensor->shape();
if (output.data.length() < tensor->valid_size() * sizeof(float)) {
output.data.Resize(tensor->valid_size() * sizeof(float));
}
// Copy data from GPU -> CPU // Copy data from GPU -> CPU
if (cudaMemcpy(output.data.data, if (cudaMemcpy(output.data.data(),
tensor->mutable_data(), tensor->mutable_data(),
tensor->valid_size() * sizeof(float), tensor->valid_size() * sizeof(float),
cudaMemcpyDeviceToHost) != 0) { cudaMemcpyDeviceToHost) != 0) {
......
...@@ -37,28 +37,26 @@ TEST(inference, anakin) { ...@@ -37,28 +37,26 @@ TEST(inference, anakin) {
float data[1 * 3 * 224 * 224] = {1.0f}; float data[1 * 3 * 224 * 224] = {1.0f};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "input_0", PaddleTensor tensor{.name = "input_0",
.shape = std::vector<int>({1, 3, 224, 224}), .shape = std::vector<int>({1, 3, 224, 224}),
.data = buf, .data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::FLOAT32}; .dtype = PaddleDType::FLOAT32};
// For simplicity, we set all the slots with the same data. // For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor); std::vector<PaddleTensor> paddle_tensor_feeds;
paddle_tensor_feeds.emplace_back(std::move(tensor));
float data_out[1000];
PaddleBuf buf_out{.data = data_out, .length = sizeof(data)};
PaddleTensor tensor_out{.name = "prob_out", PaddleTensor tensor_out{.name = "prob_out",
.shape = std::vector<int>({1000, 1}), .shape = std::vector<int>({1000, 1}),
.data = buf_out, .data = PaddleBuf(),
.dtype = PaddleDType::FLOAT32}; .dtype = PaddleDType::FLOAT32};
std::vector<PaddleTensor> outputs(1, tensor_out); std::vector<PaddleTensor> outputs;
outputs.emplace_back(std::move(tensor_out));
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
float* data_o = static_cast<float*>(outputs[0].data.data); float* data_o = static_cast<float*>(outputs[0].data.data());
for (size_t j = 0; j < 1000; ++j) { for (size_t j = 0; j < 1000; ++j) {
LOG(INFO) << "output[" << j << "]: " << data_o[j]; LOG(INFO) << "output[" << j << "]: " << data_o[j];
} }
......
...@@ -178,8 +178,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -178,8 +178,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(static_cast<void *>(input_ptr), std::memcpy(static_cast<void *>(input_ptr),
inputs[i].data.data, inputs[i].data.data(),
inputs[i].data.length); inputs[i].data.length());
feeds->push_back(input); feeds->push_back(input);
} }
return true; return true;
...@@ -241,10 +241,11 @@ bool NativePaddlePredictor::GetFetch( ...@@ -241,10 +241,11 @@ bool NativePaddlePredictor::GetFetch(
} }
outputs->at(i).shape = shape; outputs->at(i).shape = shape;
outputs->at(i).data.length = sizeof(float) * data.size(); auto &buffer = outputs->at(i).data;
outputs->at(i).data.data = malloc(outputs->at(i).data.length); if (buffer.empty() || buffer.length() < sizeof(float) * data.size()) {
std::memcpy( buffer.Resize(sizeof(float) * data.size());
outputs->at(i).data.data, data.data(), outputs->at(i).data.length); }
std::memcpy(buffer.data(), data.data(), buffer.length());
outputs->at(i).dtype = PaddleDType::FLOAT32; outputs->at(i).dtype = PaddleDType::FLOAT32;
// TODO(panyx0718): support other types? fill tensor name? avoid a copy. // TODO(panyx0718): support other types? fill tensor name? avoid a copy.
} }
......
...@@ -27,13 +27,12 @@ namespace paddle { ...@@ -27,13 +27,12 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt; PaddleTensor pt;
pt.data.data = t->data<void>();
if (t->type() == typeid(int64_t)) { if (t->type() == typeid(int64_t)) {
pt.data.length = t->numel() * sizeof(int64_t); pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64; pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) { } else if (t->type() == typeid(float)) {
pt.data.length = t->numel() * sizeof(float); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else { } else {
LOG(FATAL) << "unsupported type."; LOG(FATAL) << "unsupported type.";
...@@ -79,8 +78,8 @@ void MainWord2Vec(bool use_gpu) { ...@@ -79,8 +78,8 @@ void MainWord2Vec(bool use_gpu) {
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
size_t len = outputs[0].data.length; size_t len = outputs[0].data.length();
float* data = static_cast<float*>(outputs[0].data.data); float* data = static_cast<float*>(outputs[0].data.data());
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0); ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0); ASSERT_GT(data[j], -1.0);
...@@ -103,8 +102,6 @@ void MainWord2Vec(bool use_gpu) { ...@@ -103,8 +102,6 @@ void MainWord2Vec(bool use_gpu) {
EXPECT_LT(lod_data[i] - data[i], 1e-3); EXPECT_LT(lod_data[i] - data[i], 1e-3);
EXPECT_GT(lod_data[i] - data[i], -1e-3); EXPECT_GT(lod_data[i] - data[i], -1e-3);
} }
free(outputs[0].data.data);
} }
void MainImageClassification(bool use_gpu) { void MainImageClassification(bool use_gpu) {
...@@ -143,13 +140,12 @@ void MainImageClassification(bool use_gpu) { ...@@ -143,13 +140,12 @@ void MainImageClassification(bool use_gpu) {
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
ASSERT_EQ(outputs.size(), 1UL); ASSERT_EQ(outputs.size(), 1UL);
size_t len = outputs[0].data.length; size_t len = outputs[0].data.length();
float* data = static_cast<float*>(outputs[0].data.data); float* data = static_cast<float*>(outputs[0].data.data());
float* lod_data = output1.data<float>(); float* lod_data = output1.data<float>();
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
EXPECT_NEAR(lod_data[j], data[j], 1e-3); EXPECT_NEAR(lod_data[j], data[j], 1e-3);
} }
free(data);
} }
void MainThreadsWord2Vec(bool use_gpu) { void MainThreadsWord2Vec(bool use_gpu) {
...@@ -192,8 +188,8 @@ void MainThreadsWord2Vec(bool use_gpu) { ...@@ -192,8 +188,8 @@ void MainThreadsWord2Vec(bool use_gpu) {
// check outputs range // check outputs range
ASSERT_EQ(local_outputs.size(), 1UL); ASSERT_EQ(local_outputs.size(), 1UL);
const size_t len = local_outputs[0].data.length; const size_t len = local_outputs[0].data.length();
float* data = static_cast<float*>(local_outputs[0].data.data); float* data = static_cast<float*>(local_outputs[0].data.data());
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0); ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0); ASSERT_GT(data[j], -1.0);
...@@ -205,7 +201,6 @@ void MainThreadsWord2Vec(bool use_gpu) { ...@@ -205,7 +201,6 @@ void MainThreadsWord2Vec(bool use_gpu) {
for (int i = 0; i < refs[tid].numel(); ++i) { for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3); EXPECT_NEAR(ref_data[i], data[i], 1e-3);
} }
free(data);
}); });
} }
for (int i = 0; i < num_jobs; ++i) { for (int i = 0; i < num_jobs; ++i) {
...@@ -251,14 +246,13 @@ void MainThreadsImageClassification(bool use_gpu) { ...@@ -251,14 +246,13 @@ void MainThreadsImageClassification(bool use_gpu) {
// check outputs correctness // check outputs correctness
ASSERT_EQ(local_outputs.size(), 1UL); ASSERT_EQ(local_outputs.size(), 1UL);
const size_t len = local_outputs[0].data.length; const size_t len = local_outputs[0].data.length();
float* data = static_cast<float*>(local_outputs[0].data.data); float* data = static_cast<float*>(local_outputs[0].data.data());
float* ref_data = refs[tid].data<float>(); float* ref_data = refs[tid].data<float>();
EXPECT_EQ(refs[tid].numel(), len / sizeof(float)); EXPECT_EQ(refs[tid].numel(), len / sizeof(float));
for (int i = 0; i < refs[tid].numel(); ++i) { for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3); EXPECT_NEAR(ref_data[i], data[i], 1e-3);
} }
free(data);
}); });
} }
for (int i = 0; i < num_jobs; ++i) { for (int i = 0; i < num_jobs; ++i) {
......
...@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() { ...@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device; int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
std::vector<std::function<void()>> broadcast_calls; std::vector<std::function<void()>> broadcast_calls;
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
for (auto out_var_handle : out_var_handles) { for (auto out_var_handle : out_var_handles) {
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name_);
...@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() { ...@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>()); send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
out_handle = out_var_handle; out_handle = out_var_handle;
} else { } else {
send_recv_buffer = send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
VariableVisitor::GetMutableTensor(out_var).mutable_data( .Resize(in_tensor.dims())
out_var_handle->place_); .mutable_data(out_var_handle->place_);
} }
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
broadcast_calls.emplace_back( broadcast_calls.emplace_back(
[send_recv_buffer, numel, type, root_id, &nccl_ctx] { [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
......
...@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0);
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
...@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars); checker(op.InputArgumentNames(), recv_vars);
} }
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var; all_vars_.emplace(var->Name(), var);
} }
auto graph = new SSAGraph(); auto graph = new SSAGraph();
...@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(program);
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0; size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>( if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op); CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op); CreateDistTrainOp(&result, *op);
...@@ -199,15 +200,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -199,15 +200,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
} else { } else {
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name); var_name_on_devices_.emplace(var_name, op_dev_id);
} }
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
...@@ -230,19 +235,22 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -230,19 +235,22 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name); var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(all_vars, g_name)) { if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0); CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0);
} else { } else {
InsertAllReduceOp(&result, g_name); InsertAllReduceOp(&result, g_name);
} }
break; break;
default:
LOG(FATAL) << "Unknown reduce strategy ";
break;
} }
} }
} catch (boost::bad_get e) { } catch (boost::bad_get e) {
...@@ -261,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -261,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
harzaeds need to be handled. hazards need to be handled.
*/ */
PolishGraphToSupportDataHazards(&result); PolishGraphToSupportDataHazards(&result);
...@@ -273,11 +281,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -273,11 +281,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient( bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
const std::unordered_map<std::string, VarDesc *> &all_vars, PADDLE_ENFORCE(all_vars_.count(og) != 0);
const std::string &og) const { if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE(all_vars.count(og) != 0);
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true; return true;
} }
return false; return false;
...@@ -345,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, ...@@ -345,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
auto var = new VarHandle(vars.size() - 1, i, og, p); auto var = new VarHandle(vars.size(), i, og, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -363,24 +369,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -363,24 +369,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
int var_dev_id = -1; for (auto &varname : op.InputArgumentNames()) {
for (auto &var_name : op.InputArgumentNames()) { int dev_id = GetVarDeviceID(varname);
if (var_dev_id != -1) break; if (dev_id != -1) {
for (size_t i = 0; i < var_name_on_devices.size(); ++i) { return dev_id;
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
} }
} }
return var_dev_id; return -1;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
...@@ -442,13 +447,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -442,13 +447,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->vars_[dst_dev_id][og]; auto &vars = result->vars_[dst_dev_id][og];
auto var = auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
return var; return var;
} }
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) { for (auto &prev_op : result->ops_) {
...@@ -463,16 +469,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, ...@@ -463,16 +469,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
CreateComputationalOp(result, op, 0); int op_dev_id = -1;
if (op.Type() == "split_byref") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type());
CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") { if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
} }
} }
// Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
result->ops_.emplace_back( int op_dev_id = -1;
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); if (op.Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
} else if (op.Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
// send_barrier and fetch_barrier op can be scheduled on device 0
op_dev_id = 0;
}
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") { if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send"); ConnectOp(result, result->ops_.back().get(), "send");
...@@ -488,9 +544,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ...@@ -488,9 +544,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
"send, send_barrier. recv, fetch_barrier]"); "send, send_barrier. recv, fetch_barrier]");
} }
// TODO(Yancey1989): schedule rpc op on different place may CreateOpHandleIOs(result, op, op_dev_id);
// increate throughput
CreateOpHandleIOs(result, op, 0);
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
......
...@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t place_id) const; size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID( int GetOpDeviceID(const OpDesc &op) const;
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient( bool IsSparseGradient(const std::string &og) const;
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const; size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include <map>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -122,11 +122,16 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { ...@@ -122,11 +122,16 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event if (!events_.empty()) { // Use event
std::function<void()> method = callback; std::function<void()> method = callback;
// NOTE(zcd): device context must be ordered here because RecordEvent
// will use a mutex to ensure the safe of multi-threads.
std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
ordered_ctxes.emplace(p.second, p.first);
}
for (auto &p : ordered_ctxes) {
method = [method, p, this]() { method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent( static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device), events_.at(boost::get<platform::CUDAPlace>(p.second).device),
method); method);
}; };
} }
......
...@@ -30,6 +30,7 @@ class SSAGraphBuilder { ...@@ -30,6 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout); auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) { if (timeout) {
std::lock_guard<std::mutex> l(exception_mu_);
if (exception_) { if (exception_) {
auto exp = *exception_; auto exp = *exception_;
exception_.reset(); exception_.reset();
...@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp(
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted"; VLOG(10) << op << " " << op->Name() << "Signal posted";
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
std::lock_guard<std::mutex> l(exception_mu_);
exception_.reset(new platform::EnforceNotMet(ex)); exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) { } catch (...) {
LOG(FATAL) << "Unknown exception catched"; LOG(FATAL) << "Unknown exception catched";
......
...@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
std::mutex exception_mu_;
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_; std::atomic<int> running_ops_;
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#endif #endif
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {} ...@@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() { void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>() ::paddle::operators::distributed::GRPCClient>()
->SendComplete(); ->SendComplete();
} }
#endif #endif
...@@ -295,13 +295,14 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -295,13 +295,14 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) { const ProgramDesc& program, int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id); std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
return std::unique_ptr<ExecutorPrepareContext>(ctx); return ctx;
} }
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
...@@ -320,7 +321,8 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -320,7 +321,8 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
} }
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars,
bool keep_kids) {
Scope* local_scope = scope; Scope* local_scope = scope;
if (create_vars) { if (create_vars) {
if (create_local_scope) { if (create_local_scope) {
...@@ -343,12 +345,20 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -343,12 +345,20 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
} }
platform::DeviceContextPool::Instance().Get(place_)->Wait(); platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (create_vars && create_local_scope) { if (local_scope != scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} else { } else {
// Delete the local scopes created in operators. if (!keep_kids) {
scope->DropKids(); // By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// But when while_op also create a local executor to run it's sub block,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
scope->DropKids();
}
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------"; VLOG(2) << "-------------------------------------------------------";
VLOG(2) << "Memory used after deleting local scope: " VLOG(2) << "Memory used after deleting local scope: "
......
...@@ -78,7 +78,7 @@ class Executor { ...@@ -78,7 +78,7 @@ class Executor {
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true, bool create_local_scope = true,
bool create_vars = true); bool create_vars = true, bool keep_kids = false);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets, std::map<std::string, const LoDTensor*>* feed_targets,
......
...@@ -27,6 +27,7 @@ enum AttrType { ...@@ -27,6 +27,7 @@ enum AttrType {
BOOLEANS = 7; BOOLEANS = 7;
BLOCK = 8; BLOCK = 8;
LONG = 9; LONG = 9;
BLOCKS = 10;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -46,6 +47,7 @@ message OpDesc { ...@@ -46,6 +47,7 @@ message OpDesc {
repeated bool bools = 11; repeated bool bools = 11;
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14;
}; };
message Var { message Var {
......
...@@ -51,8 +51,6 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { ...@@ -51,8 +51,6 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
} }
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code());
if (!platform::is_cpu_place(t.place())) { if (!platform::is_cpu_place(t.place())) {
LoDTensor tt; LoDTensor tt;
framework::TensorCopy(t, platform::CPUPlace(), &tt); framework::TensorCopy(t, platform::CPUPlace(), &tt);
...@@ -70,7 +68,13 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { ...@@ -70,7 +68,13 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements // only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10; int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
os << t.data<float>()[i] << " "; if (t.type().hash_code() == typeid(float).hash_code()) {
os << t.data<float>()[i] << " ";
} else if (t.type().hash_code() == typeid(int64_t).hash_code()) {
os << t.data<int64_t>()[i] << " ";
} else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
}
} }
return os; return os;
......
...@@ -26,6 +26,20 @@ ...@@ -26,6 +26,20 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(LoD, PrintLoDTensor) {
LoDTensor tensor1;
tensor1.mutable_data<float>(platform::CPUPlace());
tensor1.data<float>()[0] = 0.2;
tensor1.data<float>()[1] = 0.5;
LOG(INFO) << tensor1;
LoDTensor tensor2;
tensor2.mutable_data<int64_t>(platform::CPUPlace());
tensor2.data<int64_t>()[0] = 1;
tensor2.data<int64_t>()[1] = 2;
LOG(INFO) << tensor2;
}
TEST(LoD, data) { TEST(LoD, data) {
LoD lod{{0, 1, 2}}; LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5}); lod.push_back({0, 2, 4, 5});
...@@ -37,7 +51,7 @@ TEST(LoD, data) { ...@@ -37,7 +51,7 @@ TEST(LoD, data) {
} }
} }
TEST(LodExpand, test) { TEST(LoD, ExpandLoD) {
LoD lod{{0, 2}}; LoD lod{{0, 2}};
LoDTensor tensor; LoDTensor tensor;
tensor.set_lod(lod); tensor.set_lod(lod);
......
...@@ -211,6 +211,12 @@ void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) { ...@@ -211,6 +211,12 @@ void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
need_update_ = true; need_update_ = true;
} }
void OpDesc::SetBlocksAttr(const std::string &name,
std::vector<BlockDesc *> blocks) {
this->attrs_[name] = blocks;
need_update_ = true;
}
void OpDesc::SetAttrMap( void OpDesc::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) { const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map; attrs_ = attr_map;
...@@ -305,6 +311,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -305,6 +311,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<bool> &v) const { void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools()); VectorToRepeated(v, attr_->mutable_bools());
} }
void operator()(const std::vector<BlockDesc *> &v) const {
std::vector<int> blocks_idx;
for (auto blk : v) {
blocks_idx.push_back(blk->ID());
}
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
......
...@@ -77,6 +77,8 @@ class OpDesc { ...@@ -77,6 +77,8 @@ class OpDesc {
void SetBlockAttr(const std::string &name, BlockDesc *block); void SetBlockAttr(const std::string &name, BlockDesc *block);
void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const;
......
...@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor(
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
details::SSAGraphBuilderFactory builder_factory( details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy); build_strategy);
...@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor(
#endif #endif
} }
builder_ = builder_factory.Create();
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program))); builder_->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
...@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor( ...@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const { const std::unordered_set<std::string> &vars) const {
auto *main_scope = member_->local_scopes_[0]; // the the initialize bcast, all vars would be bcast from device(0), otherwise
// bcast from the specified device.
bool initialize = builder_.get() == nullptr ? true : false;
for (auto &var : vars) { for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var); int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
if (!initialize && var_dev_id == -1) continue;
framework::Variable *main_var = nullptr;
if (initialize) {
main_var = member_->local_scopes_[0]->FindVar(var);
} else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
}
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) { if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue; continue;
} }
...@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs(
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i]; auto place = member_->places_[i];
void *buffer; void *buffer;
if (i == 0) {
if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
buffer = const_cast<void *>(main_tensor.data<void>()); buffer = const_cast<void *>(main_tensor.data<void>());
} else { } else {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
......
...@@ -19,12 +19,14 @@ limitations under the License. */ ...@@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -68,6 +70,7 @@ class ParallelExecutor { ...@@ -68,6 +70,7 @@ class ParallelExecutor {
private: private:
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
}; };
} // namespace framework } // namespace framework
......
...@@ -35,7 +35,8 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -35,7 +35,8 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t>; std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -27,7 +27,7 @@ void TensorRTSubGraphPass::Run(DataFlowGraph *graph) { ...@@ -27,7 +27,7 @@ void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_); SubGraphFuse(graph, node_inside_subgraph_teller_);
} }
} // analysis } // namespace analysis
} // inference } // namespace inference
} // paddle } // namespace paddle
...@@ -184,9 +184,9 @@ else() ...@@ -184,9 +184,9 @@ else()
set(DEPS_OPS ${DEPS_OPS} nccl_op) set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif() endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
...@@ -195,18 +195,11 @@ if(WITH_DISTRIBUTE) ...@@ -195,18 +195,11 @@ if(WITH_DISTRIBUTE)
endif() endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS}) foreach(dist_op "prefetch_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
op_library(recv_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) endforeach()
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL) # listen_and_serv_op sum_op executor SERIAL)
......
...@@ -143,7 +143,7 @@ $$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ ...@@ -143,7 +143,7 @@ $$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
__attribute__((unused)) constexpr char TanhShrinkDoc[] = R"DOC( __attribute__((unused)) constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator. TanhShrink Activation Operator.
$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ $$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC"; )DOC";
...@@ -385,7 +385,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -385,7 +385,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
STanh Activation Operator. STanh Activation Operator.
$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$ $$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
)DOC"); )DOC");
} }
......
...@@ -21,8 +21,6 @@ namespace operators { ...@@ -21,8 +21,6 @@ namespace operators {
using batch_norm_bwd = mkldnn::batch_normalization_backward; using batch_norm_bwd = mkldnn::batch_normalization_backward;
using batch_norm_fwd = mkldnn::batch_normalization_forward; using batch_norm_fwd = mkldnn::batch_normalization_forward;
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory; using mkldnn::memory;
using mkldnn::primitive; using mkldnn::primitive;
using mkldnn::reorder; using mkldnn::reorder;
...@@ -31,18 +29,6 @@ using paddle::platform::MKLDNNDeviceContext; ...@@ -31,18 +29,6 @@ using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc; using paddle::platform::MKLDNNMemDesc;
using platform::to_void_cast; using platform::to_void_cast;
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
namespace { namespace {
template <typename T> template <typename T>
struct bn_type_traits { struct bn_type_traits {
......
...@@ -22,22 +22,6 @@ limitations under the License. */ ...@@ -22,22 +22,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
class BatchNormOp : public framework::OperatorWithKernel { class BatchNormOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
......
...@@ -19,6 +19,22 @@ limitations under the License. */ ...@@ -19,6 +19,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> { class BatchNormKernel : public framework::OpKernel<T> {
public: public:
......
...@@ -110,6 +110,7 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp, ...@@ -110,6 +110,7 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
ops::BilinearInterpOpMaker, ops::BilinearInterpOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad); REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>); REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
ops::BilinearInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradKernel<float>); ops::BilinearInterpGradKernel<float>);
...@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> { ...@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
int in_chw = channels * in_hw; int in_chw = channels * in_hw;
int out_chw = channels * out_hw; int out_chw = channels * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; float ratio_h =
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T)); memcpy(output, input, input_t->numel() * sizeof(T));
...@@ -56,24 +58,24 @@ class BilinearInterpKernel : public framework::OpKernel<T> { ...@@ -56,24 +58,24 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
for (int i = 0; i < out_h; ++i) { // loop for images for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i; int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0; int hid = (h < in_h - 1) ? 1 : 0;
T h1lambda = ratio_h * i - h; float h1lambda = ratio_h * i - h;
T h2lambda = 1 - h1lambda; float h2lambda = 1.f - h1lambda;
for (int j = 0; j < out_w; ++j) { for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j; int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0; int wid = (w < in_w - 1) ? 1 : 0;
T w1lambda = ratio_w * j - w; float w1lambda = ratio_w * j - w;
T w2lambda = 1 - w1lambda; float w2lambda = 1.f - w1lambda;
// calculate four position for bilinear interpolation // calculate four position for bilinear interpolation
const T* in_pos = &input[k * in_chw + h * in_w + w]; const T* in_pos = &input[k * in_chw + h * in_w + w];
T* out_pos = &output[k * out_chw + i * out_w + j]; T* out_pos = &output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels for (int c = 0; c < channels; ++c) { // loop for channels
// bilinear interpolation // bilinear interpolation
out_pos[0] = out_pos[0] = static_cast<T>(
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
h1lambda * (w2lambda * in_pos[hid * in_w] + h1lambda * (w2lambda * in_pos[hid * in_w] +
w1lambda * in_pos[hid * in_w + wid]); w1lambda * in_pos[hid * in_w + wid]));
in_pos += in_hw; in_pos += in_hw;
out_pos += out_hw; out_pos += out_hw;
} }
...@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
int in_chw = channels * in_hw; int in_chw = channels * in_hw;
int out_chw = channels * out_hw; int out_chw = channels * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; float ratio_h =
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
...@@ -127,22 +131,24 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -127,22 +131,24 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
for (int i = 0; i < out_h; ++i) { // loop for images for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i; int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0; int hid = (h < in_h - 1) ? 1 : 0;
T h1lambda = ratio_h * i - h; float h1lambda = ratio_h * i - h;
T h2lambda = 1 - h1lambda; float h2lambda = 1 - h1lambda;
for (int j = 0; j < out_w; ++j) { for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j; int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0; int wid = (w < in_w - 1) ? 1 : 0;
T w1lambda = ratio_w * j - w; float w1lambda = ratio_w * j - w;
T w2lambda = 1 - w1lambda; float w2lambda = 1 - w1lambda;
T* in_pos = &d_input[k * in_chw + h * in_w + w]; T* in_pos = &d_input[k * in_chw + h * in_w + w];
const T* out_pos = &d_output[k * out_chw + i * out_w + j]; const T* out_pos = &d_output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels for (int c = 0; c < channels; ++c) { // loop for channels
in_pos[0] += h2lambda * w2lambda * out_pos[0]; in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
in_pos[wid] += h2lambda * w1lambda * out_pos[0]; in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0]; in_pos[hid * in_w] +=
in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0]; static_cast<T>(h1lambda * w2lambda * out_pos[0]);
in_pos[hid * in_w + wid] +=
static_cast<T>(h1lambda * w1lambda * out_pos[0]);
in_pos += in_hw; in_pos += in_hw;
out_pos += out_hw; out_pos += out_hw;
} }
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_GRPC #ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer #define RPCSERVER_T distributed::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient #define RPCCLIENT_T distributed::GRPCClient
#else #else
#include "paddle/fluid/operators/detail/brpc_client.h" #include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer #define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient #define RPCCLIENT_T distributed::BRPCClient
#endif #endif
...@@ -175,12 +175,12 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -175,12 +175,12 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Detection mAP evaluate operator. Detection mAP evaluate operator.
The general steps are as follows. First, calculate the true positive and The general steps are as follows. First, calculate the true positive and
false positive according to the input of detection and labels, then false positive according to the input of detection and labels, then
calculate the mAP evaluate value. calculate the mAP evaluate value.
Supporting '11 point' and 'integral' mAP algorithm. Please get more information Supporting '11 point' and 'integral' mAP algorithm. Please get more information
from the following articles: from the following articles:
https://sanchom.wordpress.com/tag/average-precision/ https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325 https://arxiv.org/abs/1512.02325
)DOC"); )DOC");
} }
......
if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC) if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h" #include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
DEFINE_int32(brpc_channel_num, 24, DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server"); "Number of channels to send requests connected to one server");
...@@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { ...@@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
return q; return q;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -31,13 +31,13 @@ limitations under the License. */ ...@@ -31,13 +31,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
struct ChannelContext { struct ChannelContext {
brpc::Channel channel; brpc::Channel channel;
...@@ -95,6 +95,6 @@ class BRPCClient : public RPCClient { ...@@ -95,6 +95,6 @@ class BRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(BRPCClient); DISABLE_COPY_AND_ASSIGN(BRPCClient);
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv { namespace sendrecv {
typedef std::unordered_map<std::string, typedef std::unordered_map<std::string,
paddle::operators::detail::RequestHandler*> paddle::operators::distributed::RequestHandler*>
HandlerMap; HandlerMap;
class BRPCServiceImpl : public SendRecvService { class BRPCServiceImpl : public SendRecvService {
...@@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService {
: request_send_h_(nullptr), : request_send_h_(nullptr),
request_get_h_(nullptr), request_get_h_(nullptr),
request_prefetch_h_(nullptr) { request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend); auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_send_h_ = it->second; request_send_h_ = it->second;
} }
it = rpc_call_map.find(paddle::operators::detail::kRequestSend); it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_get_h_ = it->second; request_get_h_ = it->second;
} }
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch); it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second; request_prefetch_h_ = it->second;
} }
...@@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService { ...@@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService {
} }
private: private:
paddle::operators::detail::RequestHandler* request_send_h_; paddle::operators::distributed::RequestHandler* request_send_h_;
paddle::operators::detail::RequestHandler* request_get_h_; paddle::operators::distributed::RequestHandler* request_get_h_;
paddle::operators::detail::RequestHandler* request_prefetch_h_; paddle::operators::distributed::RequestHandler* request_prefetch_h_;
}; };
} // namespace sendrecv } // namespace sendrecv
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void AsyncBRPCServer::StartServer() { void AsyncBRPCServer::StartServer() {
// Instance of your service. // Instance of your service.
...@@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() { ...@@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer WaitSeverReady"; VLOG(3) << "AsyncGRPCServer WaitSeverReady";
} }
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -19,12 +19,12 @@ limitations under the License. */ ...@@ -19,12 +19,12 @@ limitations under the License. */
#include <string> #include <string>
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class AsyncBRPCServer final : public RPCServer { class AsyncBRPCServer final : public RPCServer {
public: public:
...@@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer { ...@@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer {
int ready_; int ready_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -17,11 +17,11 @@ limitations under the License. */ ...@@ -17,11 +17,11 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC // file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data. // requests without too much copying of the tensor data.
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
GrpcByteBufferSource::GrpcByteBufferSource() {} GrpcByteBufferSource::GrpcByteBufferSource() {}
...@@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const { ...@@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_; return byte_count_;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -106,7 +106,7 @@ class GrpcBufferReader final ...@@ -106,7 +106,7 @@ class GrpcBufferReader final
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
// Source provides a way for a particular RPC implementation to provide // Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom. // received data to ParseFrom.
class Source { class Source {
...@@ -183,6 +183,6 @@ class GrpcByteSource : public Source { ...@@ -183,6 +183,6 @@ class GrpcByteSource : public Source {
char space_[sizeof(Reader)]; char space_[sizeof(Reader)];
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,19 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,19 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#include <sys/time.h> #include <sys/time.h>
#include <limits> #include <limits>
#include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void GRPCClient::InitImpl() { InitEventLoop(); } void GRPCClient::InitImpl() { InitEventLoop(); }
...@@ -75,6 +76,9 @@ bool GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -75,6 +76,9 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
var_h.scope = p_scope; var_h.scope = p_scope;
var_h.name = var_name_val; var_h.name = var_name_val;
var_h.ctx = p_ctx; var_h.ctx = p_ctx;
var_h.method = "Send";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
SendProcessor* s = new SendProcessor(ch); SendProcessor* s = new SendProcessor(ch);
...@@ -129,6 +133,9 @@ bool GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -129,6 +133,9 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
var_h.scope = p_scope; var_h.scope = p_scope;
var_h.name = var_name_val; var_h.name = var_name_val;
var_h.ctx = p_ctx; var_h.ctx = p_ctx;
var_h.method = "Get";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
...@@ -172,6 +179,9 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -172,6 +179,9 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
var_h.scope = p_scope; var_h.scope = p_scope;
var_h.name = out_var_name_val; var_h.name = out_var_name_val;
var_h.ctx = p_ctx; var_h.ctx = p_ctx;
var_h.method = "Prefetch";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
...@@ -243,10 +253,11 @@ void GRPCClient::Proceed() { ...@@ -243,10 +253,11 @@ void GRPCClient::Proceed() {
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(c); PADDLE_ENFORCE(c);
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process";
c->Process(); c->Process();
} else { } else {
LOG(FATAL) << "var: " << c->var_h_.String() LOG(FATAL) << c->var_h_.String()
<< " grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
} }
delete c; delete c;
{ {
...@@ -276,6 +287,6 @@ std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { ...@@ -276,6 +287,6 @@ std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
return ch; return ch;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -38,23 +38,27 @@ limitations under the License. */ ...@@ -38,23 +38,27 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
struct VarHandle { struct VarHandle {
// RPC endpoint.
std::string ep; std::string ep;
const platform::DeviceContext* ctx; const platform::DeviceContext* ctx;
const framework::Scope* scope; const framework::Scope* scope;
// Variable name.
std::string name; std::string name;
// RPC method name.
std::string method;
std::string String() const { std::string String() const {
std::ostringstream s; std::ostringstream s;
s << "name:[" << name << "] ep:[" << ep << "]"; s << method << " name:[" << name << "], ep:[" << ep << "]";
return s.str(); return s.str();
} }
}; };
...@@ -226,6 +230,6 @@ class GRPCClient : public RPCClient { ...@@ -226,6 +230,6 @@ class GRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(GRPCClient); DISABLE_COPY_AND_ASSIGN(GRPCClient);
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -21,8 +21,8 @@ limitations under the License. */ ...@@ -21,8 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < 564; ++i) rows->push_back(i); for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0)); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
...@@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// deserialize zero-copy // deserialize zero-copy
// framework::Variable var2; // framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx); operators::distributed::VariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar(); framework::Variable* var2 = resp.GetVar();
...@@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math::set_constant(ctx, tensor, 31.9); math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0)); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
...@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy // deserialize zero-copy
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx); operators::distributed::VariableResponse resp(&scope, &ctx);
if (from_type == 0) { if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
} else { } else {
......
...@@ -15,13 +15,13 @@ limitations under the License. */ ...@@ -15,13 +15,13 @@ limitations under the License. */
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum CallStatus { PROCESS = 0, FINISH }; enum CallStatus { PROCESS = 0, FINISH };
// reference: // reference:
...@@ -41,6 +41,19 @@ class RequestBase { ...@@ -41,6 +41,19 @@ class RequestBase {
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() = 0; virtual void Process() = 0;
std::string Status2String(const std::string& method) {
std::string status = "Process";
if (status_ == FINISH) {
status = "Finish";
}
std::ostringstream s;
s << method << " name:[" << GetReqName() << "]"
<< ", ep:[" << ctx_.peer() << "]"
<< " " << status << " using req_id:" << req_id_;
return s.str();
}
CallStatus Status() const { CallStatus Status() const {
std::lock_guard<std::mutex> l(status_mu_); std::lock_guard<std::mutex> l(status_mu_);
return status_; return status_;
...@@ -74,7 +87,7 @@ class RequestSend final : public RequestBase { ...@@ -74,7 +87,7 @@ class RequestSend final : public RequestBase {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), request_handler->dev_ctx(),
!request_handler->sync_mode())); !request_handler->sync_mode()));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
...@@ -106,7 +119,7 @@ class RequestGet final : public RequestBase { ...@@ -106,7 +119,7 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_, method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
...@@ -150,7 +163,8 @@ class RequestPrefetch final : public RequestBase { ...@@ -150,7 +163,8 @@ class RequestPrefetch final : public RequestBase {
local_scope_(nullptr) { local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true)); request_handler->dev_ctx(), true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
...@@ -271,7 +285,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -271,7 +285,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
int req_id) { int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; LOG(WARNING) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
...@@ -305,14 +319,14 @@ void AsyncGRPCServer::HandleRequest( ...@@ -305,14 +319,14 @@ void AsyncGRPCServer::HandleRequest(
bool ok = false; bool ok = false;
while (true) { while (true) {
VLOG(3) << "HandleRequest " << rpc_name << " wait next"; VLOG(4) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) { if (!cq->Next(&tag, &ok)) {
LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!"; LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!";
break; break;
} }
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag)); int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
VLOG(3) << "HandleRequest " << rpc_name << ", req_id:" << req_id VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
<< " get next"; << " get next";
auto& reqs = rpc_reqs_[rpc_name]; auto& reqs = rpc_reqs_[rpc_name];
...@@ -323,22 +337,21 @@ void AsyncGRPCServer::HandleRequest( ...@@ -323,22 +337,21 @@ void AsyncGRPCServer::HandleRequest(
base = reqs[req_id]; base = reqs[req_id];
} }
VLOG(3) << base->Status2String(rpc_name);
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) { if (!ok) {
LOG(WARNING) << "completion queue:" << rpc_name LOG(WARNING) << "completion queue:" << rpc_name
<< " recv no regular event:argument name[" << " recv no regular event"
<< base->GetReqName() << "]"; << " context:" << base->Status2String(rpc_name);
TryToRegisterNewOne(rpc_name, req_id); TryToRegisterNewOne(rpc_name, req_id);
delete base; delete base;
continue; continue;
} }
VLOG(3) << "queue id:" << rpc_name << ", req_id:" << req_id
<< ", status:" << base->Status();
switch (base->Status()) { switch (base->Status()) {
case PROCESS: { case PROCESS: {
base->Process(); base->Process();
...@@ -354,6 +367,6 @@ void AsyncGRPCServer::HandleRequest( ...@@ -354,6 +367,6 @@ void AsyncGRPCServer::HandleRequest(
} }
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -29,17 +29,17 @@ limitations under the License. */ ...@@ -29,17 +29,17 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h" #include "paddle/fluid/operators/distributed/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RequestBase; class RequestBase;
...@@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer { ...@@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer {
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_; std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h> #include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h> #include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -42,24 +42,25 @@ class ServerContext; ...@@ -42,24 +42,25 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse. // Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse. // Wire-format is identical to RecvVariableResponse.
template <> template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> { class SerializationTraits<paddle::operators::distributed::VariableResponse> {
public: public:
static Status Serialize( static Status Serialize(
const paddle::operators::detail::VariableResponse& msg, const paddle::operators::distributed::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) { grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status(); return Status();
} }
static Status Deserialize(grpc_byte_buffer* buffer, static Status Deserialize(
paddle::operators::detail::VariableResponse* msg, grpc_byte_buffer* buffer,
int max_message_size = INT_MAX) { paddle::operators::distributed::VariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) { if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload"); return Status(StatusCode::INTERNAL, "No payload");
} }
Status result = g_core_codegen_interface->ok(); Status result = g_core_codegen_interface->ok();
if (result.ok()) { if (result.ok()) {
paddle::operators::detail::GrpcByteSource source(buffer); paddle::operators::distributed::GrpcByteSource source(buffer);
int ret = msg->Parse(&source); int ret = msg->Parse(&source);
if (ret != 0) { if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
...@@ -73,7 +74,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> { ...@@ -73,7 +74,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> {
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum class GrpcMethod { enum class GrpcMethod {
kSendVariable, kSendVariable,
...@@ -118,6 +119,6 @@ class GrpcService final { ...@@ -118,6 +119,6 @@ class GrpcService final {
}; };
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,7 @@ limitations under the License. */ ...@@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
char* EncodeVarint32(char* dst, uint32_t v) { char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds // Operate on characters as unsigneds
...@@ -144,6 +144,6 @@ class ProtoEncodeHelper { ...@@ -144,6 +144,6 @@ class ProtoEncodeHelper {
char* limit_; // Just for CHECKs char* limit_; // Just for CHECKs
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestGet[] = "RequestGet";
...@@ -124,6 +124,6 @@ class RequestHandler { ...@@ -124,6 +124,6 @@ class RequestHandler {
RPCServer* rpc_server_; RPCServer* rpc_server_;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
...@@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, ...@@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true; return true;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -28,11 +28,11 @@ ...@@ -28,11 +28,11 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RequestSendHandler final : public RequestHandler { class RequestSendHandler final : public RequestHandler {
public: public:
...@@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override; const std::string& out_var_name = "") override;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
std::once_flag RPCClient::init_flag_; std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr); std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RPCClient { class RPCClient {
public: public:
...@@ -84,6 +84,6 @@ class RPCClient { ...@@ -84,6 +84,6 @@ class RPCClient {
static std::once_flag init_flag_; static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_; static std::unique_ptr<RPCClient> rpc_client_;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void RPCServer::ShutDown() { void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown "; LOG(INFO) << "RPCServer ShutDown ";
...@@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) { ...@@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RPCServer { class RPCServer {
public: public:
...@@ -86,6 +86,6 @@ class RPCServer { ...@@ -86,6 +86,6 @@ class RPCServer {
friend class RequestHandler; friend class RequestHandler;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -22,18 +22,18 @@ limitations under the License. */ ...@@ -22,18 +22,18 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace detail = paddle::operators::detail; namespace distributed = paddle::operators::distributed;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
...@@ -113,19 +113,21 @@ void StartServer() { ...@@ -113,19 +113,21 @@ void StartServer() {
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get()); g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join(); server_thread.join();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true)); g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <nccl.h> #include <nccl.h>
...@@ -23,14 +23,14 @@ limitations under the License. */ ...@@ -23,14 +23,14 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h" #include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
...@@ -222,11 +222,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, ...@@ -222,11 +222,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable** var) { framework::Variable** var) {
operators::detail::VariableResponse resp(scope, &ctx); operators::distributed::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar(); *var = resp.GetVar();
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -25,12 +25,12 @@ limitations under the License. */ ...@@ -25,12 +25,12 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
...@@ -61,6 +61,6 @@ inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { ...@@ -61,6 +61,6 @@ inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
} }
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -22,12 +22,12 @@ ...@@ -22,12 +22,12 @@
#endif #endif
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum WireType { enum WireType {
WIRETYPE_VARINT = 0, WIRETYPE_VARINT = 0,
...@@ -76,6 +76,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -76,6 +76,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
if (total_written + size_to_write > length) { if (total_written + size_to_write > length) {
size_to_write = length - total_written; size_to_write = length - total_written;
} }
// This log is useful to see how long a internal block size is of rpc.
VLOG(7) << "copy " << size_to_write << " data to CUDAPlace";
memory::Copy(boost::get<platform::CUDAPlace>(place), memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write, reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream()); gpu_dev_ctx.stream());
...@@ -103,6 +105,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -103,6 +105,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
} }
// TODO(gongwb): can we avoid copy? // TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu; platform::CPUPlace cpu;
// This log is useful to see how long a internal block size is of rpc.
VLOG(7) << "copy " << size_to_write << " data to CPUPlace";
memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write); memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write; p += size_to_write;
...@@ -158,13 +162,13 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -158,13 +162,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
static_cast<size_t>(tensor->numel()), length / framework::SizeOfType(
length / framework::SizeOfType( paddle::operators::distributed::ToTypeIndex(
paddle::operators::detail::ToTypeIndex(meta_.data_type()))); meta_.data_type())));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
ctx.GetPlace(), ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type())); paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false; return false;
...@@ -480,6 +484,6 @@ int VariableResponse::Parse(Source* source) { ...@@ -480,6 +484,6 @@ int VariableResponse::Parse(Source* source) {
return 0; return 0;
} }
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -22,17 +22,17 @@ ...@@ -22,17 +22,17 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class VariableResponse { class VariableResponse {
public: public:
...@@ -99,6 +99,6 @@ class VariableResponse { ...@@ -99,6 +99,6 @@ class VariableResponse {
sendrecv::VariableMessage meta_; sendrecv::VariableMessage meta_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -42,8 +42,8 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -42,8 +42,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait(); rpc_client->Wait();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/mean_op.h"
namespace paddle {
namespace operators {
using framework::DataLayout;
template <typename T>
class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::normal_distribution<T> dist(mean, std);
int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
// The format of output is set as the mkldnn's format
// TODO(@mozga-intel) The format of matrix sets inside the another layers.
tensor->set_layout(DataLayout::kMKLDNN);
tensor->set_format(mkldnn::memory::format::oihw);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(gaussian_random, MKLDNN, ::paddle::platform::CPUPlace,
ops::GaussianMKLDNNKernel<float>);
...@@ -15,6 +15,10 @@ limitations under the License. */ ...@@ -15,6 +15,10 @@ limitations under the License. */
#include <random> #include <random>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -62,9 +66,20 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -62,9 +66,20 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context()); ctx.device_context(), layout, library);
} }
}; };
...@@ -95,7 +110,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -95,7 +110,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
"(int, default 5(FP32)) " "(int, default 5(FP32)) "
"Output data type.") "Output data type.")
.SetDefault(framework::proto::VarType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
GaussianRandom Operator. GaussianRandom Operator.
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
...@@ -60,7 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -60,7 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep; VLOG(3) << "sending nccl id to " << ep;
...@@ -80,11 +81,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -80,11 +81,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
detail::RequestSendHandler rpc_h(true); distributed::RequestSendHandler rpc_h(true);
std::unique_ptr<detail::RPCServer> rpc_service( std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1)); new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(detail::kRequestSend, &rpc_h); rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get()); rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
...@@ -95,11 +96,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -95,11 +96,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
rpc_h.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(detail::kRequestSend); rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(detail::kRequestSend); rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
rpc_service->ShutDown(); rpc_service->ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
......
...@@ -21,14 +21,14 @@ limitations under the License. */ ...@@ -21,14 +21,14 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void RunServer(std::shared_ptr<detail::RPCServer> service) { void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer(); service->StartServer();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
...@@ -101,17 +101,16 @@ void ListenAndServOp::RunSyncLoop( ...@@ -101,17 +101,16 @@ void ListenAndServOp::RunSyncLoop(
framework::Scope *recv_scope, framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const { const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
std::vector<int> optimize_block_id_list; std::vector<int> optimize_blocks_idx;
for (int blkid = 1; blkid < num_blocks; ++blkid) { for (auto blk : optimize_blocks) {
if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(), optimize_blocks_idx.push_back(blk->ID());
blkid) == prefetch_block_id_list.end()) {
optimize_block_id_list.push_back(blkid);
}
} }
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list); auto optimize_prepared = executor->Prepare(*program, optimize_blocks_idx);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert( optimize_prepared.insert(
optimize_prepared.begin(), optimize_prepared.begin(),
...@@ -121,12 +120,12 @@ void ListenAndServOp::RunSyncLoop( ...@@ -121,12 +120,12 @@ void ListenAndServOp::RunSyncLoop(
while (true) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(detail::kRequestSend); rpc_service_->SetCond(distributed::kRequestSend);
rpc_service_->WaitBarrier(detail::kRequestSend); rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) { if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!"; LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
break; break;
} }
...@@ -134,14 +133,14 @@ void ListenAndServOp::RunSyncLoop( ...@@ -134,14 +133,14 @@ void ListenAndServOp::RunSyncLoop(
// and this will still work. // and this will still work.
// The optimize blocks which have the same parent ID would run parallel // The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future // TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = optimize_blocks[0]->Parent();
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1); parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp(); double ts = GetTimestamp();
for (size_t i = 1; i < optimize_block_id_list.size(); ++i) { for (size_t i = 1; i < optimize_blocks.size(); ++i) {
// skip the first optimize block because it is already in the // skip the first optimize block because it is already in the
// parallel_blkids. // parallel_blkids.
int blkid = optimize_block_id_list[i]; int blkid = optimize_blocks[i]->ID();
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope); program, recv_scope);
...@@ -154,11 +153,11 @@ void ListenAndServOp::RunSyncLoop( ...@@ -154,11 +153,11 @@ void ListenAndServOp::RunSyncLoop(
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
// reset received sparse vars to avoid reuse it in the next mini-batch // reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get()) dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder(); ->ResetSparseVarRecorder();
} // while(true) } // while(true)
} }
...@@ -215,13 +214,13 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -215,13 +214,13 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} }
static void FillRequestCtx( static void FillRequestCtx(
detail::RequestHandler *h, framework::Scope *scope, distributed::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx, framework::Executor *executor, platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx, *prefetch_ctx,
detail::RPCServer *rpc_server) { distributed::RPCServer *rpc_server) {
h->SetScope(scope); h->SetScope(scope);
h->SetDevCtx(dev_ctx); h->SetDevCtx(dev_ctx);
h->SetExecutor(executor); h->SetExecutor(executor);
...@@ -249,18 +248,23 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -249,18 +248,23 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode)); request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode)); new distributed::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestSend,
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch, rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get()); request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto optimize_blocks =
auto *program = optimize_block->Program(); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE(optimize_blocks.size() >= 1,
"optimize blocks should be 1 at least on the pserver side.");
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare for prefetch // prepare for prefetch
...@@ -337,8 +341,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -337,8 +341,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<std::vector<framework::BlockDesc *>>(
"BlockID to run on server side."); kOptimizeBlocks, "Optimize blocks to run on server side.")
.SetDefault({});
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId, AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.") "prefetch blocks to run on server side.")
.SetDefault({}); .SetDefault({});
......
...@@ -24,16 +24,16 @@ limitations under the License. */ ...@@ -24,16 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service); void RunServer(std::shared_ptr<distributed::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase { class ListenAndServOp : public framework::OperatorBase {
public: public:
...@@ -62,10 +62,11 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -62,10 +62,11 @@ class ListenAndServOp : public framework::OperatorBase {
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
protected: protected:
mutable std::shared_ptr<detail::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_; mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
}; };
......
...@@ -146,6 +146,6 @@ REGISTER_UNARY_LOGICAL_OP(logical_not, "$$Out = !X$$"); ...@@ -146,6 +146,6 @@ REGISTER_UNARY_LOGICAL_OP(logical_not, "$$Out = !X$$");
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU, REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
paddle::operators::LogicalNotFunctor); paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_xor, REGISTER_BINARY_LOGICAL_OP(logical_xor,
"$$Out = (X || Y) \\, \\&\\& \\, !(X \\&\\& Y)$$"); "$$Out = (X || Y) \\&\\& !(X \\&\\& Y)$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU, REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
paddle::operators::LogicalXorFunctor); paddle::operators::LogicalXorFunctor);
...@@ -93,10 +93,10 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -93,10 +93,10 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace()); auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation // computation
for (size_t k = 0; k < input_rows; ++k) { for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols; const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0; int col_idx = 0;
for (int j = 0; j < num; ++j) { for (size_t j = 0; j < num; ++j) {
int col_len = output_cols[j]; int col_len = output_cols[j];
auto* out_tensor = outputs->at(j); auto* out_tensor = outputs->at(j);
if (out_tensor != nullptr) { if (out_tensor != nullptr) {
......
...@@ -22,43 +22,24 @@ namespace paddle { ...@@ -22,43 +22,24 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
const T* orig = first;
const T* it = nullptr;
T step = 0;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (!(val < *it)) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first - orig;
}
template <typename T> template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
const int output_rows, const int output_cols, const int output_rows, const int output_cols,
T* output) { T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1; int curr_segment = 0;
int curr_offset = input_cols[0];
int curr_offset = input_cols[segment];
int curr_segment = segment;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset; int curr_col_offset = input_cols[curr_segment + 1];
while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) { while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset; curr_offset = curr_col_offset;
++curr_segment; ++curr_segment;
curr_col_offset = input_cols[curr_segment + 1];
} }
int local_col = tid_x - curr_offset; int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset; int segment_width = curr_col_offset - curr_offset;
T* input_ptr = inputs[curr_segment]; T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
...@@ -89,14 +70,14 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row, ...@@ -89,14 +70,14 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
const int in_col, const int* out_cols, const int in_col, const int* out_cols,
int out_cols_size, T** outputs_data) { int out_cols_size, T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1; int curr_segment = 0;
int curr_offset = out_cols[segment]; int curr_offset = out_cols[0];
int curr_segment = segment;
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset; int curr_col_offset = out_cols[curr_segment + 1];
while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) { while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset; curr_offset = curr_col_offset;
++curr_segment; ++curr_segment;
curr_col_offset = out_cols[curr_segment + 1];
} }
int local_col = tid_x - curr_offset; int local_col = tid_x - curr_offset;
...@@ -228,7 +209,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -228,7 +209,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
outputs_cols[0] = 0; outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) { for (int i = 0; i < o_num; ++i) {
int t_col = outputs->at(i)->numel() / out_row; int t_col = ref_inputs.at(i)->numel() / out_row;
if (sameShape) { if (sameShape) {
if (t_col != out0_col) sameShape = false; if (t_col != out0_col) sameShape = false;
} }
......
...@@ -30,6 +30,7 @@ template struct SetConstant<platform::CPUDeviceContext, double>; ...@@ -30,6 +30,7 @@ template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>; template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>; template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>; template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
#define DEFINE_CPU_TRANS(RANK) \ #define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \ template struct Transpose<platform::CPUDeviceContext, platform::float16, \
......
...@@ -295,7 +295,7 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -295,7 +295,7 @@ class ParallelDoGradOp : public framework::OperatorBase {
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}}, "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
framework::AttributeMap{}); framework::AttributeMap{{"use_mkldnn", {false}}});
VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]); VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]);
sum_op->Run(*sub_scopes[0], places[0]); sum_op->Run(*sub_scopes[0], places[0]);
WaitOnPlace(places[0]); WaitOnPlace(places[0]);
......
...@@ -41,8 +41,8 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -41,8 +41,8 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -429,7 +429,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -429,7 +429,8 @@ class RecurrentGradOp : public RecurrentBase {
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{}); {{"Out", {pg_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, place); sum_op->Run(cur_scope, place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
......
...@@ -43,8 +43,8 @@ class RecvOp : public framework::OperatorBase { ...@@ -43,8 +43,8 @@ class RecvOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
......
...@@ -44,8 +44,8 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -44,8 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
......
...@@ -45,8 +45,8 @@ class SendOp : public framework::OperatorBase { ...@@ -45,8 +45,8 @@ class SendOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -129,7 +129,10 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) { ...@@ -129,7 +129,10 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
// sub program run in listen_and_serv_op, for simple test we use sum // sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program; f::ProgramDesc program;
const auto &root_block = program.Block(0); const auto &root_block = program.Block(0);
std::vector<framework::BlockDesc *> optimize_blocks;
auto *optimize_block = program.AppendBlock(root_block); auto *optimize_block = program.AppendBlock(root_block);
optimize_blocks.push_back(optimize_block);
auto *prefetch_block = program.AppendBlock(root_block); auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape. // X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block, AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block,
...@@ -139,7 +142,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) { ...@@ -139,7 +142,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs.insert({"Fanin", 1}); attrs.insert({"Fanin", 1});
attrs.insert({"ParamList", std::vector<std::string>({"Out"})}); attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"optimize_blocks", optimize_blocks});
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})}); attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true}); attrs.insert({"sync_mode", true});
......
...@@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc; ...@@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; // Note: paddle has also "memory" namespace using mkldnn::memory; // Note: paddle has also "memory" namespace
using mkldnn::primitive; using mkldnn::primitive;
using mkldnn::softmax_forward; using mkldnn::softmax_forward;
using mkldnn::softmax_backward;
using mkldnn::prop_kind; using mkldnn::prop_kind;
using mkldnn::stream; using mkldnn::stream;
using platform::to_void_cast;
class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
public:
SoftmaxMKLDNNHandler(
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
softmax_pd_(softmax_pd) {}
SoftmaxMKLDNNHandler(
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
softmax_pd_(softmax_pd),
softmax_bwd_pd_(softmax_bwd_pd) {
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_ += "-BWD";
}
std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@softmax_p";
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
"Fail to find softmax primitive in device context");
if (softmax_p == nullptr) {
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*(softmax_pd_.get()),
*(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
dev_ctx_.SetBlob(prim_key, softmax_p);
} else {
is_reusing_ = true;
}
return softmax_p;
}
std::shared_ptr<mkldnn::softmax_backward> AcquireSoftmaxBackward(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@softmax_bwd_p";
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
"Fail to find softmax backward primitive in device context");
if (softmax_bwd_p == nullptr) {
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*softmax_bwd_pd_, *(dst_memory_p.get()), *(diff_dst_memory_p.get()),
*(diff_src_memory_p.get()));
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
} else {
is_reusing_ = true;
}
return softmax_bwd_p;
}
private:
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
};
template <typename T> template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
...@@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
// Same memory descriptor to be used for input and output // Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]}; memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Generate keys for storing/retriving primitives for this operator // Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Each MKLDNN operator may have diffrent hashing function const std::string key =
auto gethash = [](memory::dims& operand_dims) { platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
return std::string(std::to_string(operand_dims[0]) + "-" + const std::string key_softmax_pd = key + "@softmax_pd";
std::to_string(operand_dims[1]));
}; // Currently only NC data format is supported
const std::string key = gethash(softmax_tz); auto softmax_md = MKLDNNMemDesc(
const std::string key_softmax_p = key + "@softmax_p"; {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p"; // Normalization is made after innermost dimension eg. C out of NC
const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p"; auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
std::shared_ptr<void> softmax_p = dev_ctx.GetBlob(key_softmax_p); auto softmax_pd = std::make_shared<mkldnn::softmax_forward::primitive_desc>(
if (softmax_p == nullptr) { softmax_desc, mkldnn_engine);
// Currently only NC data format is supported dev_ctx.SetBlob(key_softmax_pd, softmax_pd);
auto softmax_md =
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc); SoftmaxMKLDNNHandler handler(softmax_pd, dev_ctx, mkldnn_engine, key);
// Normalization is made after innermost dimension eg. C out of NC auto softmax_src_memory_p =
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring, handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
softmax_md, 1 /*dim: C*/); auto softmax_dst_memory_p =
// create memory primitives handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
auto softmax_src_memory_p = std::make_shared<memory>( auto softmax_p =
memory::primitive_desc{softmax_md, mkldnn_engine}, handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
static_cast<void*>(const_cast<T*>(input_data)));
dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
auto softmax_dst_memory_p = std::make_shared<memory>(
memory::primitive_desc{softmax_md, mkldnn_engine},
static_cast<void*>(output_data));
dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);
auto softmax_forward_pd =
std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
mkldnn_engine);
softmax_p = std::make_shared<softmax_forward>(
*(softmax_forward_pd.get()),
*(static_cast<memory*>(softmax_src_memory_p.get())),
*(static_cast<memory*>(softmax_dst_memory_p.get())));
dev_ctx.SetBlob(key_softmax_p, softmax_p);
} else {
// Primitives already exist
auto src_memory_p = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_softmax_src_mem_p));
PADDLE_ENFORCE(src_memory_p != nullptr,
"Fail to find softmax src mem_p in device context");
auto dst_memory_p = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_softmax_dst_mem_p));
PADDLE_ENFORCE(dst_memory_p != nullptr,
"Fail to find softmax dst mem_p in device context");
src_memory_p->set_data_handle(
reinterpret_cast<void*>(const_cast<T*>(input_data)));
dst_memory_p->set_data_handle(output_data);
}
std::vector<primitive> pipeline{ std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))}; *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
...@@ -120,6 +164,77 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -120,6 +164,77 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
}; };
template <typename T>
class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* output = ctx.Input<Tensor>("Out");
const T* dst_data = output->data<T>();
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto* diff_dst_ptr = dout->template data<T>();
auto* dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
T* diff_src_ptr = dx->template mutable_data<T>(ctx.GetPlace());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
std::vector<int> src_tz(dst_tz);
PADDLE_ENFORCE(output->dims().size() == 2UL,
"The input of softmax op must be a 2D matrix.");
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
// we will make normalization after final eg. axis: 1
PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
"Softmax input and output dimensions should match");
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Currently only supports NC data format
// retrieve eltwise primitive desc from device context
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out"));
const std::string key_softmax_pd = key + "@softmax_pd";
auto softmax_pd =
std::static_pointer_cast<mkldnn::softmax_forward::primitive_desc>(
dev_ctx.GetBlob(key_softmax_pd));
PADDLE_ENFORCE(softmax_pd != nullptr,
"Fail to find softmax_pd in device context");
// TODO(jczaja): Add layouts support when there is a need to do so
// Two dimensional softmax does support NC format
auto data_softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
auto diff_softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_bwd_desc =
softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
auto softmax_bwd_pd =
std::make_shared<mkldnn::softmax_backward::primitive_desc>(
softmax_bwd_desc, mkldnn_engine, *softmax_pd);
SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
mkldnn_engine, key);
auto dst_memory_p =
handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
diff_softmax_md, to_void_cast<T>(diff_src_ptr));
// Get primitve from device context
auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
dst_memory_p, diff_dst_memory_p, diff_src_memory_p);
std::vector<primitive> pipeline{*softmax_bwd_p};
stream(stream::kind::eager).submit(pipeline).wait();
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -127,3 +242,5 @@ namespace ops = paddle::operators; ...@@ -127,3 +242,5 @@ namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNKernel<float>); ops::SoftmaxMKLDNNKernel<float>);
REGISTER_OP_KERNEL(softmax_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNGradKernel<float>);
...@@ -145,16 +145,30 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -145,16 +145,30 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported. // choose cudnn kernel if the runtime supported.
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
std::string data_format = ctx.Attr<std::string>("data_format"); #ifdef PADDLE_WITH_MKLDNN
return framework::OpKernelType( if (library_ == framework::LibraryType::kPlain &&
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), platform::CanMKLDNNBeUsed(ctx)) {
framework::StringToDataLayout(data_format), library_); library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"float16 can only be used on GPU place");
}
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
} }
}; };
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*Licensed under the Apache License, Version 2.0(the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::CPUDeviceContext;
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::sum;
using mkldnn::reorder;
using platform::to_void_cast;
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto in_vars = ctx.MultiInputVar("X");
const int N = in_vars.size();
auto out_var = ctx.OutputVar("Out");
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
LoDTensor* output = ctx.Output<LoDTensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::vector<int> dst_tz = framework::vectorize2int(output->dims());
auto src_tz = dst_tz;
memory::format output_format{memory::format::format_undef};
std::vector<float> scales;
std::vector<memory::primitive_desc> srcs_mpd;
std::vector<mkldnn::memory> srcs_mem;
PADDLE_ENFORCE(in_vars[0]->IsType<LoDTensor>(),
"Input[0] must be LoDTensors");
auto& input0 = in_vars[0]->Get<LoDTensor>();
PADDLE_ENFORCE(input0.layout() == DataLayout::kMKLDNN &&
input0.format() != memory::format::format_undef,
"Wrong layout/format for inputs[0]");
memory::format input_format = input0.format();
if (src_tz.size() == 1 && (input_format == memory::format::nchw ||
input_format == memory::format::nhwc)) {
input_format = memory::format::x;
}
if (src_tz.size() == 2 && (input_format == memory::format::nchw ||
input_format == memory::format::nhwc)) {
input_format = memory::format::nc;
}
for (int i = in_place ? 1 : 0; i < N; i++) {
PADDLE_ENFORCE(in_vars[i]->IsType<LoDTensor>(),
"all inputs must be all LoDTensors");
auto& input = in_vars[i]->Get<LoDTensor>();
PADDLE_ENFORCE(input.layout() == DataLayout::kMKLDNN &&
input.format() != memory::format::format_undef,
"Wrong layout/format for inputs");
if (input.numel() == 0) {
continue;
}
const T* input_data = input.data<T>();
auto src_md =
memory::desc(src_tz, memory::data_type::f32, input_format);
auto src_mpd = memory::primitive_desc(src_md, mkldnn_engine);
auto src_mem = memory(src_mpd, to_void_cast(input_data));
srcs_mpd.push_back(src_mpd);
srcs_mem.push_back(src_mem);
scales.push_back(1.0);
}
auto dst_md =
memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
std::shared_ptr<memory> dst_mem;
if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
} else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
}
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
inputs.push_back(srcs_mem[i]);
}
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
primitive reorder_prim;
std::shared_ptr<memory> target_mem;
if (in_place) {
output_format = input_format;
target_mem.reset(new memory(
{{{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine},
output_data));
reorder_prim = reorder(*dst_mem, *target_mem);
}
std::vector<primitive> pipeline;
pipeline.push_back(sum_prim);
if (in_place) pipeline.push_back(reorder_prim);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
} else if (out_var->IsType<framework::SelectedRows>()) {
// TODO(@mozga-intel) Add MKLDNN SelectedRows support
std::unique_ptr<framework::SelectedRows> in0;
if (in_place) {
// If is in_place, we store the input[0] to in0
auto& in_sel0 = in_vars[0]->Get<SelectedRows>();
auto& rows = in_sel0.rows();
in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
in0->mutable_value()->ShareDataWith(in_sel0.value());
}
auto get_selected_row = [&](size_t i) -> const SelectedRows& {
if (i == 0 && in0) {
return *in0.get();
} else {
return in_vars[i]->Get<SelectedRows>();
}
};
auto* out = ctx.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
auto* out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
for (int i = 0; i < N; i++) {
auto& sel_row = get_selected_row(i);
first_dim += sel_row.rows().size();
}
auto in_dim =
framework::vectorize(get_selected_row(N - 1).value().dims());
in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim));
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
return;
}
out_value->mutable_data<T>(ctx.GetPlace());
math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
int64_t offset = 0;
for (int i = 0; i < N; i++) {
auto& sel_row = get_selected_row(i);
if (sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(ctx.template device_context<CPUDeviceContext>(), sel_row,
offset, out);
offset += sel_row.value().numel();
}
} else if (out_var->IsType<framework::LoDTensorArray>()) {
// TODO(@mozga-intel) Add MKLDNN LoDTensorArray support
auto& out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
"Only support all inputs are TensorArray");
auto& in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].numel() != 0) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
framework::TensorCopy(in_array[i], in_array[i].place(),
ctx.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
result.device(*ctx.template device_context<MKLDNNDeviceContext>()
.eigen_device()) = result + in;
}
}
}
}
} else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
out_var->Type().name());
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::SumMKLDNNOpKernel<float>);
...@@ -18,6 +18,10 @@ limitations under the License. */ ...@@ -18,6 +18,10 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
...@@ -63,6 +67,18 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -63,6 +67,18 @@ class SumOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X"); auto x_vars = ctx.MultiInputVar("X");
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1; int dtype = -1;
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
...@@ -80,26 +96,27 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -80,26 +96,27 @@ class SumOp : public framework::OperatorWithKernel {
"Sum operator should have at least one tensor"); "Sum operator should have at least one tensor");
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(dtype), static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
ctx.device_context()); layout, library);
} else if (x_vars[0]->IsType<framework::SelectedRows>()) { } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
for (auto& var : x_vars) { for (auto& var : x_vars) {
auto& value = var->Get<framework::SelectedRows>().value(); auto& value = var->Get<framework::SelectedRows>().value();
if (value.IsInitialized()) { if (value.IsInitialized()) {
return framework::OpKernelType(framework::ToDataType(value.type()), return framework::OpKernelType(framework::ToDataType(value.type()),
ctx.device_context()); ctx.device_context(), layout, library);
} }
} }
// if input sparse vars are not initialized, use an default kernel type. // if input sparse vars are not initialized, use an default kernel type.
return framework::OpKernelType(framework::proto::VarType::FP32, return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context()); ctx.device_context(), layout, library);
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) { } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
auto& array = x_var->Get<framework::LoDTensorArray>(); auto& array = x_var->Get<framework::LoDTensorArray>();
for (auto& each : array) { for (auto& each : array) {
if (each.numel() != 0) { if (each.numel() != 0) {
return framework::OpKernelType(framework::ToDataType(each.type()), return framework::OpKernelType(framework::ToDataType(each.type()),
ctx.device_context()); ctx.device_context(), layout,
library);
} }
} }
} }
...@@ -116,6 +133,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,6 +133,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.") AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X"); AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Sum operator. Sum operator.
...@@ -132,7 +152,6 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -132,7 +152,6 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
auto& inputs = op_desc.Input("X"); auto& inputs = op_desc.Input("X");
auto var_type = framework::proto::VarType::SELECTED_ROWS; auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) { for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " " VLOG(10) << name << " "
<< block->FindRecursiveOrCreateVar(name).GetType(); << block->FindRecursiveOrCreateVar(name).GetType();
...@@ -206,6 +225,7 @@ namespace ops = paddle::operators; ...@@ -206,6 +225,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker, REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
ops::SumOpVarTypeInference); ops::SumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>, sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
ops::SumKernel<paddle::platform::CPUDeviceContext, double>, ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
......
...@@ -14,11 +14,14 @@ ...@@ -14,11 +14,14 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/tensorrt_engine_op.h" #include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -16,10 +16,12 @@ ...@@ -16,10 +16,12 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -179,7 +179,6 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -179,7 +179,6 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
const std::string& z_name, bool x_created, const std::string& z_name, bool x_created,
const shape_t& x_shape, const shape_t& y_shape, const shape_t& x_shape, const shape_t& y_shape,
const shape_t& z_shape) { const shape_t& z_shape) {
LOG(INFO) << "create fc op"; LOG(INFO) << "create fc op";
auto* fc = block_desc.AppendOp(); auto* fc = block_desc.AppendOp();
fc->SetType("mul"); fc->SetType("mul");
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -37,11 +37,11 @@ USE_NO_KERNEL_OP(listen_and_serv); ...@@ -37,11 +37,11 @@ USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail; namespace distributed = paddle::operators::distributed;
namespace string = paddle::string; namespace string = paddle::string;
std::unique_ptr<detail::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
void StartServer() { void StartServer() {
f::Scope scope; f::Scope scope;
...@@ -57,14 +57,14 @@ void StartServer() { ...@@ -57,14 +57,14 @@ void StartServer() {
g_req_handler->SetProgram(&empty_program); g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor); g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get()); g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend); g_rpc_service->SetCond(distributed::kRequestSend);
g_rpc_service->WaitBarrier(detail::kRequestSend); g_rpc_service->WaitBarrier(distributed::kRequestSend);
LOG(INFO) << "got nccl id and stop server..."; LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown(); g_rpc_service->ShutDown();
...@@ -72,7 +72,7 @@ void StartServer() { ...@@ -72,7 +72,7 @@ void StartServer() {
} }
TEST(SendNcclId, RPCServer) { TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new detail::RequestSendHandler(true)); g_req_handler.reset(new distributed::RequestSendHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
...@@ -91,7 +91,8 @@ TEST(SendNcclId, RPCServer) { ...@@ -91,7 +91,8 @@ TEST(SendNcclId, RPCServer) {
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
LOG(INFO) << "connect to server" << ep; LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
......
...@@ -203,11 +203,11 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -203,11 +203,11 @@ class WhileGradOp : public framework::OperatorBase {
->set_lod(inside_tensor.lod()); ->set_lod(inside_tensor.lod());
} }
} }
auto new_inside_name = cur_scope.Rename(inside_grad_name); auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{}); {{"Out", {pg_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place); sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
} }
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册