Unverified commit a79a3ea2, authored by guru4elephant, committed by GitHub

Merge branch 'develop' into develop

......@@ -54,7 +54,7 @@ option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON)
option(WITH_DOUBLE "Compile PaddlePaddle with double precision" OFF)
option(WITH_RDMA "Compile PaddlePaddle with RDMA support" OFF)
option(WITH_TIMER "Compile PaddlePaddle with stats timer" OFF)
option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler" OFF)
option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF)
option(WITH_DOC "Compile PaddlePaddle with documentation" OFF)
option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF)
option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
......@@ -132,8 +132,6 @@ if (APPLE OR WIN32)
endif()
if (WIN32)
set(WITH_AVX OFF CACHE STRING
"Disable AVX when compiling for Windows" FORCE)
set(WITH_DSO OFF CACHE STRING
"Disable DSO when compiling for Windows" FORCE)
set(WITH_MKL OFF CACHE STRING
......@@ -261,6 +259,12 @@ elseif()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in MKL only now." FORCE)
endif()
if (WITH_PROFILER)
find_package(Gperftools REQUIRED)
include_directories(${GPERFTOOLS_INCLUDE_DIR})
add_definitions(-DWITH_GPERFTOOLS)
endif()
include(generic) # simplify cmake module
include(package) # set paddle packages
include(ccache) # set ccache for compilation
......
......@@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
......@@ -19,7 +19,16 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 1.1.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.1)
Welcome to the PaddlePaddle GitHub.
PaddlePaddle (PArallel Distributed Deep LEarning) is an easy-to-use, efficient, flexible and scalable deep learning platform, originally developed by Baidu scientists and engineers for the purpose of applying deep learning to many products at Baidu.
Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest features of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### Install Latest Stable Release:
```
# Linux CPU
......@@ -27,13 +36,30 @@ pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==1.1.0.post87
pip install paddlepaddle-gpu==1.2.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==1.1.0.post85
pip install paddlepaddle-gpu==1.2.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### Install the Latest Stable Release:
```
# Linux CPU
pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==1.2.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==1.2.0.post85
# For installation on other platforms, refer to http://paddlepaddle.org/
```
## Features
- **Flexibility**
......@@ -74,35 +100,90 @@ pip install paddlepaddle-gpu==1.1.0.post85
Baidu and it has achieved a significant impact. We hope you can also explore
the capability of PaddlePaddle to make an impact on your product.
## Features
- **Flexibility**
PaddlePaddle supports a wide range of neural network architectures and optimization algorithms. It is easy to configure complex models such as a neural machine translation model with attention mechanisms or complex memory connections.
- **Efficiency**
To make efficient use of heterogeneous computing resources, PaddlePaddle optimizes every layer of the framework, including computation, memory, architecture and communication. Some examples:
- Optimized math operations through SSE/AVX intrinsics, BLAS libraries (e.g. MKL, OpenBLAS, cuBLAS) or customized CPU/GPU kernels.
- Optimized CNN networks through the MKL-DNN library.
- Highly optimized recurrent networks that can process **variable-length** sequences without padding.
- Optimized local and distributed training for models with high-dimensional sparse data.
- **Scalability**
With PaddlePaddle, it is easy to use many CPUs/GPUs and machines to speed up training. PaddlePaddle achieves high throughput and fast execution via optimized communication.
- **Connected to Products**
In addition, PaddlePaddle is also designed to be easily deployable. At Baidu, PaddlePaddle has been deployed into products and services with a vast number of users, including ad click-through rate (CTR) prediction, large-scale image classification, optical character recognition (OCR), search ranking, computer virus detection, recommendation systems, etc. PaddlePaddle is widely used in Baidu products and has achieved significant impact. We hope you can also explore the capability of PaddlePaddle to make an impact on your own products.
## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) on our website.
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
## Installation
We recommend reading the [installation instructions](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) documentation.
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.1/user_guides/howto/training/cluster_howto.html)
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/api/zh/1.1/fluid.html)
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.1/advanced_usage/development/contribute_to_paddle.html)
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
We appreciate your contributions!
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
Our new API enables much shorter and cleaner programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
We appreciate your contributions!
## Ask Questions
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
## Ask Questions
You are welcome to submit questions and bug reports as [GitHub Issues](https://github.com/PaddlePaddle/Paddle/issues).
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
......@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
config = distribute_transpiler.DistributeTranspilerConfig()
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = not args.no_split_var
config.min_block_size = 1048576
t = distribute_transpiler.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
# NOTE: *MUST* use train_prog, for we are using with guard to
......
# Tries to find Gperftools.
#
# Usage of this module as follows:
#
# find_package(Gperftools)
#
# Variables used by this module. They can change the default behaviour and need
# to be set before calling find_package:
#
# Gperftools_ROOT_DIR Set this variable to the root installation of
# Gperftools if the module has problems finding
# the proper installation path.
#
# Variables defined by this module:
#
# GPERFTOOLS_FOUND System has Gperftools libs/headers
# GPERFTOOLS_LIBRARIES The Gperftools libraries (tcmalloc & profiler)
# GPERFTOOLS_INCLUDE_DIR The location of Gperftools headers
find_library(GPERFTOOLS_TCMALLOC
NAMES tcmalloc
HINTS ${Gperftools_ROOT_DIR}/lib)
find_library(GPERFTOOLS_PROFILER
NAMES profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
find_library(GPERFTOOLS_TCMALLOC_AND_PROFILER
NAMES tcmalloc_and_profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
find_path(GPERFTOOLS_INCLUDE_DIR
NAMES gperftools/heap-profiler.h
HINTS ${Gperftools_ROOT_DIR}/include)
set(GPERFTOOLS_LIBRARIES ${GPERFTOOLS_TCMALLOC_AND_PROFILER})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
Gperftools
DEFAULT_MSG
GPERFTOOLS_LIBRARIES
GPERFTOOLS_INCLUDE_DIR)
mark_as_advanced(
Gperftools_ROOT_DIR
GPERFTOOLS_TCMALLOC
GPERFTOOLS_PROFILER
GPERFTOOLS_TCMALLOC_AND_PROFILER
GPERFTOOLS_LIBRARIES
GPERFTOOLS_INCLUDE_DIR)
# create IMPORTED targets
if (Gperftools_FOUND AND NOT TARGET gperftools::tcmalloc)
add_library(gperftools::tcmalloc UNKNOWN IMPORTED)
set_target_properties(gperftools::tcmalloc PROPERTIES
IMPORTED_LOCATION ${GPERFTOOLS_TCMALLOC}
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
add_library(gperftools::profiler UNKNOWN IMPORTED)
set_target_properties(gperftools::profiler PROPERTIES
IMPORTED_LOCATION ${GPERFTOOLS_PROFILER}
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
endif()
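For reference, a minimal consumer sketch of the module above, assuming it is saved as `cmake/FindGperftools.cmake` and that `/opt/gperftools` is a hypothetical install prefix; since the module defines IMPORTED targets that propagate their include directories, linking is all a consumer needs:
```
# Hypothetical consumer of FindGperftools.cmake (paths are assumptions).
cmake_minimum_required(VERSION 3.5)
project(profiler_demo CXX)

list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
set(Gperftools_ROOT_DIR "/opt/gperftools")  # assumed install prefix
find_package(Gperftools REQUIRED)

add_executable(profiler_demo main.cc)
# No include_directories() needed: the IMPORTED target carries its headers.
target_link_libraries(profiler_demo gperftools::profiler)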
......@@ -90,6 +90,7 @@ endif()
if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU)
FIND_PACKAGE(CUDA REQUIRED)
......
......@@ -14,14 +14,16 @@
INCLUDE(ExternalProject)
find_library(SSL_LIBRARY NAMES ssl)
find_package(OpenSSL REQUIRED)
message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
find_library(CRYPTO_LIBRARY NAMES crypto)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
......@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
# If a minimal .a is needed, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to the newest repo when they change it.
GIT_REPOSITORY "https://github.com/gongweibao/brpc"
GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......@@ -50,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DBRPC_WITH_GLOG=ON
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
${EXTERNAL_OPTIONAL_ARGS}
......@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
add_definitions(-DBRPC_WITH_GLOG)
LIST(APPEND external_project_dependencies brpc)
......@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
IF(WITH_TESTING)
#FIXME(gongwb): Move brpc's gtest dependency.
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WITH_TESTING)
ENABLE_TESTING()
ENDIF(WITH_TESTING)
INCLUDE(ExternalProject)
SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
......@@ -76,4 +80,4 @@ IF(WITH_TESTING)
ADD_DEPENDENCIES(gtest_main extern_gtest)
LIST(APPEND external_project_dependencies gtest gtest_main)
ENDIF(WITH_TESTING)
ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
......@@ -24,8 +24,8 @@ ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
GIT_REPOSITORY "https://github.com/google/leveldb"
GIT_TAG v1.18
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
......
......@@ -32,6 +32,8 @@ IF(NOT ${WITH_NGRAPH})
return()
ENDIF()
INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph")
......@@ -40,10 +42,14 @@ SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION})
SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so)
SET(NGRAPH_TBB_LIB_NAME libtbb.so.2)
SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git")
SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
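As a side note on the `include(GNUInstallDirs)` above: it defines `CMAKE_INSTALL_LIBDIR` (`lib` on most systems, `lib64` on some 64-bit RPM-based distros), which is why `NGRAPH_LIB_DIR` must be derived from it before the shared-library paths are composed. A two-line sketch to inspect the value on a build host:
```
# Sketch: see what GNUInstallDirs resolves to on this machine.
include(GNUInstallDirs)
message(STATUS "libraries install under: ${CMAKE_INSTALL_LIBDIR}")
```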
ExternalProject_Add(
${NGRAPH_PROJECT}
......@@ -63,18 +69,6 @@ ExternalProject_Add(
CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib
)
if(UNIX AND NOT APPLE)
include(GNUInstallDirs)
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
else()
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/lib)
endif()
MESSAGE(STATUS "nGraph lib will be installed at: ${NGRAPH_LIB_DIR}")
SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
# Workaround for nGraph expecting mklml to be in mkldnn install directory.
ExternalProject_Add_Step(
${NGRAPH_PROJECT}
......
......@@ -24,12 +24,6 @@ set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE)
if (WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib")
else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32)
ExternalProject_Add(
extern_snappy
GIT_REPOSITORY "https://github.com/google/snappy"
......@@ -56,6 +50,16 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
add_custom_command(TARGET extern_snappy POST_BUILD
COMMAND cmake -E copy ${SNAPPY_INSTALL_DIR}/lib/snappy.lib ${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib
)
ENDIF()
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32)
add_library(snappy STATIC IMPORTED GLOBAL)
set_property(TARGET snappy PROPERTY IMPORTED_LOCATION ${SNAPPY_LIBRARIES})
......
......@@ -56,7 +56,12 @@ else()
endif()
if (WIN32)
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/xxhash.lib")
IF(NOT EXISTS "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
add_custom_command(TARGET extern_xxhash POST_BUILD
COMMAND cmake -E copy ${XXHASH_INSTALL_DIR}/lib/xxhash.lib ${XXHASH_INSTALL_DIR}/lib/libxxhash.lib
)
ENDIF()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
else()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
endif ()
......
......@@ -19,12 +19,6 @@ SET(ZLIB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/zlib)
SET(ZLIB_ROOT ${ZLIB_INSTALL_DIR} CACHE FILEPATH "zlib root directory." FORCE)
SET(ZLIB_INCLUDE_DIR "${ZLIB_INSTALL_DIR}/include" CACHE PATH "zlib include directory." FORCE)
IF(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib" CACHE FILEPATH "zlib library." FORCE)
ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32)
INCLUDE_DIRECTORIES(${ZLIB_INCLUDE_DIR}) # For zlib code to include its own headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include zlib.h.
......@@ -49,6 +43,16 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${ZLIB_INSTALL_DIR}/lib/libz.lib")
add_custom_command(TARGET extern_zlib POST_BUILD
COMMAND cmake -E copy ${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib ${ZLIB_INSTALL_DIR}/lib/libz.lib
)
ENDIF()
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.lib" CACHE FILEPATH "zlib library." FORCE)
ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32)
ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET zlib PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES})
......
......@@ -110,6 +110,14 @@ function(find_fluid_modules TARGET_NAME)
endif()
endfunction(find_fluid_modules)
function(common_link TARGET_NAME)
if (WITH_PROFILER)
target_link_libraries(${TARGET_NAME} gperftools::profiler)
endif()
endfunction()
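A minimal sketch of how `common_link` composes with an ordinary target (`my_kernel` is a hypothetical name, not a real Paddle target); the wrapper functions below invoke it right after `target_link_libraries`, so the profiler dependency is attached only when `WITH_PROFILER` is ON:
```
# Hypothetical usage of common_link.
add_library(my_kernel STATIC my_kernel.cc)
target_link_libraries(my_kernel ${my_kernel_deps})
common_link(my_kernel)  # adds gperftools::profiler only when WITH_PROFILER=ON
```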
# Find all third_party modules; this is used by the paddle static library
# to reduce dependencies when building the inference libs.
set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY)
......@@ -274,6 +282,7 @@ function(cc_library TARGET_NAME)
endif()
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
common_link(${TARGET_NAME})
endif()
# cpplint code style
......@@ -340,6 +349,7 @@ function(cc_binary TARGET_NAME)
if(cc_binary_DEPS)
target_link_libraries(${TARGET_NAME} ${cc_binary_DEPS})
add_dependencies(${TARGET_NAME} ${cc_binary_DEPS})
common_link(${TARGET_NAME})
endif()
endfunction(cc_binary)
......@@ -362,6 +372,7 @@ function(cc_test TARGET_NAME)
target_link_libraries(${TARGET_NAME} ${win32_deps})
endif(WIN32)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME})
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......@@ -420,6 +431,7 @@ function(nv_binary TARGET_NAME)
if(nv_binary_DEPS)
target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS})
add_dependencies(${TARGET_NAME} ${nv_binary_DEPS})
common_link(${TARGET_NAME})
endif()
endif()
endfunction(nv_binary)
......@@ -433,6 +445,7 @@ function(nv_test TARGET_NAME)
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
......@@ -499,6 +512,7 @@ function(hip_binary TARGET_NAME)
if(hip_binary_DEPS)
target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS})
add_dependencies(${TARGET_NAME} ${hip_binary_DEPS})
common_link(${TARGET_NAME})
endif()
endif()
endfunction(hip_binary)
......@@ -518,6 +532,7 @@ function(hip_test TARGET_NAME)
set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags)
add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
endif()
endfunction(hip_test)
......@@ -560,6 +575,7 @@ function(go_library TARGET_NAME)
endif()
if(go_library_DEPS)
add_dependencies(${TARGET_NAME} ${go_library_DEPS})
common_link(${TARGET_NAME})
endif(go_library_DEPS)
# The "source file" of the library is `${dummyfile}` which never
......
......@@ -32,13 +32,23 @@ function(copy TARGET)
list(GET copy_lib_SRCS ${index} src)
list(GET copy_lib_DSTS ${index} dst)
if (WIN32)
if(IS_DIRECTORY ${src})
get_filename_component(last_path ${src} NAME)
string(APPEND dst "/" ${last_path})
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}"
)
if(EXISTS ${src})
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND cmake -E copy_directory "${src}" "${dst}"
COMMENT "copying ${src} -> ${dst}")
else()
message(WARNING "${src} does not exist!")
endif()
else()
# The Windows cmd shell will not expand wildcards automatically.
# below expand the files,libs and copy them by rules.
file(GLOB header_files ${src} "*.h")
file(GLOB static_lib_files ${src} "*.lib")
file(GLOB dll_lib_files ${src} "*.dll")
set(src_files ${header_files} ${static_lib_files} ${dll_lib_files})
# below expand the files, and copy them by rules.
file(GLOB src_files ${src})
if (NOT "${src_files}" STREQUAL "")
list(REMOVE_DUPLICATES src_files)
endif ()
......@@ -50,6 +60,7 @@ function(copy TARGET)
COMMAND ${CMAKE_COMMAND} -E copy "${src_file}" "${dst}"
COMMENT "copying ${src_file} -> ${dst}")
endforeach ()
endif()
else (WIN32) # not windows
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND mkdir -p "${dst}"
......@@ -95,7 +106,7 @@ copy(xxhash_lib
DEPS xxhash
)
if (NOT PROTOBUF_FOUND)
if (NOT PROTOBUF_FOUND OR WIN32)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf")
copy(protobuf_lib
SRCS ${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LIBRARY}
......@@ -129,8 +140,16 @@ if (WITH_MKLDNN)
)
endif ()
if (NOT WIN32)
if (NOT MOBILE_INFERENCE AND NOT RPI)
if (WITH_NGRAPH)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/ngraph")
copy(ngraph_lib
SRCS ${NGRAPH_INC_DIR} ${NGRAPH_LIB_DIR}
DSTS ${dst_dir} ${dst_dir}
DEPS ngraph
)
endif ()
if (NOT MOBILE_INFERENCE AND NOT RPI)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy")
copy(snappy_lib
SRCS ${SNAPPY_INCLUDE_DIR} ${SNAPPY_LIBRARIES}
......@@ -148,8 +167,7 @@ if (NOT WIN32)
SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib
DEPS zlib)
endif ()
endif (NOT WIN32)
endif ()
# paddle fluid module
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
......@@ -183,8 +201,13 @@ if (WITH_ANAKIN AND WITH_MKL)
endif ()
set(module "inference")
if(WIN32)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/libpaddle_fluid.*)
else(WIN32)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*)
endif(WIN32)
copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
SRCS ${src_dir}/${module}/*.h ${paddle_fluid_lib}
${src_dir}/${module}/api/paddle_*.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
)
......@@ -224,7 +247,7 @@ copy(third_party DEPS fluid_lib_dist
# only need libpaddle_fluid.so/a and paddle_*.h for inference-only library
copy(inference_api_lib DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.*
SRCS ${paddle_fluid_lib}
${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_*.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
)
......
......@@ -166,6 +166,8 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
......
......@@ -74,6 +74,7 @@ paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr
paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
paddle.fluid.layers.bpr_loss ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None))
......@@ -84,6 +85,8 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name']
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
......@@ -190,7 +193,7 @@ paddle.fluid.layers.clip ArgSpec(args=['x', 'min', 'max', 'name'], varargs=None,
paddle.fluid.layers.clip_by_norm ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name'], varargs=None, keywords=None, defaults=(-100, None))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
......@@ -202,6 +205,10 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......@@ -306,6 +313,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......@@ -426,3 +434,17 @@ paddle.fluid.Scope.drop_kids drop_kids(self: paddle.fluid.core.Scope) -> None
paddle.fluid.Scope.find_var find_var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable
paddle.fluid.Scope.new_scope new_scope(self: paddle.fluid.core.Scope) -> paddle.fluid.core.Scope
paddle.fluid.Scope.var var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable
paddle.reader.map_readers ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None)
paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None)
paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None)
paddle.reader.chain ArgSpec(args=[], varargs='readers', keywords=None, defaults=None)
paddle.reader.shuffle ArgSpec(args=['reader', 'buf_size'], varargs=None, keywords=None, defaults=None)
paddle.reader.firstn ArgSpec(args=['reader', 'n'], varargs=None, keywords=None, defaults=None)
paddle.reader.xmap_readers ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,))
paddle.reader.PipeReader.__init__ ArgSpec(args=['self', 'command', 'bufsize', 'file_type'], varargs=None, keywords=None, defaults=(8192, 'plain'))
paddle.reader.PipeReader.get_line ArgSpec(args=['self', 'cut_lines', 'line_break'], varargs=None, keywords=None, defaults=(True, '\n'))
paddle.reader.multiprocess_reader ArgSpec(args=['readers', 'use_pipe', 'queue_size'], varargs=None, keywords=None, defaults=(True, 1000))
paddle.reader.Fake.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.reader.creator.np_array ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.reader.creator.text_file ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None)
paddle.reader.creator.recordio ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,))
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(imperative)
add_subdirectory(operators)
add_subdirectory(string)
add_subdirectory(recordio)
......
......@@ -3,8 +3,9 @@
# We create a hidden file and compile it instead of the original source file.
function(windows_symbolic TARGET)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(multiValueArgs SRCS PATH)
cmake_parse_arguments(windows_symbolic "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(final_path ${CMAKE_CURRENT_SOURCE_DIR}/${windows_symbolic_PATH})
foreach(src ${windows_symbolic_SRCS})
get_filename_component(src ${src} NAME_WE)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu)
......@@ -72,6 +73,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
......@@ -118,8 +121,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
......@@ -127,11 +131,14 @@ cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
if(NOT WIN32)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler)
endif(NOT WIN32)
if(WITH_NGRAPH)
if(NOT WIN32)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler ngraph)
endif(NOT WIN32)
endif(WITH_NGRAPH)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......@@ -164,18 +171,27 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
if(WITH_NGRAPH)
if(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph ngraph_operator variable_helper)
else(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(NOT WIN32)
else(WITH_NGRAPH)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(WITH_NGRAPH)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
target_link_libraries(executor garbage_collector)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph build_strategy
......@@ -196,7 +212,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
......
......@@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) {
if (use_slots_is_dense_[i]) {
feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
} else {
feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
}
feed_vec_[i] = var->GetMutable<LoDTensor>();
}
}
}
......@@ -201,22 +197,22 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
for (size_t i = 0; i < all_slots_.size(); ++i) {
int num = strtol(endptr, &endptr, 10);
if (num < 0) {
VLOG(1) << "error: the number of ids is a negative number: " << num;
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "error: the number of ids is a negative number: " << num;
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
} else if (num == 0) {
VLOG(1)
VLOG(0)
<< "error: the number of ids can not be zero, you need "
"padding it in data generator; or if there is something wrong"
" with the data, please check if the data contains unresolvable "
"characters.";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
} else if (errno == ERANGE || num > INT_MAX) {
VLOG(1) << "error: the number of ids greater than INT_MAX";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "error: the number of ids greater than INT_MAX";
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
}
......@@ -224,15 +220,15 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
for (int i = 0; i < num; ++i) {
strtof(endptr, &endptr);
if (errno == ERANGE) {
VLOG(1) << "error: the value is out of the range of "
VLOG(0) << "error: the value is out of the range of "
"representable values for float";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
}
if (i + 1 != num && endptr - str == len) {
VLOG(1) << "error: there is a wrong with the number of ids.";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "error: there is a wrong with the number of ids.";
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
}
......@@ -241,32 +237,43 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
for (int i = 0; i < num; ++i) {
strtoull(endptr, &endptr, 10);
if (errno == ERANGE) {
VLOG(1) << "error: the value is out of the range of "
VLOG(0) << "error: the value is out of the range of "
"representable values for uint64_t";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
}
if (i + 1 != num && endptr - str == len) {
VLOG(1) << "error: there is a wrong with the number of ids.";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
VLOG(0) << "error: there is a wrong with the number of ids.";
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
}
}
} else {
VLOG(1) << "error: this type<" << all_slots_type_[i]
VLOG(0) << "error: this type<" << all_slots_type_[i]
<< "> is not supported";
return false;
}
}
if (endptr - str != len) {
VLOG(1) << "error: there is some data at the end of the line.";
VLOG(1) << "please check line<" << instance_cout << "> in file<"
// Hadoop may append a '\t' to the end of each output line of a reduce
// task (when the output of a Hadoop reduce task has only one field, a
// '\t' is added at the end of the line by default; the option
// `-D mapred.textoutputformat.ignoreseparator=true` avoids it). This does
// not affect the correctness of the data, so a line is only judged
// abnormal when the characters at its end are not whitespace.
while (endptr - str != len) {
if (!isspace(*(endptr++))) {
VLOG(0)
<< "error: there are extra characters at the end of the line.";
VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">";
return false;
}
}
}
VLOG(3) << "instances cout: " << instance_cout;
VLOG(3) << "The file format is correct";
return true;
......@@ -291,6 +298,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
......@@ -327,6 +335,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
(*ins_vec)[i].InitOffset();
}
}
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]);
}
......@@ -338,36 +347,25 @@ void MultiSlotDataFeed::PutToFeedVec(
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else {
float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
const auto& feasign = ins_vec[i].GetUint64Data();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
} else {
int64_t* tensor_ptr =
feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
}
}
}
......
......@@ -30,35 +30,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Pack Tensor type and LoDTensor type into MixTensor type, in order
// to record either Tensor or LoDTensor information at the same time.
class MixTensor {
public:
MixTensor() {}
explicit MixTensor(LoDTensor* lodtensor) {
is_dense_ = false;
lodtensor_ = lodtensor;
}
explicit MixTensor(Tensor* tensor) {
is_dense_ = true;
tensor_ = tensor;
}
bool IsDense() { return is_dense_; }
LoDTensor* GetLoDTensor() {
PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
return lodtensor_;
}
Tensor* GetTensor() {
PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
return tensor_;
}
private:
bool is_dense_;
LoDTensor* lodtensor_;
Tensor* tensor_;
};
// DataFeed is the base virtual class for all other DataFeeds.
// It is used to read files and parse the data for the subsequent trainer.
// Example:
......@@ -133,7 +104,7 @@ class DataFeed {
use_slots_index_; // -1: not used; >=0: the index of use_slots_
// The data read by DataFeed will be stored here
std::vector<MixTensor> feed_vec_;
std::vector<LoDTensor*> feed_vec_;
// the batch size defined by user
int default_batch_size_;
......
......@@ -152,21 +152,15 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
std::map<std::string, const paddle::framework::LoDTensor*>
lodtensor_targets;
std::map<std::string, const paddle::framework::Tensor*> tensor_targets;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
const auto& name = slot.name();
readers[idx]->AddFeedVar(scope->Var(name), name);
if (slot.is_dense()) {
tensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::Tensor>();
} else {
lodtensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
}
}
}
readers[idx]->Start();
while (readers[idx]->Next()) {
int index = 0;
......@@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
if (!slot.is_used()) {
continue;
}
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.is_dense()) { // dense branch
const paddle::framework::Tensor* tens = tensor_targets[slot.name()];
if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>();
int batch_size = tens->dims()[0];
......@@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
PADDLE_THROW("Error type in proto file.");
}
} else { // sparse branch
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>();
for (size_t i = 0; i < tens->NumElements(); ++i) {
......
......@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->mutable_data(expected_kernel_type.place_, in.type());
framework::VisitDataType(
framework::ToDataType(in.type()),
in.type(),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_);
......@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>());
return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16:
......@@ -144,26 +144,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name());
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
void* in_data = GetDataFromTensor(in, in_type);
// The output tensor has the same dims as the input. Reorder doesn't change dims.
out->Resize(in.dims());
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
auto in_memory = memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto in_memory =
memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::Reorder(in_memory, out_memory);
} else {
out->ShareDataWith(in);
}
out->set_layout(out_layout);
// reset format since the out tensor will be fed to a non-MKLDNN op kernel
out->set_format(memory::format::format_undef);
......
......@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
}
}
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
static const std::map<std::type_index, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
auto iter = dict.find(type);
inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static std::unordered_map<int, MKLDNNDataType> dict{
{DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
{DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
{DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef;
}
......
......@@ -26,7 +26,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
std::unordered_map<int, size_t> proto_to_size_;
};
static DataTypeMap* InitDataTypeMap();
......@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
}
static DataTypeMap* InitDataTypeMap() {
......@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customized type here.
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
RegType(int64_t, proto::VarType::INT64);
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
_ForEachDataType_(RegType);
#undef RegType
return retv;
......@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type));
}
size_t SizeOfType(std::type_index type) {
auto it = gDataTypeMap().cpp_to_size_.find(type);
if (it != gDataTypeMap().cpp_to_size_.end()) {
size_t SizeOfType(proto::VarType::Type type) {
auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", type.name());
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
}
} // namespace framework
......
......@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
case proto::VarType::FP16:
visitor.template apply<platform::float16>();
break;
case proto::VarType::FP32:
visitor.template apply<float>();
break;
case proto::VarType::FP64:
visitor.template apply<double>();
break;
case proto::VarType::INT32:
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template apply<int8_t>();
break;
default:
#define VisitDataTypeCallback(cpp_type, proto_type) \
do { \
if (type == proto_type) { \
visitor.template apply<cpp_type>(); \
return; \
} \
} while (0)
_ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback
PADDLE_THROW("Not supported %d", type);
}
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
......
......@@ -26,15 +26,15 @@ TEST(DataType, float16) {
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
tensor.mutable_data(cpu, dtype);
// test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
// test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info
std::string type = "float16";
std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
......@@ -12,17 +12,36 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE)
if(NOT WITH_GRPC)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
else()
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor)
endif()
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_rpc)
else()
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor)
endif()
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif()
......@@ -33,10 +52,10 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
......@@ -44,10 +63,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
......@@ -48,7 +48,14 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
// FIXME(typhoonzero): If scope0 (the global scope) has NCCL_ID_VAR,
// this is a distributed or inter-process call; find a better way.
#ifdef PADDLE_WITH_CUDA
if (NoDummyInputSize() == 1 &&
local_scopes_[0]->FindLocalVar(NCCL_ID_VARNAME) == nullptr) {
#else
if (NoDummyInputSize() == 1) {
#endif
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
......@@ -120,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
......
......@@ -58,10 +58,23 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
}
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
if (strategy_.trainer_id_ > 0) {
PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
}
VLOG(1) << "CollectiveContext:" << context->String();
// Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
multi_devices_pass->Set<int>("num_trainers",
new int(strategy_.num_trainers_));
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
......@@ -133,7 +146,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
} else if (pass->Type() == "sequential_execution_pass") {
VLOG(1) << "set enable_sequential_execution:"
LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_;
pass->Erase(kAllOpDescs);
......@@ -141,7 +154,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "all_reduce_deps_pass") {
VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_;
pass->Erase(kAllOpDescs);
......
......@@ -74,6 +74,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;
bool remove_unnecessary_lock_{false};
// NOTE:
......
......@@ -20,11 +20,13 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place)
platform::Place place,
size_t scope_idx)
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope),
place_(place) {}
place_(place),
scope_idx_(scope_idx) {}
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
......
......@@ -28,7 +28,8 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t scope_idx);
std::string Name() const override;
......@@ -38,6 +39,8 @@ struct ComputationOpHandle : public OpHandleBase {
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
size_t GetScopeIdx() const { return scope_idx_; }
protected:
void RunImpl() override;
......@@ -47,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
size_t scope_idx_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
......
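Threading scope_idx through ComputationOpHandle looks mechanical, but it is what lets the new passes map an op back to its device-local scope. A hedged sketch of the lookup a consumer can then perform, mirroring the eager_deletion_pass further below (gcs, places, and ref_cnts are assumed to be the pass attributes):

```
// Hypothetical consumer of GetScopeIdx():
size_t idx = compute_op->GetScopeIdx();
auto *gc = gcs.at(places[idx]).get();  // per-device garbage collector
auto *ref_cnt_map = &ref_cnts[idx];    // per-device reference counts
```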
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
namespace paddle {
namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (dynamic_cast<StreamGarbageCollector *>(gc_)) {
platform::CUDADeviceGuard guard(
boost::get<platform::CUDAPlace>(place).device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
PADDLE_ENFORCE_NOT_NULL(event_);
}
}
#endif
}
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#ifdef PADDLE_WITH_CUDA
if (event_) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
#endif
}
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
// Var not found, or its reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
continue;
}
auto *var = exec_scope->FindVar(name);
if (var == nullptr) {
continue;
}
VLOG(2) << "Erase variable " << name;
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
} else if (var->IsType<SelectedRows>()) {
garbages.emplace_back(
var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto *tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto &t : *tensor_arr) {
garbages.emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
}
}
if (!garbages.empty()) {
ClearGarbages(&garbages);
}
}
void EagerDeletionOpHandle::ClearGarbages(
std::deque<std::shared_ptr<memory::Allocation>> *garbages) {
#ifdef PADDLE_WITH_CUDA
if (event_) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream =
reinterpret_cast<StreamGarbageCollector *>(gc_)->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(std::move(*garbages), callback_func);
} else {
#endif
gc_->Add(std::move(*garbages));
#ifdef PADDLE_WITH_CUDA
}
#endif
}
} // namespace details
} // namespace framework
} // namespace paddle
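The event dance in ClearGarbages above is the core of the asynchronous GC: nothing is freed until every kernel already queued on the compute stream has run. A condensed sketch of the same ordering trick (a free-standing helper under assumed stream arguments, not the handle's actual code):

```
#include <cuda_runtime.h>

// Make work later enqueued on `gc` wait for all work currently
// enqueued on `compute`.
void OrderFreeAfterCompute(cudaStream_t compute, cudaStream_t gc) {
  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
  cudaEventRecord(ev, compute);    // snapshot the compute stream's position
  cudaStreamWaitEvent(gc, ev, 0);  // gc stream blocks until that point passes
  cudaEventDestroy(ev);            // legal: freed once outstanding work ends
}
```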
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <string>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
class Scope;
namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle();
std::string Name() const override;
protected:
void RunImpl() override;
private:
void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);
const Scope *scope_;
std::unordered_set<std::string> var_names_;
GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
const auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);
// a reverse map of last_live_ops
// i.e., last op --> variable names which can be deleted.
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
op_vars_map;
for (auto &var_ops_map : last_live_ops) {
for (auto &var_ops_pair : var_ops_map) {
const std::string &var_name = var_ops_pair.first;
for (auto *op : var_ops_pair.second) {
op_vars_map[op].insert(var_name);
}
}
}
for (auto &pair : op_vars_map) {
auto *op = pair.first;
auto &var_names = pair.second;
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var);
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
}
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
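Every RequirePassAttr above must be satisfied before the pass can run. A hedged sketch of how a caller would wire them up through the pass registry (the local variable names, and the exact Apply signature, are assumptions):

```
// Sketch: applying eager_deletion_pass by hand.
auto pass = ir::PassRegistry::Instance().Get("eager_deletion_pass");
pass->SetNotOwned(details::kRuntimeReferenceCount, &ref_cnts);
pass->SetNotOwned(details::kLastLiveOpsOfVars, &last_live_ops_of_vars);
pass->SetNotOwned(details::kAllPlaces, &places);
pass->SetNotOwned(details::kGarbageCollector, &gcs);
graph = pass->Apply(std::move(graph));
```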
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
const proto::VarType::Type var_type)
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
......@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_;
const proto::VarType::Type type_;
int64_t total_numel_;
};
} // namespace details
......
......@@ -133,6 +133,7 @@ static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
static const char kNumTrainers[] = "num_trainers";
void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear();
......@@ -299,6 +300,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
int num_trainers = Get<int>(kNumTrainers);
for (auto &node : nodes) {
if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var());
......@@ -383,7 +386,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps(&result, node, places_.size());
}
if (!is_forwarding && places_.size() > 1) {
if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
......@@ -562,7 +565,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id]));
local_scopes_[dev_id], places_[dev_id], dev_id));
CreateOpHandleIOs(result, node, dev_id);
}
......@@ -685,8 +688,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
result->CreateOpNode(node->Op()), s, p, scope_idx));
CreateOpHandleIOs(result, node, scope_idx);
}
}
......@@ -895,4 +898,5 @@ REGISTER_PASS(multi_devices_pass,
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy);
.RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNumTrainers);
......@@ -23,6 +23,8 @@ namespace details {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
preceding_ops_.clear();
pending_ops_.clear();
for (auto &op : ops) {
preceding_ops_[op];
pending_ops_[op];
......@@ -40,6 +42,7 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
ret.reserve(preceding_ops_.size());
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
......
......@@ -14,7 +14,7 @@
#pragma once
#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -34,6 +34,11 @@ class OpGraphView {
bool HasOp(OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
......@@ -44,6 +49,28 @@ class OpGraphView {
pending_ops_;
};
template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
Callback &&callback) const {
EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
q.push(op);
do {
op = q.front();
q.pop();
for (auto &pending_op : pending_ops_.at(op)) {
if (visited.count(pending_op) == 0) {
visited.insert(pending_op);
q.push(pending_op);  // keep the BFS going through transitive pending ops
if (!callback(pending_op)) {
return false;
}
}
}
} while (!q.empty());
return true;
}
} // namespace details
} // namespace framework
} // namespace paddle
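VisitAllPendingOps is a breadth-first walk over everything downstream of an op, with early termination as soon as the callback returns false. A usage sketch (graph_view and start are assumed handles):

```
// Count downstream ops, but give up once more than 100 are seen.
size_t n = 0;
bool completed = graph_view.VisitAllPendingOps(start, [&](OpHandleBase *op) {
  return ++n <= 100;  // returning false stops the traversal early
});
// completed == false means the walk was cut short by the callback.
```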
......@@ -53,7 +53,7 @@ struct ReduceLoDTensor {
}
};
inline void GatherSelectedRows(
inline void GatherLocalSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
......
......@@ -16,6 +16,12 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#endif
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool(
......@@ -26,6 +32,112 @@ namespace paddle {
namespace framework {
namespace details {
std::once_flag CollectiveContext::init_flag_;
std::unique_ptr<CollectiveContext> CollectiveContext::context_;
static inline std::string GetRemoteVarName(const std::string &var_name,
int trainer_id) {
return string::Sprintf("%s_merged_tmp@trainer_%d", var_name, trainer_id);
}
void ReduceOpHandle::Wait(
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes) {
// TODO(gongwb): use event wait?
for (auto &dev_ctx : dev_ctxes) {
dev_ctx.second->Wait();
}
}
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
template <typename DevCtx, typename DataType>
void ReduceOpHandle::GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selected_rows,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
VarHandle *out_var_handle, const platform::Place &out_place,
SelectedRows *dst_selected_rows) {
const CollectiveContext &collective_context =
*CollectiveContext::GetInstance();
// 1. gather local selected rows, merge them
std::string gathered_var_name = out_var_handle->name_ + "_gathered_tmp";
auto scope = local_scopes_.at(out_var_handle->scope_idx_);
auto gathered_var_mid = scope->Var(gathered_var_name);
auto gathered_select_rows =
gathered_var_mid->GetMutable<framework::SelectedRows>();
GatherLocalSelectedRows(src_selected_rows, in_places, dev_ctxes, out_place,
gathered_select_rows);
// FIXME(gongwb): remove this Wait.
Wait(dev_ctxes);
// merge them
auto merged_dev_ctx = dynamic_cast<DevCtx *>(dev_ctxes.at(out_place));
std::string merged_var_name =
GetRemoteVarName(out_var_handle->name_, collective_context.trainer_id_);
auto merged_select_rows =
scope->Var(merged_var_name)->GetMutable<SelectedRows>();
operators::math::scatter::MergeAdd<DevCtx, DataType> merge_func;
merge_func(*merged_dev_ctx, *gathered_select_rows, merged_select_rows);
// 2. start collective server if it doesn't exist
operators::distributed::CollectiveServer *server =
operators::distributed::CollectiveServer::GetInstance(
collective_context.endpoints_[collective_context.trainer_id_],
collective_context.endpoints_.size() - 1);
auto rpc_server = server->GetRPCServer();
rpc_server->RegisterVar(merged_var_name,
operators::distributed::kRequestGetMonomerVariable,
scope, merged_dev_ctx);
// 3. gather them from all remote nodes.
std::vector<const SelectedRows *> remote;
operators::distributed::CollectiveClient *client =
operators::distributed::CollectiveClient::GetInstance();
std::vector<operators::distributed::RemoteVar> vars;
for (unsigned int i = 0; i < collective_context.endpoints_.size(); i++) {
if (i == (unsigned)collective_context.trainer_id_) continue;
operators::distributed::RemoteVar var;
var.trainer_id_ = i;
var.var_name_ = GetRemoteVarName(out_var_handle->name_, i);
var.ep_ = collective_context.endpoints_[i];
vars.push_back(var);
VLOG(4) << "gather from:" << var.String();
}
// erase gathered vars
merged_dev_ctx->Wait();
scope->EraseVars(std::vector<std::string>{gathered_var_name});
PADDLE_ENFORCE(client->Gather(vars, &remote, *merged_dev_ctx, scope));
PADDLE_ENFORCE(remote.size() == vars.size());
// 4. merge local and remote selected rows.
std::vector<const SelectedRows *> all;
all.resize(collective_context.endpoints_.size());
for (auto v : vars) {
all[v.trainer_id_] =
scope->FindVar(v.var_name_)->GetMutable<SelectedRows>();
}
all[collective_context.trainer_id_] = merged_select_rows;
merge_func(*merged_dev_ctx, all, dst_selected_rows);
rpc_server->WaitVarBarrier(merged_var_name);
rpc_server->ClearVar(merged_var_name);
// 5. clear mid vars
std::vector<std::string> tmp_vars{merged_var_name};
for (auto r : vars) {
tmp_vars.push_back(r.var_name_);
}
scope->EraseVars(tmp_vars);
}
#endif
void ReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
......@@ -90,8 +202,36 @@ void ReduceOpHandle::RunImpl() {
this->RunAndRecordEvent([&] {
std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
const CollectiveContext &collective_context =
*CollectiveContext::GetInstance();
VLOG(10) << "GatherSelectedRows CollectiveContext:"
<< collective_context.String();
// TODO(gongwb): add cpu support
if (collective_context.endpoints_.size() <= 1 ||
is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) {
GatherLocalSelectedRows(in_selected_rows, in_places, dev_ctxes_,
t_out_p,
out_var->GetMutable<framework::SelectedRows>());
return;
}
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
if (in_selected_rows[0]->value().type() ==
framework::proto::VarType::FP32) {
GatherSelectedRows<platform::CUDADeviceContext, float>(
in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
out_var->GetMutable<framework::SelectedRows>());
} else if (in_selected_rows[0]->value().type() ==
framework::proto::VarType::FP64) {
GatherSelectedRows<platform::CUDADeviceContext, double>(
in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
out_var->GetMutable<framework::SelectedRows>());
} else {
PADDLE_THROW("only support double or float when gather SelectedRows");
}
#endif
});
} else {
std::vector<const LoDTensor *> lod_tensors =
......@@ -106,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
if (!FLAGS_cpu_deterministic) {
ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
} else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
......@@ -116,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
->FindVar(out_var_handle->name_)
->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
......@@ -30,6 +30,32 @@
namespace paddle {
namespace framework {
namespace details {
struct CollectiveContext {
std::vector<std::string> endpoints_;
int trainer_id_{0};
std::string String() const {
std::stringstream ss;
ss << "endpoints_:";
for (auto e : endpoints_) {
ss << e << ",";
}
ss << "trainer_id_:" << trainer_id_;
return ss.str();
}
static CollectiveContext *GetInstance() {
std::call_once(init_flag_,
[&]() { context_.reset(new CollectiveContext()); });
return context_.get();
}
private:
static std::once_flag init_flag_;
static std::unique_ptr<CollectiveContext> context_;
};
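CollectiveContext is a process-wide singleton: the pass builder fills in the endpoints and trainer id once (see the ParallelExecutorPassBuilder hunk above), and ReduceOpHandle reads it later to choose between local and multi-trainer gather. A sketch of that wiring (the endpoint values are placeholders):

```
auto *ctx = CollectiveContext::GetInstance();
ctx->endpoints_ = {"127.0.0.1:6170", "127.0.0.1:6171"};
ctx->trainer_id_ = 0;
VLOG(1) << "CollectiveContext:" << ctx->String();
```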
struct ReduceOpHandle : public OpHandleBase {
std::vector<Scope *> local_scopes_;
......@@ -64,6 +90,19 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
template <typename DevCtx, typename DataType>
void GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
VarHandle *out_var_handle, const platform::Place &out_place,
SelectedRows *dst_selecte_rows);
#endif
void Wait(
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes);
template <typename T>
std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
using ReferenceCountMap = std::unordered_map<std::string, int>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<int>>;
using DeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
using AtomicDeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
using DeviceGarbageCollectorMap =
std::unordered_map<int,
std::unique_ptr<GarbageCollector<framework::Tensor>>>;
class ReferenceCountOpHandle : public OpHandleBase {
public:
ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
const platform::CUDAPlace &place,
const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) {
platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}
for (auto &name : var_names) AddVar(name);
}
~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_));
}
}
std::string Name() const override { return "reference_count"; }
void AddVar(const std::string &name) {
auto it = var_names_.find(name);
if (it != var_names_.end())
++(it->second);
else
var_names_[name] = 1;
}
protected:
void RunImpl() override {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<Tensor *> tensors;
for (auto &pair : var_names_) {
auto &name = pair.first;
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue;
auto *var = exec_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<LoDTensor>()) {
if (it->second.fetch_sub(pair.second) <= pair.second) {
tensors.emplace_back(var->GetMutable<LoDTensor>());
}
} else if (var->IsType<SelectedRows>()) {
if (it->second.fetch_sub(pair.second) <= pair.second) {
tensors.emplace_back(
var->GetMutable<SelectedRows>()->mutable_value());
}
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
private:
void ClearTensors(const std::vector<Tensor *> &tensors) {
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream();
auto callback_stream = gc->stream();
auto callback_func = [=]() {
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
};
gc_->Add(tensors, callback_func);
} else {
gc_->Add(tensors);
}
}
bool IsStreamGarabageCollector() const {
return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
}
const Scope *scope_;
platform::CUDADeviceContext *dev_ctx_;
std::unordered_map<std::string, int> var_names_;
GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
cudaEvent_t event_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,187 +14,240 @@
#include <queue>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
std::queue<VarHandleBase *> queue;
queue.push(var_in);
// A functor to shrink/remove operators that depend on other operators in a set
class ShrinkDepsOpFunctor {
private:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
public:
explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
OpSet operator()(const OpSet &op_set) const {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
for (size_t i = 0; i < rels.size(); ++i) {
if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) {
ret.emplace(static_cast<KeyType>(ops[i]));
}
}
return ret;
}
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> &ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
}
PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops");
std::vector<std::vector<RelationShip>> ret(ops.size());
for (auto &e : ret) {
e.assign(ops.size(), kSame);
}
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
if (i != j && ret[i][j] == kSame) {
ret[i][j] = kBefore;
ret[j][i] = kAfter;
found_num += 2;
if (found_num == total_num) {
return false;
}
}
}
return true;
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
}
for (size_t i = 0; i < ops.size(); ++i) {
for (size_t j = i + 1; j < ops.size(); ++j) {
if (ret[i][j] != kSame) continue;
ret[i][j] = kNoDeps;
ret[j][i] = kNoDeps;
}
}
return ret;
}
const OpGraphView graph_;
};
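In effect, the functor prunes a candidate set down to its "latest" members: if A runs before B, waiting for B already guarantees A has finished, so A can be dropped. A small illustration (a, b, and graph are assumed handles):

```
// Assume a path a -> ... -> b exists in the graph.
ShrinkDepsOpFunctor shrink(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
std::unordered_set<ComputationOpHandle *> candidates{a, b};
auto last_ops = shrink(candidates);  // == {b}: waiting on b covers a
```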
/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
auto *var = queue.front();
queue.pop();
for (auto *op : var->PendingOps()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
queue.push(out_var);
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);  // keep the BFS going through transitive pending ops
}
}
} while (!queue.empty());
} while (!q.empty());
return nullptr;
}
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
const ShrinkDepsOpFunctor &shrink_func,
bool *ok) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
// that generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
candidates = var->PendingOps();
}
// No pending ops or generated op is nullptr
if (candidates.empty()) {
*ok = false;
return {};
}
}
// stage two. Try to cast them to computation op.
// return (*ok=false) when failed.
//
// The reason why we cannot let every type of op handle be the last lived
// op is:
// some op handles may operate on many DeviceContexts, while our garbage
// collector can only wait on one DeviceContext for now. So currently, we
// wait on the nearest downstream compute op instead.
std::unordered_set<ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
*ok = false;
return {};
}
computation_op.emplace(compute_op);
}
}
// stage three. Try to shrink the computation ops if they depend on each
// other: keep the smallest set whose completion implies all candidates ran.
*ok = true;
return shrink_func(computation_op);
}
static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
// It is not easy to find the right reference counts of variables in a graph:
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refer to variables
// in computation ops
std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&](
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
if (!platform::is_gpu_place(var_handle->place_) ||
boost::get<platform::CUDAPlace>(var_handle->place_) != place)
continue;
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name();
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(),
"Last Live Ops and Reference Counts of vars should be "
"initialized at here.");
// This is weird, but there really are some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name);
if (var_desc == nullptr) continue;
}
const auto &vars = graph->Get<GraphVars>(kGraphVars);
if (var_desc->Persistable()) continue;
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS) {
continue;
}
last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size());
// compute op only runs in one device
if (ref_cnts[place.device]->count(var_name))
++(*ref_cnts[place.device])[var_name];
else
(*ref_cnts[place.device])[var_name] = 1;
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
names.insert(var_name);
var_names_in_op.push_back(var_name);
}
return var_names_in_op;
};
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
// Can this variable be reused or deleted? If not, we do not compute
// reference counts or dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
auto update_ref_cnts_from_non_compute_op = [&](
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto var_name = var_handle->Node()->Name();
auto var_place = var_handle->place_;
if (!platform::is_gpu_place(var_place)) continue;
auto place = boost::get<platform::CUDAPlace>(var_place);
if (names.count(var_name) == 0) continue;
if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name];
auto *next_compute_op = FindNextComputationOpHandle(var_handle);
if (next_compute_op != nullptr) {
if (compute_ref_cnt_map.count(next_compute_op)) {
compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
VLOG(5) << "Add reference count of " << var_name << " to Operator "
<< next_compute_op->Name();
} else {
// Create new reference_count_op_handle
ir::Node *ref_cnt_node = graph->CreateEmptyNode(
"reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
}
}
}
if (var_desc == nullptr || var_desc->Persistable()) {
continue;
}
};
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS &&
var_type != proto::VarType::LOD_TENSOR_ARRAY) {
// Var type cannot be deleted
continue;
}
for (auto &op : all_ops) {
update_ref_cnts_from_non_compute_op(op, op->Inputs());
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
bool ok;
auto result = ExtractComputationOpFromLastLivedVar(
name_var_pair.second.back(), i, shrink_func, &ok);
std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
it->second->AddOutput(dummy_leaf);
new_all_ops.emplace_back(std::move(it->second));
if (ok) {
auto &var_name = name_var_pair.first;
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
}
}
}
all_ops.swap(new_all_ops);
return graph;
}
......@@ -205,5 +258,4 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
......@@ -14,7 +14,6 @@
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -22,10 +21,6 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kGlobalReferenceCount[] = "reference_count";
constexpr char kCurReferenceCount[] = "current_reference_count";
constexpr char kGarbageCollector[] = "garbage_collector";
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
namespace details {} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
namespace details {
class ComputationOpHandle;
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<size_t>>;
using GarbageCollectorMap =
std::map<platform::Place, std::unique_ptr<GarbageCollector>>;
const char kGlobalReferenceCount[] = "global_reference_count";
const char kRuntimeReferenceCount[] = "runtime_reference_count";
const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,9 +18,6 @@
#include <vector>
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace paddle {
namespace framework {
......@@ -69,27 +66,12 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
#ifdef PADDLE_WITH_CUDA
const std::string gc_name = "garbage_collector";
DeviceGarbageCollectorMap *gc =
Graph().Has(gc_name) ? &(Graph().Get<DeviceGarbageCollectorMap>(gc_name))
: nullptr;
#endif
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr && platform::is_gpu_place(p)) {
auto gpu_place = boost::get<platform::CUDAPlace>(p);
auto &gc_at_place = gc->at(gpu_place.device);
gc_at_place->Wait();
gc_at_place->Reset();
}
#endif
}
for (auto &scope : local_scopes_) {
auto &local_scope =
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace paddle {
namespace framework {
......@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return dtype;
}
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
#define REG_DL_DATA_TYPE(type) \
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static const std::unordered_map<std::type_index, ::DLDataType>
type_to_dtype_map({
REG_DL_DATA_TYPE(platform::float16), // NOLINT
REG_DL_DATA_TYPE(float), // NOLINT
REG_DL_DATA_TYPE(double), // NOLINT
REG_DL_DATA_TYPE(int), // NOLINT
REG_DL_DATA_TYPE(int64_t), // NOLINT
REG_DL_DATA_TYPE(bool), // NOLINT
REG_DL_DATA_TYPE(size_t), // NOLINT
REG_DL_DATA_TYPE(int16_t), // NOLINT
REG_DL_DATA_TYPE(uint8_t), // NOLINT
REG_DL_DATA_TYPE(int8_t) // NOLINT
});
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
static std::unordered_map<int, ::DLDataType> result;
#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
_ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
return result;
}
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static auto type_to_dtype_map = CreateDLDataTypeMap();
static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(type);
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
type.name());
auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
type);
return it->second;
#undef REG_DL_DATA_TYPE
}
......
......@@ -91,23 +91,11 @@ void TestMainLoop() {
}
}
}
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \
TEST(dlpack, test_##type) { TestMainLoop<type>(); }
using float16 = platform::float16;
PADDLE_DLPACK_TEST(float16);
PADDLE_DLPACK_TEST(float);
PADDLE_DLPACK_TEST(double);
PADDLE_DLPACK_TEST(int);
PADDLE_DLPACK_TEST(int64_t);
PADDLE_DLPACK_TEST(bool);
PADDLE_DLPACK_TEST(size_t);
PADDLE_DLPACK_TEST(int16_t);
PADDLE_DLPACK_TEST(uint8_t);
PADDLE_DLPACK_TEST(int8_t);
#undef PADDLE_DLPACK_TEST
_ForEachDataType_(TestCallback);
}
} // namespace framework
} // namespace paddle
......@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
......@@ -26,6 +26,10 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/framework/ngraph_operator.h"
#endif
DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
......@@ -38,11 +42,43 @@ namespace {
int kProgramId = -1;
} // namespace
static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
const BlockDesc& block, const std::vector<std::string>& skip_var_list) {
std::unordered_map<std::string, size_t> ref_cnts;
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
skip_var_list.end());
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
if (skip_vars.count(name)) continue;
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS &&
type != proto::VarType::LOD_TENSOR_ARRAY) {
continue;
}
++ref_cnts[name];
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
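The counts computed here are static upper bounds: every op that names a variable as input or output adds one reference, while skipped and persistable variables are never tracked. At run time a copy of this map is decremented as ops finish; a toy model of that half of the scheme, with a hypothetical helper name:
```
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical helper: returns true when `name` has just lost its last
// reference and its memory can be handed to the garbage collector.
bool DecrementAndCheckDead(std::unordered_map<std::string, size_t>* runtime,
                           const std::string& name) {
  auto it = runtime->find(name);
  if (it == runtime->end()) return false;  // persistable or skipped var
  return --(it->second) == 0;
}
```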
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id)
const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars)
: prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCount<int>(prog_, block_id_);
global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
skip_ref_cnt_vars);
}
}
......@@ -50,28 +86,40 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
}
static void DeleteUnusedTensors(
    const Scope& scope, const OperatorBase* op, GarbageCollector* gc,
    std::unordered_map<std::string, size_t>* ref_cnts) {
  std::deque<std::shared_ptr<memory::Allocation>> garbages;
auto handler = [&](const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue;
        if (--(it->second) != 0) {
          continue;
        }
        auto* var = scope.FindVar(name);
        if (var == nullptr) {
          continue;
        }
        VLOG(2) << "Erase variable " << name;
if (var->IsType<LoDTensor>()) {
          garbages.emplace_back(
              var->GetMutable<LoDTensor>()->MoveMemoryHolder());
        } else if (var->IsType<SelectedRows>()) {
          garbages.emplace_back(var->GetMutable<SelectedRows>()
                                    ->mutable_value()
                                    ->MoveMemoryHolder());
} else if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *lod_tensor_arr) {
garbages.emplace_back(t.MoveMemoryHolder());
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
var->Type().name(), name);
}
}
}
......@@ -80,19 +128,19 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
handler(op->Inputs());
handler(op->Outputs());
  if (!garbages.empty()) {
    gc->Add(std::move(garbages));
}
}
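Note that DeleteUnusedTensors never frees memory directly: MoveMemoryHolder() detaches the tensor's shared allocation and parks it in the garbage queue, so the buffer is released only when the collector later clears the queue (on GPU, after a stream callback). A minimal sketch of that deferred-free idea, with `Allocation` standing in for memory::Allocation:
```
#include <deque>
#include <memory>
#include <utility>

struct Allocation {};  // stand-in for memory::Allocation

// Moving the holder into the queue keeps the buffer alive; the memory is
// actually released only when the queue entries are destroyed later.
void Defer(std::deque<std::shared_ptr<Allocation>>* garbages,
           std::shared_ptr<Allocation> holder) {
  garbages->emplace_back(std::move(holder));
}
```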
static void EnableFusedOp(ExecutorPrepareContext* ctx) {
#ifdef PADDLE_WITH_NGRAPH
VLOG(3) << "use_ngraph=True";
auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_);
auto intervals = NgraphOperator::NgraphOpIntervals(&ctx->ops_);
for (auto& interval : intervals) {
auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_,
interval.at(0), interval.at(1));
*interval[0] = std::unique_ptr<OperatorBase>(fused_op);
auto* ng_op = new NgraphOperator(ctx->prog_, ctx->block_id_, interval.at(0),
interval.at(1));
*interval[0] = std::unique_ptr<OperatorBase>(ng_op);
}
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
ctx->ops_.erase(it->at(0) + 1, it->at(1));
......@@ -109,9 +157,9 @@ void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
  auto client =
      paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  client->SendComplete();
#endif
}
......@@ -322,9 +370,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) {
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars) {
std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id));
new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
......@@ -335,16 +384,28 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
}
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids) {
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equals to block number %d",
block_ids.size());
std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
size_t idx = 0;
for (auto& bid : block_ids) {
    ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid);
} else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
}
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
++idx;
}
return result;
}
......@@ -362,22 +423,23 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
int64_t max_memory_size = GetEagerDeletionThreshold();
  // WhileOp would set keep_kids to true,
  // because WhileGradOp needs the scopes created in WhileOp.
  // Perhaps we should not perform eager deletion in WhileOp:
  // the scopes and variables created by WhileOp would be deleted
  // in WhileGradOp.
  std::unique_ptr<GarbageCollector> gc;
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          boost::get<platform::CUDAPlace>(place_), max_memory_size));
    } else {
      gc.reset(new DefaultStreamGarbageCollector(
          boost::get<platform::CUDAPlace>(place_), max_memory_size));
    }
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place_),
                                     max_memory_size));
#ifdef PADDLE_WITH_CUDA
}
#endif
......@@ -386,17 +448,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
    if (gc) {
      DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
                          &(ctx->runtime_ref_cnts_));
}
}
  platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (local_scope != scope) {
scope->DeleteScope(local_scope);
......
......@@ -27,52 +27,21 @@ limitations under the License. */
namespace paddle {
namespace framework {
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; }
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
  std::unordered_map<std::string, size_t> global_ref_cnts_;
  std::unordered_map<std::string, size_t> runtime_ref_cnts_;
};
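A hypothetical call site for the new parameter: variables listed in skip_ref_cnt_vars (typically fetch targets) are excluded from counting, so eager deletion can never reclaim them mid-run.
```
// Illustrative only -- variable and program names are hypothetical:
// auto ctx = framework::Executor::Prepare(program, /*block_id=*/0,
//                                         /*skip_ref_cnt_vars=*/{"fetch_out"});
// executor.RunPreparedContext(ctx.get(), &scope);
```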
class Executor {
......@@ -108,10 +77,14 @@ class Executor {
const std::string& fetch_holder_name = "fetch");
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids);
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
std::vector<std::vector<std::string>>());
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
......
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
......@@ -181,7 +182,7 @@ void ExecutorThreadWorker::SetDevice() {
static unsigned concurrency_cap = std::thread::hardware_concurrency();
int thread_id = this->thread_id_;
  if (static_cast<unsigned>(thread_id) < concurrency_cap) {
unsigned proc = thread_id;
cpu_set_t mask;
......@@ -222,42 +223,24 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std::cout << sstream.str() << std::endl;
}
static void print_fetch_var(Scope* scope, const std::string& var_name) {
  auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
#define PrintLoDTensorCallback(cpp_type, proto_type) \
do { \
if (tensor.type() == proto_type) { \
print_lod_tensor<cpp_type>(var_name, tensor); \
return; \
} \
} while (0)
_ForEachDataType_(PrintLoDTensorCallback);
VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
}
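The `do { ... } while (0)` wrapper in PrintLoDTensorCallback is the standard trick for making a multi-statement macro behave as a single statement, which keeps it safe inside un-braced if/else chains. A generic illustration:
```
void log_prefix();     // illustrative declarations only
void log_value(int x);

#define TWO_STEPS(x) \
  do {               \
    log_prefix();    \
    log_value(x);    \
  } while (0)

// if (cond) TWO_STEPS(1); else TWO_STEPS(2);   // parses as intended
```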
void ExecutorThreadWorker::TrainFiles() {
platform::SetNumThreads(1);
// todo: configurable
SetDevice();
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
......@@ -53,5 +55,12 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
return tensor;
}
LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name) {
Variable* var = scope.FindVar(var_name);
PADDLE_ENFORCE(var, "%s no in scope", var_name);
PADDLE_ENFORCE(var->IsType<LoDTensor>(), "Only support lod tensor now.");
return *var->GetMutable<LoDTensor>();
}
} // namespace framework
} // namespace paddle
......@@ -27,5 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index);
LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name);
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
GarbageCollector::GarbageCollector(const platform::Place &place,
size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new GarbageQueue());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void CPUGarbageCollector::ClearCallback(const std::function<void()> &callback) {
callback();
}
#ifdef PADDLE_WITH_CUDA
UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector(
const platform::CUDAPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void UnsafeFastGPUGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}
DefaultStreamGarbageCollector::DefaultStreamGarbageCollector(
const platform::CUDAPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void DefaultStreamGarbageCollector::Wait() const {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
void DefaultStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
StreamGarbageCollector::~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
cudaStream_t StreamGarbageCollector::stream() const { return stream_; }
void StreamGarbageCollector::Wait() const { callback_manager_->Wait(); }
void StreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback_manager_->AddCallback(callback);
}
#endif
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,6 @@
#pragma once
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
......@@ -24,134 +23,74 @@
namespace paddle {
namespace framework {
class GarbageCollector {
 public:
  using GarbageQueue = std::deque<std::shared_ptr<memory::Allocation>>;
  GarbageCollector(const platform::Place &place, size_t max_memory_size);
  virtual ~GarbageCollector() = default;
  virtual void Wait() const {}
  template <typename Container>
  void Add(Container &&objs);
  template <typename Container, typename Callback>
  void Add(Container &&objs, Callback &&callback);
protected:
virtual void ClearCallback(const std::function<void()> &callback) = 0;
platform::DeviceContext *dev_ctx_;
  std::unique_ptr<GarbageQueue> garbages_;
mutable std::mutex mutex_;
const size_t max_memory_size_;
  size_t cur_memory_size_{0};
};
class CPUGarbageCollector : public GarbageCollector {
 public:
  CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size);

 protected:
  void ClearCallback(const std::function<void()> &callback) override;
};
#ifdef PADDLE_WITH_CUDA
class UnsafeFastGPUGarbageCollector : public GarbageCollector {
 public:
  UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
                                size_t max_memory_size);

 protected:
  void ClearCallback(const std::function<void()> &callback) override;
};

class DefaultStreamGarbageCollector : public GarbageCollector {
 public:
  DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
                                size_t max_memory_size);
  void Wait() const override;

 protected:
  void ClearCallback(const std::function<void()> &callback) override;
};
class StreamGarbageCollector : public GarbageCollector {
 public:
  StreamGarbageCollector(const platform::CUDAPlace &place,
                         size_t max_memory_size);
  ~StreamGarbageCollector();
  void Wait() const override;
  cudaStream_t stream() const;

 protected:
  void ClearCallback(const std::function<void()> &callback) override;
private:
cudaStream_t stream_;
......@@ -159,5 +98,33 @@ class StreamGarbageCollector : public GarbageCollector<T> {
};
#endif
template <typename Container>
void GarbageCollector::Add(Container &&objs) {
Add(std::forward<Container>(objs), []() {});
}
template <typename Container, typename Callback>
void GarbageCollector::Add(Container &&objs, Callback &&callback) {
GarbageQueue *garbage_queue = nullptr;
{
std::lock_guard<std::mutex> guard(mutex_);
for (auto &obj : objs) {
if (!obj) continue;
cur_memory_size_ += obj->size();
garbages_->push_back(std::move(obj));
}
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
garbage_queue = garbages_.release();
garbages_.reset(new GarbageQueue());
}
}
if (garbage_queue) {
callback();
ClearCallback([garbage_queue]() { delete garbage_queue; });
}
}
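Add() implements a byte-budgeted batch free: garbage accumulates under the mutex, and only when cur_memory_size_ crosses max_memory_size_ is the whole queue handed to ClearCallback. A condensed standalone model of the same control flow (names illustrative):
```
#include <cstddef>
#include <deque>
#include <memory>
#include <mutex>
#include <utility>

// Toy collector: flush() stands in for ClearCallback and takes ownership of
// the accumulated batch exactly when the byte budget is exceeded.
class BudgetedQueue {
 public:
  explicit BudgetedQueue(size_t budget) : budget_(budget) {}

  template <typename Flush>
  void Add(size_t bytes, Flush flush) {
    std::deque<size_t>* batch = nullptr;
    {
      std::lock_guard<std::mutex> guard(mutex_);
      pending_.push_back(bytes);
      total_ += bytes;
      if (total_ >= budget_) {
        total_ = 0;
        batch = new std::deque<size_t>(std::move(pending_));
        pending_.clear();
      }
    }
    // The callback runs outside the lock, mirroring ClearCallback above.
    if (batch) flush(std::unique_ptr<std::deque<size_t>>(batch));
  }

 private:
  std::mutex mutex_;
  std::deque<size_t> pending_;
  size_t total_{0};
  const size_t budget_;
};
```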
} // namespace framework
} // namespace paddle
......@@ -42,6 +42,8 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(depthwise_conv_mkldnn_pass base)
......
......@@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
std::string type = is_conv3d() ? "conv3d" : "conv2d";
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(type, "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input);
conv_bias_pattern(conv_input, is_conv3d());
int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d");
desc.SetType(type);
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
......@@ -135,3 +137,5 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
} // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
paddle::framework::ir::ConvBiasFusePass);
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);
......@@ -26,11 +26,19 @@ namespace ir {
class ConvBiasFusePass : public FusePassBase {
public:
virtual ~ConvBiasFusePass() {}
virtual bool is_conv3d() const { return false; }
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
/*
* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp.
*/
class Conv3DBiasFusePass : public ConvBiasFusePass {
public:
bool is_conv3d() const override { return true; }
};
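Conv3DBiasFusePass reuses the entire 2-D fusion logic: ApplyImpl consults the is_conv3d() hook when choosing which op type to match, so the 3-D variant only overrides a single virtual. A hypothetical sketch of invoking either pass through the registry:
```
// Illustrative only; the pass names come from the REGISTER_PASS calls above.
// auto pass = framework::ir::PassRegistry::Instance().Get(
//     "conv3d_bias_mkldnn_fuse_pass");
// graph = pass->Apply(std::move(graph));
```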
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(elementwise_add_op_1); \
GET_IR_NODE(elementwise_add_in_y_1); \
GET_IR_NODE(elementwise_add_out_1); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& bias1, const std::string& activation,
const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {bias1});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string bias1_name = elementwise_add_in_y_1->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, conv_op); // Input
IR_NODE_LINK_TO(conv_filter, conv_op); // Filter
IR_NODE_LINK_TO(conv_op, conv_out); // Output
IR_NODE_LINK_TO(elementwise_add_in_y, conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, conv_op); // Bias
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAdd2ActFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAdd2ActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& activation, const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetType("conv2d_fusion");
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
desc.Flush();
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("conv2d", "Input")
->AsInput();
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto =
PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAddActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddActFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -38,8 +38,7 @@ void CheckProgram(const ProgramDesc &program) {
switch (role_id) {
case _INT(OpRole::kForward):
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
          LOG(ERROR) << "Cannot add backward operator before forward operator "
<< op->Type();
}
break;
......
......@@ -73,14 +73,21 @@ class Graph {
}
bool Has(const std::string &attr_name) const {
    return attrs_.count(attr_name) > 0;
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
PADDLE_THROW(
"Invalid attribute type of %s error, expected: %s, actual: %s",
attr_name, typeid(AttrType *).name(),
attrs_.at(attr_name).type().name());
}
}
template <typename AttrType>
......@@ -177,14 +184,13 @@ class Graph {
return nullptr;
}
const ProgramDesc &program() const { return program_; }
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
 private:
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
......
......@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
......@@ -25,6 +26,7 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace framework {
namespace ir {
......@@ -104,7 +106,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
for (auto &node : GraphTraits::DFS(graph)) {
for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) {
VLOG(4) << "pdnode " << pdnode->name() << " marked";
VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
pdnodes2nodes_[pdnode.get()].insert(&node);
}
}
......@@ -1030,10 +1032,11 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
}
PDNode *patterns::ConvBias::operator()(
    paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
  std::string type = is_conv3d ? "conv3d" : "conv2d";
  // Create Operators
  conv_input->assert_is_op_input(type, "Input");
  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables
......@@ -1041,11 +1044,11 @@ PDNode *patterns::ConvBias::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_only_output_of_op(type)
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
......@@ -1098,6 +1101,115 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var;
}
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput();
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
auto act_op = pattern->NewNode(act_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
auto op_type = node->Name();
return conv_act_set.count(op_type);
});
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_var()
// is activation op's output.
->assert_more([&](Node *node) {
for (auto *in_op : node->inputs) {
if (conv_act_set.count(in_op->Name())) {
return true;
}
}
return false;
})
->AsOutput();
conv_op->LinksFrom({conv_in, conv_filter});
conv_out->LinksFrom({conv_op});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
act_op->LinksFrom({elementwise_add_out}).LinksTo({act_out});
return act_out;
}
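Each assert_* call above appends a predicate to the PDNode; during matching, a graph node is marked as a candidate only if every predicate passes (the Tell() call seen earlier in MarkPDNodesInGraph). A toy model of that builder-and-test mechanism:
```
#include <functional>
#include <utility>
#include <vector>

struct NodeSketch {};  // stand-in for ir::Node

// Chained predicate registration, mirroring assert_is_op / assert_more.
struct PDNodeSketch {
  std::vector<std::function<bool(NodeSketch*)>> predicates;

  PDNodeSketch& assert_more(std::function<bool(NodeSketch*)> pred) {
    predicates.push_back(std::move(pred));
    return *this;  // returning *this is what enables the fluent chain
  }

  // A node matches only if every registered predicate holds.
  bool Tell(NodeSketch* n) const {
    for (const auto& p : predicates)
      if (!p(n)) return false;
    return true;
  }
};
```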
PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise_add_op_1 = pattern->NewNode(elementwise_add_op_1_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y_1 = pattern->NewNode(elementwise_add_in_y_1_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out_1 = pattern->NewNode(elementwise_add_out_1_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
auto act_op = pattern->NewNode(act_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
auto op_type = node->Name();
return conv_act_set.count(op_type);
});
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_var()
// is activation op's output.
->assert_more([&](Node *node) {
for (auto *in_op : node->inputs) {
if (conv_act_set.count(in_op->Name())) {
return true;
}
}
return false;
})
->AsOutput();
conv_op->LinksFrom({conv_in, conv_filter}).LinksTo({conv_out});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
  elementwise_add_op_1->LinksFrom({elementwise_add_out, elementwise_add_in_y_1})
      .LinksTo({elementwise_add_out_1});
act_op->LinksFrom({elementwise_add_out_1}).LinksTo({act_out});
return act_out;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
struct ConvBias : public PatternBase {
ConvBias(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bias") {}
PDNode* operator()(PDNode* conv_input);
PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(eltwise);
......@@ -671,6 +671,51 @@ struct ElementwiseAdd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
};
// Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv-related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_in_y); // input
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
};
// Conv + ElementwiseAdd + ElementwiseAdd + Activation
struct ConvElementwiseadd2Act : public PatternBase {
ConvElementwiseadd2Act(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope,
"conv_elementwiseadd2_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_in_y); // input
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(elementwise_add_op_1);
PATTERN_DECL_NODE(elementwise_add_in_y_1); // input
PATTERN_DECL_NODE(elementwise_add_out_1);
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
......@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->HasAttr("is_test")) {
if (op->HasAttr("is_test") || op->HasProtoAttr("is_test")) {
op->SetAttr("is_test", true);
} else if (std::find(begin(op_list), end(op_list), op->Type()) !=
end(op_list)) {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
#include <string>
namespace paddle {
namespace framework {
......@@ -21,9 +22,19 @@ namespace ir {
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Aplies MKL-DNN placement strategy.";
const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) {
n->Op()->SetAttr("use_mkldnn", true);
if (n->IsOp()) {
auto* op = n->Op();
if (op->HasAttr("use_mkldnn") || op->HasProtoAttr("use_mkldnn")) {
if (op_types_list.empty()) {
op->SetAttr("use_mkldnn", true);
} else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
op->SetAttr("use_mkldnn", true);
}
}
}
}
return graph;
......@@ -33,5 +44,5 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
.RequirePassAttr("mkldnn_enabled_op_types");
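With RequirePassAttr, the pass refuses to run until the attribute is supplied; per the logic above, an empty set means "enable MKL-DNN for every op that supports it". A hypothetical configuration sketch (illustrative, not from the patch):
```
// Illustrative only; Set() transfers ownership of the attribute object.
// auto pass = framework::ir::PassRegistry::Instance().Get(
//     "mkldnn_placement_pass");
// pass->Set("mkldnn_enabled_op_types",
//           new std::unordered_set<std::string>({"conv2d", "pool2d"}));
// graph = pass->Apply(std::move(graph));
```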
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle {
namespace framework {
......@@ -24,10 +25,11 @@ constexpr char Node::kControlDepVarName[];
const char Node::kControlDepVarName[] = "__control_var";
#endif
std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
Node::Type type) {
return std::unique_ptr<Node>(new Node(name, type));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -51,11 +51,18 @@ class Pass {
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for pass.", attr_name);
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
PADDLE_THROW(
"Invalid attribute type of %s error, expected: %s, actual: %s",
attr_name, typeid(AttrType *).name(),
attrs_.at(attr_name).type().name());
}
}
bool Has(const std::string &attr_name) const {
    return attrs_.count(attr_name) > 0;
}
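The try/catch added in Get() turns a raw boost::bad_any_cast into an error naming the attribute and both types. The underlying storage pattern, distilled into a standalone sketch:
```
#include <boost/any.hpp>
#include <map>
#include <string>

// Attributes are stored type-erased and recovered by casting back to the
// pointer type they were registered with; a mismatched AttrType throws
// boost::bad_any_cast, which the real Get() converts into a clear error.
std::map<std::string, boost::any> attrs;

template <typename AttrType>
AttrType& GetAttr(const std::string& name) {
  return *boost::any_cast<AttrType*>(attrs.at(name));
}
```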
void Erase(const std::string &attr_name) {
......
......@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
    if (t.type() == proto::VarType::FP32) {
      os << t.data<float>()[i] << " ";
    } else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " ";
} else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
......@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
  auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
......@@ -12,28 +12,109 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <functional>
#include <vector>
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace framework {
static std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
const VariableNameMap& var_map,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = var_map.at(name);
PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s name %s expects one associated var", op->Type(),
name);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]];
} else {
return nullptr;
}
}
static std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, name, op->Inputs(), ngb_node_map);
}
static std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, name, op->Outputs(), ngb_node_map);
}
static void SetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<ngraph::Node> node,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = op->Outputs().at(name);
if (var_names.size() == 1) {
(*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node;
} else {
PADDLE_THROW("name %s has more than 1 var_names.", name);
}
}
static bool HasOutput(const std::shared_ptr<OperatorBase>& op,
const std::string name) {
auto& outputs = op->Outputs();
if (outputs.find(name) == outputs.end()) return false;
return outputs.at(name).size() > 0;
}
template <typename T>
static void BuildBinaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = GetInputNode(op, "X", ngb_node_map);
auto y = GetInputNode(op, "Y", ngb_node_map);
auto out = std::make_shared<T>(x, y);
SetOutputNode(op, "Out", out, ngb_node_map);
}
template <typename T>
static void BuildUnaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = GetInputNode(op, "X", ngb_node_map);
auto out = std::make_shared<T>(input);
SetOutputNode(op, "Out", out, ngb_node_map);
}
std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
    NgraphBridge::NG_NODE_MAP = {{"relu", BuildUnaryNode<ngraph::op::Relu>},
{"tanh", BuildUnaryNode<ngraph::op::Tanh>}};
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
  auto& op_type = op->Type();
  NG_NODE_MAP[op_type](op, ngb_node_map_);
}
} // namespace framework
} // namespace paddle
#endif
......@@ -14,22 +14,18 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/node.hpp"
namespace paddle {
namespace framework {
class OperatorBase;
class NgraphBridge {
public:
static std::map<
......@@ -43,16 +39,15 @@ class NgraphBridge {
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map)
      : ngb_node_map_(var_node_map) {}
  void BuildNgNode(const std::shared_ptr<OperatorBase>& op);
private:
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      ngb_node_map_;
};
} // namespace framework
} // namespace paddle
#endif
......@@ -12,21 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace framework {
static ngraph::Shape Ddim2Shape(const DDim& dims) {
ngraph::Shape sp;
for (int i = 0; i < dims.size(); ++i) {
int k = dims[i];
k = k == 0 ? 1 : k;
sp.push_back(k);
}
return sp;
}
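Ddim2Shape copies every extent and promotes 0 to 1, since nGraph requires strictly positive dimensions. A standalone restatement of the rule:
```
#include <cstddef>
#include <cstdint>
#include <vector>

// {0, 3, 4} -> {1, 3, 4}: zero (unset) extents become 1.
std::vector<size_t> ToNgShape(const std::vector<int64_t>& dims) {
  std::vector<size_t> shape;
  for (int64_t k : dims) shape.push_back(k == 0 ? 1 : static_cast<size_t>(k));
  return shape;
}
```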
static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
{proto::VarType::FP32, ngraph::element::f32},
{proto::VarType::FP64, ngraph::element::f64},
......@@ -42,9 +56,10 @@ typedef enum { /* nGraph support state on ops */
PARTIAL_TEST /* Support partial list of ops for test */
} op_state;
// perform graph build through bridge and execute computation
class NgraphEngine {
 public:
  explicit NgraphEngine(const Scope& scope, const platform::Place& place,
const std::vector<std::shared_ptr<OperatorBase>>& ops,
const std::unordered_map<
std::string, ngraph::element::Type>& var_type_map,
......@@ -59,13 +74,23 @@ class NgraphOperator {
persistables_(persist),
fetches_(fetches),
post_op_inputs_(post_op_inputs),
        ng_op_state_(ng_op_state) {
var_in_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
BuildNgIO();
GetNgFunction();
}
void Run(const Scope& scope, const platform::Place& place) const;
private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
      func_cache_;
const Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
......@@ -74,10 +99,39 @@ class NgraphOperator {
std::unordered_set<std::string> fetches_;
std::unordered_set<std::string> post_op_inputs_;
op_state ng_op_state_;
// ngraph backend, e.g. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_;
// ngraph function to call and execute
std::shared_ptr<ngraph::Function> ngraph_function_;
// var_names of the inputs
std::vector<std::string> var_in_;
// var_names of the outputs, in fetch order
std::vector<std::string> var_out_;
// map input vars to nodes
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_in_node_map_;
// map each var name to an ngraph node
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_;
// build the cache key used to check whether the function is cached
std::shared_ptr<std::string> GetCacheKey();
// get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<OperatorBase> op);
// Call ngraph bridge to map ops
void BuildNgNodes();
// get the ngraph input and output var list
void BuildNgIO();
// build the ngraph function
void BuildNgFunction();
// Check cache for ngraph function or otherwise build the function
void GetNgFunction();
};
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
FusedOperator::FusedOpIntervals(
NgraphOperator::NgraphOpIntervals(
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
intervals;
......@@ -86,7 +140,7 @@ FusedOperator::FusedOpIntervals(
}
size_t size = ops->size();
size_t left = 0;
while (left < size && ops.at(left)->Type() != kFeedOpType) {
while (left < size && ops->at(left)->Type() != kFeedOpType) {
++left;
}
if (left == size) {
......@@ -116,7 +170,7 @@ FusedOperator::FusedOpIntervals(
size_t start = pivot, end = start;
while (pivot < right &&
(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
ops.at(pivot)->Type()) !=
ops->at(pivot)->Type()) !=
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
++pivot;
++end;
......@@ -130,13 +184,15 @@ FusedOperator::FusedOpIntervals(
return intervals;
}
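`NgraphOpIntervals` above makes a single pass over the op list: it skips ahead to the first feed op, then greedily collects maximal runs of consecutive ops whose types appear in `NG_NODE_MAP`, stopping at the fetch op. A minimal sketch of that grouping logic, with plain strings and a hypothetical supported-op set standing in for the real operator iterators:

```
#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

int main() {
  // Hypothetical op sequence; the real code scans std::unique_ptr<OperatorBase>
  // iterators between the feed and fetch ops.
  std::vector<std::string> ops = {"feed", "mul", "relu", "print", "mul", "fetch"};
  std::unordered_set<std::string> supported = {"mul", "relu"};
  std::vector<std::pair<size_t, size_t>> intervals;  // [start, end)
  size_t i = 0;
  while (i < ops.size() && ops[i] != "feed") ++i;  // skip to the first feed op
  for (; i < ops.size() && ops[i] != "fetch"; ++i) {
    if (!supported.count(ops[i])) continue;
    size_t start = i;
    while (i < ops.size() && supported.count(ops[i])) ++i;
    intervals.emplace_back(start, i);  // maximal run of supported ops
    if (i < ops.size() && ops[i] == "fetch") break;
  }
  for (auto& iv : intervals)
    std::cout << "[" << iv.first << ", " << iv.second << ")\n";  // [1, 3) [4, 5)
  return 0;
}
```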
FusedOperator::FusedOperator(
NgraphOperator::NgraphOperator(
const ProgramDesc& prog, size_t block_id,
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
: OperatorBase(type, inputs, outputs, attrs),
pdesc_(prog),
block_(block_id) {
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
it != end; ++it) {
fused_ops_.push_back(std::move(*it));
......@@ -152,13 +208,13 @@ FusedOperator::FusedOperator(
}
if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) {
is_complete = true;
is_full_ = true;
}
Process();
}
void FusedOperator::Process() {
void NgraphOperator::Process() {
auto& bdesc = pdesc_.Block(block_);
for (auto& var : bdesc.AllVars()) {
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
......@@ -194,7 +250,7 @@ void FusedOperator::Process() {
}
}
void FusedOperator::RunImpl(const Scope& scope,
void NgraphOperator::RunImpl(const Scope& scope,
const platform::Place& place) const {
op_state ng_op_state = PARTIAL_TEST;
auto& bdesc = pdesc_.Block(block_);
......@@ -205,16 +261,284 @@ void FusedOperator::RunImpl(const Scope& scope,
}
}
if (is_full) {
if (is_full_) {
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
}
NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_,
persistables_, fetches_, post_op_inputs_,
ng_op_state);
ngraph_op.Run(scope, place);
ngraph_engine.Run(scope, place);
}
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
NgraphEngine::func_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
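// Run shape inference first so each input LoDTensor carries concrete dims.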
op->RuntimeInferShape(scope_, place_);
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
var_in_.end()) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
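// First time this input is seen: create an ngraph Parameter with matching
// type and shape and register it in both node maps.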
auto ng_type = var_type_map_.at(var_name);
auto prm =
std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
(*var_node_map_)[var_name] = prm;
(*var_in_node_map_)[var_name] = prm;
}
}
}
}
}
}
void NgraphEngine::BuildNgNodes() {
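// Ensure every required output has a node (placeholder Parameters for
// outputs no op has produced yet), then let NgraphBridge translate each
// fused op into ngraph nodes.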
for (auto& var_name : var_out_) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_.at(var_name);
auto prm =
std::make_shared<ngraph::op::Parameter>(ng_type, ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
}
paddle::framework::NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
}
void NgraphEngine::BuildNgIO() {
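// Pass 1: record distinct input names in first-use order and infer their
// shapes. Pass 2 (below): keep only the outputs that must be materialized
// for the current op_state.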
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
inputs.insert(var_name);
const bool is_output = outputs.find(var_name) != outputs.end();
if (!is_output &&
std::find(var_in_.begin(), var_in_.end(), var_name) ==
var_in_.end()) {
// fill var_in_ here so inputs keep their lhs/rhs order of first use
var_in_.push_back(var_name);
}
}
}
if (op->Type() != "fill_constant") {
GetNgInputShape(op);
}
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
outputs.insert(var_name);
}
}
}
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
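// Materialize an output only if something outside the fused block
// consumes it: a later op, a fetch target, or (when training) a
// persistable variable.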
switch (ng_op_state_) {
case PARTIAL_TEST:
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
}
break;
case FULL_TEST:
if (fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
}
break;
case PARTIAL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
}
break;
case FULL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
}
break;
default:
var_out_.push_back(var_name);
}
}
}
}
}
void NgraphEngine::BuildNgFunction() {
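// Collect the output nodes and input Parameters in their recorded order
// and wrap them into a single callable ngraph::Function.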
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
ngraph::op::ParameterVector func_inputs;
for (auto& vo : var_out_) {
func_outputs.push_back(var_node_map_->at(vo));
}
for (auto& vi : var_in_) {
std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi));
func_inputs.push_back(prm);
}
ngraph_function_ =
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
std::shared_ptr<std::string> NgraphEngine::GetCacheKey() {
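// The key concatenates the op count, each op type, and per-variable
// name/type/shape information, so any structural change yields a new key.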
auto cache_key = std::make_shared<std::string>("");
*cache_key += std::to_string(fused_ops_.size());
for (auto& op : fused_ops_) {
*cache_key += op->Type();
}
for (auto& var_name : var_in_) {
auto shape = var_node_map_->at(var_name)->get_shape();
*cache_key += var_name;
*cache_key += var_type_map_.at(var_name).c_type_string();
for (size_t i = 0; i < shape.size(); ++i) {
*cache_key += std::to_string(shape.at(i));
}
}
for (auto& var_name : var_out_) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
for (int i = 0; i < ddim.size(); ++i) {
*cache_key += std::to_string(ddim[i]);
}
}
}
return cache_key;
}
void NgraphEngine::GetNgFunction() {
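// Reuse a previously compiled function when the structural key matches;
// otherwise build one and cache it.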
bool cache_on = true;
if (cache_on) {
std::string cache_key_val = *GetCacheKey();
if (func_cache_.find(cache_key_val) != func_cache_.end()) {
ngraph_function_ = func_cache_.at(cache_key_val);
} else {
BuildNgFunction();
func_cache_[cache_key_val] = ngraph_function_;
}
} else {
BuildNgFunction();
}
}
void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
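// Wrap the Paddle tensor buffers as ngraph runtime tensors (inputs, then
// outputs) and invoke the compiled function on the CPU backend.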
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
for (size_t i = 0; i < var_in_.size(); ++i) {
auto vi = var_in_.at(i);
auto sp = var_node_map_->at(vi)->get_shape();
std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
if (tensor_pd->type() == proto::VarType::FP32) {
const float* arr = tensor_pd->data<float>();
ti = backend_->create_tensor(ngraph::element::f32, sp,
const_cast<float*>(arr));
} else if (tensor_pd->type() == proto::VarType::INT32) {
const int* arr = tensor_pd->data<int>();
ti = backend_->create_tensor(ngraph::element::i32, sp,
const_cast<int*>(arr));
} else if (tensor_pd->type() == proto::VarType::INT64) {
const int64_t* arr = tensor_pd->data<int64_t>();
ti = backend_->create_tensor(ngraph::element::i64, sp,
const_cast<int64_t*>(arr));
} else if (tensor_pd->type() == proto::VarType::FP64) {
const double* arr = tensor_pd->data<double>();
ti = backend_->create_tensor(ngraph::element::f64, sp,
const_cast<double*>(arr));
} else if (tensor_pd->type() == proto::VarType::BOOL) {
const bool* arr = tensor_pd->data<bool>();
ti = backend_->create_tensor(ngraph::element::boolean, sp,
const_cast<bool*>(arr));
} else {
PADDLE_THROW("Data type not handling for var %s", vi);
}
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
bool is_test = ng_op_state_ == PARTIAL_TEST || ng_op_state_ == FULL_TEST;
bool is_persistable = persistables_.find(vi) != persistables_.end();
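// Persistable weights do not change across test-mode runs, so mark their
// tensors non-stale and let the backend skip refreshing them.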
if (is_test && is_persistable) {
ti->set_stale(false);
}
t_in.push_back(ti);
}
for (size_t i = 0; i < var_out_.size(); ++i) {
auto var_name = var_out_[i];
auto* var = scope.FindVar(var_name);
std::shared_ptr<ngraph::runtime::Tensor> to;
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
auto dd = tensor_pd->dims();
ngraph::Shape sp = Ddim2Shape(dd);
auto ng_type = var_type_map_.at(var_name);
if (ng_type == ngraph::element::f32) {
auto pd_arr = tensor_pd->mutable_data<float>(place);
to = backend_->create_tensor(ngraph::element::f32, sp, pd_arr);
} else if (ng_type == ngraph::element::i64) {
auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
to = backend_->create_tensor(ngraph::element::i64, sp, pd_arr);
} else if (ng_type == ngraph::element::f64) {
auto pd_arr = tensor_pd->mutable_data<double>(place);
to = backend_->create_tensor(ngraph::element::f64, sp, pd_arr);
} else if (ng_type == ngraph::element::boolean) {
auto pd_arr = tensor_pd->mutable_data<bool>(place);
to = backend_->create_tensor(ngraph::element::boolean, sp, pd_arr);
} else {
PADDLE_THROW("Data type not handled in for var %s", var_name);
}
t_out.push_back(to);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", var_name);
}
}
backend_->call(ngraph_function_, t_out, t_in);
} // NgraphEngine::Run
} // namespace framework
} // namespace paddle
#endif
......@@ -237,6 +237,23 @@ void OpDesc::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args;
}
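// True iff the attribute is declared in the op's registered OpProto;
// returns false (rather than throwing) when the op has no proto/checker.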
bool OpDesc::HasProtoAttr(const std::string &name) const {
auto &op_info = OpInfoMap::Instance();
if (op_info.Has(desc_.type())) {
auto op_info_ptr = op_info.Get(desc_.type());
if (op_info_ptr.HasOpProtoAndChecker()) {
const proto::OpProto &proto = op_info_ptr.Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return true;
}
}
}
}
return false;
}
proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
......
......@@ -65,6 +65,8 @@ class OpDesc {
return attrs_.find(name) != attrs_.end();
}
bool HasProtoAttr(const std::string &name) const;
proto::AttrType GetAttrType(const std::string &name) const;
std::vector<std::string> AttrNames() const;
......
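A brief note on the design: `HasProtoAttr` lets callers distinguish attributes declared in an op's proto definition from extra attributes attached at runtime, and it degrades gracefully for ops without a registered proto instead of throwing.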
(The remaining file diffs in this commit are collapsed.)