Commit edad60e6 authored by Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize-cpp-reader

@@ -54,7 +54,8 @@ option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON)
 option(WITH_DOUBLE "Compile PaddlePaddle with double precision" OFF)
 option(WITH_RDMA "Compile PaddlePaddle with RDMA support" OFF)
 option(WITH_TIMER "Compile PaddlePaddle with stats timer" OFF)
-option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler" OFF)
+option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF)
+option(WITH_JEMALLOC "Compile PaddlePaddle with jemalloc" OFF)
 option(WITH_DOC "Compile PaddlePaddle with documentation" OFF)
 option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF)
 option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
@@ -65,6 +66,7 @@ option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
 option(GLIDE_INSTALL "Download and install go dependencies " ON)
 option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
 option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
+option(WITH_PSLIB "Compile with pslib support" OFF)
 option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
 option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
 option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
@@ -125,18 +127,12 @@ if(ANDROID OR IOS)
     add_definitions(-DPADDLE_MOBILE_INFERENCE)
 endif()
-if (APPLE OR WIN32)
+if (APPLE)
     set(WITH_MKL OFF CACHE STRING
-        "Disable MKL for building on mac and windows" FORCE)
+        "Disable MKL for building on mac" FORCE)
 endif()
 if (WIN32)
-    set(WITH_AVX OFF CACHE STRING
-        "Disable AVX when compiling for Windows" FORCE)
-    set(WITH_DSO OFF CACHE STRING
-        "Disable DSO when compiling for Windows" FORCE)
-    set(WITH_MKL OFF CACHE STRING
-        "Disable MKL when compiling for Windows" FORCE)
     set(WITH_DISTRIBUTE OFF CACHE STRING
         "Disable DISTRIBUTE when compiling for Windows" FORCE)
     set(WITH_C_API OFF CACHE STRING
@@ -209,14 +205,20 @@ include(external/xxhash) # download xxhash
 include(external/dlpack)
 include(external/snappy) # download snappy
 include(external/snappystream) # download snappystream
+include(external/warpctc) # download, build, install warpctc
 if (NOT WIN32)
-    # there is no official support of warpctc, nccl, cupti in windows
-    include(external/warpctc) # download, build, install warpctc
+    # there is no official support of nccl, cupti in windows
     include(cupti)
     include(external/gzstream)
 endif (NOT WIN32)
+if(WITH_PSLIB)
+    include(external/libmct)
+    include(external/pslib_brpc)
+    include(external/pslib)
+endif(WITH_PSLIB)
 if(WITH_DISTRIBUTE)
     if(WITH_GRPC)
         include(external/grpc)
@@ -254,6 +256,18 @@ elseif()
     set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in MKL only now." FORCE)
 endif()
+if (WITH_PROFILER)
+    find_package(Gperftools REQUIRED)
+    include_directories(${GPERFTOOLS_INCLUDE_DIR})
+    add_definitions(-DWITH_GPERFTOOLS)
+endif()
+if (WITH_JEMALLOC)
+    find_package(JeMalloc REQUIRED)
+    include_directories(${JEMALLOC_INCLUDE_DIR})
+    add_definitions(-DWITH_JEMALLOC)
+endif()
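As a usage note for these two new switches: below is a minimal, hypothetical sketch (not part of this patch) of how a consumer target could pick up the imported targets defined by the FindGperftools.cmake and FindJeMalloc.cmake modules added in this commit; the `my_tool` target and its source are placeholders.

```
# Hypothetical consumer of the WITH_PROFILER / WITH_JEMALLOC options above.
# gperftools::profiler and jemalloc::jemalloc are the IMPORTED targets that
# the new find modules in this commit create.
add_executable(my_tool main.cc)          # placeholder target and source
if(WITH_PROFILER)
  target_link_libraries(my_tool gperftools::profiler)
endif()
if(WITH_JEMALLOC)
  target_link_libraries(my_tool jemalloc::jemalloc)
endif()
```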
 include(generic)  # simplify cmake module
 include(package)  # set paddle packages
 include(ccache)   # set ccache for compilation
@@ -278,6 +292,12 @@ set(EXTERNAL_LIBS
     ${PYTHON_LIBRARIES}
 )
+if(WITH_PSLIB)
+    list(APPEND EXTERNAL_LIBS pslib)
+    list(APPEND EXTERNAL_LIBS pslib_brpc)
+    list(APPEND EXTERNAL_LIBS libmct)
+endif(WITH_PSLIB)
 if(WITH_AMD_GPU)
     find_package(HIP)
     include(hip)
......
@@ -94,52 +94,52 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
 # specify sphinx version as 1.5.6 and remove -U option for [pip install -U
 # sphinx-rtd-theme] since -U option will cause sphinx being updated to newest
 # version(1.7.1 for now), which causes building documentation failed.
-RUN pip3 install -U wheel && \
-    pip3 install -U docopt PyYAML sphinx==1.5.6 && \
-    pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
-    pip3.6 install -U wheel && \
-    pip3.6 install -U docopt PyYAML sphinx==1.5.6 && \
-    pip3.6 install sphinx-rtd-theme==0.1.9 recommonmark && \
-    pip3.7 install -U wheel && \
-    pip3.7 install -U docopt PyYAML sphinx==1.5.6 && \
-    pip3.7 install sphinx-rtd-theme==0.1.9 recommonmark && \
+RUN pip3 --no-cache-dir install -U wheel && \
+    pip3 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
+    pip3 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
+    pip3.6 --no-cache-dir install -U wheel && \
+    pip3.6 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
+    pip3.6 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
+    pip3.7 --no-cache-dir install -U wheel && \
+    pip3.7 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
+    pip3.7 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
     easy_install -U pip && \
-    pip install -U pip setuptools wheel && \
-    pip install -U docopt PyYAML sphinx==1.5.6 && \
-    pip install sphinx-rtd-theme==0.1.9 recommonmark
+    pip --no-cache-dir install -U pip setuptools wheel && \
+    pip --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
+    pip --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark
-RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
-    pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
-    pip3 install opencv-python && \
-    pip3.6 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
-    pip3.6 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
-    pip3.6 install opencv-python && \
-    pip3.7 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
-    pip3.7 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
-    pip3.7 install opencv-python && \
-    pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
-    pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
-    pip install opencv-python
+RUN pip3 --no-cache-dir install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
+    pip3 --no-cache-dir install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
+    pip3 --no-cache-dir install opencv-python && \
+    pip3.6 --no-cache-dir install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
+    pip3.6 --no-cache-dir install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
+    pip3.6 --no-cache-dir install opencv-python && \
+    pip3.7 --no-cache-dir install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
+    pip3.7 --no-cache-dir install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
+    pip3.7 --no-cache-dir install opencv-python && \
+    pip --no-cache-dir install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
+    pip --no-cache-dir install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
+    pip --no-cache-dir install opencv-python
 #For docstring checker
-RUN pip3 install pylint pytest astroid isort
-RUN pip3.6 install pylint pytest astroid isort
-RUN pip3.7 install pylint pytest astroid isort
-RUN pip install pylint pytest astroid isort LinkChecker
+RUN pip3 --no-cache-dir install pylint pytest astroid isort
+RUN pip3.6 --no-cache-dir install pylint pytest astroid isort
+RUN pip3.7 --no-cache-dir install pylint pytest astroid isort
+RUN pip --no-cache-dir install pylint pytest astroid isort LinkChecker
 COPY ./python/requirements.txt /root/
-RUN pip3 install -r /root/requirements.txt
-RUN pip3.6 install -r /root/requirements.txt
-RUN pip3.7 install -r /root/requirements.txt
-RUN pip install -r /root/requirements.txt
+RUN pip3 --no-cache-dir install -r /root/requirements.txt
+RUN pip3.6 --no-cache-dir install -r /root/requirements.txt
+RUN pip3.7 --no-cache-dir install -r /root/requirements.txt
+RUN pip --no-cache-dir install -r /root/requirements.txt
 # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
 # the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2
-RUN apt-get install -y libssl-dev libffi-dev
-RUN pip3 install certifi urllib3[secure]
-RUN pip3.6 install certifi urllib3[secure]
-RUN pip3.7 install certifi urllib3[secure]
-RUN pip install certifi urllib3[secure]
+RUN apt-get install -y libssl-dev libffi-dev && apt-get clean -y
+RUN pip3 --no-cache-dir install certifi urllib3[secure]
+RUN pip3.6 --no-cache-dir install certifi urllib3[secure]
+RUN pip3.7 --no-cache-dir install certifi urllib3[secure]
+RUN pip --no-cache-dir install certifi urllib3[secure]
 # Install woboq_codebrowser to /woboq
@@ -149,6 +149,14 @@ RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \
     -DCMAKE_BUILD_TYPE=Release . \
     make)
+# ar mishandles 4GB files
+# https://sourceware.org/bugzilla/show_bug.cgi?id=14625
+# remove them when apt-get support 2.27 and higher version
+RUN wget -q https://launchpad.net/ubuntu/+archive/primary/+sourcefiles/binutils/2.27-9ubuntu1/binutils_2.27.orig.tar.gz && \
+    tar -xzf binutils_2.27.orig.tar.gz && \
+    cd binutils-2.27 && \
+    ./configure && make -j && make install && cd .. && rm -rf binutils-2.27 binutils_2.27.orig.tar.gz
 # Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service
 RUN mkdir /var/run/sshd
 RUN echo 'root:root' | chpasswd
......
@@ -2,8 +2,8 @@
 [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
-[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html)
-[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html)
+[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)
+[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html)
 [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
 [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
@@ -19,7 +19,16 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
 Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
-### Latest PaddlePaddle Release: [Fluid 1.1.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.1)
+Welcome to the PaddlePaddle GitHub.
+PaddlePaddle (PArallel Distributed Deep LEarning) is an easy-to-use, efficient, flexible and scalable deep learning platform, originally developed by Baidu scientists and engineers to apply deep learning to Baidu's many products.
+Our vision is to enable deep learning for everyone via PaddlePaddle.
+To track the latest features of PaddlePaddle, please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases).
+### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
 ### Install Latest Stable Release:
 ```
 # Linux CPU
@@ -27,13 +36,30 @@ pip install paddlepaddle
 # Linux GPU cuda9cudnn7
 pip install paddlepaddle-gpu
 # Linux GPU cuda8cudnn7
-pip install paddlepaddle-gpu==1.1.0.post87
+pip install paddlepaddle-gpu==1.2.0.post87
 # Linux GPU cuda8cudnn5
-pip install paddlepaddle-gpu==1.1.0.post85
+pip install paddlepaddle-gpu==1.2.0.post85
 # For installation on other platform, refer to http://paddlepaddle.org/
 ```
+### Latest PaddlePaddle Version: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
+### Install the Latest Stable Release:
+```
+# Linux CPU
+pip install paddlepaddle
+# Linux GPU cuda9cudnn7
+pip install paddlepaddle-gpu
+# Linux GPU cuda8cudnn7
+pip install paddlepaddle-gpu==1.2.0.post87
+# Linux GPU cuda8cudnn5
+pip install paddlepaddle-gpu==1.2.0.post85
+# For installation guides on other platforms, refer to http://paddlepaddle.org/
+```
 ## Features
 - **Flexibility**
@@ -74,35 +100,90 @@ pip install paddlepaddle-gpu==1.1.0.post85
 Baidu and it has achieved a significant impact. We hope you can also explore
 the capability of PaddlePaddle to make an impact on your product.
+## Features
+- **Flexibility**
+    PaddlePaddle supports a wide range of neural network architectures and optimization algorithms. It is easy to configure complex models such as neural machine translation models with attention mechanisms or complex memory connections.
+- **Efficiency**
+    To use asynchronous computing resources efficiently, PaddlePaddle optimizes different layers of the framework, including computation, storage, architecture and communication. A few examples:
+    - Math operations are optimized via SSE/AVX intrinsics, BLAS libraries (e.g. MKL, OpenBLAS, cuBLAS) or customized CPU/GPU kernels.
+    - CNN networks are optimized via the MKL-DNN library.
+    - Recurrent networks are highly optimized to handle **variable-length** sequences without `padding`.
+    - Both local and distributed training are optimized for models with high-dimensional sparse data.
+- **Scalability**
+    With PaddlePaddle it is easy to use many CPUs/GPUs and machines to speed up training; PaddlePaddle achieves high throughput and fast execution via optimized communication.
+- **Connected to Products**
+    In addition, PaddlePaddle is also designed to be easy to deploy. At Baidu, PaddlePaddle has been deployed into products and services with a vast number of users, including ad click-through rate (CTR) prediction, large-scale image classification, optical character recognition (OCR), search ranking, computer virus detection, recommendation systems and more. PaddlePaddle is widely used within Baidu products and has achieved a significant impact. We hope you can also explore the capability of PaddlePaddle to make an impact on your product.
 ## Installation
-It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) on our website.
+It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
+## Installation
+It is recommended to read the [installation guide](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
 ## Documentation
-We provide [English](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) and
-[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) documentation.
+We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
+[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation.
 - [Deep Learning 101](https://github.com/PaddlePaddle/book)
   You might want to start from this online interactive book that can run in a Jupyter Notebook.
-- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.1/user_guides/howto/training/cluster_howto.html)
+- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
   You can run distributed training jobs on MPI clusters.
-- [Python API](http://paddlepaddle.org/documentation/api/zh/1.1/fluid.html)
+- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
   Our new API enables much shorter programs.
-- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.1/advanced_usage/development/contribute_to_paddle.html)
+- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
   We appreciate your contributions!
+## Documentation
+We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
+[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation.
+- [Deep Learning 101](https://github.com/PaddlePaddle/book)
+  You might want to start from this online interactive book that can run in a Jupyter Notebook.
+- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
+  You can run distributed training jobs on MPI clusters.
+- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
+  The new API enables much shorter and cleaner programs.
+- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
+  We appreciate your contributions!
 ## Ask Questions
 You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
+## Ask Questions
+You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
 ## Copyright and License
 PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
+## Copyright and License
+PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
     # the role, should be either PSERVER or TRAINER
     training_role = os.getenv("PADDLE_TRAINING_ROLE")
-    config = distribute_transpiler.DistributeTranspilerConfig()
+    config = fluid.DistributeTranspilerConfig()
     config.slice_var_up = not args.no_split_var
+    config.min_block_size = 1048576
     t = distribute_transpiler.DistributeTranspiler(config=config)
     t.transpile(
         trainer_id,
         # NOTE: *MUST* use train_prog, for we are using with guard to
......
# Tries to find Gperftools.
#
# Usage of this module as follows:
#
# find_package(Gperftools)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# Gperftools_ROOT_DIR Set this variable to the root installation of
# Gperftools if the module has problems finding
# the proper installation path.
#
# Variables defined by this module:
#
# GPERFTOOLS_FOUND System has Gperftools libs/headers
# GPERFTOOLS_LIBRARIES The Gperftools libraries (tcmalloc & profiler)
# GPERFTOOLS_INCLUDE_DIR The location of Gperftools headers
find_library(GPERFTOOLS_TCMALLOC
NAMES tcmalloc
HINTS ${Gperftools_ROOT_DIR}/lib)
find_library(GPERFTOOLS_PROFILER
NAMES profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
find_library(GPERFTOOLS_TCMALLOC_AND_PROFILER
NAMES tcmalloc_and_profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
find_path(GPERFTOOLS_INCLUDE_DIR
NAMES gperftools/heap-profiler.h
HINTS ${Gperftools_ROOT_DIR}/include)
set(GPERFTOOLS_LIBRARIES ${GPERFTOOLS_TCMALLOC_AND_PROFILER})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
Gperftools
DEFAULT_MSG
GPERFTOOLS_LIBRARIES
GPERFTOOLS_INCLUDE_DIR)
mark_as_advanced(
Gperftools_ROOT_DIR
GPERFTOOLS_TCMALLOC
GPERFTOOLS_PROFILER
GPERFTOOLS_TCMALLOC_AND_PROFILER
GPERFTOOLS_LIBRARIES
GPERFTOOLS_INCLUDE_DIR)
# create IMPORTED targets
if (Gperftools_FOUND AND NOT TARGET gperftools::tcmalloc)
add_library(gperftools::tcmalloc UNKNOWN IMPORTED)
set_target_properties(gperftools::tcmalloc PROPERTIES
IMPORTED_LOCATION ${GPERFTOOLS_TCMALLOC}
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
add_library(gperftools::profiler UNKNOWN IMPORTED)
set_target_properties(gperftools::profiler PROPERTIES
IMPORTED_LOCATION ${GPERFTOOLS_PROFILER}
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
endif()
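A minimal usage sketch for this module; the install prefix and target names below are assumptions, while everything else follows the variables and imported targets defined above.

```
# Hypothetical caller of FindGperftools.cmake. Point Gperftools_ROOT_DIR at a
# non-standard prefix before calling find_package, as documented above.
set(Gperftools_ROOT_DIR "/opt/gperftools")          # assumed install prefix
find_package(Gperftools REQUIRED)
message(STATUS "gperftools libraries: ${GPERFTOOLS_LIBRARIES}")
add_executable(profiled_app main.cc)                # placeholder target
target_link_libraries(profiled_app gperftools::profiler)
```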
# - Find JeMalloc library
# Find the native JeMalloc includes and library
#
# JEMALLOC_INCLUDE_DIR - where to find jemalloc.h, etc.
# JEMALLOC_LIBRARIES - List of libraries when using jemalloc.
# JEMALLOC_FOUND - True if jemalloc found.
find_path(JEMALLOC_INCLUDE_DIR
NAMES jemalloc/jemalloc.h
HINTS ${JEMALLOC_ROOT_DIR}/include)
find_library(JEMALLOC_LIBRARIES
NAMES jemalloc
HINTS ${JEMALLOC_ROOT_DIR}/lib)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALLOC_INCLUDE_DIR)
mark_as_advanced(
JEMALLOC_LIBRARIES
JEMALLOC_INCLUDE_DIR)
if (JEMALLOC_FOUND)
add_library(jemalloc::jemalloc UNKNOWN IMPORTED)
set_target_properties(jemalloc::jemalloc PROPERTIES
IMPORTED_LOCATION ${JEMALLOC_LIBRARIES}
INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}")
endif()
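The analogous sketch for the jemalloc module, again with an assumed install prefix; it relies only on the JEMALLOC_FOUND result and the jemalloc::jemalloc imported target created above.

```
# Hypothetical caller of the FindJeMalloc module above. JEMALLOC_ROOT_DIR
# narrows the header/library search to a known prefix.
set(JEMALLOC_ROOT_DIR "/usr/local")                 # assumed install prefix
find_package(JeMalloc REQUIRED)
if(JEMALLOC_FOUND)
  add_executable(alloc_bench bench.cc)              # placeholder target
  target_link_libraries(alloc_bench jemalloc::jemalloc)
endif()
```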
@@ -84,8 +84,13 @@ if(NOT WITH_GOLANG)
     add_definitions(-DPADDLE_WITHOUT_GOLANG)
 endif(NOT WITH_GOLANG)
+if(WITH_PSLIB)
+    add_definitions(-DPADDLE_WITH_PSLIB)
+endif()
 if(WITH_GPU)
     add_definitions(-DPADDLE_WITH_CUDA)
+    add_definitions(-DEIGEN_USE_GPU)
     FIND_PACKAGE(CUDA REQUIRED)
@@ -129,6 +134,7 @@ if(WITH_GPU)
         message(WARNING "Anakin needs CUDNN >= 7.0 to compile. Force WITH_ANAKIN=OFF")
         set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when CUDNN >= 7.0." FORCE)
     endif()
+    add_definitions(-DWITH_ANAKIN)
 endif()
 if(WITH_ANAKIN)
     # NOTICE(minqiyang): the end slash is important because $CUDNN_INCLUDE_DIR
......
@@ -5,6 +5,8 @@ endif()
 set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
 set(paddle_known_gpu_archs7 "30 35 50 52")
 set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
+set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
+set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
 ######################################################################################
 # A function for automatic detection of GPUs installed (if autodetection is enabled)
@@ -59,7 +61,7 @@ endfunction()
 # select_nvcc_arch_flags(out_variable)
 function(select_nvcc_arch_flags out_variable)
   # List of arch names
-  set(archs_names "Kepler" "Maxwell" "Pascal" "All" "Manual")
+  set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "All" "Manual")
   set(archs_name_default "All")
   if(NOT CMAKE_CROSSCOMPILING)
     list(APPEND archs_names "Auto")
@@ -93,6 +95,8 @@ function(select_nvcc_arch_flags out_variable)
     set(cuda_arch_bin "60 61")
   elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
     set(cuda_arch_bin "70")
+  elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
+    set(cuda_arch_bin "75")
   elseif(${CUDA_ARCH_NAME} STREQUAL "All")
     set(cuda_arch_bin ${paddle_known_gpu_archs})
   elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
@@ -139,10 +143,12 @@ endfunction()
 message(STATUS "CUDA detected: " ${CUDA_VERSION})
 if (${CUDA_VERSION} LESS 7.0)
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
+  add_definitions("-DPADDLE_CUDA_BINVER=\"60\"")
 elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
   list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
+  add_definitions("-DPADDLE_CUDA_BINVER=\"70\"")
 elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
@@ -150,6 +156,17 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
   # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
   # warning for now.
   list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
+  add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
+elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
+  set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
+  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
+  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
+  add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
+elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
+  set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
+  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
+  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
+  add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
 endif()
 include_directories(${CUDA_INCLUDE_DIRS})
......
@@ -44,9 +44,9 @@ if(WIN32)
     set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll")
 endif(WIN32)
-if(Apple)
+if(APPLE)
     set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so")
-endif(Apple)
+endif(APPLE)
 find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a
     PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
@@ -89,6 +89,7 @@ if(CUDNN_FOUND)
     if(NOT CUDNN_MAJOR_VERSION)
         set(CUDNN_VERSION "???")
     else()
+        add_definitions("-DPADDLE_CUDNN_BINVER=\"${CUDNN_MAJOR_VERSION}\"")
         math(EXPR CUDNN_VERSION
             "${CUDNN_MAJOR_VERSION} * 1000 +
              ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}")
......
@@ -23,11 +23,8 @@ set(BOOST_PROJECT "extern_boost")
 # checked that the devtools package of CentOS 6 installs boost 1.41.0.
 # So we use 1.41.0 here.
 set(BOOST_VER "1.41.0")
-if((NOT DEFINED BOOST_TAR) OR (NOT DEFINED BOOST_URL))
-    message(STATUS "use pre defined download url")
-    set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
-    set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
-endif()
+set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
+set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
 MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}")
......
@@ -14,14 +14,16 @@
 INCLUDE(ExternalProject)
-find_library(SSL_LIBRARY NAMES ssl)
+find_package(OpenSSL REQUIRED)
+message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
+message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
 ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
+SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
-find_library(CRYPTO_LIBRARY NAMES crypto)
 ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
+SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
 SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
 SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
 INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
 # Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
-set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
+set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
 # If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
 ExternalProject_Add(
     extern_brpc
     ${EXTERNAL_PROJECT_LOG_ARGS}
+    # TODO(gongwb): change to de newst repo when they changed.
     GIT_REPOSITORY "https://github.com/gongweibao/brpc"
-    GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
+    GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
     PREFIX ${BRPC_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -50,7 +53,7 @@ ExternalProject_Add(
     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
     -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
     -DCMAKE_PREFIX_PATH=${prefix_path}
-    -DBRPC_WITH_GLOG=ON
+    -DWITH_GLOG=ON
     -DIOBUF_WITH_HUGE_BLOCK=ON
     -DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
     ${EXTERNAL_OPTIONAL_ARGS}
@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
 ADD_DEPENDENCIES(brpc extern_brpc)
+add_definitions(-DBRPC_WITH_GLOG)
 LIST(APPEND external_project_dependencies brpc)
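Regarding the StackOverflow reference above: a hedged, self-contained sketch of the "|"-separated prefix-path trick. ExternalProject_Add's LIST_SEPARATOR option turns the placeholder back into a real ";" list inside the sub-build; the repository URL and prefixes here are placeholders, not Paddle's.

```
include(ExternalProject)
# Two hypothetical dependency prefixes, joined with "|" instead of ";" so the
# value survives being passed through CMAKE_ARGS as a single argument.
set(prefix_path "/opt/depA|/opt/depB")
ExternalProject_Add(demo_dep
    GIT_REPOSITORY "https://example.com/demo.git"   # placeholder URL
    LIST_SEPARATOR |
    CMAKE_ARGS -DCMAKE_PREFIX_PATH=${prefix_path})
```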
@@ -32,4 +32,4 @@ endif()
 add_dependencies(cub extern_cub)
-LIST(APPEND externl_project_dependencies cub)
+LIST(APPEND external_project_dependencies cub)
@@ -28,4 +28,4 @@ endif()
 add_dependencies(dlpack extern_dlpack)
-LIST(APPEND externl_project_dependencies dlpack)
+LIST(APPEND external_project_dependencies dlpack)
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-IF(WITH_TESTING)
+#FIXME:(gongwb) Move brpc's gtest dependency.
+IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
+  IF(WITH_TESTING)
     ENABLE_TESTING()
+  ENDIF(WITH_TESTING)
   INCLUDE(ExternalProject)
   SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
@@ -76,4 +80,4 @@ IF(WITH_TESTING)
   ADD_DEPENDENCIES(gtest_main extern_gtest)
   LIST(APPEND external_project_dependencies gtest gtest_main)
-ENDIF(WITH_TESTING)
+ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
@@ -24,8 +24,8 @@ ExternalProject_Add(
     extern_leveldb
     ${EXTERNAL_PROJECT_LOG_ARGS}
     PREFIX ${LEVELDB_SOURCES_DIR}
-    URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
-    URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
+    GIT_REPOSITORY "https://github.com/google/leveldb"
+    GIT_TAG v1.18
     CONFIGURE_COMMAND ""
     BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
     INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_LIBMCT})
return()
ENDIF(NOT ${WITH_LIBMCT})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with LIBMCT in Paddle yet."
"Force WITH_LIBMCT=OFF")
SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(LIBMCT_PROJECT "extern_libmct")
IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL))
MESSAGE(STATUS "use pre defined download url")
SET(LIBMCT_VER "0.1.0" CACHE STRING "" FORCE)
SET(LIBMCT_NAME "libmct" CACHE STRING "" FORCE)
SET(LIBMCT_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${LIBMCT_VER}/${LIBMCT_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "LIBMCT_NAME: ${LIBMCT_NAME}, LIBMCT_URL: ${LIBMCT_URL}")
SET(LIBMCT_SOURCE_DIR "${THIRD_PARTY_PATH}/libmct")
SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_SOURCE_DIR}/src/${LIBMCT_PROJECT}")
SET(LIBMCT_DST_DIR "libmct")
SET(LIBMCT_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(LIBMCT_INSTALL_DIR ${LIBMCT_INSTALL_ROOT}/${LIBMCT_DST_DIR})
SET(LIBMCT_ROOT ${LIBMCT_INSTALL_DIR})
SET(LIBMCT_INC_DIR ${LIBMCT_ROOT}/include)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${LIBMCT_ROOT}/lib")
INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})
FILE(WRITE ${LIBMCT_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(LIBMCT)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${LIBMCT_NAME}/include ${LIBMCT_NAME}/lib \n"
" DESTINATION ${LIBMCT_DST_DIR})\n")
ExternalProject_Add(
${LIBMCT_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LIBMCT_SOURCE_DIR}
DOWNLOAD_DIR ${LIBMCT_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_NAME}.tar.gz
&& tar zxvf ${LIBMCT_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${LIBMCT_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${LIBMCT_INSTALL_ROOT}
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32)
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c)
file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
add_library(libmct STATIC ${dummyfile})
else()
add_library(libmct INTERFACE)
endif()
#ADD_LIBRARY(libmct SHARED IMPORTED GLOBAL)
ADD_DEPENDENCIES(libmct ${LIBMCT_PROJECT})
LIST(APPEND external_project_dependencies libmct)
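The tail of this file uses a pattern worth calling out: a prebuilt or header-only package still needs a real library target so that ADD_DEPENDENCIES works uniformly across CMake versions. Below is a hedged, generic sketch of that dummy-library trick (all names hypothetical; the libmct code above additionally keys the choice on the platform):

```
# Generic form of the dummy-library pattern used above (names hypothetical).
add_custom_target(extern_mydep)          # stand-in for the real ExternalProject
if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
  # Old CMake: compile a one-symbol placeholder archive that can be depended on.
  set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/mydep_dummy.c)
  file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
  add_library(mydep STATIC ${dummyfile})
else()
  # Newer CMake: a true INTERFACE library carries no sources at all.
  add_library(mydep INTERFACE)
endif()
add_dependencies(mydep extern_mydep)
```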
@@ -23,15 +23,14 @@ SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn)
 SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
 SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
-IF(WIN32 OR APPLE)
+IF(APPLE)
     MESSAGE(WARNING
-        "Windows or Mac is not supported with MKLDNN in Paddle yet."
+        "Mac is not supported with MKLDNN in Paddle yet."
         "Force WITH_MKLDNN=OFF")
-    SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in Windows and MacOS" FORCE)
+    SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in MacOS" FORCE)
     return()
 ENDIF()
-SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
 MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
 SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
 SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
@@ -44,22 +43,33 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
 ELSE()
     MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
 ENDIF()
-SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
-SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
-SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
-SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
+IF(NOT WIN32)
+    SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
+    SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
+    SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
+    SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
+ENDIF(NOT WIN32)
 ExternalProject_Add(
     ${MKLDNN_PROJECT}
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS ${MKLDNN_DEPENDS}
-    GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
+    GIT_REPOSITORY "https://github.com/intel/mkl-dnn.git"
     GIT_TAG "830a10059a018cd2634d94195140cf2d8790a75a"
     PREFIX ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
     CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+    CMAKE_ARGS -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
+    CMAKE_ARGS -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+    CMAKE_ARGS -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+    CMAKE_ARGS -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+    CMAKE_ARGS -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+    CMAKE_ARGS -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
     CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
     CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+    CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON
     CMAKE_ARGS -DMKLROOT=${MKLML_ROOT}
     CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
     CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
@@ -67,6 +77,11 @@ ExternalProject_Add(
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
                      -DMKLROOT:PATH=${MKLML_ROOT}
 )
+if(WIN32)
+    SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/mkldnn.lib" CACHE FILEPATH "mkldnn library." FORCE)
+else(WIN32)
+    SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE)
+endif(WIN32)
 ADD_LIBRARY(shared_mkldnn SHARED IMPORTED GLOBAL)
 SET_PROPERTY(TARGET shared_mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
@@ -85,12 +100,16 @@ ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
 # copy the real so.0 lib to install dir
 # it can be directly contained in wheel or capi
-SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0)
-ADD_CUSTOM_COMMAND(OUTPUT ${MKLDNN_SHARED_LIB}
-    COMMAND cp ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB}
-    DEPENDS mkldnn)
+if(WIN32)
+    SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/lib/mkldnn.dll)
+else(WIN32)
+    SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0)
+    ADD_CUSTOM_COMMAND(OUTPUT ${MKLDNN_SHARED_LIB}
+        COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB}
+        DEPENDS mkldnn shared_mkldnn)
+endif(WIN32)
 ADD_CUSTOM_TARGET(mkldnn_shared_lib ALL DEPENDS ${MKLDNN_SHARED_LIB})
+ADD_DEPENDENCIES(mkldnn_shared_lib ${MKLDNN_PROJECT} mkldnn)
 IF(WITH_C_API)
     INSTALL(FILES ${MKLDNN_SHARED_LIB} DESTINATION lib)
 ENDIF()
......
@@ -16,56 +16,60 @@ IF(NOT ${WITH_MKLML})
     return()
 ENDIF(NOT ${WITH_MKLML})
-IF(WIN32 OR APPLE)
-    MESSAGE(WARNING
-        "Windows or Mac is not supported with MKLML in Paddle yet."
-        "Force WITH_MKLML=OFF")
-    SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE)
+IF(APPLE)
+    MESSAGE(WARNING "Mac is not supported with MKLML in Paddle yet. Force WITH_MKLML=OFF.")
+    SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in MacOS" FORCE)
     return()
 ENDIF()
 INCLUDE(ExternalProject)
-SET(MKLML_PROJECT "extern_mklml")
-IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL))
-    MESSAGE(STATUS "use pre defined download url")
-    SET(MKLML_VER "mklml_lnx_2019.0.20180710" CACHE STRING "" FORCE)
-    SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
-ENDIF()
-MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}")
-SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
-SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
 SET(MKLML_DST_DIR "mklml")
 SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
 SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
 SET(MKLML_ROOT ${MKLML_INSTALL_DIR})
 SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
 SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
-SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
-SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
 SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
-INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
-FILE(WRITE ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt
-    "PROJECT(MKLML)\n"
-    "cmake_minimum_required(VERSION 3.0)\n"
-    "install(DIRECTORY ${MKLML_VER}/include ${MKLML_VER}/lib \n"
-    "        DESTINATION ${MKLML_DST_DIR})\n")
+SET(TIME_VERSION "2019.0.1.20181227")
+IF(WIN32)
+    SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
+    SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib)
+    SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
+    SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
+    SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
+ELSE()
+    SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
+    SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
+    SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
+    SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
+    SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
+ENDIF()
+SET(MKLML_PROJECT "extern_mklml")
+MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}")
+SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
+SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
 ExternalProject_Add(
     ${MKLML_PROJECT}
     ${EXTERNAL_PROJECT_LOG_ARGS}
     PREFIX ${MKLML_SOURCE_DIR}
+    URL ${MKLML_URL}
     DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR}
-    DOWNLOAD_COMMAND wget --no-check-certificate ${MKLML_URL} -c -q -O ${MKLML_VER}.tgz
-                     && tar zxf ${MKLML_VER}.tgz
     DOWNLOAD_NO_PROGRESS 1
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
     UPDATE_COMMAND ""
-    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT}
-    CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLML_INSTALL_ROOT}
+    INSTALL_COMMAND
+        ${CMAKE_COMMAND} -E copy_directory ${MKLML_DOWNLOAD_DIR}/include ${MKLML_INC_DIR} &&
+        ${CMAKE_COMMAND} -E copy_directory ${MKLML_DOWNLOAD_DIR}/lib ${MKLML_LIB_DIR}
 )
+INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
 ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL)
 SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB})
 ADD_DEPENDENCIES(mklml ${MKLML_PROJECT})
......
@@ -32,18 +32,23 @@ IF(NOT ${WITH_NGRAPH})
     return()
 ENDIF()
+INCLUDE(GNUInstallDirs)
 INCLUDE(ExternalProject)
 SET(NGRAPH_PROJECT "extern_ngraph")
-SET(NGRAPH_VERSION "0.9")
-SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0")
+SET(NGRAPH_GIT_TAG "08851c2c45fcf9fa9c74871dd3dbc3fe38f37cc9")
 SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
 SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
 SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
-SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION})
+SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
+SET(NGRAPH_SHARED_LIB_NAME libngraph.so)
 SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so)
 SET(NGRAPH_TBB_LIB_NAME libtbb.so.2)
 SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git")
+SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
+SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
+SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
 ExternalProject_Add(
     ${NGRAPH_PROJECT}
@@ -63,18 +68,6 @@ ExternalProject_Add(
     CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib
 )
-if(UNIX AND NOT APPLE)
-    include(GNUInstallDirs)
-    SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
-else()
-    SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/lib)
-endif()
-MESSAGE(STATUS "nGraph lib will be installed at: ${NGRAPH_LIB_DIR}")
-SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
-SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
-SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
 # Workaround for nGraph expecting mklml to be in mkldnn install directory.
 ExternalProject_Add_Step(
     ${NGRAPH_PROJECT}
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_PSLIB})
return()
ENDIF(NOT ${WITH_PSLIB})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB in Paddle yet."
"Force WITH_PSLIB=OFF")
SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}")
SET(PSLIB_DST_DIR "pslib")
SET(PSLIB_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_INSTALL_DIR ${PSLIB_INSTALL_ROOT}/${PSLIB_DST_DIR})
SET(PSLIB_ROOT ${PSLIB_INSTALL_DIR})
SET(PSLIB_INC_DIR ${PSLIB_ROOT}/include)
SET(PSLIB_LIB_DIR ${PSLIB_ROOT}/lib)
SET(PSLIB_LIB ${PSLIB_LIB_DIR}/libps.so)
SET(PSLIB_IOMP_LIB ${PSLIB_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_ROOT}/lib")
INCLUDE_DIRECTORIES(${PSLIB_INC_DIR})
FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_NAME}/include ${PSLIB_NAME}/lib \n"
" DESTINATION ${PSLIB_DST_DIR})\n")
ExternalProject_Add(
${PSLIB_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
&& tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_INSTALL_ROOT}
)
ADD_LIBRARY(pslib SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib PROPERTY IMPORTED_LOCATION ${PSLIB_LIB})
ADD_DEPENDENCIES(pslib ${PSLIB_PROJECT})
LIST(APPEND external_project_dependencies pslib)
IF(WITH_C_API)
INSTALL(FILES ${PSLIB_LIB} ${PSLIB_IOMP_LIB} DESTINATION lib)
ENDIF()
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_PSLIB_BRPC})
return()
ENDIF(NOT ${WITH_PSLIB_BRPC})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB_BRPC in Paddle yet."
"Force WITH_PSLIB_BRPC=OFF")
SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(PSLIB_BRPC_PROJECT "extern_pslib_brpc")
IF((NOT DEFINED PSLIB_BRPC_NAME) OR (NOT DEFINED PSLIB_BRPC_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_BRPC_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_BRPC_NAME "pslib_brpc" CACHE STRING "" FORCE)
SET(PSLIB_BRPC_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_BRPC_VER}/${PSLIB_BRPC_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_BRPC_NAME: ${PSLIB_BRPC_NAME}, PSLIB_BRPC_URL: ${PSLIB_BRPC_URL}")
SET(PSLIB_BRPC_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib_brpc")
SET(PSLIB_BRPC_DOWNLOAD_DIR "${PSLIB_BRPC_SOURCE_DIR}/src/${PSLIB_BRPC_PROJECT}")
SET(PSLIB_BRPC_DST_DIR "pslib_brpc")
SET(PSLIB_BRPC_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_BRPC_INSTALL_DIR ${PSLIB_BRPC_INSTALL_ROOT}/${PSLIB_BRPC_DST_DIR})
SET(PSLIB_BRPC_ROOT ${PSLIB_BRPC_INSTALL_DIR})
SET(PSLIB_BRPC_INC_DIR ${PSLIB_BRPC_ROOT}/include)
SET(PSLIB_BRPC_LIB_DIR ${PSLIB_BRPC_ROOT}/lib)
SET(PSLIB_BRPC_LIB ${PSLIB_BRPC_LIB_DIR}/libbrpc.a)
SET(PSLIB_BRPC_IOMP_LIB ${PSLIB_BRPC_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_BRPC_ROOT}/lib")
INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})
FILE(WRITE ${PSLIB_BRPC_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB_BRPC)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_BRPC_NAME}/include ${PSLIB_BRPC_NAME}/lib \n"
" DESTINATION ${PSLIB_BRPC_DST_DIR})\n")
ExternalProject_Add(
${PSLIB_BRPC_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_BRPC_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_BRPC_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_NAME}.tar.gz
&& tar zxvf ${PSLIB_BRPC_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_BRPC_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_BRPC_INSTALL_ROOT}
)
ADD_LIBRARY(pslib_brpc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib_brpc PROPERTY IMPORTED_LOCATION ${PSLIB_BRPC_LIB})
ADD_DEPENDENCIES(pslib_brpc ${PSLIB_BRPC_PROJECT})
LIST(APPEND external_project_dependencies pslib_brpc)
IF(WITH_C_API)
INSTALL(FILES ${PSLIB_BRPC_LIB} ${PSLIB_BRPC_IOMP_LIB} DESTINATION lib)
ENDIF()
...@@ -18,8 +18,8 @@ ENDIF()
INCLUDE(python_module)
FIND_PACKAGE(PythonInterp ${PY_VERSION} REQUIRED)
FIND_PACKAGE(PythonLibs ${PY_VERSION} REQUIRED)
if(WIN32)
  execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
...@@ -79,6 +79,5 @@ IF(PYTHONINTERP_FOUND)
      "please use pip to upgrade protobuf. pip install -U protobuf")
  ENDIF()
ENDIF(PYTHONINTERP_FOUND)
INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR})
INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR})
...@@ -24,12 +24,6 @@ set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE)
ExternalProject_Add(
    extern_snappy
    GIT_REPOSITORY "https://github.com/google/snappy"
...@@ -56,6 +50,16 @@ ExternalProject_Add(
        -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
        -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
add_custom_command(TARGET extern_snappy POST_BUILD
COMMAND cmake -E copy ${SNAPPY_INSTALL_DIR}/lib/snappy.lib ${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib
)
ENDIF()
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32)
add_library(snappy STATIC IMPORTED GLOBAL)
set_property(TARGET snappy PROPERTY IMPORTED_LOCATION ${SNAPPY_LIBRARIES})
......
...@@ -26,25 +26,33 @@ SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
# Used in unit test test_WarpCTCLayer
SET(WARPCTC_LIB_DIR "${WARPCTC_INSTALL_DIR}/lib"
    CACHE PATH "Warp-ctc Library Directory" FORCE)

IF(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR WIN32)
    SET(USE_OMP OFF)
ELSE()
    SET(USE_OMP ON)
ENDIF()
IF(WIN32)
SET(WARPCTC_REPOSITORY "https://github.com/wopeizl/warp-ctc.git")
ELSE()
SET(WARPCTC_REPOSITORY "https://github.com/dzhwinter/warp-ctc.git")
ENDIF()
ExternalProject_Add(
    extern_warpctc
    ${EXTERNAL_PROJECT_LOG_ARGS}
    GIT_REPOSITORY  ${WARPCTC_REPOSITORY}
    PREFIX          ${WARPCTC_SOURCES_DIR}
    UPDATE_COMMAND  ""
    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
                    -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                    -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
                    -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
                    -DWITH_GPU=${WITH_GPU}
                    -DWITH_OMP=${USE_OMP}
...@@ -59,6 +67,18 @@ ExternalProject_Add(
                    -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                    -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
)
IF(WIN32)
IF(NOT EXISTS "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}")
add_custom_command(TARGET extern_warpctc POST_BUILD
COMMAND cmake -E copy ${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} ${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}
)
ENDIF()
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
else(WIN32)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
ENDIF(WIN32)
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR})  # For warpctc code to include its headers.
......
...@@ -56,7 +56,12 @@ else()
endif()

if (WIN32)
    IF(NOT EXISTS "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
add_custom_command(TARGET extern_xxhash POST_BUILD
COMMAND cmake -E copy ${XXHASH_INSTALL_DIR}/lib/xxhash.lib ${XXHASH_INSTALL_DIR}/lib/libxxhash.lib
)
ENDIF()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
else()
    set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
endif ()
......
...@@ -19,12 +19,6 @@ SET(ZLIB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/zlib)
SET(ZLIB_ROOT ${ZLIB_INSTALL_DIR} CACHE FILEPATH "zlib root directory." FORCE)
SET(ZLIB_INCLUDE_DIR "${ZLIB_INSTALL_DIR}/include" CACHE PATH "zlib include directory." FORCE)

INCLUDE_DIRECTORIES(${ZLIB_INCLUDE_DIR})  # For zlib code to include its own headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install)  # For Paddle code to include zlib.h.
...@@ -49,6 +43,16 @@ ExternalProject_Add(
        -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
        -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${ZLIB_INSTALL_DIR}/lib/libz.lib")
add_custom_command(TARGET extern_zlib POST_BUILD
COMMAND cmake -E copy ${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib ${ZLIB_INSTALL_DIR}/lib/libz.lib
)
ENDIF()
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.lib" CACHE FILEPATH "zlib library." FORCE)
ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32)
ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET zlib PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES})
......
...@@ -110,6 +110,18 @@ function(find_fluid_modules TARGET_NAME)
  endif()
endfunction(find_fluid_modules)
function(common_link TARGET_NAME)
if (WITH_PROFILER)
target_link_libraries(${TARGET_NAME} gperftools::profiler)
endif()
if (WITH_JEMALLOC)
target_link_libraries(${TARGET_NAME} jemalloc::jemalloc)
endif()
endfunction()
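Note: a standalone sketch (not part of this commit) of what targets linked by common_link gain with -DWITH_PROFILER=ON; ProfilerStart/ProfilerStop come from gperftools itself, and the output file name here is an arbitrary placeholder.

#include <gperftools/profiler.h>

int main() {
  ProfilerStart("cpu.prof");           // begin writing CPU samples to cpu.prof
  volatile double acc = 0;
  for (int i = 0; i < 10000000; ++i) acc += i * 0.5;  // work worth sampling
  ProfilerStop();                      // flush samples and close the profile
  return 0;
}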
# Collect all third_party modules used by the paddle static library,
# to reduce the dependencies when building the inference libs.
set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY)
...@@ -259,7 +271,11 @@ function(cc_library TARGET_NAME)
        list(APPEND cc_library_DEPS dynload_mklml)
      endif()
      add_dependencies(${TARGET_NAME} mklml)
if(WIN32)
target_link_libraries(${TARGET_NAME} ${MKLML_IOMP_LIB})
else(WIN32)
        target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif(WIN32)
    endif()
    # remove link to python, see notes at:
    # https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
...@@ -274,6 +290,7 @@ function(cc_library TARGET_NAME)
      endif()
      target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
      add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
common_link(${TARGET_NAME})
  endif()
  # cpplint code style
...@@ -340,6 +357,7 @@ function(cc_binary TARGET_NAME)
  if(cc_binary_DEPS)
    target_link_libraries(${TARGET_NAME} ${cc_binary_DEPS})
    add_dependencies(${TARGET_NAME} ${cc_binary_DEPS})
common_link(${TARGET_NAME})
  endif()
endfunction(cc_binary)
...@@ -362,6 +380,7 @@ function(cc_test TARGET_NAME)
      target_link_libraries(${TARGET_NAME} ${win32_deps})
    endif(WIN32)
    add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME})
    add_test(NAME ${TARGET_NAME}
             COMMAND ${TARGET_NAME} ${cc_test_ARGS}
             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
...@@ -420,6 +439,7 @@ function(nv_binary TARGET_NAME)
    if(nv_binary_DEPS)
      target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS})
      add_dependencies(${TARGET_NAME} ${nv_binary_DEPS})
common_link(${TARGET_NAME})
    endif()
  endif()
endfunction(nv_binary)
...@@ -433,6 +453,7 @@ function(nv_test TARGET_NAME)
    cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
    target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
    add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME})
    add_test(${TARGET_NAME} ${TARGET_NAME})
    if (nv_test_SERIAL)
      set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
...@@ -499,6 +520,7 @@ function(hip_binary TARGET_NAME)
    if(hip_binary_DEPS)
      target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS})
      add_dependencies(${TARGET_NAME} ${hip_binary_DEPS})
common_link(${TARGET_NAME})
    endif()
  endif()
endfunction(hip_binary)
...@@ -518,6 +540,7 @@ function(hip_test TARGET_NAME)
    set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
    target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags)
    add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags)
common_link(${TARGET_NAME})
    add_test(${TARGET_NAME} ${TARGET_NAME})
  endif()
endfunction(hip_test)
...@@ -560,6 +583,7 @@ function(go_library TARGET_NAME)
  endif()
  if(go_library_DEPS)
    add_dependencies(${TARGET_NAME} ${go_library_DEPS})
common_link(${TARGET_NAME})
  endif(go_library_DEPS)
  # The "source file" of the library is `${dummyfile}` which never
......
...@@ -32,13 +32,23 @@ function(copy TARGET)
      list(GET copy_lib_SRCS ${index} src)
      list(GET copy_lib_DSTS ${index} dst)
      if (WIN32)
if(IS_DIRECTORY ${src})
get_filename_component(last_path ${src} NAME)
string(APPEND dst "/" ${last_path})
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}"
)
if(EXISTS ${src})
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND cmake -E copy_directory "${src}" "${dst}"
COMMENT "copying ${src} -> ${dst}")
else()
message(WARNING "${src} not exist!")
endif()
else()
        # windows cmd shell will not expand wildcard automatically.
        # below expand the files, and copy them by rules.
        file(GLOB src_files ${src})
        if (NOT "${src_files}" STREQUAL "")
          list(REMOVE_DUPLICATES src_files)
        endif ()
...@@ -50,6 +60,7 @@ function(copy TARGET)
            COMMAND ${CMAKE_COMMAND} -E copy "${src_file}" "${dst}"
            COMMENT "copying ${src_file} -> ${dst}")
        endforeach ()
endif()
      else (WIN32)  # not windows
        add_custom_command(TARGET ${TARGET} PRE_BUILD
          COMMAND mkdir -p "${dst}"
...@@ -95,7 +106,7 @@ copy(xxhash_lib
  DEPS xxhash
)

if (NOT PROTOBUF_FOUND OR WIN32)
    set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf")
    copy(protobuf_lib
      SRCS ${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LIBRARY}
...@@ -104,20 +115,20 @@ if (NOT PROTOBUF_FOUND)
    )
endif ()

if (WITH_MKLML)
    set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mklml")
    copy(mklml_lib
      SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR}
      DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir}
      DEPS mklml
    )
elseif (NOT CBLAS_FOUND OR WIN32)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/openblas")
copy(openblas_lib
SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include
DSTS ${dst_dir} ${dst_dir}
DEPS extern_openblas
)
endif ()

if (WITH_MKLDNN)
...@@ -125,12 +136,20 @@ if (WITH_MKLDNN)
    copy(mkldnn_lib
      SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB}
      DSTS ${dst_dir} ${dst_dir}/lib
      DEPS mkldnn_shared_lib
    )
endif ()

if (WITH_NGRAPH)
    set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/ngraph")
copy(ngraph_lib
SRCS ${NGRAPH_INC_DIR} ${NGRAPH_LIB_DIR}
DSTS ${dst_dir} ${dst_dir}
DEPS ngraph
)
endif ()
if (NOT MOBILE_INFERENCE AND NOT RPI)
    set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy")
    copy(snappy_lib
      SRCS ${SNAPPY_INCLUDE_DIR} ${SNAPPY_LIBRARIES}
...@@ -148,8 +167,7 @@ if (NOT WIN32)
      SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
      DSTS ${dst_dir} ${dst_dir}/lib
      DEPS zlib)
endif ()

# paddle fluid module
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
...@@ -182,9 +200,21 @@ if (WITH_ANAKIN AND WITH_MKL)
  list(APPEND inference_deps anakin_inference_lib)
endif ()
if (TENSORRT_FOUND)
copy(tensorrt_lib DEPS ${inference_deps}
SRCS ${TENSORRT_ROOT}/include/Nv*.h ${TENSORRT_ROOT}/lib/libnvinfer*
DSTS ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/include ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/lib)
endif ()
set(module "inference")
if(WIN32)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/libpaddle_fluid.*)
else(WIN32)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*)
endif(WIN32)
copy(inference_lib DEPS ${inference_deps}
  SRCS ${src_dir}/${module}/*.h ${paddle_fluid_lib}
       ${src_dir}/${module}/api/paddle_*.h
  DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
)
...@@ -224,7 +254,7 @@ copy(third_party DEPS fluid_lib_dist
# only need libpaddle_fluid.so/a and paddle_*.h for inference-only library
copy(inference_api_lib DEPS fluid_lib_dist
  SRCS ${paddle_fluid_lib}
       ${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_*.h
  DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
)
......
...@@ -84,7 +84,7 @@ function(op_library TARGET)
    endif()
    if (WIN32)
      # remove windows unsupported op, because windows has no nccl, no warpctc such ops.
      foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
        if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
          return()
        endif()
...@@ -110,7 +110,7 @@ function(op_library TARGET)
  # Define operators that don't need pybind here.
  foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
          "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
          "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op")
    if ("${TARGET}" STREQUAL "${manual_pybind_op}")
      set(pybind_flag 1)
    endif()
...@@ -166,6 +166,8 @@ function(op_library TARGET)
    # Append first implemented MKLDNN activation operator
    if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
    else()
      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
    endif()
......
...@@ -57,46 +57,43 @@ int main()
  return 0;
}" SSE3_FOUND)

# Check AVX
set(CMAKE_REQUIRED_FLAGS ${AVX_FLAG})
set(AVX_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h>
int main()
{
  __m256 a = _mm256_set_ps (-1.0f, 2.0f, -3.0f, 4.0f, -1.0f, 2.0f, -3.0f, 4.0f);
  __m256 b = _mm256_set_ps (1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f);
  __m256 result = _mm256_add_ps (a, b);
  return 0;
}" AVX_FOUND)

# Check AVX 2
set(CMAKE_REQUIRED_FLAGS ${AVX2_FLAG})
set(AVX2_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h>
int main()
{
  __m256i a = _mm256_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4);
  __m256i result = _mm256_abs_epi32 (a);
  return 0;
}" AVX2_FOUND)

# Check AVX512F
set(CMAKE_REQUIRED_FLAGS ${AVX512F_FLAG})
set(AVX512F_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h>
int main()
{
  __m512i a = _mm512_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4,
                                13, -5, 6, -7, 9, 2, -6, 3);
  __m512i result = _mm512_abs_epi32 (a);
  return 0;
}" AVX512F_FOUND)

set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_RETAINED})
mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND AVX512F_FOUND)
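Note: the CHECK_CXX_SOURCE_RUNS probes above only prove that the *build* machine can execute AVX code; a binary that must run on other hosts can re-check at runtime. A minimal standalone sketch using a GCC/Clang builtin (illustration only, not part of this commit):

#include <cstdio>

int main() {
#if defined(__GNUC__) || defined(__clang__)
  // __builtin_cpu_supports queries CPUID at runtime
  std::printf("avx=%d avx2=%d avx512f=%d\n",
              __builtin_cpu_supports("avx"),
              __builtin_cpu_supports("avx2"),
              __builtin_cpu_supports("avx512f"));
#endif
  return 0;
}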
...@@ -60,7 +60,7 @@ class Float16Transpiler:
            raise TypeError("place should be as CPUPlace/CUDAPlace type")
        if scope is None:
            scope = global_scope()
        if not isinstance(scope, core._Scope):
            raise TypeError("scope should be as Scope type or None")
        self.scope = scope
......
(This diff is collapsed.)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(imperative)
add_subdirectory(operators)
add_subdirectory(string)
add_subdirectory(recordio)
......
# windows treat symbolic file as a real file, which is different with unix
# We create a hidden file and compile it instead of origin source file.
function(windows_symbolic TARGET)
  set(oneValueArgs "")
  set(multiValueArgs SRCS PATH)
  cmake_parse_arguments(windows_symbolic "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(final_path ${CMAKE_CURRENT_SOURCE_DIR}/${windows_symbolic_PATH})
  foreach(src ${windows_symbolic_SRCS})
    get_filename_component(src ${src} NAME_WE)
    if (NOT EXISTS ${final_path}/${src}.cc OR NOT EXISTS ${final_path}/${src}.cu)
      message(FATAL " ${src}.cc and ${src}.cu must exist, and ${src}.cu must be a symbolic file.")
    endif()
    file(GENERATE OUTPUT ${final_path}/.${src}.cu INPUT ${final_path}/${src}.cc)
    add_custom_command(OUTPUT ${final_path}/.${src}.cu
      COMMAND ${CMAKE_COMMAND} -E copy_if_different "${final_path}/${src}.cc" "${final_path}/.${src}.cu"
      COMMENT "create hidden file of ${src}.cu")
    add_custom_target(${TARGET} ALL DEPENDS .${src}.cu)
  endforeach()
endfunction()

add_subdirectory(ir)
add_subdirectory(details)
# ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
cc_test(data_type_test SRCS data_type_test.cc DEPS data_type place tensor)
if(WITH_GPU)
...@@ -47,10 +39,10 @@ if(WITH_GPU)
    nv_library(tensor SRCS tensor.cc .tensor_util.cu DEPS place memory data_type device_context)
    add_dependencies(tensor tensor_util)
  else()
    nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context)
  endif(WIN32)
else()
  cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type device_context)
endif()

cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
...@@ -72,25 +64,33 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)

cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)

cc_library(var_type_traits SRCS var_type_traits.cc DEPS lod_tensor selected_rows framework_proto)
if (WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool xxhash var_type_traits)
cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
        DEPS operator op_registry device_context math_function scope)

if(WITH_GPU)
  if (WIN32)
    # windows treat symbolic file as a real file, which is different with unix
    # We create a hidden file and compile it instead of origin source file.
    windows_symbolic(hidden_file SRCS data_type_transform.cu)
    nv_library(data_type_transform SRCS .data_type_transform.cu DEPS tensor)
    add_dependencies(data_type_transform hidden_file)
...@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
...@@ -127,17 +128,19 @@ cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)

cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)

if(WITH_NGRAPH)
  cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
  cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
      shape_inference data_transform lod_tensor profiler)
endif(WITH_NGRAPH)

cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)

py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
if (NOT WIN32)
...@@ -163,24 +166,34 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)

if(WITH_DISTRIBUTE)
  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
      lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
  if(WITH_NGRAPH)
    cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
  else(WITH_NGRAPH)
    cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
  endif(WITH_NGRAPH)
  cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
target_link_libraries(executor garbage_collector)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
        threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
        graph build_strategy
        fast_threaded_ssa_graph_executor variable_helper)

if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib timer)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper timer)
endif(WITH_PSLIB)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto)
...@@ -190,7 +203,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)

cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
......
...@@ -15,34 +15,123 @@
#pragma once

#include <cstdint>
#include "paddle/fluid/framework/unroll_array_ops.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

template <typename T, size_t N>
class Array {
static_assert(N > 0, "The size of array must be larger than 0");
 public:
  static constexpr size_t kSize = N;

  HOSTDEVICE inline Array() {}

  template <typename... Args>
  HOSTDEVICE inline explicit Array(const T &val, Args... args) {
static_assert(N == sizeof...(Args) + 1, "Invalid argument");
UnrollVarArgsAssign<T>::Run(data_, val, args...);
  }

  HOSTDEVICE inline void Fill(const T &val) {
    UnrollFillConstant<N>::Run(data_, val);
  }

  HOSTDEVICE inline const T *Get() const { return data_; }

  HOSTDEVICE inline T *GetMutable() { return data_; }

  HOSTDEVICE inline T &operator[](size_t i) { return *advance(data_, i); }
// Writing "return data_[i]" would cause compilation warning/error:
// "array subscript is above array bound" in Python 35 CI.
// It seems that it is a false warning of GCC if we do not check the bounds
// of array index. But for better performance, we do not check in operator[]
// like what is in STL. If users want to check the bounds, use at() instead
HOSTDEVICE inline const T &operator[](size_t i) const {
return *advance(data_, i);
}
HOSTDEVICE inline T &at(size_t i) {
#ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds");
#endif
return (*this)[i];
}
HOSTDEVICE inline const T &at(size_t i) const {
#ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds");
#endif
return (*this)[i];
}
  HOSTDEVICE constexpr size_t size() const { return N; }
HOSTDEVICE inline bool operator==(const Array<T, N> &other) const {
return UnrollCompare<N>::Run(data_, other.data_);
}
HOSTDEVICE inline bool operator!=(const Array<T, N> &other) const {
return !(*this == other);
}
 private:
template <typename U>
HOSTDEVICE static inline U *advance(U *ptr, size_t i) {
return ptr + i;
}
  T data_[N];
};
template <typename T>
class Array<T, 0> {
public:
static constexpr size_t kSize = 0;
HOSTDEVICE inline Array() {}
HOSTDEVICE inline void Fill(const T &val) {}
HOSTDEVICE inline constexpr T *Get() const { return nullptr; }
// Add constexpr to GetMutable() cause warning in MAC
HOSTDEVICE inline T *GetMutable() { return nullptr; }
HOSTDEVICE inline T &operator[](size_t) {
#ifdef __CUDA_ARCH__
static T obj();
return obj;
#else
PADDLE_THROW("Array<T, 0> has no element");
#endif
}
HOSTDEVICE inline const T &operator[](size_t) const {
#ifdef __CUDA_ARCH__
static const T obj();
return obj;
#else
PADDLE_THROW("Array<T, 0> has no element");
#endif
}
HOSTDEVICE inline T &at(size_t i) { return (*this)[i]; }
HOSTDEVICE inline const T &at(size_t i) const { return (*this)[i]; }
HOSTDEVICE constexpr size_t size() const { return 0; }
HOSTDEVICE constexpr bool operator==(const Array<T, 0> &other) const {
return true;
}
HOSTDEVICE constexpr bool operator!=(const Array<T, 0> &other) const {
return false;
}
};
}  // namespace framework
}  // namespace paddle
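A minimal usage sketch for the new Array<T, N> API (hypothetical test file; assumes it is compiled inside the Paddle source tree so the header above resolves):

#include <cassert>
#include "paddle/fluid/framework/array.h"

int main() {
  paddle::framework::Array<int, 3> a(1, 2, 3);  // variadic constructor
  assert(a.size() == 3 && a[1] == 2);
  paddle::framework::Array<int, 3> b(1, 2, 3);
  assert(a == b);   // element-wise comparison via UnrollCompare
  b.Fill(0);        // UnrollFillConstant writes every element
  assert(b[2] == 0 && !(a == b));
  return 0;
}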
...@@ -29,6 +29,9 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#endif
namespace paddle {
namespace framework {
...@@ -47,6 +50,11 @@ void AsyncExecutor::CreateThreads(
  worker->SetDataFeed(reader);
  worker->SetFetchVarNames(fetch_var_names);
  worker->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->SetParamConfig(&_param_config);
#endif
}

void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
...@@ -60,12 +68,177 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
  readers[0]->SetFileList(filelist);
}
#ifdef PADDLE_WITH_PSLIB
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_server(dist_desc, index);
InitParamConfig();
}
void AsyncExecutor::InitWorker(const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list,
int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_worker(
dist_desc, const_cast<uint64_t*>(host_sign_list.data()), node_num, index);
InitParamConfig();
}
uint64_t AsyncExecutor::StartServer() { return _pslib_ptr->run_server(); }
void AsyncExecutor::StopServer() { _pslib_ptr->stop_server(); }
void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
int node_num) {
_pslib_ptr->gather_servers(const_cast<uint64_t*>(host_sign_list.data()),
node_num);
}
void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param_size();
++i) {
if (_pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.table_class()
.find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()
->server_param()
.downpour_server_param()
.downpour_table_param(i)
.accessor()
.fea_dim();
break;
}
}
_param_config.slot_dim = _param_config.fea_dim - 2;
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
++t) {
_param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
}
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
}
_param_config.slot_input_vec[table.table_id()] =
std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id());
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
std::vector<std::string> tmp_dense_variable_name;
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
tmp_dense_variable_name.push_back(table.dense_variable_name(i));
}
std::vector<std::string> tmp_dense_gradient_variable_name;
for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
tmp_dense_gradient_variable_name.push_back(
table.dense_gradient_variable_name(i));
}
_param_config.dense_variable_name[table.table_id()] =
std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim());
}
}
void AsyncExecutor::InitModel() {
for (auto table_id : _param_config.dense_table_id) {
std::vector<paddle::ps::Region> regions;
for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
}
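The initialization above draws every dense parameter from N(0, 1) and scales it by 0.2/sqrt(row_count). A standalone sketch of that rule (illustration only; the function name and layout are hypothetical):

#include <cmath>
#include <random>
#include <vector>

std::vector<float> init_dense_param(int rows, int cols, std::mt19937* rng) {
  std::normal_distribution<float> ndistr(0.0f, 1.0f);
  const float init_range = 0.2f / std::sqrt(static_cast<float>(rows));
  std::vector<float> param(static_cast<size_t>(rows) * cols);
  for (auto& v : param) v = ndistr(*rng) * init_range;  // N(0,1) * scale
  return param;
}

int main() {
  std::mt19937 rng(2019);
  auto p = init_dense_param(128, 64, &rng);
  return p.empty() ? 1 : 0;
}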
void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
LOG(FATAL) << "save model failed";
exit(-1);
}
}
void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;
param.threshold = 1;
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread =
std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
}
#endif
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                const std::string& data_feed_desc_str,
                                const std::vector<std::string>& filelist,
                                const int thread_num,
                                const std::vector<std::string>& fetch_var_names,
                                const std::string& mode, const bool debug) {
  std::vector<std::thread> threads;

  auto& block = main_program.Block(0);
...@@ -82,7 +255,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                &data_feed_desc);

  actual_thread_num = thread_num;
  int file_cnt = filelist.size();
  PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
...@@ -106,12 +279,22 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
  // todo: should be factory method for creating datafeed
  std::vector<std::shared_ptr<DataFeed>> readers;
  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
#ifdef PADDLE_WITH_PSLIB
PrepareDenseThread(mode);
#endif
  std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
  workers.resize(actual_thread_num);
  for (auto& worker : workers) {
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
worker.reset(new AsyncExecutorThreadWorker);
} else {
      worker.reset(new ExecutorThreadWorker);
    }
#else
worker.reset(new ExecutorThreadWorker);
#endif
}
  // prepare thread resource here
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
...@@ -121,14 +304,23 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,

  // start executing ops in multiple threads
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
if (debug) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
workers[thidx].get()));
} else {
      threads.push_back(
          std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
    }
}
  for (auto& th : threads) {
    th.join();
  }
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
_pull_dense_thread->stop();
}
#endif
  root_scope_->DropKids();

  return;
......
...@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once

#include <time.h>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <random> // local_random_engine
#include <set>
#include <string>
#include <thread>  // NOLINT
...@@ -30,6 +32,31 @@ limitations under the License. */
namespace paddle {
namespace framework {
inline double current_realtime() {
#if !defined(_WIN32)
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
return tp.tv_sec + tp.tv_nsec * 1e-9;
#else
return 0.0;
#endif
}
inline std::default_random_engine& local_random_engine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++,
static_cast<uint64_t>(current_realtime() * 1000)};
engine.seed(sseq);
}
};
thread_local engine_wrapper_t r;
return r.engine;
}
class AsyncExecutor {
 public:
  AsyncExecutor(Scope* scope, const platform::Place& place);
...@@ -39,7 +66,19 @@ class AsyncExecutor {
                   const std::vector<std::string>& filelist,
                   const int thread_num,
                   const std::vector<std::string>& fetch_names,
                   const std::string& mode, const bool debug = false);
#ifdef PADDLE_WITH_PSLIB
  void InitServer(const std::string& dist_desc, int index);
  void InitWorker(const std::string& dist_desc,
                  const std::vector<uint64_t>& host_sign_list, int node_num,
                  int index);
  uint64_t StartServer();
  void StopServer();
  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
  void InitModel();
  void SaveModel(const std::string& path);
  void InitParamConfig();
#endif

 private:
  void CreateThreads(ExecutorThreadWorker* worker,
...@@ -48,10 +87,21 @@ class AsyncExecutor {
                     const std::vector<std::string>& fetch_var_names,
                     Scope* root_scope, const int thread_index,
                     const bool debug);
#ifdef PADDLE_WITH_PSLIB
  void PrepareDenseThread(const std::string& mode);
#endif

 public:
#ifdef PADDLE_WITH_PSLIB
  std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
  std::shared_ptr<DensePullThread> _pull_dense_thread;
  AsyncWorkerParamConfig _param_config;
#endif
  Scope* root_scope_;
  platform::Place place_;

 private:
  int actual_thread_num;
};

}  // namespace framework
......
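For orientation, a sketch of how the PSLIB entry points declared above might be driven on one node. The call order is inferred from the names only, and the descriptor string and host list are placeholders, so treat this as an assumption rather than the documented protocol:

    // Hypothetical single-node driver; dist_desc and hosts are placeholders,
    // and the sequencing below is an inference, not confirmed by this patch.
    void RunMpiNode(paddle::framework::AsyncExecutor* exe,
                    const std::string& dist_desc,
                    std::vector<uint64_t> hosts, int node_num, int rank) {
      exe->InitServer(dist_desc, rank);
      hosts[rank] = exe->StartServer();     // publish this node's endpoint
      exe->GatherServers(hosts, node_num);  // every node learns all endpoints
      exe->InitWorker(dist_desc, hosts, node_num, rank);
      exe->InitModel();                     // initialize parameters once
      // ... RunFromFile(..., /*mode=*/"mpi", ...) trains here ...
      exe->StopServer();
    }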
...@@ -165,7 +165,7 @@ template <typename T>
class GreaterThanChecker {
 public:
  explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
  void operator()(const T& value) const {
    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails.");
  }
...@@ -177,7 +177,7 @@ template <typename T>
class EqualGreaterThanChecker {
 public:
  explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
  void operator()(const T& value) const {
    PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
  }
...@@ -193,7 +193,7 @@ class DefaultValueSetter {
 public:
  explicit DefaultValueSetter(T default_value)
      : default_value_(default_value) {}
  void operator()(T* value) const { *value = default_value_; }

 private:
  T default_value_;
...@@ -203,7 +203,7 @@ template <typename T>
class EnumInContainer {
 public:
  explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
  void operator()(const T& val) const {
    PADDLE_ENFORCE(container_.find(val) != container_.end(),
                   "Value %s is not in enum container %s", val,
                   ContainerDebugString());
...@@ -232,7 +232,8 @@ class EnumInContainer {
// an attribute can have more than one limit
template <typename T>
class TypedAttrChecker {
  typedef std::function<void(T*)> DefaultValueChecker;
  typedef std::function<void(const T&)> ValueChecker;

 public:
  explicit TypedAttrChecker(const std::string& attr_name)
...@@ -268,17 +269,17 @@ class TypedAttrChecker {
    return *this;
  }

  void operator()(AttributeMap* attr_map) const {
    if (!attr_map->count(attr_name_)) {
      // the user did not set this attr
      PADDLE_ENFORCE(!default_value_setter_.empty(),
                     "Attribute '%s' is required!", attr_name_);
      // default_value_setter_ has no more than one element
      T val;
      (default_value_setter_[0])(&val);
      (*attr_map)[attr_name_] = val;
    }
    Attribute& attr = attr_map->at(attr_name_);
    ExtractAttribute<T> extract_attr(attr_name_);
    T* attr_value = extract_attr(attr);
    for (const auto& checker : value_checkers_) {
...@@ -289,12 +290,12 @@ class TypedAttrChecker {
 private:
  std::string attr_name_;
  std::vector<ValueChecker> value_checkers_;
  std::vector<DefaultValueChecker> default_value_setter_;
};

// check whether all attributes of an op fit their own limits
class OpAttrChecker {
  typedef std::function<void(AttributeMap*)> AttrChecker;

 public:
  template <typename T>
...@@ -304,7 +305,7 @@ class OpAttrChecker {
    return *(checker.target<TypedAttrChecker<T>>());
  }

  void Check(AttributeMap* attr_map) const {
    for (const auto& checker : attr_checkers_) {
      checker(attr_map);
    }
......
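The net effect of this hunk is that checkers read attribute values through const references and mutate the map only through an explicit pointer. A minimal sketch of the call side under the new signature (the attribute name and bound are invented for illustration):

    // Hypothetical registration and check; "num_heads" is a made-up attribute.
    OpAttrChecker op_checker;
    op_checker.AddAttrChecker<int>("num_heads").SetDefault(1).GreaterThan(0);

    AttributeMap attrs;        // the user left "num_heads" unset
    op_checker.Check(&attrs);  // fills in the default, then validates it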
...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
......
...@@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
  CheckInit();
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    if (name == use_slots_[i]) {
      feed_vec_[i] = var->GetMutable<LoDTensor>();
    }
  }
}
...@@ -68,6 +64,7 @@ bool DataFeed::PickOneFile(std::string* filename) {
    return false;
  }
  *filename = filelist_[file_idx_++];
  LOG(ERROR) << "pick file:" << *filename;
  return true;
}
...@@ -200,22 +197,22 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
    for (size_t i = 0; i < all_slots_.size(); ++i) {
      int num = strtol(endptr, &endptr, 10);
      if (num < 0) {
        VLOG(0) << "error: the number of ids is a negative number: " << num;
        VLOG(0) << "please check line<" << instance_cout << "> in file<"
                << filename << ">";
        return false;
      } else if (num == 0) {
        VLOG(0)
            << "error: the number of ids can not be zero, you need to pad "
               "it in the data generator; or if there is something wrong "
               "with the data, please check if the data contains "
               "unresolvable characters.";
        VLOG(0) << "please check line<" << instance_cout << "> in file<"
                << filename << ">";
        return false;
      } else if (errno == ERANGE || num > INT_MAX) {
        VLOG(0) << "error: the number of ids is greater than INT_MAX";
        VLOG(0) << "please check line<" << instance_cout << "> in file<"
                << filename << ">";
        return false;
      }
...@@ -223,15 +220,15 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
      for (int i = 0; i < num; ++i) {
        strtof(endptr, &endptr);
        if (errno == ERANGE) {
          VLOG(0) << "error: the value is out of the range of "
                     "representable values for float";
          VLOG(0) << "please check line<" << instance_cout << "> in file<"
                  << filename << ">";
          return false;
        }
        if (i + 1 != num && endptr - str == len) {
          VLOG(0) << "error: there is something wrong with the number of ids.";
          VLOG(0) << "please check line<" << instance_cout << "> in file<"
                  << filename << ">";
          return false;
        }
...@@ -240,32 +237,43 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
      for (int i = 0; i < num; ++i) {
        strtoull(endptr, &endptr, 10);
        if (errno == ERANGE) {
          VLOG(0) << "error: the value is out of the range of "
                     "representable values for uint64_t";
          VLOG(0) << "please check line<" << instance_cout << "> in file<"
                  << filename << ">";
          return false;
        }
        if (i + 1 != num && endptr - str == len) {
          VLOG(0) << "error: there is something wrong with the number of ids.";
          VLOG(0) << "please check line<" << instance_cout << "> in file<"
                  << filename << ">";
          return false;
        }
      }
    } else {
      VLOG(0) << "error: this type<" << all_slots_type_[i]
              << "> is not supported";
      return false;
    }
  }
  // Hadoop may append a '\t' to the end of each line written by a reduce
  // task (when the reduce output has only one field, a '\t' is added at
  // the end of the line by default; the option
  // `-D mapred.textoutputformat.ignoreseparator=true` avoids it). This
  // does not affect the correctness of the data, so a line is treated as
  // malformed only if its tail contains non-whitespace characters.
  while (endptr - str != len) {
    if (!isspace(*(endptr++))) {
      VLOG(0)
          << "error: there are extra characters at the end of the line.";
      VLOG(0) << "please check line<" << instance_cout << "> in file<"
              << filename << ">";
      return false;
    }
  }
}
VLOG(3) << "instances count: " << instance_cout;
VLOG(3) << "The file format is correct";
return true;
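For reference, a standalone sketch of the tail check introduced above (TailIsWhitespaceOnly is a hypothetical name): after all slots are parsed, any leftover bytes must be whitespace for the line to be accepted.

    #include <cctype>
    #include <cstring>

    // Returns true when the unparsed tail of `line` (starting at `endptr`)
    // contains only whitespace, e.g. the trailing '\t' added by Hadoop.
    bool TailIsWhitespaceOnly(const char* line, const char* endptr) {
      const char* end = line + std::strlen(line);
      for (; endptr != end; ++endptr) {
        if (!std::isspace(static_cast<unsigned char>(*endptr))) return false;
      }
      return true;
    }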
...@@ -290,6 +298,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);

    if (idx != -1) {
      (*instance)[idx].Init(all_slots_type_[i]);
      if ((*instance)[idx].GetType()[0] == 'f') {  // float
...@@ -326,6 +335,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
      (*ins_vec)[i].InitOffset();
    }
  }

  for (size_t i = 0; i < instance.size(); ++i) {
    (*ins_vec)[i].AddIns(instance[i]);
  }
...@@ -337,36 +347,25 @@ void MultiSlotDataFeed::PutToFeedVec(
    const auto& type = ins_vec[i].GetType();
    const auto& offset = ins_vec[i].GetOffset();
    int total_instance = static_cast<int>(offset.back());

    if (type[0] == 'f') {  // float
      const auto& feasign = ins_vec[i].GetFloatData();
      float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
          {total_instance, 1}, platform::CPUPlace());
      memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
    } else if (type[0] == 'u') {  // uint64
      // no uint64_t type in paddlepaddle
      const auto& feasign = ins_vec[i].GetUint64Data();
      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
          {total_instance, 1}, platform::CPUPlace());
      memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
    }

    LoD data_lod{offset};
    feed_vec_[i]->set_lod(data_lod);
    if (use_slots_is_dense_[i]) {
      int dim = total_instance / batch_size_;
      feed_vec_[i]->Resize({batch_size_, dim});
    }
  }
}
......
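A concrete illustration of the new dense-slot path above, with made-up sizes: every slot is first written as a {total_instance, 1} LoDTensor, and dense slots are then reshaped in place to {batch_size, dim}.

    #include <cassert>

    // Dense slots must divide evenly into the batch; this mirrors the
    // Resize({batch_size_, dim}) call above with assumed numbers.
    int main() {
      const int batch_size = 4;
      const int total_instance = 12;  // ids emitted for this slot
      assert(total_instance % batch_size == 0);
      const int dim = total_instance / batch_size;  // 3 ids per instance
      // feed_vec_[i]->Resize({batch_size, dim}) views the buffer as 4 x 3.
      return dim == 3 ? 0 : 1;
    }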
...@@ -30,35 +30,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Pack Tensor type and LoDTensor type into MixTensor type, in order
// to record either Tensor or LoDTensor information at the same time.
class MixTensor {
public:
MixTensor() {}
explicit MixTensor(LoDTensor* lodtensor) {
is_dense_ = false;
lodtensor_ = lodtensor;
}
explicit MixTensor(Tensor* tensor) {
is_dense_ = true;
tensor_ = tensor;
}
bool IsDense() { return is_dense_; }
LoDTensor* GetLoDTensor() {
PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
return lodtensor_;
}
Tensor* GetTensor() {
PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
return tensor_;
}
private:
bool is_dense_;
LoDTensor* lodtensor_;
Tensor* tensor_;
};
// DataFeed is the base virtual class for all other DataFeeds.
// It is used to read files and parse the data for the subsequent trainer.
// Example:
...@@ -133,7 +104,7 @@ class DataFeed {
      use_slots_index_;  // -1: not used; >=0: the index of use_slots_

  // The data read by DataFeed will be stored here
  std::vector<LoDTensor*> feed_vec_;

  // the batch size defined by user
  int default_batch_size_;
......
...@@ -152,21 +152,15 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
  const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
  std::map<std::string, const paddle::framework::LoDTensor*>
      lodtensor_targets;
  for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    if (slot.is_used()) {
      const auto& name = slot.name();
      readers[idx]->AddFeedVar(scope->Var(name), name);
      lodtensor_targets[name] =
          &scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
    }
  }
  readers[idx]->Start();
  while (readers[idx]->Next()) {
    int index = 0;
...@@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
      if (!slot.is_used()) {
        continue;
      }
      const paddle::framework::LoDTensor* tens =
          lodtensor_targets[slot.name()];
      if (slot.is_dense()) {  // dense branch
        if (slot.type() == "uint64") {
          const int64_t* data = tens->data<int64_t>();
          int batch_size = tens->dims()[0];
...@@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
          PADDLE_THROW("Error type in proto file.");
        }
      } else {  // sparse branch
        if (slot.type() == "uint64") {
          const int64_t* data = tens->data<int64_t>();
          for (size_t i = 0; i < tens->NumElements(); ++i) {
......
...@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
  out->mutable_data(expected_kernel_type.place_, in.type());

  framework::VisitDataType(
      in.type(),
      CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));

  out->set_layout(expected_kernel_type.data_layout_);
...@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
    case mkldnn::memory::data_type::f32:
      return platform::to_void_cast(tensor.data<float>());
    case mkldnn::memory::data_type::s8:
      return platform::to_void_cast(tensor.data<int8_t>());
    case mkldnn::memory::data_type::u8:
      return platform::to_void_cast(tensor.data<unsigned char>());
    case mkldnn::memory::data_type::s16:
...@@ -144,26 +144,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
  memory::data_type in_type = ToMKLDNNDataType(in.type());
  PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
                 "Input tensor type is not supported: %s", in.type());
  memory::data_type out_type = in_type;

  auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
  auto out_format =
      platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));

  // output tensor has the same dims as input. Reorder doesn't change dims
  out->Resize(in.dims());

  if (in_format != out_format) {
    void* in_data = GetDataFromTensor(in, in_type);
    auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());

    auto in_memory =
        memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
    auto out_memory =
        memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);

    platform::Reorder(in_memory, out_memory);
  } else {
    out->ShareDataWith(in);
  }

  out->set_layout(out_layout);
  // reset format since the out tensor will be fed to a non-MKLDNN op kernel
  out->set_format(memory::format::format_undef);
......
...@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
  }
}

inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
  static std::unordered_map<int, MKLDNNDataType> dict{
      {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
      {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
      {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
      {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
      {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
  auto iter = dict.find(static_cast<int>(type));
  if (iter != dict.end()) return iter->second;
  return MKLDNNDataType::data_undef;
}
......
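The rewrite above keys the lookup table on proto::VarType tags via DataTypeTrait instead of std::type_index. Because the trait values are constexpr, the proto-to-C++ pairing can even be checked at compile time; a sketch (assuming the traits defined later in this diff are in scope):

    // DataTypeTrait<T>::DataType is a constexpr proto::VarType value.
    static_assert(paddle::framework::DataTypeTrait<float>::DataType ==
                      paddle::framework::proto::VarType::FP32,
                  "float must map to FP32");
    static_assert(paddle::framework::DataTypeTrait<int8_t>::DataType ==
                      paddle::framework::proto::VarType::INT8,
                  "int8_t must map to INT8");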
...@@ -26,7 +26,7 @@ struct DataTypeMap {
  std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
  std::unordered_map<int, std::type_index> proto_to_cpp_;
  std::unordered_map<int, std::string> proto_to_str_;
  std::unordered_map<int, size_t> proto_to_size_;
};

static DataTypeMap* InitDataTypeMap();
...@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
  map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
  map->cpp_to_proto_.emplace(typeid(T), proto_type);
  map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
  map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
}

static DataTypeMap* InitDataTypeMap() {
...@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \
  RegisterType<cc_type>(retv, proto_type, #cc_type)

  _ForEachDataType_(RegType);

#undef RegType
  return retv;
...@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
      static_cast<int>(type));
}

size_t SizeOfType(proto::VarType::Type type) {
  auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
  if (it != gDataTypeMap().proto_to_size_.end()) {
    return it->second;
  }
  PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
}

}  // namespace framework
......
...@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);

template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
#define VisitDataTypeCallback(cpp_type, proto_type) \
  do {                                              \
    if (type == proto_type) {                       \
      visitor.template apply<cpp_type>();           \
      return;                                       \
    }                                               \
  } while (0)

  _ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback
  PADDLE_THROW("Not supported %d", type);
}
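To make the dispatch above concrete, here is a minimal visitor of the shape VisitDataType expects (the functor itself is invented for illustration): apply<T>() is instantiated once per type in _ForEachDataType_, and only the branch whose proto tag matches the runtime type runs.

    #include <iostream>

    // Prints the static size of whichever C++ type `type` maps to.
    struct SizePrinter {
      template <typename T>
      void apply() const {
        std::cout << "element size: " << sizeof(T) << " bytes\n";
      }
    };

    // e.g. VisitDataType(proto::VarType::FP32, SizePrinter());  // prints 4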
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);

inline std::ostream& operator<<(std::ostream& out,
                                const proto::VarType::Type& type) {
  out << DataTypeToString(type);
......
...@@ -26,15 +26,15 @@ TEST(DataType, float16) {
  Tensor tensor;
  CPUPlace cpu;
  tensor.mutable_data(cpu, dtype);

  // test fp16 tensor
  EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));

  // test fp16 size
  EXPECT_EQ(f::SizeOfType(dtype), 2u);

  // test debug info
  std::string type = "::paddle::platform::float16";
  EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
...@@ -18,312 +18,159 @@ limitations under the License. */
namespace paddle {
namespace framework {
/// @cond HIDDEN
template <int i>
Dim<i> make_dim(const int64_t* d) {
return Dim<i>(*d, make_dim<i - 1>(d + 1));
}
template <>
Dim<0> make_dim<0>(const int64_t* d) {
return Dim<0>(*d);
}
void make_ddim(DDim& ddim, const int64_t* dims, int n) {
switch (n) {
case 0:
ddim = make_dim<0>(dims);
break;
case 1:
ddim = make_dim<1>(dims);
break;
case 2:
ddim = make_dim<2>(dims);
break;
case 3:
ddim = make_dim<3>(dims);
break;
case 4:
ddim = make_dim<4>(dims);
break;
case 5:
ddim = make_dim<5>(dims);
break;
case 6:
ddim = make_dim<6>(dims);
break;
case 7:
ddim = make_dim<7>(dims);
break;
case 8:
ddim = make_dim<8>(dims);
break;
case 9:
ddim = make_dim<9>(dims);
break;
default:
PADDLE_THROW("Dynamic dimensions must have between [1, 9] dimensions.");
}
}
/// @endcond
DDim make_ddim(std::initializer_list<int64_t> dims) {
  return DDim(dims.begin(), dims.size());
}

DDim make_ddim(const std::vector<int64_t>& dims) {
  return DDim(dims.data(), dims.size());
}

DDim make_ddim(const std::vector<int>& dims) {
  return DDim(dims.data(), dims.size());
}

struct DDimEqualityVisitor {
  explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}

  template <int D>
  inline bool operator()(const Dim<D>& self) const {
    return UnrollCompare<D>::Run(self.Get(), d_);
  }

  const int64_t* d_;
};

bool DDim::operator==(const DDim& d) const {
  return size() == d.size() &&
         this->apply_visitor(DDimEqualityVisitor(d.Get()));
}

bool DDim::operator!=(const DDim& d) const { return !(*this == d); }

struct DDimPlusVisitor {
  explicit DDimPlusVisitor(const int64_t* d1, const int64_t* d2)
      : d1_(d1), d2_(d2) {}

  template <int D>
  inline void operator()(Dim<D>& self) const {
    UnrollAdd<D>::Run(d1_, d2_, self.GetMutable());
  }

  const int64_t* d1_;
  const int64_t* d2_;
};

DDim DDim::operator+(const DDim& d) const {
  PADDLE_ENFORCE(size() == d.size());
  DDim ret;
  ret.rank_ = rank_;
  ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
  return ret;
}

struct DDimMulVisitor {
  explicit DDimMulVisitor(const int64_t* d1, const int64_t* d2)
      : d1_(d1), d2_(d2) {}

  template <int D>
  inline void operator()(Dim<D>& self) const {
    UnrollMul<D>::Run(d1_, d2_, self.GetMutable());
  }

  const int64_t* d1_;
  const int64_t* d2_;
};

DDim DDim::operator*(const DDim& d) const {
  PADDLE_ENFORCE(size() == d.size());
  DDim ret;
  ret.rank_ = rank_;
  ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
  return ret;
}
int64_t get(const DDim& ddim, int idx) { return ddim[idx]; }

void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }  // NOLINT

std::vector<int64_t> vectorize(const DDim& ddim) {
  std::vector<int64_t> result(DDim::kMaxRank);
  dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
  result.resize(ddim.size());
  return result;
}

// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim& ddim) {
  std::vector<int> result(DDim::kMaxRank);
  dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
  result.resize(ddim.size());
  return result;
}

struct ProductVisitor {
  template <int D>
  inline int64_t operator()(const Dim<D>& dim) {
    return product(dim);
  }
};

int64_t product(const DDim& ddim) {
  return ddim.apply_visitor(ProductVisitor());
}
struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int64_t>& vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int64_t>& v, int b, int e)
: vector(v), begin(b), end(e) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
}
template <int S>
void operator()(const Dim<S>& dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
}
--end;
if (end > 0) {
this->operator()(dim.tail);
}
}
void operator()(const Dim<0>& dim) {
PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
}
};
DDim slice_ddim(const DDim& dim, int begin, int end) {
  PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
                 "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
                 begin, end, dim.size());
  // The constructor of DDim checks whether end - begin is valid.
  return DDim(dim.Get() + begin, end - begin);
}

int arity(const DDim& d) { return d.size(); }

struct DDimPrinter {
  std::ostream& os;
  explicit DDimPrinter(std::ostream& os_) : os(os_) {}

  template <int D>
  void operator()(const Dim<D>& t) {
    os << t;
  }
};

std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
  ddim.apply_visitor(DDimPrinter(os));
  return os;
}

DDim flatten_to_2d(const DDim& src, int num_col_dims) {
  return DDim({product(slice_ddim(src, 0, num_col_dims)),
               product(slice_ddim(src, num_col_dims, src.size()))});
}

DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }

DDim stride(const DDim& ddim) {
  DDim strides;
  strides.rank_ = ddim.size();
  strides[ddim.size() - 1] = 1;
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return strides;
}

DDim stride_numel(const DDim& ddim) {
  DDim strides;
  strides.rank_ = ddim.size();
  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return strides;
}

}  // namespace framework
......
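A quick worked example of the two stride functions above, with values computed by hand: for a row-major shape {2, 3, 4}, stride yields element strides and stride_numel yields trailing-volume strides.

    #include <cassert>

    // For ddim = {2, 3, 4}:
    //   stride:       {3*4, 4, 1}     == {12, 4, 1}
    //   stride_numel: {2*3*4, 3*4, 4} == {24, 12, 4}
    void StrideExample() {
      auto d = paddle::framework::make_ddim({2, 3, 4});
      auto s = paddle::framework::stride(d);
      assert(s[0] == 12 && s[1] == 4 && s[2] == 1);
      auto sn = paddle::framework::stride_numel(d);
      assert(sn[0] == 24 && sn[1] == 12 && sn[2] == 4);
    }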
...@@ -18,62 +18,145 @@ limitations under the License. */
#include <stdexcept>
#include <vector>

#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"

namespace paddle {
namespace framework {
#define PADDLE_VISIT_DDIM_BASE(rank, callback) \
case (rank): { \
constexpr auto kRank = (rank); \
return (callback); \
}
#define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \
PADDLE_THROW("Invalid rank %d", rank); \
}
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
PADDLE_VISIT_DDIM(n, (static_dim_assign<kRank, T1, T2>(in, out)));
}
/**
 * \brief A dynamically sized dimension.
 *
 * The number of dimensions must be between [1, 9].
 */
class DDim {
 public:
  constexpr static int kMaxRank = 9;

  DDim() : rank_(1) { dim_[0] = 0; }

  DDim(const DDim& ddim) : dim_() { CopyFrom(ddim); }

  DDim(const int* d, int n) : rank_(n) {
    dynamic_dim_assign(d, dim_.GetMutable(), n);
  }

  DDim(const int64_t* d, int n) : rank_(n) {
    dynamic_dim_assign(d, dim_.GetMutable(), n);
  }

  template <int D>
  /*implicit*/ DDim(const Dim<D>& in) : rank_(D) {  // NOLINT
    UnsafeCast<D>() = in;
  }

  /*implicit*/ DDim(std::initializer_list<int64_t> init_list)
      : DDim(init_list.begin(), init_list.size()) {}

  inline DDim& operator=(const DDim& ddim) { return CopyFrom(ddim); }

  template <int D>
  inline DDim& operator=(const Dim<D>& dim) {
    rank_ = D;
    UnsafeCast<D>() = dim;
    return *this;
  }

  inline int64_t& operator[](int idx) { return dim_[idx]; }

  inline int64_t operator[](int idx) const { return dim_[idx]; }

  inline int64_t& at(int idx) {
    PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
    return dim_[idx];
  }

  inline int64_t at(int idx) const {
    PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
    return dim_[idx];
  }

  template <typename Visitor>
  typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor(
      Visitor&& visitor) {
    PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
  }

  template <typename Visitor>
  typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor(
      Visitor&& visitor) const {
    PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
  }

  bool operator==(const DDim& d) const;
  bool operator!=(const DDim& d) const;
  DDim operator+(const DDim& d) const;
  DDim operator*(const DDim& d) const;

  inline const int64_t* Get() const { return dim_.Get(); }
  inline int64_t* GetMutable() { return dim_.GetMutable(); }

  inline int size() const { return rank_; }

 private:
  template <int D>
  inline Dim<D>& UnsafeCast() {
    static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
    auto* p = static_cast<void*>(&dim_);
    return *reinterpret_cast<Dim<D>*>(p);
  }

  template <int D>
  inline const Dim<D>& UnsafeCast() const {
    static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
    auto* p = static_cast<const void*>(&dim_);
    return *reinterpret_cast<const Dim<D>*>(p);
  }

  inline DDim& CopyFrom(const DDim& ddim) {
    PADDLE_VISIT_DDIM(ddim.rank_, (*this = ddim.UnsafeCast<kRank>()));
  }

  friend DDim stride(const DDim& ddim);
  friend DDim stride_numel(const DDim& ddim);

 private:
  Dim<kMaxRank> dim_;
  int rank_;
};
#undef PADDLE_VISIT_DDIM_BASE
#undef PADDLE_VISIT_DDIM
/**
 * \brief Make a DDim from std::vector<int64_t>
 *
...@@ -92,7 +175,7 @@ DDim make_ddim(const std::vector<int>& dims);
DDim make_ddim(std::initializer_list<int64_t> dims);

int64_t get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val);  // NOLINT

std::vector<int64_t> vectorize(const DDim& ddim);
std::vector<int> vectorize2int(const DDim& ddim);
...@@ -129,12 +212,3 @@ DDim stride(const DDim& ddim);
DDim stride_numel(const DDim& ddim);
}  // namespace framework
}  // namespace paddle
namespace boost {
template <typename T>
T get(const paddle::framework::DDim& in) {
return boost::get<T>(in.var);
}
} // namespace boost
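Putting the new fixed-capacity DDim through its paces, a small sketch (the element-wise results are computed by hand in the comments):

    void DDimExample() {
      using paddle::framework::DDim;
      using paddle::framework::make_ddim;

      DDim a = make_ddim({2, 3});
      DDim b = make_ddim({4, 5});
      DDim sum = a + b;   // element-wise: {6, 8}
      DDim prod = a * b;  // element-wise: {8, 15}
      // No boost::variant is involved any more: apply_visitor dispatches on
      // rank_ through the PADDLE_VISIT_DDIM switch at compile-known ranks.
      bool eq = (a == make_ddim({2, 3}));  // true
      (void)sum; (void)prod; (void)eq;
    }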
...@@ -12,17 +12,36 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc ...@@ -12,17 +12,36 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE)
if(NOT WITH_GRPC)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
endif()
if(WITH_GPU) if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
else()
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor)
endif()
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else() else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor) variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_rpc)
else()
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor)
endif()
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif() endif()
...@@ -31,12 +50,14 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -31,12 +50,14 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif() cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
...@@ -44,16 +65,20 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he ...@@ -44,16 +65,20 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass)
if (WITH_GPU) if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif() endif()
cc_test(memory_reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types.cc DEPS framework_proto graph)
cc_test(analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
           simple_threadpool device_context)
cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
        device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
...@@ -68,4 +93,5 @@ cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fuse
cc_library(build_strategy SRCS build_strategy.cc DEPS
           graph_viz_pass multi_devices_graph_pass
           multi_devices_graph_print_pass multi_devices_graph_check_pass
           fuse_elewise_add_act_pass multi_batch_merge_pass
           memory_optimize_pass lock_free_optimize_pass)
...@@ -19,6 +19,13 @@
#include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
// asynchronous nccl allreduce or synchronous issue:
// https://github.com/PaddlePaddle/Paddle/issues/15049
DEFINE_bool(
    sync_nccl_allreduce, false,
    "If set true, will call `cudaStreamSynchronize(nccl_stream)` "
    "after allreduce; this mode can get better performance in some scenarios.");
namespace paddle {
namespace framework {
namespace details {
...@@ -48,10 +55,6 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
void AllReduceOpHandle::RunImpl() {
  platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);

  WaitInputVarGenerated();
  auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
  auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
...@@ -98,16 +101,32 @@ void AllReduceOpHandle::RunImpl() {
        auto comm = nccl_ctx.comm_;
        all_reduce_calls.emplace_back([=] {
          PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
              buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
              comm, stream));
        });
      }
    this->RunAndRecordEvent([&] {
      if (all_reduce_calls.size() == 1UL) {
        // Do not use NCCLGroupGuard when NCCL is managed per thread per device.
        all_reduce_calls[0]();
      } else {
        platform::NCCLGroupGuard guard;
        for (auto &call : all_reduce_calls) {
          call();
        }
      }
    });
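    // Background (general NCCL semantics, not from this diff): NCCLGroupGuard
    // brackets the calls with ncclGroupStart()/ncclGroupEnd(), fusing the
    // per-device allreduces into one group launch; with a single call there is
    // nothing to fuse, so the guard is skipped.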
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
cudaStreamSynchronize(stream);
}
}
#else
    PADDLE_THROW("Not compiled with CUDA");
#endif
...@@ -120,7 +139,7 @@ void AllReduceOpHandle::RunImpl() {
      // Reduce all tensors to trg on CPU
      ReduceLoDTensor func(lod_tensors, &trg);
      VisitDataType(lod_tensors[0]->type(), func);

      for (size_t i = 1; i < local_scopes_.size(); ++i) {
        auto &scope =
...@@ -136,7 +155,6 @@ void AllReduceOpHandle::RunImpl() {
      });
    }
  }
}

std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// sort ops in BFS order
std::vector<ir::Node*> BFSSortGraphOps(const ir::Graph& graph);
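// e.g. for the diamond graph a -> {b, c} -> d, one valid BFS order visits a,
// then b and c (in either order), then d.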
class ControlFlowGraph;
class AnalysisVarPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;
// update program descs
void RenameVarInGraphDesc(const std::string& var,
const std::string& cache_var, size_t idx) const;
// update ir nodes
void RenameVarInGraphNode(const std::string& var,
const std::string& cache_var, size_t idx,
ir::Graph* graph) const;
void SubGraphOptimize(OpDesc* op_desc) const;
  // check whether a tensor can be reused
  bool NodeCanReused(ir::Node* node) const;
  // scan the sub-block and collect its input/output variables.
  std::unordered_set<std::string> GetSubBlockVars(
      const std::unordered_set<ir::Node*>&) const;
  // check whether the op has a sub-block
bool OpHasSubBlock(OpDesc* desc) const;
private:
// Reuse Node Pool, Owned.
mutable OrderedNodePairPool pool_;
  // control flow graph
mutable std::unique_ptr<ControlFlowGraph> cfg_;
// skip set
mutable std::unordered_set<std::string> skip_set_;
// var nodes
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};
class ControlFlowGraph {
public:
ControlFlowGraph() = default;
  // for the IR graph in ParallelExecutor
explicit ControlFlowGraph(const ir::Graph& graph);
void LiveVariableAnalysis();
void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx);
const std::set<std::string> LiveIn(ir::Node* op) const;
const std::set<std::string> LiveOut(ir::Node* op) const;
const std::set<std::string> Use(ir::Node* op) const;
const std::vector<ir::Node*> Ops() const;
std::vector<ir::Node*>& Ops();
// for ssa-graph nodes
ir::Node* GetNodeFromVarName(const std::string& name, ir::Node* op) const;
private:
void BuildCFGGraph();
void ConnectNodes();
using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
  // successor ops that use the output variables.
  NodeListMap successors_;
  // predecessor ops that generated the input variables.
  NodeListMap predecessors_;
  // variables live before running the current op.
  VarSetMap live_in_;
  // variables live after running the current op.
VarSetMap live_out_;
VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs
std::vector<ir::Node*> ops_; // op sequence by topology sort
};
} // namespace details
} // namespace framework
} // namespace paddle
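A minimal usage sketch for the pass declared above (assumptions, not shown in this diff: the pass is registered under the name "analysis_var_pass", and the usual ir::PassRegistry / Pass::Apply conventions of this codebase apply; the graph must carry kAllOpDescs, as the tests below also set up):

// Hypothetical call site, for illustration only.
std::unique_ptr<paddle::framework::ir::Graph> ApplyAnalysisVarPass(
    std::unique_ptr<paddle::framework::ir::Graph> graph,
    const paddle::framework::ProgramDesc& prog) {
  using namespace paddle::framework;
  graph->Set(details::kAllOpDescs,
             new std::vector<OpDesc*>(prog.Block(0).AllOps()));  // graph takes ownership
  auto pass = ir::PassRegistry::Instance().Get("analysis_var_pass");
  return pass->Apply(std::move(graph));
}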
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/analysis_var_pass.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
class DummyOp : public OperatorBase {
public:
DummyOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class AssignOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
}
};
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Construct a small, classical dependency graph; the left column is the
instruction number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analyze these variables' liveness ranges.
*/
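// The analysis solves the standard backward dataflow equations:
//   LiveOut(op) = union of LiveIn(s) over all successors s of op
//   LiveIn(op)  = Use(op) + (LiveOut(op) - Def(op))
// Worked for instruction 4 above (d = b + c): Use = {b, c}, Def = {d},
// LiveOut = {d}, hence LiveIn = {b, c}; this is exactly what
// TEST(CFGGraph, IRGraph) below asserts for the sum op.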
namespace paddle {
namespace framework {
namespace details {
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(proto::VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"d"});
op->SetOutput("Out", {"e"});
}
return prog;
}
template <typename Container>
inline static std::string DebugString(const Container& c) {
std::stringstream ss;
for (auto& item : c) {
ss << item << " ";
}
return ss.str();
}
TEST(CFGGraph, IRGraph) {
// prepare ir graph
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
ControlFlowGraph cfg(graph);
cfg.LiveVariableAnalysis();
// test assign op
ASSERT_TRUE((std::set<std::string>{"a"} == cfg.LiveIn(cfg.Ops()[0])));
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveOut(cfg.Ops()[0])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveIn(cfg.Ops()[1])));
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveOut(cfg.Ops()[1])));
// test sum op
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveIn(cfg.Ops()[2])));
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveOut(cfg.Ops()[2])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveIn(cfg.Ops()[3])));
ASSERT_TRUE((std::set<std::string>{} == cfg.LiveOut(cfg.Ops()[3])));
}
// 1. normal test
TEST(SortOpLikeDescOrder, NormalTest) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = SortOpLikeDescOrder(graph);
auto op_descs = prog.Block(0).AllOps();
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 2. remove some op_desc
TEST(SortOpLikeDescOrder, RemoveOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = graph.Nodes();
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->outputs.back()->Name() == "e") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
  d->outputs.erase(std::remove(d->outputs.begin(), d->outputs.end(), found_node),
                   d->outputs.end());
  graph.RemoveNode(found_node);
  graph.RemoveNode(e);
  // the remaining nodes keep the same order
auto remain_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < remain_nodes.size(); ++i) {
auto node = remain_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 3. add some op_desc
TEST(SortOpLikeDescOrder, AddOpDesc) {
auto prog = FillProgramDesc();
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
  // the cached descs differ from the real ones:
  // mimic an intermediate pass modifying the ProgramDesc.
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto op_descs = prog.Block(0).AllOps();
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
op_descs.insert(op_descs.begin() + 4, op);
auto nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 4. add and delete some op_desc
TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// remove sum node
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
auto nodes = graph.Nodes();
for (auto node : nodes) {
if (node->Name() == "sum") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* c = find_node_in_graph("c");
ir::Node* e = find_node_in_graph("e");
    d->outputs.erase(std::remove(d->outputs.begin(), d->outputs.end(), found_node),
                     d->outputs.end());
    c->outputs.erase(std::remove(c->outputs.begin(), c->outputs.end(), found_node),
                     c->outputs.end());
ir::Node* pending_op = found_node->outputs[0]->outputs[0];
graph.RemoveNode(e);
graph.RemoveNode(pending_op);
graph.RemoveNode(found_node);
}
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.insert(op_descs.begin() + 2, op);
// check the order
auto mynodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < mynodes.size(); ++i) {
auto node = mynodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 5. add and replace some op_desc inplace.
TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
auto op_descs = prog.Block(0).AllOps();
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.emplace_back(op);
// replace op_desc inplace
auto nodes = graph.Nodes();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->Op() && node->Name() == "assign") {
if (node->outputs.size() == 1 && node->outputs[0]->Name() == "e") {
found_node = node;
break;
}
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* e = find_node_in_graph("e");
    d->outputs.erase(std::remove(d->outputs.begin(), d->outputs.end(), found_node),
                     d->outputs.end());
    e->inputs.erase(std::remove(e->inputs.begin(), e->inputs.end(), found_node),
                    e->inputs.end());
graph.RemoveNode(found_node);
}
op_descs.erase(op_descs.begin() + 3);
auto replace_op = prog.MutableBlock(0)->AppendOp();
replace_op->SetType("sum");
replace_op->SetInput("X", {"d", "d1"});
replace_op->SetOutput("Out", {"e"});
{
ir::Node* sum2 = graph.CreateOpNode(replace_op);
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
ir::Node* d1 = find_node_in_graph("d1");
sum2->inputs.emplace_back(d);
sum2->inputs.emplace_back(d1);
sum2->outputs.emplace_back(e);
e->inputs.emplace_back(sum2);
d->outputs.emplace_back(sum2);
d1->outputs.emplace_back(sum2);
}
op_descs.emplace_back(replace_op);
// compare op order
auto graph_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < graph_nodes.size(); ++i) {
auto node = graph_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -60,20 +60,35 @@ struct BuildStrategy {
    kCustomized = 2,
  };
  enum class OptimizeStrategy {
    // To be implemented: brute force; recursively compute unused var names.
    kBruteForce = 0,
    kControlFlowGraph = 1,  // use the control-flow-graph algorithm; faster.
  };
  ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
  GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
  OptimizeStrategy strategy_{OptimizeStrategy::kControlFlowGraph};

  std::string debug_graphviz_path_{""};
  bool fuse_elewise_add_act_ops_{false};
  bool memory_optimize_{false};
  bool memory_early_delete_{false};
  bool enable_sequential_execution_{false};
  bool fuse_broadcast_op_{false};
  // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
  // num_trainers is 1, so the current fields of build_strategy don't tell
  // whether it's a distributed model.
  bool is_distribution_{false};
  int num_trainers_{1};
  int trainer_id_{0};
  std::vector<std::string> trainers_endpoints_;
  bool remove_unnecessary_lock_{false};

  // NOTE:
...@@ -91,20 +106,29 @@ struct BuildStrategy {
  bool IsFinalized() const { return is_finalized_; }

  bool IsMultiDevPass(const std::string &pass_name) const;

  // Apply the passes built by the pass_builder_. The passes will be
  // applied to the Program and output an ir::Graph.
  std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program,
                                   const std::vector<platform::Place> &places,
                                   const std::string &loss_var_name,
                                   const std::vector<Scope *> &local_scopes,
                                   const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                   const bool use_cuda,
                                   platform::NCCLContextMap *nccl_ctxs) const;
#else
                                   const bool use_cuda) const;
#endif
  // If set true, ParallelExecutor would build the main_program into multiple
  // graphs, each of which would run with one device. This approach can
  // achieve better performance in some scenarios.
  mutable bool enable_parallel_graph_ = false;
 private:
  mutable bool is_finalized_ = false;
  mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
......
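A hedged call-site sketch for the widened Apply signature above (the variable names and the NCCL context are illustrative, not from this diff):

  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
      main_program, places, loss_var_name, local_scopes, nranks,
      /*use_cuda=*/true, nccl_ctxs);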
...@@ -20,11 +20,13 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
                                         platform::Place place,
                                         size_t scope_idx)
    : OpHandleBase(node),
      op_(framework::OpRegistry::CreateOp(*node->Op())),
      scope_(scope),
      place_(place),
      scope_idx_(scope_idx) {}
void ComputationOpHandle::RunImpl() {
  WaitInputVarGenerated(place_);
......
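The extra scope_idx parameter records which device-local scope the op handle belongs to; a hedged construction sketch (the variables are illustrative, not from this diff):

  auto *op_handle = new ComputationOpHandle(node, local_scopes[i], places[i],
                                            /*scope_idx=*/i);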
...@@ -25,7 +25,7 @@ struct ExecutionStrategy {
  size_t num_threads_{0};
  bool use_cuda_{true};
  bool allow_op_delay_{false};
  size_t num_iteration_per_drop_scope_{1};
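  // Note (an inference from the field name, not stated in this diff): this
  // controls how many iterations run before the executor drops its local
  // scopes; the new default of 1 frees intermediate variables every iteration
  // instead of every 100, trading scope-recreation cost for lower peak memory.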
  ExecutorType type_{kDefault};
  bool dry_run_{false};
};
......