未验证 提交 34011bdb 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Merge branch 'develop' into transpose_int8_mkldnn_2

python/paddle/fluid/tests/unittests/reader_reset_test.recordio
paddle/operators/check_t.save paddle/operators/check_t.save
paddle/operators/check_tensor.ls paddle/operators/check_tensor.ls
paddle/operators/tensor.save paddle/operators/tensor.save
......
...@@ -27,18 +27,27 @@ message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: " ...@@ -27,18 +27,27 @@ message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
message(STATUS "AR tools: ${CMAKE_AR}") message(STATUS "AR tools: ${CMAKE_AR}")
if(WIN32) if(WIN32)
option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
set(CMAKE_SUPPRESS_REGENERATION ON) set(CMAKE_SUPPRESS_REGENERATION ON)
set(CMAKE_STATIC_LIBRARY_PREFIX lib) set(CMAKE_STATIC_LIBRARY_PREFIX lib)
add_definitions("/DGOOGLE_GLOG_DLL_DECL=") add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT") if (MSVC_STATIC_CRT)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd") message(STATUS "Use static C runtime time, refer to https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=vs-2019")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT")
endif()
add_compile_options(/wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838) add_compile_options(/wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838)
set(PADDLE_LINK_FLAGS "/IGNORE:4006 /IGNORE:4098 /IGNORE:4217 /IGNORE:4221") set(PADDLE_LINK_FLAGS "/IGNORE:4006 /IGNORE:4098 /IGNORE:4217 /IGNORE:4221")
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}") set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
else(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations")
endif(WIN32) endif(WIN32)
find_package(CUDA QUIET) find_package(CUDA QUIET)
...@@ -65,13 +74,13 @@ option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" ...@@ -65,13 +74,13 @@ option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools"
option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF)
option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF) option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(WITH_HIGH_LEVEL_API_TEST "Test fluid python high-level api interface" OFF) option(WITH_HIGH_LEVEL_API_TEST "Test fluid python high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF) option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF)
...@@ -150,8 +159,6 @@ include(external/cub) ...@@ -150,8 +159,6 @@ include(external/cub)
include(external/rocprim) include(external/rocprim)
include(external/xxhash) # download xxhash include(external/xxhash) # download xxhash
include(external/dlpack) include(external/dlpack)
include(external/snappy) # download snappy
include(external/snappystream) # download snappystream
include(external/warpctc) # download, build, install warpctc include(external/warpctc) # download, build, install warpctc
if (NOT WIN32) if (NOT WIN32)
...@@ -164,6 +171,9 @@ if(WITH_PSLIB) ...@@ -164,6 +171,9 @@ if(WITH_PSLIB)
include(external/pslib_brpc) include(external/pslib_brpc)
include(external/pslib) include(external/pslib)
endif(WITH_PSLIB) endif(WITH_PSLIB)
if(WITH_BOX_PS)
include(external/box_ps)
endif(WITH_BOX_PS)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
if(WITH_GRPC) if(WITH_GRPC)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
English | [简体中文](./README_cn.md) English | [简体中文](./README_cn.md)
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/documentation/docs/en/1.4/beginners_guide/index_en.html) [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org.cn/documentation/docs/en/1.5/beginners_guide/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/index_cn.html) [![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/index_cn.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
...@@ -77,33 +77,33 @@ Now our developers could acquire Tesla V100 online computing resources for free. ...@@ -77,33 +77,33 @@ Now our developers could acquire Tesla V100 online computing resources for free.
## Installation ## Installation
It is recommended to read [this doc](http://www.paddlepaddle.org/documentation/docs/en/1.4/beginners_guide/index_en.html) on our website. It is recommended to read [this doc](http://www.paddlepaddle.org.cn/documentation/docs/en/1.5/beginners_guide/index_en.html) on our website.
## Documentation ## Documentation
We provide [English](http://www.paddlepaddle.org/documentation/docs/en/1.4/beginners_guide/index_en.html) and We provide [English](http://www.paddlepaddle.org.cn/documentation/docs/en/1.5/beginners_guide/index_en.html) and
[Chinese](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html) documentation. [Chinese](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book) - [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook. You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/en/1.4/user_guides/howto/training/multi_node_en.html) - [Distributed Training](http://paddlepaddle.org.cn/documentation/docs/en/1.5/user_guides/howto/training/multi_node_en.html)
You can run distributed training jobs on MPI clusters. You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/docs/en/1.4/api/index_en.html) - [Python API](http://paddlepaddle.org.cn/documentation/docs/en/1.5/api/index_en.html)
Our new API enables much shorter programs. Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/en/1.4/advanced_usage/development/contribute_to_paddle/index_en.html) - [How to Contribute](http://paddlepaddle.org.cn/documentation/docs/en/1.5/advanced_usage/development/contribute_to_paddle/index_en.html)
We appreciate your contributions! We appreciate your contributions!
## Communication ## Communication
- [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc. - [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 432676488 (PaddlePaddle). - QQ discussion group: 796771754 (PaddlePaddle).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc. - [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Copyright and License ## Copyright and License
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
[English](./README.md) | 简体中文 [English](./README.md) | 简体中文
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/documentation/docs/en/1.4/beginners_guide/index_en.html) [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org.cn/documentation/docs/en/1.5/beginners_guide/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/index_cn.html) [![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/index_cn.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
...@@ -59,33 +59,33 @@ PaddlePaddle用户可领取**免费Tesla V100在线算力资源**,训练模型 ...@@ -59,33 +59,33 @@ PaddlePaddle用户可领取**免费Tesla V100在线算力资源**,训练模型
## 安装 ## 安装
推荐阅读官网上的[安装说明](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html) 推荐阅读官网上的[安装说明](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html)
## 文档 ## 文档
我们提供[英文](http://www.paddlepaddle.org/documentation/docs/en/1.4/beginners_guide/index_en.html) 我们提供[英文](http://www.paddlepaddle.org.cn/documentation/docs/en/1.5/beginners_guide/index_en.html)
[中文](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html) 文档 [中文](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 文档
- [深度学习101](https://github.com/PaddlePaddle/book) - [深度学习101](https://github.com/PaddlePaddle/book)
或许您想从这个在线交互式书籍开始,可以在Jupyter Notebook中运行 或许您想从这个在线交互式书籍开始,可以在Jupyter Notebook中运行
- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.4/user_guides/howto/training/multi_node.html) - [分布式训练](http://paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/howto/training/multi_node.html)
可以在MPI集群上运行分布式训练任务 可以在MPI集群上运行分布式训练任务
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.4/api_cn/index_cn.html) - [Python API](http://paddlepaddle.org.cn/documentation/docs/zh/1.5/api_cn/index_cn.html)
新的API支持代码更少更简洁的程序 新的API支持代码更少更简洁的程序
- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.4/advanced_usage/development/contribute_to_paddle/index_cn.html) - [贡献方式](http://paddlepaddle.org.cn/documentation/docs/zh/1.5/advanced_usage/development/contribute_to_paddle/index_cn.html)
欢迎您的贡献! 欢迎您的贡献!
## 交流与反馈 ## 交流与反馈
- 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)来提交问题、报告与建议 - 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)来提交问题、报告与建议
- QQ群: 432676488 (PaddlePaddle) - QQ群: 796771754 (PaddlePaddle)
- [论坛](http://ai.baidu.com/forum/topic/list/168): 欢迎大家在PaddlePaddle论坛分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围 - [论坛](http://ai.baidu.com/forum/topic/list/168): 欢迎大家在PaddlePaddle论坛分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围
## 版权和许可证 ## 版权和许可证
......
...@@ -62,6 +62,10 @@ if(WITH_PSLIB) ...@@ -62,6 +62,10 @@ if(WITH_PSLIB)
add_definitions(-DPADDLE_WITH_PSLIB) add_definitions(-DPADDLE_WITH_PSLIB)
endif() endif()
if(WITH_BOX_PS)
add_definitions(-DPADDLE_WITH_BOX_PS)
endif()
if(WITH_GPU) if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA) add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU) add_definitions(-DEIGEN_USE_GPU)
...@@ -88,14 +92,20 @@ if(WITH_GPU) ...@@ -88,14 +92,20 @@ if(WITH_GPU)
include_directories(${CUDA_TOOLKIT_INCLUDE}) include_directories(${CUDA_TOOLKIT_INCLUDE})
if(TENSORRT_FOUND) if(TENSORRT_FOUND)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 8) if(WIN32)
message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile") if(${CUDA_VERSION_MAJOR} VERSION_LESS 9)
endif() message(FATAL_ERROR "TensorRT needs CUDA >= 9.0 to compile on Windows")
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) endif()
message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile") else()
endif() if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
if(${TENSORRT_MAJOR_VERSION} VERSION_LESS 4) message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile")
message(FATAL_ERROR "Paddle needs TensorRT >= 4.0 to compile") endif()
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile")
endif()
if(${TENSORRT_MAJOR_VERSION} VERSION_LESS 4)
message(FATAL_ERROR "Paddle needs TensorRT >= 4.0 to compile")
endif()
endif() endif()
include_directories(${TENSORRT_INCLUDE_DIR}) include_directories(${TENSORRT_INCLUDE_DIR})
endif() endif()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import shutil
import glob
def main():
    """Copy a file, directory, or glob pattern from argv[1] to argv[2].

    Usage: copyfile.py <src> <dst>
      - If <src> is a directory, it is copied recursively into <dst> as
        <dst>/<basename(src)>, replacing any previous copy there.
      - Otherwise <src> is treated as a file path or glob pattern and every
        match is copied into the directory <dst> (created if missing).
    """
    src = sys.argv[1]
    dst = sys.argv[2]
    if os.path.isdir(src):  # copy directory
        # normpath strips any trailing separator so basename() is never empty
        # (os.path.split("a/b/")[-1] == "", which would corrupt dst).
        dst = os.path.join(dst, os.path.basename(os.path.normpath(src)))
        if not os.path.exists(dst):
            shutil.copytree(src, dst)
            print("first copy directory: {0} --->>> {1}".format(src, dst))
        else:
            # Replace the stale copy wholesale so files removed from src
            # do not linger in dst.
            shutil.rmtree(dst)
            shutil.copytree(src, dst)
            print("overwritten copy directory: {0} --->>> {1}".format(src, dst))
    else:  # copy file(s); wildcard patterns allowed
        if not os.path.exists(dst):
            os.makedirs(dst)
        srcFiles = glob.glob(src)
        if not srcFiles:
            # Surface an empty match instead of silently copying nothing.
            print("copy file: no files matched {0}".format(src))
        for srcFile in srcFiles:
            shutil.copy(srcFile, dst)
            print("copy file: {0} --->>> {1}".format(srcFile, dst))


if __name__ == "__main__":
    main()
...@@ -186,10 +186,6 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11") ...@@ -186,10 +186,6 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_FAST_MATH)
# Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
endif()
# in cuda9, suppress cuda warning on eigen # in cuda9, suppress cuda warning on eigen
list(APPEND CUDA_NVCC_FLAGS "-w") list(APPEND CUDA_NVCC_FLAGS "-w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings # Set :expt-relaxed-constexpr to suppress Eigen warnings
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Download the pre-built box_ps stub and expose it as the imported shared
# library target `box_ps`. Skipped entirely unless WITH_BOX_PS is ON.
#
# BUG FIX: the original used "if(NOT ${WITH_BOX_PS})", which double-dereferences
# the variable; when WITH_BOX_PS is undefined that expands to "if(NOT )" and
# configuration fails. if() auto-dereferences plain variable names.
if(NOT WITH_BOX_PS)
    return()
endif()

# box_ps ships Linux-only binaries; force the option off elsewhere.
if(WIN32 OR APPLE)
    message(WARNING
        "Windows or Mac is not supported with BOX_PS in Paddle yet."
        "Force WITH_BOX_PS=OFF")
    set(WITH_BOX_PS OFF CACHE STRING "Disable BOX_PS package in Windows and MacOS" FORCE)
    return()
endif()

include(ExternalProject)

set(BOX_PS_PROJECT "extern_box_ps")
# Version/URL may be supplied on the command line; otherwise fall back to the
# pre-built stub hosted on BCE BOS.
if((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL))
    message(STATUS "use pre defined download url")
    set(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE)
    set(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE)
    set(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps_stub.tar.gz" CACHE STRING "" FORCE)
endif()
message(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}")

# Layout: download/extract under THIRD_PARTY_PATH/box_ps; install the
# include/ and lib/ trees under THIRD_PARTY_PATH/install/box_ps.
set(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps")
set(BOX_PS_DOWNLOAD_DIR "${BOX_PS_SOURCE_DIR}/src/${BOX_PS_PROJECT}")
set(BOX_PS_DST_DIR "box_ps")
set(BOX_PS_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
set(BOX_PS_INSTALL_DIR ${BOX_PS_INSTALL_ROOT}/${BOX_PS_DST_DIR})
set(BOX_PS_ROOT ${BOX_PS_INSTALL_DIR})
set(BOX_PS_INC_DIR ${BOX_PS_ROOT}/include)
set(BOX_PS_LIB_DIR ${BOX_PS_ROOT}/lib)
set(BOX_PS_LIB ${BOX_PS_LIB_DIR}/libbox_ps.so)
set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${BOX_PS_ROOT}/lib")
include_directories(${BOX_PS_INC_DIR})

# Generate a minimal CMakeLists for the extracted payload that just installs
# its include/ and lib/ directories.
file(WRITE ${BOX_PS_DOWNLOAD_DIR}/CMakeLists.txt
    "PROJECT(BOX_PS)\n"
    "cmake_minimum_required(VERSION 3.0)\n"
    "install(DIRECTORY ${BOX_PS_NAME}/include ${BOX_PS_NAME}/lib \n"
    " DESTINATION ${BOX_PS_DST_DIR})\n")

ExternalProject_Add(
    ${BOX_PS_PROJECT}
    ${EXTERNAL_PROJECT_LOG_ARGS}
    PREFIX ${BOX_PS_SOURCE_DIR}
    DOWNLOAD_DIR ${BOX_PS_DOWNLOAD_DIR}
    DOWNLOAD_COMMAND wget --no-check-certificate ${BOX_PS_URL} -c -q -O ${BOX_PS_NAME}.tar.gz
                     && tar zxvf ${BOX_PS_NAME}.tar.gz
    DOWNLOAD_NO_PROGRESS 1
    UPDATE_COMMAND ""
    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${BOX_PS_INSTALL_ROOT}
    CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BOX_PS_INSTALL_ROOT}
)

# Expose the downloaded shared library as an imported target.
add_library(box_ps SHARED IMPORTED GLOBAL)
set_property(TARGET box_ps PROPERTY IMPORTED_LOCATION ${BOX_PS_LIB})
add_dependencies(box_ps ${BOX_PS_PROJECT})
...@@ -33,7 +33,7 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr ...@@ -33,7 +33,7 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args # Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog") set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF # If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add( ExternalProject_Add(
...@@ -62,7 +62,7 @@ ExternalProject_Add( ...@@ -62,7 +62,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
) )
ADD_DEPENDENCIES(extern_brpc protobuf ssl crypto leveldb gflags glog gtest snappy) ADD_DEPENDENCIES(extern_brpc protobuf ssl crypto leveldb gflags glog gtest)
ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL) ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES}) SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc) ADD_DEPENDENCIES(brpc extern_brpc)
......
...@@ -3,15 +3,6 @@ INCLUDE(ExternalProject) ...@@ -3,15 +3,6 @@ INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3) SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
if(NOT WITH_FAST_MATH)
# EIGEN_FAST_MATH: https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html
# enables some optimizations which might affect the accuracy of the result.
# This currently enables the SSE vectorization of sin() and cos(),
# and speedups sqrt() for single precision.
# Defined to 1 by default. Define it to 0 to disable.
add_definitions(-DEIGEN_FAST_MATH=0)
endif()
if(WIN32) if(WIN32)
set(EIGEN_GIT_REPOSITORY https://github.com/wopeizl/eigen-git-mirror) set(EIGEN_GIT_REPOSITORY https://github.com/wopeizl/eigen-git-mirror)
......
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
#FIXME:(gongwb) Move brpc's gtest dependency. #FIXME:(gongwb) Move brpc's gtest dependency.
include(GNUInstallDirs)
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WITH_TESTING) IF(WITH_TESTING)
ENABLE_TESTING() ENABLE_TESTING()
...@@ -28,14 +31,14 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) ...@@ -28,14 +31,14 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WIN32) IF(WIN32)
set(GTEST_LIBRARIES set(GTEST_LIBRARIES
"${GTEST_INSTALL_DIR}/lib/gtest.lib" CACHE FILEPATH "gtest libraries." FORCE) "${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest.lib" CACHE FILEPATH "gtest libraries." FORCE)
set(GTEST_MAIN_LIBRARIES set(GTEST_MAIN_LIBRARIES
"${GTEST_INSTALL_DIR}/lib/gtest_main.lib" CACHE FILEPATH "gtest main libraries." FORCE) "${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest_main.lib" CACHE FILEPATH "gtest main libraries." FORCE)
ELSE(WIN32) ELSE(WIN32)
set(GTEST_LIBRARIES set(GTEST_LIBRARIES
"${GTEST_INSTALL_DIR}/lib/libgtest.a" CACHE FILEPATH "gtest libraries." FORCE) "${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgtest.a" CACHE FILEPATH "gtest libraries." FORCE)
set(GTEST_MAIN_LIBRARIES set(GTEST_MAIN_LIBRARIES
"${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE) "${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE)
ENDIF(WIN32) ENDIF(WIN32)
IF(WITH_MKLML) IF(WITH_MKLML)
...@@ -48,7 +51,7 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) ...@@ -48,7 +51,7 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${GTEST_DEPENDS} DEPENDS ${GTEST_DEPENDS}
GIT_REPOSITORY "https://github.com/google/googletest.git" GIT_REPOSITORY "https://github.com/google/googletest.git"
GIT_TAG "release-1.8.0" GIT_TAG "release-1.8.1"
PREFIX ${GTEST_SOURCES_DIR} PREFIX ${GTEST_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -34,8 +34,6 @@ ExternalProject_Add( ...@@ -34,8 +34,6 @@ ExternalProject_Add(
BUILD_IN_SOURCE 1 BUILD_IN_SOURCE 1
) )
ADD_DEPENDENCIES(extern_leveldb snappy)
ADD_LIBRARY(leveldb STATIC IMPORTED GLOBAL) ADD_LIBRARY(leveldb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET leveldb PROPERTY IMPORTED_LOCATION ${LEVELDB_LIBRARIES}) SET_PROPERTY(TARGET leveldb PROPERTY IMPORTED_LOCATION ${LEVELDB_LIBRARIES})
ADD_DEPENDENCIES(leveldb extern_leveldb) ADD_DEPENDENCIES(leveldb extern_leveldb)
...@@ -43,7 +43,7 @@ IF(WIN32) ...@@ -43,7 +43,7 @@ IF(WIN32)
ELSE() ELSE()
#TODO(intel-huying): #TODO(intel-huying):
# Now enable Erf function in mklml library temporarily, it will be updated as offical version later. # Now enable Erf function in mklml library temporarily, it will be updated as offical version later.
SET(MKLML_VER "Glibc225_vsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) SET(MKLML_VER "csrmm2_mklml_lnx_2019.0.2" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) SET(MKLML_URL "http://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
......
...@@ -222,6 +222,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -222,6 +222,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
-DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_INSTALL_LIBDIR=lib
-DBUILD_SHARED_LIBS=OFF -DBUILD_SHARED_LIBS=OFF
-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}
CMAKE_CACHE_ARGS CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR}
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build and install google/snappy as a static third-party library and expose
# it as the imported target `snappy`. Consistent lowercase command style and
# bare endif()/else() replace the file's mixed IF/else/endif(WIN32) legacy forms.
include(ExternalProject)

# NOTE: snappy is needed when linking with recordio
set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE)

if(WIN32)
    # Silence MSVC conversion warnings (C4244/C4267) emitted by snappy's sources.
    set(SNAPPY_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4267")
else()
    set(SNAPPY_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
endif()

ExternalProject_Add(
    extern_snappy
    GIT_REPOSITORY "https://github.com/google/snappy"
    GIT_TAG "1.1.7"
    PREFIX ${SNAPPY_SOURCES_DIR}
    UPDATE_COMMAND ""
    CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
               -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
               -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
               -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
               -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
               -DCMAKE_CXX_FLAGS=${SNAPPY_CMAKE_CXX_FLAGS}
               -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
               -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
               -DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
               -DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
               -DCMAKE_POSITION_INDEPENDENT_CODE=ON
               -DBUILD_TESTING=OFF
               -DSNAPPY_BUILD_TESTS:BOOL=OFF
               -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
               ${EXTERNAL_OPTIONAL_ARGS}
    CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${SNAPPY_INSTALL_DIR}
                     -DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPY_INSTALL_DIR}/lib
                     -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)

# The produced static archive name differs per platform.
if(WIN32)
    set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib")
else()
    set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif()

add_library(snappy STATIC IMPORTED GLOBAL)
set_property(TARGET snappy PROPERTY IMPORTED_LOCATION ${SNAPPY_LIBRARIES})
include_directories(${SNAPPY_INCLUDE_DIR})
add_dependencies(snappy extern_snappy)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build and install hoxnox/snappystream (an iostream wrapper around snappy)
# and expose it as the imported target `snappystream`. Requires the `snappy`
# external project (see external/snappy.cmake).
include(ExternalProject)

set(SNAPPYSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy_stream)
set(SNAPPYSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy_stream)
set(SNAPPYSTREAM_INCLUDE_DIR "${SNAPPYSTREAM_INSTALL_DIR}/include" CACHE PATH "snappy stream include directory." FORCE)

if(WIN32)
    # Fix me, VS2015 come without VLA support
    set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/snappystream.lib")
    # BUG FIX: the original wrote "MESSAGE(WARNING, ...)" -- the trailing comma
    # makes "WARNING," a message-text token instead of the WARNING mode keyword,
    # so the diagnostic was not emitted as a warning.
    message(WARNING "In windows, snappystream has no compile support for windows,
please build it manually and put it at " ${SNAPPYSTREAM_INSTALL_DIR})
else()
    set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/libsnappystream.a")
    ExternalProject_Add(
        extern_snappystream
        GIT_REPOSITORY "https://github.com/hoxnox/snappystream.git"
        GIT_TAG "0.2.8"
        PREFIX ${SNAPPYSTREAM_SOURCES_DIR}
        UPDATE_COMMAND ""
        CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
                   -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                   -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
                   -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
                   -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
                   -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
                   -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
                   -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
                   # BUG FIX: these two previously pointed at SNAPPY_INSTALL_DIR
                   # (copy-paste from snappy.cmake), installing snappystream into
                   # snappy's tree while CMAKE_CACHE_ARGS below and
                   # SNAPPYSTREAM_LIBRARIES expect the snappy_stream prefix.
                   -DCMAKE_INSTALL_PREFIX=${SNAPPYSTREAM_INSTALL_DIR}
                   -DCMAKE_INSTALL_LIBDIR=${SNAPPYSTREAM_INSTALL_DIR}/lib
                   -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                   -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                   -DSNAPPY_ROOT=${SNAPPY_INSTALL_DIR}
                   ${EXTERNAL_OPTIONAL_ARGS}
        CMAKE_CACHE_ARGS
                   -DCMAKE_INSTALL_PREFIX:PATH=${SNAPPYSTREAM_INSTALL_DIR}
                   -DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPYSTREAM_INSTALL_DIR}/lib
                   -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
        DEPENDS snappy
    )
endif()

add_library(snappystream STATIC IMPORTED GLOBAL)
set_property(TARGET snappystream PROPERTY IMPORTED_LOCATION ${SNAPPYSTREAM_LIBRARIES})
include_directories(${SNAPPYSTREAM_INCLUDE_DIR})  # For snappysteam to include its own headers.
include_directories(${THIRD_PARTY_PATH}/install)  # For Paddle to include snappy stream headers.
add_dependencies(snappystream extern_snappystream)
...@@ -204,7 +204,7 @@ foreach(flag ${GPU_COMMON_FLAGS}) ...@@ -204,7 +204,7 @@ foreach(flag ${GPU_COMMON_FLAGS})
safe_set_nvflag(${flag}) safe_set_nvflag(${flag})
endforeach() endforeach()
if(WIN32) if(WIN32 AND MSVC_STATIC_CRT)
# windows build turn off warnings. # windows build turn off warnings.
safe_set_static_flag() safe_set_static_flag()
foreach(flag_var foreach(flag_var
......
...@@ -13,6 +13,14 @@ ...@@ -13,6 +13,14 @@
# limitations under the License. # limitations under the License.
# make package for paddle fluid shared and static library # make package for paddle fluid shared and static library
# On Windows the copy() helper drives copyfile.py, so a Python interpreter is
# required; locate one only if the caller has not already provided it.
if(WIN32 AND NOT PYTHON_EXECUTABLE)
    find_package(PythonInterp REQUIRED)
endif()

# Directory holding the helper scripts shipped with the build system.
set(COPY_SCRIPT_DIR ${PADDLE_SOURCE_DIR}/cmake)
function(copy TARGET) function(copy TARGET)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
...@@ -26,42 +34,16 @@ function(copy TARGET) ...@@ -26,42 +34,16 @@ function(copy TARGET)
message(FATAL_ERROR "${TARGET} source numbers are not equal to destination numbers") message(FATAL_ERROR "${TARGET} source numbers are not equal to destination numbers")
endif () endif ()
math(EXPR len "${copy_lib_SRCS_len} - 1") math(EXPR len "${copy_lib_SRCS_len} - 1")
add_custom_target(${TARGET} DEPENDS ${copy_lib_DEPS}) add_custom_target(${TARGET} DEPENDS ${copy_lib_DEPS})
foreach (index RANGE ${len}) foreach (index RANGE ${len})
list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_SRCS ${index} src)
list(GET copy_lib_DSTS ${index} dst) list(GET copy_lib_DSTS ${index} dst)
if (WIN32) if (WIN32) #windows
if(IS_DIRECTORY ${src}) file(TO_NATIVE_PATH ${src} native_src)
get_filename_component(last_path ${src} NAME) file(TO_NATIVE_PATH ${dst} native_dst)
string(APPEND dst "/" ${last_path}) add_custom_command(TARGET ${TARGET} POST_BUILD
add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND ${PYTHON_EXECUTABLE} ${COPY_SCRIPT_DIR}/copyfile.py ${native_src} ${native_dst})
COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}" else (WIN32) #not windows
)
if(EXISTS ${src})
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND cmake -E copy_directory "${src}" "${dst}"
COMMENT "copying ${src} -> ${dst}")
else()
message(WARNING "${src} not exist!")
endif()
else()
# windows cmd shell will not expand wildcard automatically.
# below expand the files, and copy them by rules.
file(GLOB src_files ${src})
if (NOT "${src_files}" STREQUAL "")
list(REMOVE_DUPLICATES src_files)
endif ()
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory "${dst}"
)
foreach (src_file ${src_files})
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy "${src_file}" "${dst}"
COMMENT "copying ${src_file} -> ${dst}")
endforeach ()
endif()
else (WIN32) # not windows
add_custom_command(TARGET ${TARGET} PRE_BUILD add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND mkdir -p "${dst}" COMMAND mkdir -p "${dst}"
COMMAND cp -r "${src}" "${dst}" COMMAND cp -r "${src}" "${dst}"
...@@ -167,18 +149,6 @@ if (WITH_NGRAPH) ...@@ -167,18 +149,6 @@ if (WITH_NGRAPH)
) )
endif () endif ()
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy")
copy(snappy_lib
SRCS ${SNAPPY_INCLUDE_DIR} ${SNAPPY_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib
DEPS snappy)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappystream")
copy(snappystream_lib
SRCS ${SNAPPYSTREAM_INCLUDE_DIR} ${SNAPPYSTREAM_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib
DEPS snappystream)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/zlib") set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/zlib")
copy(zlib_lib copy(zlib_lib
SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES} SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
...@@ -189,13 +159,11 @@ copy(zlib_lib ...@@ -189,13 +159,11 @@ copy(zlib_lib
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
set(dst_dir "${FLUID_INSTALL_DIR}/paddle/fluid") set(dst_dir "${FLUID_INSTALL_DIR}/paddle/fluid")
set(module "framework") set(module "framework")
if (NOT WIN32) set(framework_lib_deps framework_py_proto)
set(framework_lib_deps framework_py_proto)
endif (NOT WIN32)
copy(framework_lib DEPS ${framework_lib_deps} copy(framework_lib DEPS ${framework_lib_deps}
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/data_feed.pb.h ${src_dir}/${module}/ir/memory_optimize_pass/*.h SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/data_feed.pb.h ${src_dir}/${module}/ir/memory_optimize_pass/*.h
${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h ${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet
) )
...@@ -211,7 +179,7 @@ set(module "inference/api") ...@@ -211,7 +179,7 @@ set(module "inference/api")
if (TENSORRT_FOUND) if (TENSORRT_FOUND)
copy(tensorrt_lib DEPS ${inference_deps} copy(tensorrt_lib DEPS ${inference_deps}
SRCS ${TENSORRT_ROOT}/include/Nv*.h ${TENSORRT_ROOT}/lib/libnvinfer* SRCS ${TENSORRT_ROOT}/include/Nv*.h ${TENSORRT_ROOT}/lib/*nvinfer*
DSTS ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/include ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/lib) DSTS ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/include ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/lib)
endif () endif ()
......
...@@ -2,14 +2,28 @@ if(NOT WITH_GPU) ...@@ -2,14 +2,28 @@ if(NOT WITH_GPU)
return() return()
endif() endif()
# Select per-platform TensorRT library names and the default search root.
if(WIN32)
    # There is no conventional TensorRT install prefix on Windows; declare the
    # cache entry with an empty default and require the user to set it.
    set(TENSORRT_ROOT "" CACHE PATH "TENSORRT ROOT")
    if("${TENSORRT_ROOT}" STREQUAL "")
        message(WARNING "Please specify the TensorRT root path: TENSORRT_ROOT.")
    endif()
    # Normalize backslashes so the path is safe to use in CMake commands.
    string(REPLACE "\\" "/" TENSORRT_ROOT "${TENSORRT_ROOT}")
    set(TR_INFER_LIB nvinfer.lib)
    set(TR_INFER_RT nvinfer.dll)
    set(TR_INFER_PLUGIN_RT nvinfer_plugin.dll)
else()
    set(TENSORRT_ROOT "/usr" CACHE PATH "TENSORRT ROOT")
    set(TR_INFER_LIB libnvinfer.a)
    set(TR_INFER_RT libnvinfer.so)
    set(TR_INFER_PLUGIN_RT libnvinfer_plugin.so)
endif()
find_path(TENSORRT_INCLUDE_DIR NvInfer.h find_path(TENSORRT_INCLUDE_DIR NvInfer.h
PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include
NO_DEFAULT_PATH NO_DEFAULT_PATH
) )
find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a find_library(TENSORRT_LIBRARY NAMES ${TR_INFER_LIB} ${TR_INFER_RT}
PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib
NO_DEFAULT_PATH NO_DEFAULT_PATH
......
hash: 107c058cf5c9163a75d40eef2273a793c36112683c25d72aa8288827fdde3a19
updated: 2017-10-30T03:46:19.137696069Z
imports:
- name: github.com/alecthomas/gometalinter
version: bae2f1293d092fd8167939d5108d1b025eaef9de
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd
version: f1d7dd87da3e8feab4aaf675b8e29c6a5ed5f58b
subpackages:
- alarm
- auth
- auth/authpb
- client
- clientv3
- clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/go-stack/stack
version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages:
- jsonpb
- proto
- name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/inconshreveable/log15
version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/mattn/go-colorable
version: 5411d3eea5978e6cdc258b30de592b60df6aba96
- name: github.com/mattn/go-isatty
version: 57fdcb988a5c543893cc61bce354a6e24ab70022
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio
version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus
version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 9419663f5a44be8b34ca85f08abc5fe1be11f8a3
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- ssh/terminal
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
- context
- http2
- http2/hpack
- idna
- internal/timeseries
- lex/httplex
- trace
- name: golang.org/x/sys
version: e48874b42435b4347fc52bdee0424a52abc974d7
repo: https://github.com/golang/sys.git
vcs: git
subpackages:
- unix
- windows
- name: golang.org/x/text
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- secure/bidirule
- transform
- unicode/bidi
- unicode/norm
- name: google.golang.org/grpc
version: 8050b9cbc271307e5a716a9d782803d09b0d6f2d
subpackages:
- codes
- credentials
- grpclog
- internal
- keepalive
- metadata
- naming
- peer
- stats
- tap
- transport
- name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert
package: github.com/PaddlePaddle/Paddle/go
import:
- package: github.com/PaddlePaddle/recordio
- package: github.com/coreos/etcd
version: ^3.2.1
subpackages:
- clientv3
- clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag
version: ^1.7.4-pre
- package: github.com/sirupsen/logrus
version: ^1.0.0
- package: github.com/topicai/candy
- package: golang.org/x/crypto
repo: https://github.com/golang/crypto.git
vcs: git
- package: golang.org/x/sys
repo: https://github.com/golang/sys.git
vcs: git
- package: golang.org/x/text
repo: https://github.com/golang/text.git
vcs: git
- package: github.com/satori/go.uuid
version: v1.1.0
- package: github.com/alecthomas/gometalinter
version: v1.2.1
- package: github.com/inconshreveable/log15
version: v2.13
- package: github.com/go-stack/stack
version: v1.6.0
- package: github.com/golang/protobuf
此差异已折叠。
...@@ -4,7 +4,6 @@ add_subdirectory(framework) ...@@ -4,7 +4,6 @@ add_subdirectory(framework)
add_subdirectory(imperative) add_subdirectory(imperative)
add_subdirectory(operators) add_subdirectory(operators)
add_subdirectory(string) add_subdirectory(string)
add_subdirectory(recordio)
add_subdirectory(pybind) add_subdirectory(pybind)
# NOTE: please add subdirectory inference at last. # NOTE: please add subdirectory inference at last.
......
...@@ -63,7 +63,7 @@ if(WITH_GPU) ...@@ -63,7 +63,7 @@ if(WITH_GPU)
else() else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif() endif()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
...@@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co ...@@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type) glog box_wrapper shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
...@@ -135,6 +135,8 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc ...@@ -135,6 +135,8 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_library(op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
...@@ -177,7 +179,7 @@ if(WITH_DISTRIBUTE) ...@@ -177,7 +179,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS} lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer) graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
...@@ -188,12 +190,12 @@ else() ...@@ -188,12 +190,12 @@ else()
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer) graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper) target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper conditional_block_op_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
......
...@@ -168,10 +168,10 @@ class ArchiveBase { ...@@ -168,10 +168,10 @@ class ArchiveBase {
#else #else
if (newsize > Capacity()) { if (newsize > Capacity()) {
#endif #endif
Reserve(std::max(Capacity() * 2, newsize)); Reserve((std::max)(Capacity() * 2, newsize));
} }
finish_ = buffer_ + newsize; finish_ = buffer_ + newsize;
cursor_ = std::min(cursor_, finish_); cursor_ = (std::min)(cursor_, finish_);
} }
void Reserve(size_t newcap) { void Reserve(size_t newcap) {
...@@ -207,7 +207,7 @@ class ArchiveBase { ...@@ -207,7 +207,7 @@ class ArchiveBase {
#else #else
if (size > size_t(limit_ - finish_)) { if (size > size_t(limit_ - finish_)) {
#endif #endif
Reserve(std::max(Capacity() * 2, Length() + size)); Reserve((std::max)(Capacity() * 2, Length() + size));
} }
} }
...@@ -311,6 +311,18 @@ class Archive<BinaryArchiveType> : public ArchiveBase { ...@@ -311,6 +311,18 @@ class Archive<BinaryArchiveType> : public ArchiveBase {
*this >> x; *this >> x;
return x; return x;
} }
// printf-style formatted append into the archive's write buffer.
// First formats into the space remaining between Finish() and Limit(); if
// snprintf reports truncation, grows the buffer via PrepareWrite() and
// formats a second time, then advances the write cursor by the formatted
// length (the trailing NUL is not counted).
template <class... ARGS>
void Printf(const char* fmt, ARGS&&... args) {
  size_t temp = Limit() - Finish();  // bytes currently available for writing
  int len = snprintf(Finish(), temp, fmt, args...);
  CHECK(len >= 0);  // NOLINT
  if ((size_t)len >= temp) {
    // Truncated: reserve len + 1 bytes (room for the NUL) and redo the
    // formatting; the second pass must produce exactly the same length.
    PrepareWrite(len + 1);
    CHECK(snprintf(Finish(), (size_t)len + 1, fmt, args...) == len);
  }
  AdvanceFinish(len);
}
}; };
template <class AR, class T, size_t N> template <class AR, class T, size_t N>
......
...@@ -40,7 +40,7 @@ class ChannelObject { ...@@ -40,7 +40,7 @@ class ChannelObject {
// capacity can be zero // capacity can be zero
explicit ChannelObject(size_t capacity) { explicit ChannelObject(size_t capacity) {
capacity_ = std::min(MaxCapacity(), capacity); capacity_ = (std::min)(MaxCapacity(), capacity);
} }
void Clear() { void Clear() {
...@@ -192,7 +192,7 @@ class ChannelObject { ...@@ -192,7 +192,7 @@ class ChannelObject {
std::condition_variable full_cond_; std::condition_variable full_cond_;
static constexpr size_t MaxCapacity() { static constexpr size_t MaxCapacity() {
return std::numeric_limits<size_t>::max() / 2; return (std::numeric_limits<size_t>::max)() / 2;
} }
void Notify() { void Notify() {
...@@ -289,7 +289,7 @@ template <class T> ...@@ -289,7 +289,7 @@ template <class T>
using Channel = std::shared_ptr<ChannelObject<T>>; using Channel = std::shared_ptr<ChannelObject<T>>;
template <class T> template <class T>
Channel<T> MakeChannel(size_t capacity = std::numeric_limits<size_t>::max()) { Channel<T> MakeChannel(size_t capacity = (std::numeric_limits<size_t>::max)()) {
return std::make_shared<ChannelObject<T>>(capacity); return std::make_shared<ChannelObject<T>>(capacity);
} }
...@@ -370,7 +370,7 @@ class ChannelWriter { ...@@ -370,7 +370,7 @@ class ChannelWriter {
void Reset(ChannelObject<T>* channel) { void Reset(ChannelObject<T>* channel) {
CHECK(buffer_.empty()) << "Forgot to flush"; CHECK(buffer_.empty()) << "Forgot to flush";
CHECK(channel != nullptr) << "Channel can not be nullptr"; // CHECK(channel != nullptr) << "Channel can not be nullptr";
channel_ = channel; channel_ = channel;
buffer_.clear(); buffer_.clear();
failed_ = !channel; failed_ = !channel;
......
...@@ -33,11 +33,53 @@ limitations under the License. */ ...@@ -33,11 +33,53 @@ limitations under the License. */
#include "io/shell.h" #include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Re-initializes the candidate buffer to hold `length` records and resets all
// bookkeeping counters. The whole operation is serialized on _mutex.
void RecordCandidateList::ReSize(size_t length) {
  _mutex.lock();
  _capacity = length;
  // A zero-sized buffer would make the modulo in AddAndGet undefined.
  CHECK(_capacity > 0);  // NOLINT
  _candidate_list.clear();
  _candidate_list.resize(_capacity);
  _full = false;
  _cur_size = 0;
  _total_size = 0;
  _mutex.unlock();
}
// Resets the sampling bookkeeping without touching the capacity or the
// underlying storage, so the list can be refilled from scratch.
void RecordCandidateList::ReInit() {
  _mutex.lock();
  _cur_size = 0;
  _total_size = 0;
  _full = false;
  _mutex.unlock();
}
// Feeds `record` into the candidate buffer and returns a randomly chosen
// stored candidate through `result`. While the buffer is filling, records are
// appended; once full, an incoming record replaces a random slot with
// probability _capacity / _total_size (reservoir-sampling style). Serialized
// on _mutex for concurrent callers.
void RecordCandidateList::AddAndGet(const Record& record,
                                    RecordCandidate* result) {
  _mutex.lock();
  size_t index = 0;
  ++_total_size;
  auto fleet_ptr = FleetWrapper::GetInstance();
  if (!_full) {
    // Still filling: append and mark full once capacity is reached.
    _candidate_list[_cur_size++] = record;
    _full = (_cur_size == _capacity);
  } else {
    CHECK(_cur_size == _capacity);
    // Draw in [0, _total_size); indices below _capacity replace that slot.
    index = fleet_ptr->LocalRandomEngine()() % _total_size;
    if (index < _capacity) {
      _candidate_list[index] = record;
    }
  }
  // Return a uniformly random element among those stored so far.
  index = fleet_ptr->LocalRandomEngine()() % _cur_size;
  *result = _candidate_list[index];
  _mutex.unlock();
}
void DataFeed::AddFeedVar(Variable* var, const std::string& name) { void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit(); CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
...@@ -101,11 +143,24 @@ void DataFeed::AssignFeedVar(const Scope& scope) { ...@@ -101,11 +143,24 @@ void DataFeed::AssignFeedVar(const Scope& scope) {
} }
} }
// Copies `size` bytes from the host buffer `src` into the feed tensor buffer
// `dst`, choosing the copy mechanism from the feed's placement (place_).
void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) {
  if (platform::is_cpu_place(this->place_)) {
    // CPU tensor: plain host memcpy.
    memcpy(dst, src, size);
  } else {
#ifdef PADDLE_WITH_CUDA
    // GPU tensor: host-to-device copy. NOTE(review): this is a synchronous
    // cudaMemcpy with no stream argument — confirm that blocking here is
    // acceptable on the feeding path.
    cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
#else
    PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
  }
}
template <typename T> template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size); PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
queue_size_ = queue_size; queue_size_ = queue_size;
queue_ = paddle::framework::MakeChannel<T>(); queue_ = paddle::framework::MakeChannel<T>();
queue_->SetCapacity(queue_size);
} }
template <typename T> template <typename T>
...@@ -169,6 +224,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() { ...@@ -169,6 +224,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0; this->thread_id_ = 0;
this->thread_num_ = 1; this->thread_num_ = 1;
this->parse_ins_id_ = false; this->parse_ins_id_ = false;
this->parse_content_ = false;
this->input_channel_ = nullptr; this->input_channel_ = nullptr;
this->output_channel_ = nullptr; this->output_channel_ = nullptr;
this->consume_channel_ = nullptr; this->consume_channel_ = nullptr;
...@@ -252,6 +308,11 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) { ...@@ -252,6 +308,11 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num; thread_num_ = thread_num;
} }
// Enables/disables parsing of the per-instance content field from input
// lines (consumed via parse_content_ when instances are parsed).
template <typename T>
void InMemoryDataFeed<T>::SetParseContent(bool parse_content) {
  parse_content_ = parse_content;
}
template <typename T> template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) { void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id; parse_ins_id_ = parse_ins_id;
...@@ -301,7 +362,8 @@ void MultiSlotDataFeed::Init( ...@@ -301,7 +362,8 @@ void MultiSlotDataFeed::Init(
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
SetQueueSize(data_feed_desc.batch_size()); // temporarily set queue size = batch size * 100
SetQueueSize(data_feed_desc.batch_size() * 100);
size_t all_slot_num = multi_slot_desc.slots_size(); size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num); all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num); all_slots_type_.resize(all_slot_num);
...@@ -610,15 +672,16 @@ void MultiSlotDataFeed::PutToFeedVec( ...@@ -610,15 +672,16 @@ void MultiSlotDataFeed::PutToFeedVec(
if (type[0] == 'f') { // float if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData(); const auto& feasign = ins_vec[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>( float* tensor_ptr =
{total_instance, 1}, platform::CPUPlace()); feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); CopyToFeedTensor(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64 } else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle // no uint64_t type in paddlepaddle
const auto& feasign = ins_vec[i].GetUint64Data(); const auto& feasign = ins_vec[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>( int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace()); {total_instance, 1}, this->place_);
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); CopyToFeedTensor(tensor_ptr, &feasign[0],
total_instance * sizeof(int64_t));
} }
LoD data_lod{offset}; LoD data_lod{offset};
...@@ -709,6 +772,18 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -709,6 +772,18 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
pos += len + 1; pos += len + 1;
VLOG(3) << "ins_id " << instance->ins_id_; VLOG(3) << "ins_id " << instance->ins_id_;
} }
if (parse_content_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
instance->content_ = std::string(str + pos, len);
pos += len + 1;
VLOG(3) << "content " << instance->content_;
}
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
...@@ -833,8 +908,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -833,8 +908,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
std::vector<std::vector<size_t>> offset(use_slots_.size(), std::vector<std::vector<size_t>> offset(use_slots_.size(),
std::vector<size_t>{0}); std::vector<size_t>{0});
std::vector<bool> visit(use_slots_.size(), false); std::vector<bool> visit(use_slots_.size(), false);
ins_content_vec_.clear();
ins_content_vec_.reserve(ins_vec.size());
ins_id_vec_.clear();
ins_id_vec_.reserve(ins_vec.size());
for (size_t i = 0; i < ins_vec.size(); ++i) { for (size_t i = 0; i < ins_vec.size(); ++i) {
auto& r = ins_vec[i]; auto& r = ins_vec[i];
ins_id_vec_.push_back(r.ins_id_);
ins_content_vec_.push_back(r.content_);
for (auto& item : r.float_feasigns_) { for (auto& item : r.float_feasigns_) {
batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_); batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_);
visit[item.slot()] = true; visit[item.slot()] = true;
...@@ -872,15 +953,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -872,15 +953,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
const auto& type = all_slots_type_[i]; const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float if (type[0] == 'f') { // float
float* feasign = batch_float_feasigns[i].data(); float* feasign = batch_float_feasigns[i].data();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>( float* tensor_ptr =
{total_instance, 1}, platform::CPUPlace()); feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
memcpy(tensor_ptr, feasign, total_instance * sizeof(float)); CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64 } else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle // no uint64_t type in paddlepaddle
uint64_t* feasign = batch_uint64_feasigns[i].data(); uint64_t* feasign = batch_uint64_feasigns[i].data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>( int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace()); {total_instance, 1}, this->place_);
memcpy(tensor_ptr, feasign, total_instance * sizeof(int64_t)); CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
} }
auto& slot_offset = offset[i]; auto& slot_offset = offset[i];
LoD data_lod{slot_offset}; LoD data_lod{slot_offset};
...@@ -906,15 +987,16 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() { ...@@ -906,15 +987,16 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() {
if (type[0] == 'f') { // float if (type[0] == 'f') { // float
const auto& feasign = ins_vec_[i].GetFloatData(); const auto& feasign = ins_vec_[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>( float* tensor_ptr =
{total_instance, 1}, platform::CPUPlace()); feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); CopyToFeedTensor(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64 } else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle // no uint64_t type in paddlepaddle
const auto& feasign = ins_vec_[i].GetUint64Data(); const auto& feasign = ins_vec_[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>( int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace()); {total_instance, 1}, this->place_);
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); CopyToFeedTensor(tensor_ptr, &feasign[0],
total_instance * sizeof(int64_t));
} }
LoD data_lod{offset}; LoD data_lod{offset};
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -104,13 +105,25 @@ class DataFeed { ...@@ -104,13 +105,25 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {} virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {} virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetFileListMutex(std::mutex* mutex) { virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex; mutex_for_pick_file_ = mutex;
} }
virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; } virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
virtual const std::vector<std::string>& GetInsIdVec() const {
return ins_id_vec_;
}
virtual const std::vector<std::string>& GetInsContentVec() const {
return ins_content_vec_;
}
virtual int GetCurBatchSize() { return batch_size_; }
virtual void LoadIntoMemory() { virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
} }
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
virtual const paddle::platform::Place& GetPlace() const { return place_; }
protected: protected:
// The following three functions are used to check if it is executed in this // The following three functions are used to check if it is executed in this
...@@ -124,6 +137,7 @@ class DataFeed { ...@@ -124,6 +137,7 @@ class DataFeed {
// This function is used to pick one file from the global filelist(thread // This function is used to pick one file from the global filelist(thread
// safe). // safe).
virtual bool PickOneFile(std::string* filename); virtual bool PickOneFile(std::string* filename);
virtual void CopyToFeedTensor(void* dst, const void* src, size_t size);
std::vector<std::string> filelist_; std::vector<std::string> filelist_;
size_t* file_idx_; size_t* file_idx_;
...@@ -158,6 +172,9 @@ class DataFeed { ...@@ -158,6 +172,9 @@ class DataFeed {
bool finish_set_filelist_; bool finish_set_filelist_;
bool finish_start_; bool finish_start_;
std::string pipe_command_; std::string pipe_command_;
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
}; };
// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds. // PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
...@@ -215,6 +232,7 @@ class InMemoryDataFeed : public DataFeed { ...@@ -215,6 +232,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id); virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id); virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
protected: protected:
...@@ -225,6 +243,7 @@ class InMemoryDataFeed : public DataFeed { ...@@ -225,6 +243,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_; int thread_id_;
int thread_num_; int thread_num_;
bool parse_ins_id_; bool parse_ins_id_;
bool parse_content_;
std::ifstream file_; std::ifstream file_;
std::shared_ptr<FILE> fp_; std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_; paddle::framework::ChannelObject<T>* input_channel_;
...@@ -419,6 +438,42 @@ struct Record { ...@@ -419,6 +438,42 @@ struct Record {
std::vector<FeatureItem> uint64_feasigns_; std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_; std::vector<FeatureItem> float_feasigns_;
std::string ins_id_; std::string ins_id_;
std::string content_;
};
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
RecordCandidate& operator=(const Record& rec) {
feas.clear();
ins_id_ = rec.ins_id_;
for (auto& fea : rec.uint64_feasigns_) {
feas.insert({fea.slot(), fea.sign()});
}
return *this;
}
};
class RecordCandidateList {
public:
RecordCandidateList() = default;
RecordCandidateList(const RecordCandidateList&) = delete;
RecordCandidateList& operator=(const RecordCandidateList&) = delete;
void ReSize(size_t length);
void ReInit();
void AddAndGet(const Record& record, RecordCandidate* result);
private:
size_t _capacity = 0;
std::mutex _mutex;
bool _full = false;
size_t _cur_size = 0;
size_t _total_size = 0;
std::vector<RecordCandidate> _candidate_list;
}; };
template <class AR> template <class AR>
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
#endif #endif
...@@ -121,24 +120,31 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -121,24 +120,31 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const Tensor& in, Tensor* out) { const Tensor& in, Tensor* out) {
auto in_layout = kernel_type_for_var.data_layout_; auto in_layout = kernel_type_for_var.data_layout_;
auto out_layout = expected_kernel_type.data_layout_; auto out_layout = expected_kernel_type.data_layout_;
auto place = expected_kernel_type.place_;
PADDLE_ENFORCE( PADDLE_ENFORCE(
in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN, in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to " "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"); "non-MKLDNN");
innerTransDataLayoutFromMKLDNN(in_layout, out_layout, in, out, place);
}
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
PADDLE_ENFORCE(in.format() != memory::format::format_undef && PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::format_undef,
in.format() != memory::format::any, "Input tensor should have specified memory format");
"Input tensor should have specified memory format"); PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::any,
"Input tensor should have specified memory format");
// Set default as NCHW in case not specified // Set default as NCHW in case not specified
out_layout = out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout; out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>( auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
pool.Get(expected_kernel_type.place_));
auto& cpu_engine = dev_ctx->GetEngine(); auto& cpu_engine = dev_ctx->GetEngine();
std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims()); std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
...@@ -165,7 +171,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -165,7 +171,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data); auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
auto reorder_dst_memory_p = auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out_format, expected_kernel_type.place_); handler.AcquireDstMemory(out, out_format, place);
auto reorder_p = auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
...@@ -177,7 +183,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -177,7 +183,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
} }
out->set_layout(out_layout); out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel // reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef); out->set_format(MKLDNNMemoryFormat::format_undef);
#endif #endif
} }
......
...@@ -21,30 +21,33 @@ ...@@ -21,30 +21,33 @@
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using MKLDNNFormat = mkldnn::memory::format;
using MKLDNNDataType = mkldnn::memory::data_type; using MKLDNNDataType = mkldnn::memory::data_type;
inline MKLDNNFormat ToMKLDNNFormat(const DataLayout& layout) { inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
switch (layout) { switch (layout) {
case DataLayout::kNHWC: case DataLayout::kNHWC:
return MKLDNNFormat::nhwc; return MKLDNNMemoryFormat::nhwc;
case DataLayout::kNCHW: case DataLayout::kNCHW:
return MKLDNNFormat::nchw; return MKLDNNMemoryFormat::nchw;
default: default:
PADDLE_THROW("Fail to convert layout %s to MKLDNN format", PADDLE_THROW("Fail to convert layout %s to MKLDNN format",
DataLayoutToString(layout)); DataLayoutToString(layout));
} }
} }
inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) { inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
switch (format) { switch (format) {
case MKLDNNFormat::nhwc: case MKLDNNMemoryFormat::nhwc:
return DataLayout::kNHWC; return DataLayout::kNHWC;
case MKLDNNFormat::nchw: case MKLDNNMemoryFormat::nchw:
return DataLayout::kNCHW; return DataLayout::kNCHW;
default: default:
PADDLE_THROW("Fail to convert MKLDNN format to paddle layout"); PADDLE_THROW("Fail to convert MKLDNN format to paddle layout");
...@@ -69,6 +72,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -69,6 +72,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out); const Tensor& in, Tensor* out);
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place);
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const OpKernelType& kernel_type_for_var, void TransDataLayout(const OpKernelType& kernel_type_for_var,
......
...@@ -48,6 +48,8 @@ DatasetImpl<T>::DatasetImpl() { ...@@ -48,6 +48,8 @@ DatasetImpl<T>::DatasetImpl() {
erase_duplicate_feas_ = true; erase_duplicate_feas_ = true;
keep_unmerged_ins_ = true; keep_unmerged_ins_ = true;
min_merge_size_ = 2; min_merge_size_ = 2;
parse_ins_id_ = false;
parse_content_ = false;
} }
// set filelist, file_idx_ will reset to zero. // set filelist, file_idx_ will reset to zero.
...@@ -103,6 +105,16 @@ void DatasetImpl<T>::SetChannelNum(int channel_num) { ...@@ -103,6 +105,16 @@ void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num; channel_num_ = channel_num;
} }
template <typename T>
void DatasetImpl<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
template <typename T>
void DatasetImpl<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T> template <typename T>
void DatasetImpl<T>::SetMergeByInsId( void DatasetImpl<T>::SetMergeByInsId(
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas, const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
...@@ -114,6 +126,14 @@ void DatasetImpl<T>::SetMergeByInsId( ...@@ -114,6 +126,14 @@ void DatasetImpl<T>::SetMergeByInsId(
keep_unmerged_ins_ = keep_unmerged_ins; keep_unmerged_ins_ = keep_unmerged_ins;
} }
template <typename T>
void DatasetImpl<T>::SetFeaEval(bool fea_eval, int record_candidate_size) {
slots_shuffle_fea_eval_ = fea_eval;
slots_shuffle_rclist_.ReSize(record_candidate_size);
VLOG(3) << "SetFeaEval fea eval mode: " << fea_eval
<< " with record candidate size: " << record_candidate_size;
}
template <typename T> template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() { std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
std::vector<paddle::framework::DataFeed*> ret; std::vector<paddle::framework::DataFeed*> ret;
...@@ -352,8 +372,6 @@ void DatasetImpl<T>::CreateReaders() { ...@@ -352,8 +372,6 @@ void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Filelist size in Dataset: " << filelist_.size(); VLOG(3) << "Filelist size in Dataset: " << filelist_.size();
VLOG(3) << "channel num in Dataset: " << channel_num_; VLOG(3) << "channel num in Dataset: " << channel_num_;
CHECK(thread_num_ > 0) << "thread num should > 0"; CHECK(thread_num_ > 0) << "thread num should > 0";
CHECK(thread_num_ <= filelist_.size())
<< "thread num should <= filelist size";
CHECK(channel_num_ > 0) << "channel num should > 0"; CHECK(channel_num_ > 0) << "channel num should > 0";
CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num"; CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num";
VLOG(3) << "readers size: " << readers_.size(); VLOG(3) << "readers size: " << readers_.size();
...@@ -372,7 +390,8 @@ void DatasetImpl<T>::CreateReaders() { ...@@ -372,7 +390,8 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileListMutex(&mutex_for_pick_file_); readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
readers_[i]->SetFileListIndex(&file_idx_); readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_); readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(merge_by_insid_); readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseContent(parse_content_);
if (input_channel_ != nullptr) { if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get()); readers_[i]->SetInputChannel(input_channel_.get());
} }
...@@ -648,5 +667,167 @@ void MultiSlotDataset::MergeByInsId() { ...@@ -648,5 +667,167 @@ void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId end"; VLOG(3) << "MultiSlotDataset::MergeByInsId end";
} }
void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
slots_shuffle_rclist_.ReInit();
for (const auto& rec : slots_shuffle_original_data_) {
RecordCandidate rand_rec;
Record new_rec = rec;
slots_shuffle_rclist_.AddAndGet(rec, &rand_rec);
for (auto it = new_rec.uint64_feasigns_.begin();
it != new_rec.uint64_feasigns_.end();) {
if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) {
it = new_rec.uint64_feasigns_.erase(it);
debug_erase_cnt += 1;
} else {
++it;
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
}
}
result->push_back(std::move(new_rec));
}
VLOG(2) << "erase feasign num: " << debug_erase_cnt
<< " repush feasign num: " << debug_push_cnt;
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
int out_channel_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
out_channel_size += multi_output_channel_[i]->Size();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
out_channel_size += multi_consume_channel_[i]->Size();
}
}
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<< input_channel_->Size()
<< " output channel size: " << out_channel_size;
if (!slots_shuffle_fea_eval_) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle";
return;
}
if ((!input_channel_ || input_channel_->Size() == 0) &&
slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle";
return;
}
platform::Timer timeline;
timeline.Start();
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::set<uint16_t> index_slots;
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name();
if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) {
index_slots.insert(i);
}
}
if (slots_shuffle_original_data_.size() == 0) {
// before first slots shuffle, instances could be in
// input_channel, oupput_channel or consume_channel
if (input_channel_ && input_channel_->Size() != 0) {
slots_shuffle_original_data_.reserve(input_channel_->Size());
input_channel_->Close();
input_channel_->ReadAll(slots_shuffle_original_data_);
} else {
CHECK(out_channel_size > 0); // NOLINT
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<Record> vec_data;
multi_output_channel_[i]->Close();
multi_output_channel_[i]->ReadAll(vec_data);
slots_shuffle_original_data_.reserve(
slots_shuffle_original_data_.size() + vec_data.size());
slots_shuffle_original_data_.insert(
slots_shuffle_original_data_.end(),
std::make_move_iterator(vec_data.begin()),
std::make_move_iterator(vec_data.end()));
vec_data.clear();
vec_data.shrink_to_fit();
multi_output_channel_[i]->Clear();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
std::vector<Record> vec_data;
multi_consume_channel_[i]->Close();
multi_consume_channel_[i]->ReadAll(vec_data);
slots_shuffle_original_data_.reserve(
slots_shuffle_original_data_.size() + vec_data.size());
slots_shuffle_original_data_.insert(
slots_shuffle_original_data_.end(),
std::make_move_iterator(vec_data.begin()),
std::make_move_iterator(vec_data.end()));
vec_data.clear();
vec_data.shrink_to_fit();
multi_consume_channel_[i]->Clear();
}
}
}
} else {
// if already have original data for slots shuffle, clear channel
input_channel_->Clear();
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
if (!multi_output_channel_[i]) {
continue;
}
multi_output_channel_[i]->Clear();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
if (!multi_consume_channel_[i]) {
continue;
}
multi_consume_channel_[i]->Clear();
}
}
}
int end_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
if (!multi_output_channel_[i]) {
continue;
}
end_size += multi_output_channel_[i]->Size();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
if (!multi_consume_channel_[i]) {
continue;
}
end_size += multi_consume_channel_[i]->Size();
}
}
CHECK(input_channel_->Size() == 0)
<< "input channel should be empty before slots shuffle";
std::vector<Record> random_data;
random_data.clear();
// get slots shuffled random_data
GetRandomData(index_slots, &random_data);
input_channel_->Open();
input_channel_->Write(std::move(random_data));
random_data.clear();
random_data.shrink_to_fit();
input_channel_->Close();
timeline.Pause();
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() end"
<< ", memory data size for slots shuffle=" << input_channel_->Size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <set>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
...@@ -57,10 +58,15 @@ class Dataset { ...@@ -57,10 +58,15 @@ class Dataset {
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// set channel num // set channel num
virtual void SetChannelNum(int channel_num) = 0; virtual void SetChannelNum(int channel_num) = 0;
// set parse ins id
virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0;
// set merge by ins id // set merge by ins id
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list, virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size, bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins) = 0; bool keep_unmerged_ins) = 0;
// set fea eval mode
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
// get file list // get file list
virtual const std::vector<std::string>& GetFileList() = 0; virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num // get thread num
...@@ -94,6 +100,10 @@ class Dataset { ...@@ -94,6 +100,10 @@ class Dataset {
virtual void LocalShuffle() = 0; virtual void LocalShuffle() = 0;
// global shuffle data // global shuffle data
virtual void GlobalShuffle() = 0; virtual void GlobalShuffle() = 0;
// for slots shuffle
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) = 0;
// create readers // create readers
virtual void CreateReaders() = 0; virtual void CreateReaders() = 0;
// destroy readers // destroy readers
...@@ -126,13 +136,17 @@ class DatasetImpl : public Dataset { ...@@ -126,13 +136,17 @@ class DatasetImpl : public Dataset {
const std::string& fs_ugi); const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num); virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list, virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size, bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins); bool keep_unmerged_ins);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
virtual const std::vector<std::string>& GetFileList() { return filelist_; } virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; } virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; } virtual int GetTrainerNum() { return trainer_num_; }
virtual Channel<T> GetInputChannel() { return input_channel_; }
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; } virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() { virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_); return std::make_pair(fs_name_, fs_ugi_);
...@@ -150,6 +164,9 @@ class DatasetImpl : public Dataset { ...@@ -150,6 +164,9 @@ class DatasetImpl : public Dataset {
virtual void ReleaseMemory(); virtual void ReleaseMemory();
virtual void LocalShuffle(); virtual void LocalShuffle();
virtual void GlobalShuffle(); virtual void GlobalShuffle();
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {}
virtual void CreateReaders(); virtual void CreateReaders();
virtual void DestroyReaders(); virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize(); virtual int64_t GetMemoryDataSize();
...@@ -168,6 +185,8 @@ class DatasetImpl : public Dataset { ...@@ -168,6 +185,8 @@ class DatasetImpl : public Dataset {
// and when finish reading, we set cur_channel = 1 - cur_channel, // and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in output_channel, else consume_channel // so if cur_channel=0, all data are in output_channel, else consume_channel
int cur_channel_; int cur_channel_;
std::vector<T> slots_shuffle_original_data_;
RecordCandidateList slots_shuffle_rclist_;
int thread_num_; int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_; paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_; int trainer_num_;
...@@ -180,10 +199,13 @@ class DatasetImpl : public Dataset { ...@@ -180,10 +199,13 @@ class DatasetImpl : public Dataset {
int64_t fleet_send_sleep_seconds_; int64_t fleet_send_sleep_seconds_;
std::vector<std::thread> preload_threads_; std::vector<std::thread> preload_threads_;
bool merge_by_insid_; bool merge_by_insid_;
bool parse_ins_id_;
bool parse_content_;
bool erase_duplicate_feas_; bool erase_duplicate_feas_;
bool keep_unmerged_ins_; bool keep_unmerged_ins_;
int min_merge_size_; int min_merge_size_;
std::vector<std::string> merge_slots_list_; std::vector<std::string> merge_slots_list_;
bool slots_shuffle_fea_eval_ = false;
}; };
// use std::vector<MultiSlotType> or Record as data type // use std::vector<MultiSlotType> or Record as data type
...@@ -191,6 +213,9 @@ class MultiSlotDataset : public DatasetImpl<Record> { ...@@ -191,6 +213,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
public: public:
MultiSlotDataset() {} MultiSlotDataset() {}
virtual void MergeByInsId(); virtual void MergeByInsId();
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {} virtual ~MultiSlotDataset() {}
}; };
......
...@@ -20,12 +20,9 @@ ...@@ -20,12 +20,9 @@
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
// asynchronous nccl allreduce or synchronous issue: #ifdef PADDLE_WITH_CUDA
// https://github.com/PaddlePaddle/Paddle/issues/15049 DECLARE_bool(sync_nccl_allreduce);
DEFINE_bool( #endif
sync_nccl_allreduce, true,
"If set true, will call `cudaStreamSynchronize(nccl_stream)`"
"after allreduce, this mode can get better performance in some scenarios.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -77,7 +77,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -77,7 +77,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Specifies the restrictions between different pass. // Specifies the restrictions between different pass.
if (strategy_.enable_parallel_graph_) { if (strategy_.enable_parallel_graph_) {
VLOG_IF(3, strategy_.fuse_all_optimizer_ops_) VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
<< "Currently, fuse_all_optimizer_ops doesn't works under " << "Currently, fuse_all_optimizer_ops doesn't work under "
"parallel_graph."; "parallel_graph.";
strategy_.fuse_all_optimizer_ops_ = false; strategy_.fuse_all_optimizer_ops_ = false;
} }
...@@ -96,6 +96,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -96,6 +96,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
<< "fuse_all_optimizer_ops only work in Reducer mode."; << "fuse_all_optimizer_ops only work in Reducer mode.";
strategy_.fuse_all_reduce_ops_ = false; strategy_.fuse_all_reduce_ops_ = false;
} }
if (strategy_.async_mode_) {
VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
<< "Currently, fuse_all_optimizer_ops doesn't work under "
"async mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
} }
void AppendMultiGraphOptPasses() { void AppendMultiGraphOptPasses() {
......
...@@ -31,7 +31,7 @@ struct ExecutionStrategy { ...@@ -31,7 +31,7 @@ struct ExecutionStrategy {
// iterations the framework cleans up a local execution scope. // iterations the framework cleans up a local execution scope.
// In some models, the value of this parameter has a great // In some models, the value of this parameter has a great
// influence on the performance(about 15%) of the program. // influence on the performance(about 15%) of the program.
size_t num_iteration_per_drop_scope_{1}; size_t num_iteration_per_drop_scope_{100};
// At present, the kExperimental executor is the fastest in most models. // At present, the kExperimental executor is the fastest in most models.
ExecutorType type_{kExperimental}; ExecutorType type_{kExperimental};
// This debug option. // This debug option.
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <deque>
#include <memory> #include <memory>
#include <queue>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -191,13 +191,13 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -191,13 +191,13 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) { const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
++remaining_; ++remaining_;
this->pool_.enqueue([=] { this->pool_.enqueue([=] {
std::queue<OpHandleBase *> op_queue; std::deque<OpHandleBase *> op_queue;
op_queue.push(op); op_queue.push_front(op);
size_t complete = 0; size_t complete = 0;
while (!op_queue.empty()) { while (!op_queue.empty()) {
OpHandleBase *op_to_run = op_queue.front(); OpHandleBase *op_to_run = op_queue.back();
op_queue.pop(); op_queue.pop_back();
if (!RunOp(op_to_run, complete_q, &complete)) { if (!RunOp(op_to_run, complete_q, &complete)) {
return; return;
...@@ -213,7 +213,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -213,7 +213,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
// NOTE(zjl): op with highest priority should run // NOTE(zjl): op with highest priority should run
// first without switching to another thread. // first without switching to another thread.
if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) { if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) {
op_queue.push(pending_op); op_queue.push_back(pending_op);
} else { } else {
if (op_to_run == nullptr) { if (op_to_run == nullptr) {
op_to_run = pending_op; op_to_run = pending_op;
...@@ -224,7 +224,9 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -224,7 +224,9 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
} }
} }
if (op_to_run != nullptr) op_queue.push(op_to_run); if (op_to_run != nullptr) {
op_queue.push_front(op_to_run);
}
} }
--remaining_; --remaining_;
complete_q->Push(complete); complete_q->Push(complete);
......
...@@ -114,12 +114,19 @@ class DeviceWorker { ...@@ -114,12 +114,19 @@ class DeviceWorker {
virtual void BindingDataFeedMemory() = 0; virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope); virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(DataFeed* data_feed); virtual void SetDataFeed(DataFeed* data_feed);
virtual void SetNeedDump(bool need_dump_field) {}
virtual void SetChannelWriter(ChannelObject<std::string>* queue) {}
virtual void SetPlace(const paddle::platform::Place& place) { virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place; place_ = place;
} }
virtual void SetReaderPlace(const paddle::platform::Place& place) {
device_reader_->SetPlace(place);
}
virtual Scope* GetThreadScope() { return thread_scope_; }
protected: protected:
Scope* root_scope_ = nullptr; Scope* root_scope_ = nullptr;
Scope* thread_scope_;
paddle::platform::Place place_; paddle::platform::Place place_;
DataFeed* device_reader_ = nullptr; DataFeed* device_reader_ = nullptr;
int64_t batch_num_; int64_t batch_num_;
...@@ -151,15 +158,18 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -151,15 +158,18 @@ class HogwildWorker : public CPUWorkerBase {
virtual void PrintFetchVars(); virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog); virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory(); virtual void BindingDataFeedMemory();
template <typename T>
void SetZero(LoDTensor* tensor, LoDTensor* root_tensor, int tensor_dim);
protected: protected:
void CreateThreadOperators(const ProgramDesc& program); void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program); void CreateThreadScope(const ProgramDesc& program);
std::vector<std::string> op_names_; std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_; std::vector<OperatorBase*> ops_;
Scope* thread_scope_; // Scope* thread_scope_;
HogwildWorkerParameter param_; HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_; std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_;
}; };
class DownpourWorker : public HogwildWorker { class DownpourWorker : public HogwildWorker {
...@@ -169,6 +179,8 @@ class DownpourWorker : public HogwildWorker { ...@@ -169,6 +179,8 @@ class DownpourWorker : public HogwildWorker {
virtual void Initialize(const TrainerDesc& desc); virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles(); virtual void TrainFiles();
virtual void TrainFilesWithProfiler(); virtual void TrainFilesWithProfiler();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
protected: protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_; std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
...@@ -176,11 +188,15 @@ class DownpourWorker : public HogwildWorker { ...@@ -176,11 +188,15 @@ class DownpourWorker : public HogwildWorker {
void FillSparseValue(size_t table_id); void FillSparseValue(size_t table_id);
void PushGradients(); void PushGradients();
void CollectLabelInfo(size_t table_id); void CollectLabelInfo(size_t table_id);
void AdjustInsWeight();
private: private:
bool need_to_push_dense_; bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_; bool dump_slot_;
bool need_to_push_sparse_; bool need_to_push_sparse_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
DownpourWorkerParameter param_; DownpourWorkerParameter param_;
float scale_datanorm_; float scale_datanorm_;
// just save the value in param_ for easy access // just save the value in param_ for easy access
...@@ -205,6 +221,10 @@ class DownpourWorker : public HogwildWorker { ...@@ -205,6 +221,10 @@ class DownpourWorker : public HogwildWorker {
std::shared_ptr<PullDenseWorker> _pull_dense_worker; std::shared_ptr<PullDenseWorker> _pull_dense_worker;
std::vector<::std::future<int32_t>> push_sparse_status_; std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_; std::vector<::std::future<int32_t>> push_dense_status_;
// adjust ins weight
AdjustInsWeightConfig adjust_ins_weight_config_;
std::vector<float> nid_show_;
}; };
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
...@@ -22,16 +23,34 @@ limitations under the License. */ ...@@ -22,16 +23,34 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
Dataset* dataset) { Dataset *dataset) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
SetDataset(dataset); SetDataset(dataset);
const std::vector<paddle::framework::DataFeed*> readers = dump_fields_path_ = trainer_desc.dump_fields_path();
dump_converter_ = trainer_desc.dump_converter();
need_dump_field_ = false;
if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
need_dump_field_ = true;
}
if (need_dump_field_) {
auto &file_list = dataset->GetFileList();
if (file_list.size() == 0) {
need_dump_field_ = false;
}
}
mpi_rank_ = trainer_desc.mpi_rank() / 2;
const std::vector<paddle::framework::DataFeed *> readers =
dataset->GetReaders(); dataset->GetReaders();
thread_num_ = readers.size(); thread_num_ = readers.size();
workers_.resize(thread_num_); workers_.resize(thread_num_);
for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
i++) {
need_merge_var_names_.push_back(
trainer_desc.downpour_param().stat_var_names(i));
}
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
...@@ -39,6 +58,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -39,6 +58,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i]->SetDeviceIndex(i); workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]); workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc); workers_[i]->Initialize(trainer_desc);
workers_[i]->SetNeedDump(need_dump_field_);
} }
VLOG(3) << "going to initialize pull dense worker"; VLOG(3) << "going to initialize pull dense worker";
...@@ -48,7 +68,51 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -48,7 +68,51 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
SetDebug(trainer_desc.debug()); SetDebug(trainer_desc.debug());
} }
// Background writer-thread body: drains the shared dump channel `queue_`
// and appends each record, newline-terminated, to the dump file `fp_`.
// Linux-only: fwrite_unlocked is not portable, so the body is compiled
// out elsewhere and the thread exits immediately.
// NOTE(review): diff-extraction residue — the old-side "InitOtherEnv"
// signature is fused onto the next line; the definition here is DumpWork().
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) { void DistMultiTrainer::DumpWork() {
#ifdef _LINUX
// Loop until the channel is closed; Get() returning false signals shutdown
// (see FinalizeDumpEnv, which closes the queue and joins this thread).
while (1) {
std::string out_str;
if (!queue_->Get(out_str)) {
break;
}
size_t write_count =
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp_.get());
// Short writes are logged and the record is skipped rather than retried.
if (write_count != out_str.length()) {
VLOG(3) << "dump text failed";
continue;
}
write_count = fwrite_unlocked("\n", 1, 1, fp_.get());
if (write_count != 1) {
VLOG(3) << "dump text failed";
continue;
}
}
#endif
}
// Sets up the field-dump pipeline: creates the shared string channel,
// opens a per-rank output file "<dump_fields_path>/part-NNN", hands the
// channel to every worker, and starts the writer thread (DumpWork).
void DistMultiTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
// NOTE(review): err_no from fs_open_write is never checked — confirm
// whether an open failure should abort instead of dumping into a bad fp_.
fp_ = fs_open_write(path, &err_no, dump_converter_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_ = std::thread(&DistMultiTrainer::DumpWork, this);
}
// Tears down the dump pipeline: closing the channel makes the writer
// thread's Get() return false so it exits, then the thread is joined and
// the channel released. Must only run after InitDumpEnv started the thread.
void DistMultiTrainer::FinalizeDumpEnv() {
queue_->Close();
dump_thread_.join();
queue_.reset();
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) {
if (need_dump_field_) {
InitDumpEnv();
}
pull_dense_worker_->SetRootScope(root_scope_); pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start(); pull_dense_worker_->Start();
VLOG(3) << "init other env done."; VLOG(3) << "init other env done.";
...@@ -67,12 +131,48 @@ void DistMultiTrainer::Run() { ...@@ -67,12 +131,48 @@ void DistMultiTrainer::Run() {
} }
// Shuts the trainer down: joins all worker threads, element-wise merges the
// per-thread copies of each stat variable back into the root scope, flushes
// the dump pipeline if enabled, stops the dense puller, and drops child
// scopes. (Lines with doubled text below are diff-extraction residue.)
void DistMultiTrainer::Finalize() { void DistMultiTrainer::Finalize() {
for (auto& th : threads_) { for (auto &th : threads_) {
th.join(); th.join();
} }
// Accumulate each named stat var from every worker's thread scope into
// the root-scope tensor; thread 0's scope is skipped (j starts at 1).
for (int i = 0; i < need_merge_var_names_.size(); i++) {
Variable *root_var = root_scope_->FindVar(need_merge_var_names_[i]);
if (root_var == nullptr) {
continue;
}
LoDTensor *root_tensor = root_var->GetMutable<LoDTensor>();
for (int j = 1; j < thread_num_; j++) {
Scope *cur_thread_scope = workers_[j]->GetThreadScope();
Variable *thread_var =
cur_thread_scope->FindVar(need_merge_var_names_[i]);
LoDTensor *thread_tensor = thread_var->GetMutable<LoDTensor>();
// Silently skip shape-mismatched tensors rather than failing the merge.
if (root_tensor->numel() != thread_tensor->numel()) {
continue;
}
// Dispatch MergeToRootScope<T> on the tensor's runtime dtype via the
// framework's _ForEachDataType_ expansion macro.
// NOTE(review): macro is defined inside the function body and never
// #undef'd — confirm it does not collide with later code in this file.
#define MergeCallback(cpp_type, proto_type) \
do { \
if (root_tensor->type() == proto_type) { \
MergeToRootScope<cpp_type>(root_tensor, thread_tensor); \
} \
} while (0)
_ForEachDataType_(MergeCallback);
}
}
if (need_dump_field_) {
FinalizeDumpEnv();
}
pull_dense_worker_->Stop(); pull_dense_worker_->Stop();
root_scope_->DropKids(); root_scope_->DropKids();
} }
// Element-wise accumulates a worker-thread tensor into the root-scope
// tensor of the same name. Caller (Finalize) guarantees matching numel.
template <typename T>
void DistMultiTrainer::MergeToRootScope(LoDTensor *root_tensor,
                                        LoDTensor *tensor) {
  T *dst = root_tensor->data<T>();
  T *src = tensor->data<T>();
  const auto count = tensor->numel();
  for (int idx = 0; idx < count; ++idx) {
    dst[idx] += src[idx];
  }
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -15,6 +15,12 @@ limitations under the License. */ ...@@ -15,6 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -58,6 +64,10 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -58,6 +64,10 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_[i] = param_.skip_ops(i); skip_ops_[i] = param_.skip_ops(i);
} }
for (int i = 0; i < param_.stat_var_names_size(); ++i) {
stat_var_name_map_[param_.stat_var_names(i)] = 1;
}
need_to_push_sparse_ = param_.push_sparse(); need_to_push_sparse_ = param_.push_sparse();
need_to_push_dense_ = param_.push_dense(); need_to_push_dense_ = param_.push_dense();
...@@ -66,6 +76,87 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -66,6 +76,87 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
scale_datanorm_ = desc.scale_datanorm(); scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot(); dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size());
for (int i = 0; i < desc.dump_fields_size(); ++i) {
dump_fields_[i] = desc.dump_fields(i);
}
adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
}
// Attaches this worker's buffered writer to the trainer-owned dump channel;
// records pushed through writer_ are drained by DistMultiTrainer::DumpWork.
void DownpourWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
writer_.Reset(queue);
}
// Enables or disables per-instance field dumping for this worker
// (checked in TrainFiles before building and writing dump records).
void DownpourWorker::SetNeedDump(bool need_dump_field) {
need_dump_field_ = need_dump_field;
}
// Renders tensor elements [start, end) as ":v1:v2:...": the textual form
// used by the field-dump records. Returns "access violation" (and logs)
// when the requested range falls outside the tensor.
template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
  const int64_t total = tensor->numel();
  if (start < 0 || end > total) {
    VLOG(3) << "access violation";
    return "access violation";
  }
  const T* buf = tensor->data<T>();
  std::ostringstream oss;
  for (int64_t pos = start; pos < end; ++pos) {
    oss << ":" << buf[pos];
  }
  return oss.str();
}
// Integer variant of PrintLodTensorType: renders int64 elements
// [start, end) as ":v1:v2:...", printing each value reinterpreted as
// uint64_t. Returns "access violation" (and logs) on an out-of-range span.
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
                                  int64_t end) {
  const int64_t total = tensor->numel();
  if (start < 0 || end > total) {
    VLOG(3) << "access violation";
    return "access violation";
  }
  const int64_t* buf = tensor->data<int64_t>();
  std::ostringstream oss;
  for (int64_t pos = start; pos < end; ++pos) {
    oss << ":" << static_cast<uint64_t>(buf[pos]);
  }
  return oss.str();
}
// Dtype dispatcher for dump formatting: routes FP32/FP64 to the templated
// printer and INT64 to the integer printer; any other dtype yields the
// literal string "unsupported type".
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
  const auto dtype = tensor->type();
  if (dtype == proto::VarType::FP32) {
    return PrintLodTensorType<float>(tensor, start, end);
  }
  if (dtype == proto::VarType::INT64) {
    return PrintLodTensorIntType(tensor, start, end);
  }
  if (dtype == proto::VarType::FP64) {
    return PrintLodTensorType<double>(tensor, start, end);
  }
  return "unsupported type";
}
// Computes the [begin, end) flat-element span of instance `index` inside a
// rank-2 tensor: LoD offsets scaled by the row width when a LoD is present,
// otherwise a fixed-size row slice.
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
  auto& dims = tensor->dims();
  const auto width = dims[1];
  if (tensor->lod().size() == 0) {
    // Dense case: every instance owns exactly one row.
    return {index * width, (index + 1) * width};
  }
  auto& lod = tensor->lod()[0];
  return {lod[index] * width, lod[index + 1] * width};
}
// Validates that a dump-candidate output tensor matches the current batch:
// it must be rank-2, and either carry a level-0 LoD with batch_size + 1
// offsets or (dense case) have a first dimension equal to batch_size.
bool CheckValidOutput(LoDTensor* tensor, int batch_size) {
auto& dims = tensor->dims();
if (dims.size() != 2) return false;
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
// NOTE(review): size_t vs int comparison — safe for non-negative
// batch_size, which is what callers pass here.
if (lod.size() != batch_size + 1) {
return false;
}
} else {
if (dims[0] != batch_size) {
return false;
}
}
return true;
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
...@@ -150,30 +241,130 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -150,30 +241,130 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
auto& tensor_lod = tensor->lod()[0]; auto& tensor_lod = tensor->lod()[0];
LoD data_lod{tensor_lod}; LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod); tensor_emb->set_lod(data_lod);
bool is_nid = (adjust_ins_weight_config_.need_adjust() &&
adjust_ins_weight_config_.nid_slot() == emb_slot_name);
if (is_nid) {
nid_show_.clear();
}
int nid_ins_index = 0;
for (int index = 0; index < len; ++index) { for (int index = 0; index < len; ++index) {
if (use_cvm_) { if (use_cvm_) {
if (ids[index] == 0u) { if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(), memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
if (is_nid) {
nid_show_.push_back(-1);
++nid_ins_index;
}
continue; continue;
} }
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(), memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(),
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
if (is_nid && index == tensor->lod()[0][nid_ins_index]) {
nid_show_.push_back(fea_value[fea_idx][0]);
++nid_ins_index;
}
fea_idx++; fea_idx++;
} else { } else {
if (ids[index] == 0u) { if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2, memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
if (is_nid) {
nid_show_.push_back(-1);
++nid_ins_index;
}
continue; continue;
} }
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2, memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2,
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
if (is_nid && index == tensor->lod()[0][nid_ins_index]) {
nid_show_.push_back(fea_value[fea_idx][0]);
++nid_ins_index;
}
fea_idx++; fea_idx++;
} }
} }
} }
} }
// Boosts per-instance training weights for instances whose nid_show (the
// show count pulled for the nid slot, collected into nid_show_ by
// FillSparseValue) is below a configured threshold: the new weight is
// log(e + (threshold - show)/threshold * ratio), and it only replaces the
// existing weight when larger. Linux-only; no-op elsewhere.
void DownpourWorker::AdjustInsWeight() {
#ifdef _LINUX
// check var and tensor not null
if (!adjust_ins_weight_config_.need_adjust()) {
VLOG(0) << "need_adjust=false, skip adjust ins weight";
return;
}
// Resolve the nid slot variable/tensor; bail out (with a log) when absent.
Variable* nid_var =
thread_scope_->FindVar(adjust_ins_weight_config_.nid_slot());
if (nid_var == nullptr) {
VLOG(0) << "nid slot var " << adjust_ins_weight_config_.nid_slot()
<< " is nullptr, skip adjust ins weight";
return;
}
LoDTensor* nid_tensor = nid_var->GetMutable<LoDTensor>();
if (nid_tensor == nullptr) {
VLOG(0) << "tensor of nid slot var " << adjust_ins_weight_config_.nid_slot()
<< " is nullptr, skip adjust ins weight";
return;
}
// Resolve the instance-weight slot variable/tensor the same way.
Variable* ins_weight_var =
thread_scope_->FindVar(adjust_ins_weight_config_.ins_weight_slot());
if (ins_weight_var == nullptr) {
VLOG(0) << "ins weight var " << adjust_ins_weight_config_.ins_weight_slot()
<< " is nullptr, skip adjust ins weight";
return;
}
LoDTensor* ins_weight_tensor = ins_weight_var->GetMutable<LoDTensor>();
if (ins_weight_tensor == nullptr) {
VLOG(0) << "tensor of ins weight tensor "
<< adjust_ins_weight_config_.ins_weight_slot()
<< " is nullptr, skip adjust ins weight";
return;
}
float* ins_weights = ins_weight_tensor->data<float>();
size_t len = ins_weight_tensor->numel();  // len = batch size
// here we assume nid_show slot only has one feasign in each instance
CHECK(len == nid_show_.size()) << "ins_weight size should be equal to "
<< "nid_show size, " << len << " vs "
<< nid_show_.size();
float nid_adjw_threshold = adjust_ins_weight_config_.nid_adjw_threshold();
float nid_adjw_ratio = adjust_ins_weight_config_.nid_adjw_ratio();
int64_t nid_adjw_num = 0;
double nid_adjw_weight = 0.0;
size_t ins_index = 0;
// NOTE(review): `int i` vs `size_t len` is a signed/unsigned comparison;
// harmless for realistic batch sizes but worth confirming.
for (int i = 0; i < len; ++i) {
float nid_show = nid_show_[i];
VLOG(3) << "nid_show " << nid_show;
// nid_show is -1 for padding feasigns (see FillSparseValue); skip them.
if (nid_show < 0) {
VLOG(3) << "nid_show < 0, continue";
continue;
}
float ins_weight = 1.0;
if (nid_show >= 0 && nid_show < nid_adjw_threshold) {
ins_weight = log(M_E +
(nid_adjw_threshold - nid_show) / nid_adjw_threshold *
nid_adjw_ratio);
// count nid adjw insnum and weight
++nid_adjw_num;
nid_adjw_weight += ins_weight;
// choose large ins weight
VLOG(3) << "ins weight new " << ins_weight << ", ins weight origin "
<< ins_weights[ins_index];
if (ins_weight > ins_weights[ins_index]) {
VLOG(3) << "ins " << ins_index << " weight changes to " << ins_weight;
ins_weights[ins_index] = ins_weight;
}
// NOTE(review): ins_index advances only inside this below-threshold
// branch, so it lags i whenever nid_show >= threshold and the
// ins_weights[] access above may target the wrong instance — confirm
// whether ++ins_index should instead happen once per loop iteration.
++ins_index;
}
}
// NOTE(review): label says "avg_adjw_weight" but nid_adjw_weight is the
// accumulated sum, not divided by nid_adjw_num — confirm intent.
VLOG(3) << "nid adjw info: total_adjw_num: " << nid_adjw_num
<< ", avg_adjw_weight: " << nid_adjw_weight;
#endif
}
void DownpourWorker::TrainFilesWithProfiler() { void DownpourWorker::TrainFilesWithProfiler() {
VLOG(3) << "Begin to train files with profiler"; VLOG(3) << "Begin to train files with profiler";
platform::SetNumThreads(1); platform::SetNumThreads(1);
...@@ -202,6 +393,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -202,6 +393,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
double total_time = 0.0; double total_time = 0.0;
double read_time = 0.0; double read_time = 0.0;
double pull_sparse_time = 0.0; double pull_sparse_time = 0.0;
double adjust_ins_weight_time = 0.0;
double collect_label_time = 0.0; double collect_label_time = 0.0;
double fill_sparse_time = 0.0; double fill_sparse_time = 0.0;
double push_sparse_time = 0.0; double push_sparse_time = 0.0;
...@@ -209,8 +401,6 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -209,8 +401,6 @@ void DownpourWorker::TrainFilesWithProfiler() {
int cur_batch; int cur_batch;
int batch_cnt = 0; int batch_cnt = 0;
uint64_t total_inst = 0; uint64_t total_inst = 0;
double op_sum_time = 0;
std::unordered_map<std::string, double> op_to_time;
timeline.Start(); timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) { while ((cur_batch = device_reader_->Next()) > 0) {
timeline.Pause(); timeline.Pause();
...@@ -245,6 +435,16 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -245,6 +435,16 @@ void DownpourWorker::TrainFilesWithProfiler() {
timeline.Pause(); timeline.Pause();
fill_sparse_time += timeline.ElapsedSec(); fill_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
timeline.Start();
auto nid_iter = std::find(sparse_value_names_[tid].begin(),
sparse_value_names_[tid].end(),
adjust_ins_weight_config_.nid_slot());
if (nid_iter != sparse_value_names_[tid].end()) {
AdjustInsWeight();
}
timeline.Pause();
adjust_ins_weight_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
} }
VLOG(3) << "Fill sparse value for all sparse table done."; VLOG(3) << "Fill sparse value for all sparse table done.";
...@@ -358,6 +558,8 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -358,6 +558,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
if (thread_id_ == 0) { if (thread_id_ == 0) {
// should be configured here // should be configured here
if (batch_cnt > 0 && batch_cnt % 100 == 0) { if (batch_cnt > 0 && batch_cnt % 100 == 0) {
double op_sum_time = 0;
std::unordered_map<std::string, double> op_to_time;
for (size_t i = 0; i < op_total_time.size(); ++i) { for (size_t i = 0; i < op_total_time.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt); op_name[i].c_str(), op_total_time[i] / batch_cnt);
...@@ -382,10 +584,15 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -382,10 +584,15 @@ void DownpourWorker::TrainFilesWithProfiler() {
fprintf(stderr, "push dense time: %fs\n", push_dense_time / batch_cnt); fprintf(stderr, "push dense time: %fs\n", push_dense_time / batch_cnt);
fprintf(stderr, "collect label time: %fs\n", fprintf(stderr, "collect label time: %fs\n",
collect_label_time / batch_cnt); collect_label_time / batch_cnt);
fprintf(stderr, "adjust ins weight time: %fs\n",
adjust_ins_weight_time / batch_cnt);
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt); fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100); fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "op run percent: %f\n", op_sum_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n", fprintf(stderr, "pull sparse time percent: %f\n",
pull_sparse_time / total_time * 100); pull_sparse_time / total_time * 100);
fprintf(stderr, "adjust ins weight time percent: %f\n",
adjust_ins_weight_time / total_time * 100);
fprintf(stderr, "collect label time percent: %f\n", fprintf(stderr, "collect label time percent: %f\n",
collect_label_time / total_time * 100); collect_label_time / total_time * 100);
fprintf(stderr, "fill sparse time percent: %f\n", fprintf(stderr, "fill sparse time percent: %f\n",
...@@ -425,6 +632,12 @@ void DownpourWorker::TrainFiles() { ...@@ -425,6 +632,12 @@ void DownpourWorker::TrainFiles() {
&feature_values_[tid], table.fea_dim()); &feature_values_[tid], table.fea_dim());
CollectLabelInfo(i); CollectLabelInfo(i);
FillSparseValue(i); FillSparseValue(i);
auto nid_iter = std::find(sparse_value_names_[tid].begin(),
sparse_value_names_[tid].end(),
adjust_ins_weight_config_.nid_slot());
if (nid_iter != sparse_value_names_[tid].end()) {
AdjustInsWeight();
}
} }
VLOG(3) << "fill sparse value for all sparse table done."; VLOG(3) << "fill sparse value for all sparse table done.";
...@@ -518,11 +731,52 @@ void DownpourWorker::TrainFiles() { ...@@ -518,11 +731,52 @@ void DownpourWorker::TrainFiles() {
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
} }
} }
if (need_dump_field_) {
int batch_size = device_reader_->GetCurBatchSize();
std::vector<std::string> ars(batch_size);
for (auto& ar : ars) {
ar.clear();
}
auto& ins_id_vec = device_reader_->GetInsIdVec();
auto& ins_content_vec = device_reader_->GetInsContentVec();
for (size_t i = 0; i < ins_id_vec.size(); i++) {
ars[i] += ins_id_vec[i];
ars[i] = ars[i] + "\t" + ins_content_vec[i];
}
for (auto& field : dump_fields_) {
Variable* var = thread_scope_->FindVar(field);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (!CheckValidOutput(tensor, batch_size)) {
continue;
}
for (int i = 0; i < batch_size; ++i) {
auto output_dim = tensor->dims()[1];
std::string output_dimstr =
boost::lexical_cast<std::string>(output_dim);
ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
auto bound = GetTensorBound(tensor, i);
ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
}
}
// #pragma omp parallel for
for (size_t i = 0; i < ars.size(); i++) {
if (ars[i].length() == 0) {
continue;
}
writer_ << ars[i];
}
}
PrintFetchVars(); PrintFetchVars();
thread_scope_->DropKids(); thread_scope_->DropKids();
++batch_cnt; ++batch_cnt;
} }
if (need_dump_field_) {
writer_.Flush();
}
} }
} // end namespace framework } // end namespace framework
......
...@@ -30,6 +30,7 @@ limitations under the License. */ ...@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer_factory.h" #include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
...@@ -58,10 +59,30 @@ ExecutorPrepareContext::ExecutorPrepareContext( ...@@ -58,10 +59,30 @@ ExecutorPrepareContext::ExecutorPrepareContext(
void ExecutorPrepareContext::PrepareUnusedVars( void ExecutorPrepareContext::PrepareUnusedVars(
const std::vector<std::string>& keep_vars, bool force_disable_gc) { const std::vector<std::string>& keep_vars, bool force_disable_gc) {
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
// FIXME(zjl): There is difference when ngraph and gc are both enabled
// in unittests. I do not know why it happens. Maybe ngraph engine
// would cache some variables?
LOG_FIRST_N(WARNING, 1)
<< "FLAGS_use_ngraph=True, garbage collection strategy is "
"disabled in Executor";
force_disable_gc = true;
}
#endif
force_disable_gc_ = force_disable_gc; force_disable_gc_ = force_disable_gc;
if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) { if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
return; return;
} }
// If gc is enabled and block size > 1
if (prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
block_id_, ops_);
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(block_id_, ops_);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
block_id_, ops_);
}
unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars); unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
} }
...@@ -388,8 +409,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -388,8 +409,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc; std::unique_ptr<GarbageCollector> gc;
// FIXME(zjl): recurrent_op is rather complex, we would
// disable gc forcely in recurrent_op
if (!ctx->force_disable_gc_ && max_memory_size >= 0) { if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -407,13 +426,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -407,13 +426,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
} }
#endif #endif
// If gc is enabled and block size > 1
if (gc && ctx->prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
ctx->ops_);
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
ctx->block_id_, ctx->ops_);
}
} }
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
......
...@@ -5,3 +5,8 @@ else() ...@@ -5,3 +5,8 @@ else()
endif(WITH_PSLIB) endif(WITH_PSLIB)
cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope) cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope)
# box_wrapper bridges Paddle with the external BoxPS service. It is built
# either way; when WITH_BOX_PS is ON the box_ps dependency is linked in,
# otherwise the BoxPS calls in box_wrapper.cc are compiled out (guarded by
# PADDLE_WITH_BOX_PS). Fixed: bare endif() instead of the legacy
# repeated-argument endif(WITH_BOX_PS) form; indented branch bodies.
if(WITH_BOX_PS)
  cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor box_ps)
else()
  cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace framework {
std::shared_ptr<BoxWrapper> BoxWrapper::s_instance_ = nullptr;
#ifdef PADDLE_WITH_BOX_PS
std::shared_ptr<paddle::boxps::BoxPSBase> BoxWrapper::boxps_ptr_ = nullptr;
#endif
// Returns the current local date encoded as an int in YYYYMMDD form
// (e.g. 20190704); used as the day identifier passed to BoxPS FeedPass.
// Thread-safe: uses localtime_r / localtime_s into a local tm.
int BoxWrapper::GetDate() const {
  time_t now = time(nullptr);
  tm t;
#ifdef _WIN32
  localtime_s(&t, &now);
#else
  localtime_r(&now, &t);
#endif
  // Compose the date arithmetically instead of the original
  // snprintf-into-buffer-then-atoi round trip: identical result for any
  // valid tm, with no fixed-size buffer or string parsing involved.
  return (1900 + t.tm_year) * 10000 + (1 + t.tm_mon) * 100 + t.tm_mday;
}
// Feeds today's feasign list to BoxPS for the current pass, enforcing a
// zero return code. No-op unless compiled with PADDLE_WITH_BOX_PS.
// NOTE(review): parameter name "feasgin_to_box" is a typo for "feasign";
// renaming would also touch the header declaration.
void BoxWrapper::FeedPass(const std::vector<uint64_t>& feasgin_to_box) const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->FeedPass(GetDate(), feasgin_to_box);
PADDLE_ENFORCE_EQ(ret, 0, "FeedPass failed in BoxPS.");
#endif
}
// Signals BoxPS that a new training pass begins, enforcing a zero return
// code. No-op unless compiled with PADDLE_WITH_BOX_PS.
void BoxWrapper::BeginPass() const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->BeginPass();
PADDLE_ENFORCE_EQ(ret, 0, "BeginPass failed in BoxPS.");
#endif
}
// Signals BoxPS that the current training pass is complete, enforcing a
// zero return code. No-op unless compiled with PADDLE_WITH_BOX_PS.
void BoxWrapper::EndPass() const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->EndPass();
PADDLE_ENFORCE_EQ(ret, 0, "EndPass failed in BoxPS.");
#endif
}
// Pulls embeddings for the given feasign keys from BoxPS.
//   place        - CPU or CUDA place; anything else throws.
//   keys         - one pointer per slot; keys[i] points at slot_lengths[i]
//                  uint64 feasigns.
//   values       - one output buffer per slot; row j of slot i receives
//                  hidden_size floats.
//   slot_lengths - number of keys in each slot.
//   hidden_size  - floats copied per key; assumed to equal the contiguous
//                  show/click/emb span of boxps::FeatureValue — TODO confirm.
// Compiles to a no-op unless built with PADDLE_WITH_BOX_PS.
void BoxWrapper::PullSparse(const paddle::platform::Place& place,
                            const std::vector<const uint64_t*>& keys,
                            const std::vector<float*>& values,
                            const std::vector<int64_t>& slot_lengths,
                            const int hidden_size) {
#ifdef PADDLE_WITH_BOX_PS
  if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) {
    int64_t total_length =
        std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
    // Gather all slots' keys into one contiguous tensor on `place` so BoxPS
    // can be called once for the whole batch.
    LoDTensor total_keys_tensor;
    int64_t* total_keys =
        total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place);
    int64_t offset = 0;
    for (size_t i = 0; i < keys.size(); ++i) {
      if (platform::is_cpu_place(place)) {
        memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
                     boost::get<platform::CPUPlace>(place), keys[i],
                     slot_lengths[i] * sizeof(uint64_t));
      } else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
        // Device-to-device copy on the default (nullptr) stream.
        memory::Copy(boost::get<platform::CUDAPlace>(place),
                     total_keys + offset,
                     boost::get<platform::CUDAPlace>(place), keys[i],
                     slot_lengths[i] * sizeof(uint64_t), nullptr);
#else
        PADDLE_THROW(
            "Please compile WITH_GPU option, and NCCL doesn't support "
            "windows.");
#endif
      }
      offset += slot_lengths[i];
    }
    PADDLE_ENFORCE_EQ(offset, total_length,
                      "BoxWrapper::PullSparse: total feasign keys length "
                      "should be equal to the sum of length of all input "
                      "tensors.");
    // Space allocation for FeatureValue is left for boxps
    paddle::boxps::FeatureValue* total_values;
    if (platform::is_cpu_place(place)) {
      int ret = boxps_ptr_->PullSparseCPU(
          reinterpret_cast<uint64_t*>(total_keys), &total_values,
          static_cast<int>(total_length));
      PADDLE_ENFORCE_EQ(ret, 0, "PullSparseCPU failed in BoxPS.");
    } else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      int ret = boxps_ptr_->PullSparseGPU(
          reinterpret_cast<uint64_t*>(total_keys), &total_values,
          static_cast<int>(total_length),
          boost::get<platform::CUDAPlace>(place).GetDeviceId());
      PADDLE_ENFORCE_EQ(ret, 0, "PullSparseGPU failed in BoxPS.");
#endif
      // NOTE(review): in a non-CUDA build this branch compiles to nothing and
      // total_values stays uninitialized; the PADDLE_THROW in the key-copy
      // loop above should make this unreachable — confirm.
    }
    // Scatter the pulled rows back into the per-slot output buffers.
    offset = 0;
    for (size_t i = 0; i < values.size(); ++i) {
      int64_t fea_num = slot_lengths[i];
      for (auto j = 0; j < fea_num; ++j) {
        // Copy the emb from BoxPS to paddle tensor. Since 'show','click','emb'
        // are continuous in memory, so we copy here using the 'show' address
        if (platform::is_cpu_place(place)) {
          memory::Copy(
              boost::get<platform::CPUPlace>(place),
              values[i] + j * hidden_size,
              boost::get<platform::CPUPlace>(place),
              reinterpret_cast<float*>(&((total_values + offset)->show)),
              sizeof(float) * hidden_size);
        } else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
          memory::Copy(
              boost::get<platform::CUDAPlace>(place),
              values[i] + j * hidden_size,
              boost::get<platform::CUDAPlace>(place),
              reinterpret_cast<float*>(&((total_values + offset)->show)),
              sizeof(float) * hidden_size, nullptr);
#endif
        }
        ++offset;
      }
    }
    PADDLE_ENFORCE_EQ(offset, total_length,
                      "BoxWrapper::PullSparse: total emb values length should "
                      "be equal to the sum of length of all input tensors.");
  } else {
    PADDLE_THROW(
        "PaddleBox: PullSparse Only Support CPUPlace and CUDAPlace Now.");
  }
#endif
}
// Pushes embedding gradients for the given feasign keys back to BoxPS.
// Mirror of PullSparse: same key layout and place handling, but data flows
// from the per-slot `grad_values` buffers into BoxPS.
//   grad_values  - one gradient buffer per slot; row j of slot i supplies
//                  hidden_size floats, assumed to match the contiguous
//                  show/click/emb span of boxps::FeaturePushValue — TODO
//                  confirm.
// Compiles to a no-op unless built with PADDLE_WITH_BOX_PS.
void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
                                const std::vector<const uint64_t*>& keys,
                                const std::vector<const float*>& grad_values,
                                const std::vector<int64_t>& slot_lengths,
                                const int hidden_size) {
#ifdef PADDLE_WITH_BOX_PS
  if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) {
    int64_t total_length =
        std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
    // Gather all slots' keys into one contiguous tensor on `place`.
    LoDTensor total_keys_tensor;
    int64_t* total_keys =
        total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place);
    int64_t offset = 0;
    for (size_t i = 0; i < keys.size(); ++i) {
      if (platform::is_cpu_place(place)) {
        memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
                     boost::get<platform::CPUPlace>(place), keys[i],
                     slot_lengths[i] * sizeof(uint64_t));
      } else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
        // Device-to-device copy on the default (nullptr) stream.
        memory::Copy(boost::get<platform::CUDAPlace>(place),
                     total_keys + offset,
                     boost::get<platform::CUDAPlace>(place), keys[i],
                     slot_lengths[i] * sizeof(uint64_t), nullptr);
#else
        PADDLE_THROW(
            "Please compile WITH_GPU option, and for now NCCL doesn't support "
            "windows.");
#endif
      }
      offset += slot_lengths[i];
    }
    PADDLE_ENFORCE_EQ(offset, total_length,
                      "BoxWrapper::PushSparseGrad: total feasign keys length "
                      "should be equal to the sum of length of all input "
                      "tensors.");
    // Unlike PullSparse, the staging buffer for FeaturePushValue is allocated
    // here rather than by BoxPS.
    auto buf = memory::AllocShared(
        place, total_length * sizeof(paddle::boxps::FeaturePushValue));
    paddle::boxps::FeaturePushValue* total_grad_values =
        reinterpret_cast<paddle::boxps::FeaturePushValue*>(buf->ptr());
    offset = 0;
    for (size_t i = 0; i < grad_values.size(); ++i) {
      int64_t fea_num = slot_lengths[i];
      for (auto j = 0; j < fea_num; ++j) {
        // Copy the emb grad from paddle tensor to BoxPS. Since
        // 'show','click','emb' are continuous in memory, so we copy here using
        // the 'show' address
        if (platform::is_cpu_place(place)) {
          memory::Copy(
              boost::get<platform::CPUPlace>(place),
              reinterpret_cast<float*>(&((total_grad_values + offset)->show)),
              boost::get<platform::CPUPlace>(place),
              grad_values[i] + j * hidden_size, sizeof(float) * hidden_size);
        } else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
          memory::Copy(
              boost::get<platform::CUDAPlace>(place),
              reinterpret_cast<float*>(&((total_grad_values + offset)->show)),
              boost::get<platform::CUDAPlace>(place),
              grad_values[i] + j * hidden_size, sizeof(float) * hidden_size,
              nullptr);
#endif
        }
        ++offset;
      }
    }
    PADDLE_ENFORCE_EQ(offset, total_length,
                      "BoxWrapper::PushSparseGrad: total emb grad values "
                      "length should be equal to the sum of length of all "
                      "input tensors.");
    if (platform::is_cpu_place(place)) {
      int ret = boxps_ptr_->PushSparseCPU(
          reinterpret_cast<uint64_t*>(total_keys), total_grad_values,
          static_cast<int>(total_length));
      PADDLE_ENFORCE_EQ(ret, 0, "PushSparseCPU failed in BoxPS.");
    } else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      int ret = boxps_ptr_->PushSparseGPU(
          reinterpret_cast<uint64_t*>(total_keys), total_grad_values,
          static_cast<int>(total_length),
          boost::get<platform::CUDAPlace>(place).GetDeviceId());
      PADDLE_ENFORCE_EQ(ret, 0, "PushSparseGPU failed in BoxPS.");
#endif
    }
  } else {
    PADDLE_THROW(
        "PaddleBox: PushSparse Only Support CPUPlace and CUDAPlace Now.");
  }
#endif
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <thread>  // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#ifdef PADDLE_WITH_BOX_PS
#include <boxps.h>
#endif
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
// Process-wide bridge between Paddle and the BoxPS parameter server.
// All methods are no-ops unless built with PADDLE_WITH_BOX_PS.
class BoxWrapper {
 public:
  virtual ~BoxWrapper() {}
  BoxWrapper() {}

  // Notify BoxPS to stage this pass' feasigns (keyed by today's date)
  // from SSD into memory.
  void FeedPass(const std::vector<uint64_t>& feasgin_to_box) const;
  // Mark the start / end of a training pass in BoxPS.
  void BeginPass() const;
  void EndPass() const;
  // Pull embeddings for `keys` into the per-slot `values` buffers;
  // slot_lengths[i] keys per slot, hidden_size floats per key.
  void PullSparse(const paddle::platform::Place& place,
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  const int hidden_size);
  // Push embedding gradients from the per-slot `grad_values` buffers.
  void PushSparseGrad(const paddle::platform::Place& place,
                      const std::vector<const uint64_t*>& keys,
                      const std::vector<const float*>& grad_values,
                      const std::vector<int64_t>& slot_lengths,
                      const int hidden_size);

  // Lazily creates and returns the singleton instance.
  // Uses std::call_once instead of the previous double-checked locking,
  // whose unsynchronized first read of s_instance_ was a data race.
  static std::shared_ptr<BoxWrapper> GetInstance() {
    static std::once_flag init_flag;
    std::call_once(init_flag, []() {
      s_instance_.reset(new paddle::framework::BoxWrapper());
#ifdef PADDLE_WITH_BOX_PS
      boxps_ptr_.reset(new paddle::boxps::FakeBoxPS());
#endif
    });
    return s_instance_;
  }

 private:
#ifdef PADDLE_WITH_BOX_PS
  static std::shared_ptr<paddle::boxps::BoxPSBase> boxps_ptr_;
#endif
  static std::shared_ptr<BoxWrapper> s_instance_;
  // Today's date as YYYYMMDD, used as the FeedPass key.
  int GetDate() const;
};
class BoxHelper {
public:
explicit BoxHelper(paddle::framework::Dataset* dataset) : dataset_(dataset) {}
virtual ~BoxHelper() {}
void BeginPass() {
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->BeginPass();
}
void EndPass() {
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->EndPass();
}
void LoadIntoMemory() {
dataset_->LoadIntoMemory();
FeedPass();
}
void PreLoadIntoMemory() {
dataset_->PreLoadIntoMemory();
feed_data_thread_.reset(new std::thread([&]() {
dataset_->WaitPreLoadDone();
FeedPass();
}));
}
void WaitFeedPassDone() { feed_data_thread_->join(); }
private:
Dataset* dataset_;
std::shared_ptr<std::thread> feed_data_thread_;
// notify boxps to feed this pass feasigns from SSD to memory
void FeedPass() {
auto box_ptr = BoxWrapper::GetInstance();
auto input_channel_ =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetInputChannel();
std::vector<Record> pass_data;
std::vector<uint64_t> feasign_to_box;
input_channel_->ReadAll(pass_data);
for (const auto& ins : pass_data) {
const auto& feasign_v = ins.uint64_feasigns_;
for (const auto feasign : feasign_v) {
feasign_to_box.push_back(feasign.sign().uint64_feasign_);
}
}
input_channel_->Open();
input_channel_->Write(pass_data);
input_channel_->Close();
box_ptr->FeedPass(feasign_to_box);
}
};
} // end namespace framework
} // end namespace paddle
...@@ -401,7 +401,9 @@ void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id, ...@@ -401,7 +401,9 @@ void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id,
std::vector<std::string> var_list, std::vector<std::string> var_list,
std::string model_path, std::string model_path,
std::string model_proto_file, std::string model_proto_file,
std::vector<std::string> table_var_list,
bool load_combine) { bool load_combine) {
#ifdef PADDLE_WITH_PSLIB
// load ProgramDesc from model file // load ProgramDesc from model file
auto read_proto_func = [](const std::string& filename) -> ProgramDesc { auto read_proto_func = [](const std::string& filename) -> ProgramDesc {
std::string contents; std::string contents;
...@@ -467,7 +469,8 @@ void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id, ...@@ -467,7 +469,8 @@ void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id,
} }
} }
delete old_scope; delete old_scope;
PushDenseParamSync(scope, table_id, old_param_list); PushDenseParamSync(scope, table_id, table_var_list);
#endif
} }
void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModel(const std::string& path, const int mode) {
...@@ -512,6 +515,57 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { ...@@ -512,6 +515,57 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#endif #endif
} }
double FleetWrapper::GetCacheThreshold() {
#ifdef PADDLE_WITH_PSLIB
double cache_threshold = 0.0;
auto ret = pslib_ptr_->_worker_ptr->flush();
ret.wait();
ret = pslib_ptr_->_worker_ptr->get_cache_threshold(0, cache_threshold);
ret.wait();
if (cache_threshold < 0) {
LOG(ERROR) << "get cache threshold failed";
exit(-1);
}
return cache_threshold;
#else
VLOG(0) << "FleetWrapper::GetCacheThreshold does nothing when no pslib";
return 0.0;
#endif
}
void FleetWrapper::CacheShuffle(int table_id, const std::string& path,
const int mode, const double cache_threshold) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->cache_shuffle(
0, path, std::to_string(mode), std::to_string(cache_threshold));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "cache shuffle failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::CacheShuffle does nothing when no pslib";
#endif
}
int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save_cache(0, path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib";
return -1;
#endif
}
void FleetWrapper::ShrinkSparseTable(int table_id) { void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id); auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
......
...@@ -136,6 +136,7 @@ class FleetWrapper { ...@@ -136,6 +136,7 @@ class FleetWrapper {
void LoadFromPaddleModel(Scope& scope, const uint64_t table_id, // NOLINT void LoadFromPaddleModel(Scope& scope, const uint64_t table_id, // NOLINT
std::vector<std::string> var_list, std::vector<std::string> var_list,
std::string model_path, std::string model_proto_file, std::string model_path, std::string model_proto_file,
std::vector<std::string> table_var_list,
bool load_combine); bool load_combine);
// mode = 0, load all feature // mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff // mode = 1, laod delta feature, which means load diff
...@@ -148,7 +149,13 @@ class FleetWrapper { ...@@ -148,7 +149,13 @@ class FleetWrapper {
// mode = 1, save delta feature, which means save diff // mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode); void SaveModel(const std::string& path, const int mode);
double GetCacheThreshold();
void CacheShuffle(int table_id, const std::string& path, const int mode,
const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode);
void ClearModel(); void ClearModel();
void ShrinkSparseTable(int table_id); void ShrinkSparseTable(int table_id);
void ShrinkDenseTable(int table_id, Scope* scope, void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay, std::vector<std::string> var_list, float decay,
......
...@@ -25,29 +25,21 @@ ...@@ -25,29 +25,21 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
DECLARE_double(eager_delete_tensor_gb);
DECLARE_double(memory_fraction_of_eager_deletion);
DECLARE_bool(fast_eager_deletion_mode);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted.");
GarbageCollector::GarbageCollector(const platform::Place &place, GarbageCollector::GarbageCollector(const platform::Place &place,
size_t max_memory_size) size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) { : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new GarbageQueue()); garbages_.reset(new GarbageQueue());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
if (max_memory_size_ > 1) {
mutex_.reset(new std::mutex());
}
} }
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place, CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
......
...@@ -46,7 +46,7 @@ class GarbageCollector { ...@@ -46,7 +46,7 @@ class GarbageCollector {
platform::DeviceContext *dev_ctx_; platform::DeviceContext *dev_ctx_;
std::unique_ptr<GarbageQueue> garbages_; std::unique_ptr<GarbageQueue> garbages_;
mutable std::mutex mutex_; mutable std::unique_ptr<std::mutex> mutex_;
const size_t max_memory_size_; const size_t max_memory_size_;
size_t cur_memory_size_{0}; size_t cur_memory_size_{0};
}; };
...@@ -118,7 +118,7 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) { ...@@ -118,7 +118,7 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) {
GarbageQueue *garbage_queue = nullptr; GarbageQueue *garbage_queue = nullptr;
{ {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(*mutex_);
for (auto &obj : objs) { for (auto &obj : objs) {
if (!obj) continue; if (!obj) continue;
cur_memory_size_ += obj->size(); cur_memory_size_ += obj->size();
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
...@@ -20,7 +21,7 @@ limitations under the License. */ ...@@ -20,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void HogwildWorker::Initialize(const TrainerDesc& desc) { void HogwildWorker::Initialize(const TrainerDesc &desc) {
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param(); param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size()); skip_ops_.resize(param_.skip_ops_size());
...@@ -30,45 +31,70 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) { ...@@ -30,45 +31,70 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
auto& block = program.Block(0); auto &block = program.Block(0);
op_names_.clear(); op_names_.clear();
for (auto& op_desc : block.AllOps()) { for (auto &op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc); std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type()); op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release(); OperatorBase *local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr); ops_.push_back(local_op_ptr);
continue; continue;
} }
} }
void HogwildWorker::CreateThreadScope(const ProgramDesc& program) { void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
auto& block = program.Block(0); auto &block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope"); root_scope_, "root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
for (auto &var : block.AllVars()) {
if (var->Persistable()) { if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name()); auto *ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
thread_id_ != 0) {
int tensor_dim =
root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
auto *ptr1 = thread_scope_->Var(var->Name());
InitializeVariable(ptr1, var->GetType());
LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
LoDTensor *root_tensor =
root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
#define MemsetCallback(cpp_type, proto_type) \
do { \
if (root_tensor->type() == proto_type) { \
SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
} \
} while (0)
_ForEachDataType_(MemsetCallback);
}
} else { } else {
auto* ptr = thread_scope_->Var(var->Name()); auto *ptr = thread_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
} }
} }
} }
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
int tensor_dim) {
T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
memset(ptr, 0, sizeof(T) * tensor_dim);
}
void HogwildWorker::BindingDataFeedMemory() { void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = const std::vector<std::string> &input_feed =
device_reader_->GetUseSlotAlias(); device_reader_->GetUseSlotAlias();
for (auto name : input_feed) { for (auto name : input_feed) {
device_reader_->AddFeedVar(thread_scope_->FindVar(name), name); device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
} }
} }
void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) { void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
CreateThreadScope(main_prog); CreateThreadScope(main_prog);
CreateThreadOperators(main_prog); CreateThreadOperators(main_prog);
} }
...@@ -78,7 +104,7 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -78,7 +104,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
device_reader_->Start(); device_reader_->Start();
std::vector<double> op_total_time; std::vector<double> op_total_time;
std::vector<std::string> op_name; std::vector<std::string> op_name;
for (auto& op : ops_) { for (auto &op : ops_) {
op_name.push_back(op->Type()); op_name.push_back(op->Type());
} }
op_total_time.resize(ops_.size()); op_total_time.resize(ops_.size());
...@@ -141,7 +167,7 @@ void HogwildWorker::TrainFiles() { ...@@ -141,7 +167,7 @@ void HogwildWorker::TrainFiles() {
device_reader_->Start(); device_reader_->Start();
int cur_batch; int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) { while ((cur_batch = device_reader_->Next()) > 0) {
for (auto& op : ops_) { for (auto &op : ops_) {
bool need_skip = false; bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) { for (auto t = 0u; t < skip_ops_.size(); ++t) {
if (op->Type().find(skip_ops_[t]) != std::string::npos) { if (op->Type().find(skip_ops_[t]) != std::string::npos) {
......
...@@ -53,5 +53,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference { ...@@ -53,5 +53,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference {
} }
}; };
#define DECLARE_INPLACE_OP_INFERER(class_name, ...) \
class class_name final : public ::paddle::framework::InplaceOpInference { \
public: \
std::unordered_map<std::string, std::string> operator()( \
const ::paddle::framework::OpDesc& op_desc, \
bool use_cuda) const final { \
return {__VA_ARGS__}; \
} \
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -194,7 +194,8 @@ std::shared_ptr<FILE> shell_popen(const std::string& cmd, ...@@ -194,7 +194,8 @@ std::shared_ptr<FILE> shell_popen(const std::string& cmd,
<< ", err_no[" << *err_no << "]"; << ", err_no[" << *err_no << "]";
} }
if (wstatus == -1 && errno == ECHILD) { if (wstatus == -1 && errno == ECHILD) {
LOG(WARNING) << "errno is ECHILD"; // temporarily remove this warning
// LOG(WARNING) << "errno is ECHILD";
} }
}}; }};
#endif #endif
...@@ -285,7 +286,8 @@ std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open( ...@@ -285,7 +286,8 @@ std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
<< "status[" << wstatus << "], cmd[" << cmd << "]"; << "status[" << wstatus << "], cmd[" << cmd << "]";
if (wstatus == -1 && errno == ECHILD) { if (wstatus == -1 && errno == ECHILD) {
LOG(WARNING) << "errno is ECHILD"; // temporarily remove this warning
// LOG(WARNING) << "errno is ECHILD";
} }
}}; }};
......
...@@ -12,21 +12,14 @@ unset(INFER_IR_PASSES CACHE) # clear the global variable ...@@ -12,21 +12,14 @@ unset(INFER_IR_PASSES CACHE) # clear the global variable
function(pass_library TARGET DEST) function(pass_library TARGET DEST)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS DIR)
set(targetPrefix "") set(targetPrefix "")
# Get optional argument cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(extraMacroArgs ${ARGN}) if(pass_library_DIR)
list(LENGTH extraMacroArgs numExtraMacroArgs) cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
if(numExtraMacroArgs GREATER 0)
list(GET extraMacroArgs 0 targetPrefix)
endif()
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(targetPrefix)
cc_library(${TARGET} SRCS ${targetPrefix}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
else() else()
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS}) cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
endif() endif()
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically. # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
...@@ -44,6 +37,7 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper) ...@@ -44,6 +37,7 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper) cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
...@@ -52,7 +46,6 @@ pass_library(graph_viz_pass base) ...@@ -52,7 +46,6 @@ pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base) pass_library(lock_free_optimize_pass base)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
...@@ -61,6 +54,7 @@ pass_library(multi_batch_merge_pass base) ...@@ -61,6 +54,7 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference) pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference) pass_library(seqpool_concat_fuse_pass inference)
pass_library(seqpool_cvm_concat_fuse_pass inference)
pass_library(repeated_fc_relu_fuse_pass inference) pass_library(repeated_fc_relu_fuse_pass inference)
pass_library(squared_mat_sub_fuse_pass inference) pass_library(squared_mat_sub_fuse_pass inference)
pass_library(is_test_pass base) pass_library(is_test_pass base)
...@@ -76,23 +70,26 @@ pass_library(quant_conv2d_dequant_fuse_pass inference) ...@@ -76,23 +70,26 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference) pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference) pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
if(WITH_GPU)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
endif()
if(ANAKIN_SUBGRAPH) if(ANAKIN_SUBGRAPH)
pass_library(simplify_anakin_priorbox_detection_out_pass inference) pass_library(simplify_anakin_priorbox_detection_out_pass inference)
endif() endif()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn) pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
pass_library(depthwise_conv_mkldnn_pass base mkldnn) pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_pass inference mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
endif() endif()
if(WITH_NGRAPH) if(WITH_NGRAPH)
...@@ -118,15 +115,19 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph ...@@ -118,15 +115,19 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
if(WITH_GPU)
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
endif()
if(NOT WIN32) if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif() endif()
if (WITH_MKLDNN) if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass) cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass) cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(elementwise_add_op_1); \
GET_IR_NODE(elementwise_add_in_y_1); \
GET_IR_NODE(elementwise_add_out_1); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Clone `base_desc` (a conv2d description) and overwrite the fields that
// differ for the fused conv+add+add+act op: the extra Bias / ResidualData
// inputs, the folded activation type, the final output name, and the
// inference-only flags.
framework::proto::OpDesc PrepareOpDesc(
    const framework::proto::OpDesc& base_desc, const std::string& bias,
    const std::string& bias1, const std::string& activation,
    const std::string& output) {
  framework::proto::OpDesc proto_copy = base_desc;
  framework::OpDesc fused_desc(proto_copy, nullptr);
  fused_desc.SetInput("Bias", {bias});
  fused_desc.SetInput("ResidualData", {bias1});
  fused_desc.SetAttr("activation", activation);
  fused_desc.SetOutput("Output", {output});
  fused_desc.SetAttr("is_test", true);
  // The fused kernel is a plain CUDA implementation, not a cuDNN one.
  fused_desc.SetAttr("use_cudnn", false);
  return *fused_desc.Proto();
}
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string bias1_name = elementwise_add_in_y_1->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
auto new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, new_conv_op); // ResidualData
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
};
gpd(graph.get(), handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
#define WARPCTC_LIB_PATH "@WARPCTC_INSTALL_DIR@/lib/" REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass)
.RequirePassAttr("cudnn_enabled_op_types");
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/placement_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * Specifies which operators should use cuDNN.
 *
 * The rewriting itself is driven by PlacementPassBase; this subclass only
 * supplies the attribute to toggle ("use_cudnn") and the set of op types to
 * toggle it on, read from the "cudnn_enabled_op_types" pass attribute.
 */
class CUDNNPlacementPass : public PlacementPassBase {
 private:
  // Human-readable label for this placement pass.
  const std::string GetPlacementName() const { return "cuDNN"; }
  // Name of the op attribute this pass sets.
  const std::string GetAttrName() const { return "use_cudnn"; }
  // Op types eligible for cuDNN placement. Per the tester below, an empty
  // set enables every op type that has a cuDNN kernel — behavior implemented
  // in PlacementPassBase (confirm there).
  const std::unordered_set<std::string> GetOpTypesList() const {
    return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
  }
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
void RegisterOpKernel() {
static bool is_registered = false;
if (!is_registered) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
platform::CUDAPlace place = platform::CUDAPlace(0);
OpKernelType plain_kernel_type =
OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
LibraryType::kPlain);
OpKernelType cudnn_kernel_type =
OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
LibraryType::kCUDNN);
auto fake_kernel_func = [](const ExecutionContext&) -> void {
static int num_calls = 0;
num_calls++;
};
all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
is_registered = true;
}
}
// Builds the small graph sketched below, runs cudnn_placement_pass with the
// given set of enabled op types, and verifies how many ops end up with
// use_cudnn == true.
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
              unsigned expected_use_cudnn_true_count) {
  // operator                                 use_cudnn
  // --------------------------------------------------
  // (a,b)->concat->c                         -
  // (c,weights,bias)->conv2d->f              false
  // f->relu->g                               -
  // g->pool2d->h                             false
  // (h,weights2,bias2)->depthwise_conv2d->k  false
  // k->relu->l                               -
  Layers layers;
  VarDesc* a = layers.data("a");
  VarDesc* b = layers.data("b");
  VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
  VarDesc* weights_0 = layers.data("weights_0");
  VarDesc* bias_0 = layers.data("bias_0");
  // Every conv/pool is created with use_cudnn = false (the trailing arg);
  // the pass is expected to flip the eligible ones to true.
  VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
  VarDesc* g = layers.relu(f);
  VarDesc* h = layers.pool2d(g, false);
  VarDesc* weights_1 = layers.data("weights_1");
  VarDesc* bias_1 = layers.data("bias_1");
  VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
  layers.relu(k);

  // Make kernel availability (cuDNN vs plain) visible to the pass.
  RegisterOpKernel();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
  pass->Set("cudnn_enabled_op_types",
            new std::unordered_set<std::string>(cudnn_enabled_op_types));
  graph.reset(pass->Apply(graph.release()));

  // Count ops whose use_cudnn attribute was set to true by the pass.
  unsigned use_cudnn_true_count = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()) {
      auto* op = node->Op();
      if (op->HasAttr("use_cudnn") &&
          boost::get<bool>(op->GetAttr("use_cudnn"))) {
        ++use_cudnn_true_count;
      }
    }
  }
  EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
}
TEST(CUDNNPlacementPass, enable_conv2d) {
  // Only conv2d enabled -> exactly the 1 conv2d op gets use_cudnn = true.
  MainTest({"conv2d"}, 1);
}
TEST(CUDNNPlacementPass, enable_relu_pool) {
  // NOTE(review): despite the test name, this enables conv2d + pool2d.
  // 1 conv2d + 1 pool2d -> 2 ops get use_cudnn = true.
  MainTest({"conv2d", "pool2d"}, 2);
}
TEST(CUDNNPlacementPass, enable_all) {
  // Empty set enables every op type with a cuDNN kernel:
  // 1 conv2d + 1 pool2d -> 2.
  // depthwise_conv2d does not have a cuDNN kernel, so it stays false.
  MainTest({}, 2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cudnn_placement_pass);
...@@ -32,19 +32,63 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -32,19 +32,63 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"}; return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
} }
void FuseOptimizerOps( ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set, &aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph); auto fused_adam_node =
FuseScaleOps(aux_var_set.at("Beta1Pow"), fused_vars_name.at("Beta1Pow"), FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph);
adam_ops, graph); auto fused_scale1 =
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"), FuseScaleOps(aux_var_set.at("Beta1Pow"), fused_vars_name.at("Beta1Pow"),
adam_ops, graph); adam_ops, graph);
auto fused_scale2 =
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph);
RemoveCycleDepsBetweenOpNodes(graph, fused_scale1, fused_scale2);
return fused_adam_node;
} }
void FuseAdamOps( void RemoveCycleDepsBetweenOpNodes(Graph *graph, const Node *fused_scale1,
const Node *fused_scale2) const {
std::unordered_set<Node *> not_need_ctrl_var_nodes;
std::unordered_set<Node *> fused_scale2_in_nodes;
fused_scale2_in_nodes.insert(fused_scale2->inputs.begin(),
fused_scale2->inputs.end());
for (auto &out_node : fused_scale1->outputs) {
if (fused_scale2_in_nodes.count(out_node)) {
PADDLE_ENFORCE(out_node->IsCtrlVar(),
"The dependency var only should be ctrl var.");
not_need_ctrl_var_nodes.insert(out_node);
}
}
for (auto &node : not_need_ctrl_var_nodes) {
// remove this node from the input op node.
PADDLE_ENFORCE(!node->inputs.empty(),
"The input should not be empty here.");
auto op_node = node->inputs.front();
PADDLE_ENFORCE(op_node->IsOp());
op_node->outputs.erase(
remove_if(
op_node->outputs.begin(), op_node->outputs.end(),
[&node](const Node *op_out_node) { return op_out_node == node; }),
op_node->outputs.end());
// remove this node from the output op nodes.
for (auto &out_op_node : node->outputs) {
out_op_node->inputs.erase(
remove_if(
out_op_node->inputs.begin(), out_op_node->inputs.end(),
[&node](const Node *op_in_node) { return op_in_node == node; }),
out_op_node->inputs.end());
}
graph->RemoveNode(node);
}
}
ir::Node *FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
...@@ -102,16 +146,13 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -102,16 +146,13 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
adam_desc.SetAttr("min_row_size_to_use_multithread", adam_desc.SetAttr("min_row_size_to_use_multithread",
min_row_size_to_use_multithread); min_row_size_to_use_multithread);
adam_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role); adam_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
return graph->CreateOpNode(&adam_desc);
auto adam_node = graph->CreateOpNode(&adam_desc);
InserInputAndOutputForOptOps(adam_ops, adam_node);
} }
void FuseScaleOps(const std::vector<std::string> &beta_name, ir::Node *FuseScaleOps(const std::vector<std::string> &beta_name,
const std::string &fused_var_name, const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops, const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const { ir::Graph *graph) const {
PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size()); PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size());
const std::string scale_op_name = "scale"; const std::string scale_op_name = "scale";
...@@ -139,7 +180,7 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -139,7 +180,7 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
scale_ops.emplace_back(*scale_op_iter); scale_ops.emplace_back(*scale_op_iter);
} }
PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size()); PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());
VLOG(7) << "The number of scale op is " << scale_ops.size() << ".";
// Check attributions // Check attributions
// NOTE: If new attribution is added, the following code maybe need change. // NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>( int op_role = boost::get<int>(
...@@ -175,29 +216,12 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -175,29 +216,12 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
scale_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role); scale_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto scale_node = graph->CreateOpNode(&scale_desc); auto scale_node = graph->CreateOpNode(&scale_desc);
for (auto scale_op : scale_ops) { InsertInputAndOutputForFusedOpNode(scale_ops, graph, scale_node);
// set inputs
scale_node->inputs.insert(scale_node->inputs.begin(),
scale_op->inputs.begin(),
scale_op->inputs.end());
for (auto &input : scale_op->inputs) {
std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
scale_node);
}
// set outputs
scale_node->outputs.insert(scale_node->outputs.begin(),
scale_op->outputs.begin(),
scale_op->outputs.end());
for (auto &output : scale_op->outputs) {
std::replace(output->inputs.begin(), output->inputs.end(), scale_op,
scale_node);
}
}
// Delete scale_ops // Delete scale_ops
for (auto &scale_op : scale_ops) { for (auto &scale_op : scale_ops) {
graph->RemoveNode(scale_op); graph->RemoveNode(scale_op);
} }
return scale_node;
} }
}; };
} // namespace ir } // namespace ir
......
...@@ -33,7 +33,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass { ...@@ -33,7 +33,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
} }
// Fuse Momentum Ops // Fuse Momentum Ops
virtual void FuseOptimizerOps( virtual ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const {
...@@ -77,9 +77,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass { ...@@ -77,9 +77,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
momentum_desc.SetAttr("use_nesterov", use_nesterov); momentum_desc.SetAttr("use_nesterov", use_nesterov);
momentum_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role); momentum_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto momentum_node = graph->CreateOpNode(&momentum_desc); return graph->CreateOpNode(&momentum_desc);
InserInputAndOutputForOptOps(momentum_ops, momentum_node);
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h" #include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include <algorithm> #include <algorithm>
#include <set>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -59,6 +60,15 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -59,6 +60,15 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
} }
return; return;
} }
// There should not have no-ctr-var between the op_nodes that link the op_node
// of op_nodes.
if (HasVarDepsBetweenOps(topo_nodes, opt_nodes)) {
VLOG(6) << "There are interdependent variables among these optimization "
"operators, which can not be handled well at present.";
return;
}
result.Set(details::kFusedOptType, new details::FusedOptType); result.Set(details::kFusedOptType, new details::FusedOptType);
result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type; result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
if (!result.Has(details::kProgramDescs)) { if (!result.Has(details::kProgramDescs)) {
...@@ -158,14 +168,54 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -158,14 +168,54 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
&result); &result);
// Step 5: Fuse optimizer Ops and Scale Ops // Step 5: Fuse optimizer Ops and Scale Ops
FuseOptimizerOps(aux_var_set, fused_vars_name, opt_nodes, &result); auto *fused_opt_node =
FuseOptimizerOps(aux_var_set, fused_vars_name, opt_nodes, &result);
InsertInputAndOutputForFusedOpNode(opt_nodes, graph, fused_opt_node);
// Step 6: Remove optimizer Ops // Step 6: Remove optimizer Ops
for (auto &opt_op : opt_nodes) { for (auto &opt_op : opt_nodes) {
graph->RemoveNode(opt_op); graph->RemoveNode(opt_op);
} }
} }
bool FuseOptimizerOpPass::HasVarDepsBetweenOps(
const std::vector<Node *> &topo_nodes,
const std::vector<Node *> &opt_nodes) const {
std::unordered_map<Node *, std::unordered_set<Node *>> preceding_ops;
std::unordered_map<Node *, std::unordered_set<Node *>> pending_ops;
for (auto &op : topo_nodes) {
preceding_ops[op];
pending_ops[op];
for (auto &var : op->outputs) {
if (var->IsCtrlVar()) continue;
for (auto &pending_op : var->outputs) {
preceding_ops[pending_op].insert(op);
pending_ops[op].insert(pending_op);
}
}
}
std::unordered_set<Node *> opt_node_set(opt_nodes.begin(), opt_nodes.end());
auto has_var_deps = [](const std::unordered_set<Node *> &op_set1,
const std::unordered_set<Node *> &op_set2) -> bool {
std::set<Node *> intersect_ops;
set_intersection(op_set1.begin(), op_set1.end(), op_set2.begin(),
op_set2.end(),
inserter(intersect_ops, intersect_ops.begin()));
return !intersect_ops.empty();
};
for (auto opt_node : opt_node_set) {
if (has_var_deps(preceding_ops.at(opt_node), opt_node_set)) {
return true;
}
if (has_var_deps(pending_ops.at(opt_node), opt_node_set)) {
return true;
}
}
return false;
}
void FuseOptimizerOpPass::GradientsFilter( void FuseOptimizerOpPass::GradientsFilter(
const std::vector<size_t> &new_grad_idx, std::vector<Node *> *opt_nodes, const std::vector<size_t> &new_grad_idx, std::vector<Node *> *opt_nodes,
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set) std::unordered_map<std::string, std::vector<std::string>> *aux_var_set)
...@@ -338,26 +388,84 @@ void FuseOptimizerOpPass::AppendAllocContinuousSpace( ...@@ -338,26 +388,84 @@ void FuseOptimizerOpPass::AppendAllocContinuousSpace(
op_desc->SetAttr("check_name", check_name); op_desc->SetAttr("check_name", check_name);
} }
void FuseOptimizerOpPass::InserInputAndOutputForOptOps( void FuseOptimizerOpPass::InsertInputAndOutputForFusedOpNode(
const std::vector<ir::Node *> &opt_nodes, ir::Node *opt_node) const { const std::vector<ir::Node *> &op_nodes, ir::Graph *graph,
ir::Node *fused_opt_node) const {
std::unordered_set<ir::Node *> inputs; std::unordered_set<ir::Node *> inputs;
std::unordered_set<ir::Node *> outputs; std::unordered_set<ir::Node *> outputs;
for (auto opt_op : opt_nodes) { for (auto opt_op : op_nodes) {
// set inputs
inputs.insert(opt_op->inputs.begin(), opt_op->inputs.end()); inputs.insert(opt_op->inputs.begin(), opt_op->inputs.end());
for (auto &input : opt_op->inputs) { for (auto &input : opt_op->inputs) {
replace(input->outputs.begin(), input->outputs.end(), opt_op, opt_node); replace(input->outputs.begin(), input->outputs.end(), opt_op,
fused_opt_node);
} }
// set outputs
outputs.insert(opt_op->outputs.begin(), opt_op->outputs.end()); outputs.insert(opt_op->outputs.begin(), opt_op->outputs.end());
for (auto &output : opt_op->outputs) { for (auto &output : opt_op->outputs) {
replace(output->inputs.begin(), output->inputs.end(), opt_op, opt_node); replace(output->inputs.begin(), output->inputs.end(), opt_op,
fused_opt_node);
}
}
// Remove the dependence vars between op_nodes.
std::unordered_set<ir::Node *> out_dep_vars;
std::unordered_set<ir::Node *> not_useful_vars;
auto deal_with_ctrl_vars = [&out_dep_vars, &not_useful_vars,
&fused_opt_node](ir::Node *ctr_var_node) {
PADDLE_ENFORCE_EQ(ctr_var_node->inputs.size(), 1);
if (ctr_var_node->inputs.front() == fused_opt_node) {
PADDLE_ENFORCE_GT(ctr_var_node->outputs.size(), 0);
auto output_ops = ctr_var_node->outputs;
output_ops.erase(std::remove_if(output_ops.begin(), output_ops.end(),
[&fused_opt_node](const ir::Node *node) {
return node == fused_opt_node;
}),
output_ops.end());
if (!output_ops.empty()) {
out_dep_vars.insert(ctr_var_node);
}
not_useful_vars.insert(ctr_var_node);
} }
};
for (auto *in_node : inputs) {
if (in_node->IsCtrlVar()) {
deal_with_ctrl_vars(in_node);
}
}
for (auto *out_node : outputs) {
if (out_node->IsCtrlVar()) {
deal_with_ctrl_vars(out_node);
}
}
for (auto &node : not_useful_vars) {
if (inputs.count(node)) {
inputs.erase(node);
}
if (outputs.count(node)) {
outputs.erase(node);
}
}
for (auto &dep_var : out_dep_vars) {
if (not_useful_vars.count(dep_var)) {
not_useful_vars.erase(dep_var);
}
dep_var->inputs.clear();
dep_var->inputs.emplace_back(fused_opt_node);
}
outputs.insert(out_dep_vars.begin(), out_dep_vars.end());
fused_opt_node->inputs.insert(fused_opt_node->inputs.begin(), inputs.begin(),
inputs.end());
fused_opt_node->outputs.insert(fused_opt_node->outputs.begin(),
outputs.begin(), outputs.end());
for (auto &ctrl_var_node : not_useful_vars) {
graph->RemoveNode(ctrl_var_node);
} }
opt_node->inputs.insert(opt_node->inputs.begin(), inputs.begin(),
inputs.end());
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -41,15 +41,16 @@ class FuseOptimizerOpPass : public ir::Pass { ...@@ -41,15 +41,16 @@ class FuseOptimizerOpPass : public ir::Pass {
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set, std::unordered_map<std::string, std::vector<std::string>> *aux_var_set,
std::vector<ir::Node *> *ops) const; std::vector<ir::Node *> *ops) const;
void InserInputAndOutputForOptOps(const std::vector<ir::Node *> &opt_ops, void InsertInputAndOutputForFusedOpNode(
ir::Node *opt_node) const; const std::vector<ir::Node *> &opt_ops, ir::Graph *graph,
ir::Node *opt_node) const;
private: private:
virtual const std::string GetOpType() const = 0; virtual const std::string GetOpType() const = 0;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const = 0; virtual const std::vector<std::string> GetAuxiliaryVarNames() const = 0;
virtual void FuseOptimizerOps( virtual ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0; const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0;
...@@ -91,6 +92,9 @@ class FuseOptimizerOpPass : public ir::Pass { ...@@ -91,6 +92,9 @@ class FuseOptimizerOpPass : public ir::Pass {
*aux_var_set) const; *aux_var_set) const;
bool IsLoDTensorType(const proto::VarType::Type &type) const; bool IsLoDTensorType(const proto::VarType::Type &type) const;
bool HasVarDepsBetweenOps(const std::vector<Node *> &topo_nodes,
const std::vector<Node *> &opt_nodes) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -31,7 +31,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass { ...@@ -31,7 +31,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
} }
// Fuse Sgd Ops // Fuse Sgd Ops
virtual void FuseOptimizerOps( virtual ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
...@@ -56,9 +56,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass { ...@@ -56,9 +56,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
// NOTE: multi_devices_pass requires that every op should have a role. // NOTE: multi_devices_pass requires that every op should have a role.
Sgd_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role); Sgd_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto sgd_node = graph->CreateOpNode(&Sgd_desc); return graph->CreateOpNode(&Sgd_desc);
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
} }
}; };
} // namespace ir } // namespace ir
......
...@@ -200,12 +200,7 @@ class Graph { ...@@ -200,12 +200,7 @@ class Graph {
// WARN: After a series of passes, the current graph can be quite // WARN: After a series of passes, the current graph can be quite
// different from OriginProgram. Caller shouldn't assume much from // different from OriginProgram. Caller shouldn't assume much from
// the returned OriginProgram. // the returned OriginProgram.
const ProgramDesc &OriginProgram() const { const ProgramDesc &OriginProgram() const { return program_; }
LOG(WARNING) << "WARN: After a series of passes, the current graph can be "
"quite different from OriginProgram. So, please avoid "
"using the `OriginProgram()` method!";
return program_;
}
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
......
...@@ -771,58 +771,33 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, ...@@ -771,58 +771,33 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
return bn_out_var; return bn_out_var;
} }
PDNode *patterns::ConvReLU::operator()( PDNode *patterns::ConvActivation::operator()(
paddle::framework::ir::PDNode *conv_input) { paddle::framework::ir::PDNode *conv_input, std::string conv_type,
std::string activation_type) {
// Create Operators // Create Operators
conv_input->assert_is_op_input("conv2d", "Input"); conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); auto *activation_op =
// Create variables pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu");
// output
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
PDNode *patterns::ConvBReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
// Create variables // Create variables
// Filter // Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter"); ->assert_is_op_input(conv_type, "Filter");
// intermediate variable, will be removed in the IR after fuse. // intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr()) auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_only_output_of_op("conv2d") ->assert_is_only_output_of_op(conv_type)
->assert_is_op_input("relu6"); ->assert_is_op_input(activation_type);
// output // output
auto *brelu_out_var = pattern->NewNode(brelu_out_repr()) auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("relu6"); ->assert_is_op_output(activation_type);
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var}); activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
return brelu_out_var; return activation_out_var;
} }
PDNode *patterns::SeqConvEltAddRelu::operator()( PDNode *patterns::SeqConvEltAddRelu::operator()(
...@@ -1296,6 +1271,41 @@ PDNode *patterns::ConvConcatReLU::operator()() { ...@@ -1296,6 +1271,41 @@ PDNode *patterns::ConvConcatReLU::operator()() {
return relu_out; return relu_out;
} }
PDNode *patterns::ConvRequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto requant_op =
pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto requant_out = pattern->NewNode(requant_out_repr())
->AsOutput()
->assert_is_op_output("requantize", "Output");
conv_op->LinksTo({conv_out});
requant_op->LinksFrom({conv_out}).LinksTo({requant_out});
return requant_out;
}
PDNode *patterns::ConvDequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
conv_op->LinksTo({conv_out});
dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::PriorBox::operator()() { PDNode *patterns::PriorBox::operator()() {
auto prior_box_op = auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
......
...@@ -431,46 +431,26 @@ struct ConvBN : public PatternBase { ...@@ -431,46 +431,26 @@ struct ConvBN : public PatternBase {
PATTERN_DECL_NODE(bn_saved_variance); PATTERN_DECL_NODE(bn_saved_variance);
}; };
// CONV with ReLU // Conv with Activation
// op: conv + relu // op: conv + activation
// named nodes: // named nodes:
// conv_input, conv_weight, // conv_input, conv_weight,
// conv_out, conv, // conv_out, conv,
// relu_out, relu // activation_out, activation
struct ConvReLU : public PatternBase { struct ConvActivation : public PatternBase {
ConvReLU(PDPattern* pattern, const std::string& name_scope) ConvActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_relu") {} : PatternBase(pattern, name_scope, "conv_activation") {}
PDNode* operator()(PDNode* conv_input); PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d",
std::string activation_type = "relu");
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(conv); PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(relu); PATTERN_DECL_NODE(activation);
// declare variable node's name // declare variable node's name
PATTERN_DECL_NODE(conv_weight); PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out); PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(relu_out); PATTERN_DECL_NODE(activation_out);
};
// CONV with ReLU6
// op: conv + relu6
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// relu6_out, relu6
struct ConvBReLU : public PatternBase {
ConvBReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bounded_relu") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(brelu);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(brelu_out);
}; };
// SEQCONV with Elementwise_Add ReLU // SEQCONV with Elementwise_Add ReLU
...@@ -811,6 +791,40 @@ struct ConvConcatReLU : public PatternBase { ...@@ -811,6 +791,40 @@ struct ConvConcatReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out); PATTERN_DECL_NODE(relu_out);
}; };
// Pattern: conv2d followed by a requantize operator.
// Named nodes:
//   conv_op, conv_out       - the convolution op and its output tensor
//   requant_op, requant_out - the requantize op and its output tensor
struct ConvRequant : public PatternBase {
  ConvRequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_requant") {}

  // Builds the pattern; returns the requantize output node.
  PDNode* operator()();

  // declared node accessors (operator and variable nodes)
  PATTERN_DECL_NODE(conv_op);
  PATTERN_DECL_NODE(conv_out);
  PATTERN_DECL_NODE(requant_op);
  PATTERN_DECL_NODE(requant_out);
};
// Pattern: conv2d followed by a dequantize operator.
// Named nodes:
//   conv_op, conv_out       - the convolution op and its output tensor
//   dequant_op, dequant_out - the dequantize op and its output tensor
struct ConvDequant : public PatternBase {
  ConvDequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_dequant") {}

  // Builds the pattern; returns the dequantize output node.
  PDNode* operator()();

  // declared node accessors (operator and variable nodes)
  PATTERN_DECL_NODE(conv_op);
  PATTERN_DECL_NODE(conv_out);
  PATTERN_DECL_NODE(dequant_op);
  PATTERN_DECL_NODE(dequant_out);
};
// PriorBox operator // PriorBox operator
// operator: prior_box_op // operator: prior_box_op
// inputs: prior_box_input, prior_box_image // inputs: prior_box_input, prior_box_image
......
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass) cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,50 +12,44 @@ ...@@ -12,50 +12,44 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm> #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/op_variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class InferCleanGraphPass : public FusePassBase { class ConditionalOpEagerDeletionPass : public Pass {
public:
virtual ~InferCleanGraphPass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const { void ApplyImpl(Graph *graph) const override {
FusePassBase::Init("original_graph", graph); auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
PADDLE_ENFORCE(graph);
// Find all conditional_op and conditional_grad_op
auto is_valid_node = [](Node* x) { std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var(); std::vector<OperatorBase *>>>
}; target_ops;
for (auto *op : all_ops) {
std::unordered_set<const Node*> invalid_nodes; auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
int valid_op = 0; if (compute_op == nullptr) continue;
for (auto* node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node); if (compute_op->Name() == "conditional_block") {
if (is_valid_node(node)) { target_ops[compute_op->GetScopeIdx()].first.emplace_back(
invalid_nodes.insert(node); compute_op->GetOp());
} else if (node->IsOp()) { } else if (compute_op->Name() == "conditional_block_grad") {
// Collect all the operators to help tracking number of operators. target_ops[compute_op->GetScopeIdx()].second.emplace_back(
++valid_op; compute_op->GetOp());
} }
} }
GraphSafeRemoveNodes(graph, invalid_nodes); for (auto &ops_pair : target_ops) {
auto &ifelse_ops = ops_pair.second.first;
AddStatis(valid_op); auto &ifelse_grad_ops = ops_pair.second.second;
} operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
ifelse_ops, ifelse_grad_ops);
void CleanEdges(std::vector<Node*>* nodes, }
const std::unordered_set<Node*>& to_remove) const {
auto it = std::remove_if(nodes->begin(), nodes->end(),
[&](Node* x) { return to_remove.count(x); });
nodes->erase(it, nodes->end());
} }
}; };
...@@ -63,5 +57,5 @@ class InferCleanGraphPass : public FusePassBase { ...@@ -63,5 +57,5 @@ class InferCleanGraphPass : public FusePassBase {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(infer_clean_graph_pass, REGISTER_PASS(conditional_block_op_eager_deletion_pass,
paddle::framework::ir::InferCleanGraphPass); paddle::framework::ir::ConditionalOpEagerDeletionPass);
...@@ -269,6 +269,11 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -269,6 +269,11 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
} }
} }
auto conditional_block_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get(
"conditional_block_op_eager_deletion_pass");
conditional_block_op_eager_deletion_pass->Apply(graph);
auto while_op_eager_deletion_pass = auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass"); ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
while_op_eager_deletion_pass->Apply(graph); while_op_eager_deletion_pass->Apply(graph);
...@@ -288,5 +293,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) ...@@ -288,5 +293,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kAllPlaces) .RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector); .RequirePassAttr(paddle::framework::ir::kGarbageCollector);
USE_PASS(conditional_block_op_eager_deletion_pass);
USE_PASS(while_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass);
USE_PASS(recurrent_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass);
...@@ -58,7 +58,7 @@ class MemOptVarInfo { ...@@ -58,7 +58,7 @@ class MemOptVarInfo {
}; };
using MemOptVarInfoMapList = std::vector< using MemOptVarInfoMapList = std::vector<
std::unordered_map<std::string, std::unique_ptr<MemOptVarInfo>>>; std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>>;
class SkipMemOptVarsGuard { class SkipMemOptVarsGuard {
public: public:
......
...@@ -100,8 +100,10 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const { ...@@ -100,8 +100,10 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const {
int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const { int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const {
auto *var_desc = GetVarDesc(var); auto *var_desc = GetVarDesc(var);
auto shapes = var_desc->GetShape(); auto shapes = var_desc->GetShape();
auto sizeof_dtype = static_cast<int64_t>(SizeOfType(var_desc->GetDataType()));
return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1), return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>()); std::multiplies<int64_t>()) *
sizeof_dtype;
} }
void MemoryReusePass::CollectShareTensorBufferOpHandles() const { void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
......
...@@ -337,6 +337,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -337,6 +337,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
for (auto iter = var_handles.rbegin(); iter != var_handles.rend(); for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) { ++iter) {
if ((*iter)->Node()->IsCtrlVar()) {
break;
}
VLOG(10) << "Try to find last living ops of " << var_name << " " VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time"; << (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure; LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -21,51 +21,77 @@ namespace paddle { ...@@ -21,51 +21,77 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const { void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(graph, "graph cannot be nullptr.");
FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph); FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern() auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input") ->NewNode("conv_activation_mkldnn_fuse/conv_input")
->AsInput() ->AsInput()
->assert_is_op_input("conv2d", "Input"); ->assert_is_op_input(conv_type(), "Input");
patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(), patterns::ConvActivation conv_activation_pattern(
"conv_bounded_relu_mkldnn_fuse"); gpd.mutable_pattern(), "conv_activation_mkldnn_fuse");
conv_brelu_pattern(conv_input); conv_activation_pattern(conv_input, conv_type(), activation_type());
int found_conv_brelu_count = 0; int found_conv_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle ConvBoundedReLUFusePass fuse"; VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_brelu_pattern); // Filter conv_activation_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern); // tmp GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern); // CONV op conv_activation_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern); // Out GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_activation_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern); // ReLU op GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
conv_activation_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(activation, activation,
conv_activation_pattern); // Activation op
// Transform Conv node into ConvBReLU node. // Transform Conv node into ConvActivation node.
OpDesc* desc = conv->Op(); OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()})); desc->SetOutput("Output",
desc->SetAttr("fuse_brelu", true); std::vector<std::string>({activation_out->Name()}));
desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));
GraphSafeRemoveNodes(graph, {brelu, conv_out}); desc->SetAttr("fuse_activation", activation_type());
PADDLE_ENFORCE(subgraph.count(conv_input)); // MKLDNN ops use alpha and beta as activation parameters but paddle ops are
IR_NODE_LINK_TO(conv, brelu_out); // not generalized
found_conv_brelu_count++; if (activation_type() == "relu6") {
desc->SetAttr("fuse_alpha",
boost::get<float>(activation->Op()->GetAttr("threshold")));
} else {
desc->SetAttr("fuse_alpha",
activation->Op()->GetAttrIfExists<float>("alpha"));
}
desc->SetAttr("fuse_beta",
activation->Op()->GetAttrIfExists<float>("beta"));
GraphSafeRemoveNodes(graph, {activation, conv_out});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
"subgraph has to contain conv_input node.");
IR_NODE_LINK_TO(conv, activation_out);
found_conv_activation_count++;
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_conv_brelu_count); AddStatis(found_conv_activation_count);
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(conv_brelu_mkldnn_fuse_pass, REGISTER_PASS(conv_activation_mkldnn_fuse_pass,
paddle::framework::ir::ConvBReLUFusePass); paddle::framework::ir::ConvActivationFusePass);
// Per-activation pass aliases so each conv+activation fusion can be
// enabled individually by name from the pass-manager configuration.
REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
              paddle::framework::ir::ConvActivationFusePass);
REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DLeakyReLUFusePass);
REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DReLU6FusePass);
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -22,18 +23,33 @@ ...@@ -22,18 +23,33 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
/* /*
* Fuse the CONV and ReLU to a ConvReLUOp. * Fuse Conv and Activation base class.
*/ */
class ConvReLUFusePass : public FusePassBase { class ConvActivationFusePass : public FusePassBase {
public: public:
virtual ~ConvReLUFusePass() {} virtual ~ConvActivationFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
virtual std::string activation_type() const { return "relu"; }
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_activation_mkldnn_fuse"};
};
/*
 * Fuses conv2d with a following leaky_relu activation.
 * All fusion logic lives in ConvActivationFusePass; this subclass only
 * selects which activation op the pattern matches.
 */
class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
 public:
  // `override` makes the compiler verify this actually overrides the
  // virtual activation_type() declared in ConvActivationFusePass.
  std::string activation_type() const override { return "leaky_relu"; }
};
/*
* Fuse Conv and BoundedReLU class
*/
class Conv2DReLU6FusePass : public ConvActivationFusePass {
public:
std::string activation_type() const { return "relu6"; }
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
...@@ -23,18 +23,24 @@ namespace ir { ...@@ -23,18 +23,24 @@ namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn = false) { const std::vector<std::string>& outputs, bool is_activation = false,
bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
op->SetAttr("name", name);
if (type == "conv2d") { if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") { } else if (is_activation) {
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs); op->SetInput("X", inputs);
if (type == "leaky_relu") {
op->SetAttr("alpha", 0.02f);
} else if (type == "relu6") {
op->SetAttr("threshold", 6.0f);
}
} }
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
...@@ -44,15 +50,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -44,15 +50,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
// a->OP0->b // a->OP0->b
// b->OP1->c // b->OP1->c
// (c, weights, bias)->conv->f // (c, weights, bias)->conv->f
// (f)->relu->g // (f)->activation->g
ProgramDesc BuildProgramDesc() { ProgramDesc BuildProgramDesc(std::string activation) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g", std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) { "h", "weights2", "bias2", "k", "l", "m"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2") {
var->SetPersistable(true); var->SetPersistable(true);
} }
} }
...@@ -61,30 +67,33 @@ ProgramDesc BuildProgramDesc() { ...@@ -61,30 +67,33 @@ ProgramDesc BuildProgramDesc() {
std::vector<std::string>({"b"})); std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}), SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"})); std::vector<std::string>({"c"}));
// conv+relu, both with MKL-DNN // conv+activation, both with MKL-DNN
SetOp(&prog, "conv2d", "conv1", SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"c", "weights", "bias"}), std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}), true); std::vector<std::string>({"f"}), false, true);
SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}), SetOp(&prog, activation, "activation1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true); std::vector<std::string>({"g"}), true, true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}), SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"})); std::vector<std::string>({"h"}));
// conv+relu, only one with MKL-DNN // conv+activation, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2", SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}), std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true); std::vector<std::string>({"k"}), false, true);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}), SetOp(&prog, "activation", "activation2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"})); std::vector<std::string>({"l"}), true, false);
SetOp(&prog, "OP4", "op4", std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}));
return prog; return prog;
} }
TEST(ConvReLUFusePass, basic) { void MainTest(std::string activation) {
auto prog = BuildProgramDesc(); auto prog = BuildProgramDesc(activation);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass"); auto pass =
PassRegistry::Instance().Get("conv_" + activation + "_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
...@@ -92,36 +101,41 @@ TEST(ConvReLUFusePass, basic) { ...@@ -92,36 +101,41 @@ TEST(ConvReLUFusePass, basic) {
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
// Remove 3 Nodes: CONV, RELU, conv_out // Remove 3 Nodes: CONV, activation, conv_out
// Add 1 Node: ConvReLU // Add 1 Node: ConvActivation
EXPECT_EQ(original_nodes_num - 2, current_nodes_num); EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_relu op in newly generated graph // Assert conv_activation op in newly generated graph
int conv_relu_count = 0; int conv_activation_count = 0;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") { if (node->IsOp() && node->Op()->Type() == "conv2d") {
auto* op = node->Op(); auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn")); ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn"))); EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
// check if only "conv1" convolution is fused
auto op_name = boost::get<std::string>(op->GetAttr("name")); auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op->GetAttrIfExists<std::string>("fuse_activation") == activation) {
++conv_activation_count;
}
// check if only "conv1" convolution is fused
if (op_name == "conv1") { if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_relu")); ASSERT_TRUE(op->HasAttr("fuse_activation"));
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
}
} else if (op_name == "conv2") { } else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_relu")); ASSERT_FALSE(op->HasAttr("fuse_activation"));
} }
} }
} }
EXPECT_EQ(conv_relu_count, 1); EXPECT_EQ(conv_activation_count, 1);
}
// One test per supported activation. Each drives MainTest with the
// activation name, which selects the matching conv_<act>_mkldnn_fuse_pass.
TEST(ConvActivationFusePass, conv_relu_fuse_pass) { MainTest("relu"); }

TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) {
  MainTest("leaky_relu");
}

TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(conv_relu_mkldnn_fuse_pass); USE_PASS(conv_activation_mkldnn_fuse_pass);
...@@ -83,7 +83,7 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU( ...@@ -83,7 +83,7 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU(
// Transform Conv node into ConvReLU node. // Transform Conv node into ConvReLU node.
OpDesc* conv_desc = conv_op->Op(); OpDesc* conv_desc = conv_op->Op();
conv_desc->SetAttr("fuse_relu", true); conv_desc->SetAttr("fuse_activation", std::string("relu"));
// Remove ReLU when all Convs were transformed. // Remove ReLU when all Convs were transformed.
auto number_of_unfused_convs_left = auto number_of_unfused_convs_left =
......
...@@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetType(type); op->SetType(type);
if (type == "conv2d") { if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("fuse_relu", false); op->SetAttr("fuse_activation", std::string(""));
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) { if (inputs.size() > 2) {
...@@ -109,8 +109,9 @@ void MainTest(const ProgramDesc& prog, bool fuse_relu) { ...@@ -109,8 +109,9 @@ void MainTest(const ProgramDesc& prog, bool fuse_relu) {
if (node->IsOp()) { if (node->IsOp()) {
auto* op = node->Op(); auto* op = node->Op();
if (op->Type() == "conv2d") { if (op->Type() == "conv2d") {
ASSERT_TRUE(op->HasAttr("fuse_relu")); ASSERT_TRUE(op->HasAttr("fuse_activation"));
bool fuse_relu_attr = boost::get<bool>(op->GetAttr("fuse_relu")); bool fuse_relu_attr =
(boost::get<std::string>(op->GetAttr("fuse_activation")) == "relu");
EXPECT_EQ(fuse_relu, fuse_relu_attr); EXPECT_EQ(fuse_relu, fuse_relu_attr);
} else if (op->Type() == "relu") { } else if (op->Type() == "relu") {
relu_count++; relu_count++;
......
...@@ -109,8 +109,7 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( ...@@ -109,8 +109,7 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
auto fuse_relu = HasAttribute<bool>(*conv_op, "fuse_relu"); if (HasFusedActivation(conv_op)) return;
if (fuse_relu && *fuse_relu) return;
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
...@@ -179,8 +178,7 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( ...@@ -179,8 +178,7 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
return; return;
} }
auto fuse_relu = HasAttribute<bool>(*residual_conv_op, "fuse_relu"); if (HasFusedActivation(residual_conv_op)) return;
if (fuse_relu && *fuse_relu) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
......
...@@ -126,6 +126,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ...@@ -126,6 +126,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
protected: protected:
void ApplyImpl(graph_ptr graph) const; void ApplyImpl(graph_ptr graph) const;
// True iff the conv node already carries a non-empty "fuse_activation"
// attribute, i.e. an activation was fused into this conv earlier.
static bool HasFusedActivation(Node* conv_node) {
  const std::string fused_act =
      conv_node->Op()->GetAttrIfExists<std::string>("fuse_activation");
  return !fused_act.empty();
}
const std::string name_scope_{"residual_connection_fuse_pass"}; const std::string name_scope_{"residual_connection_fuse_pass"};
}; };
......
...@@ -208,6 +208,14 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -208,6 +208,14 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
DequantizeOutput(g, conv_op, conv_output, "Output", output_scale, DequantizeOutput(g, conv_op, conv_output, "Output", output_scale,
is_output_unsigned, "Scale_out"); is_output_unsigned, "Scale_out");
// change threshold in bounded ReLu
if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
"relu6") {
float scale_out = boost::get<float>(conv_op->Op()->GetAttr("Scale_out"));
float threshold = boost::get<float>(conv_op->Op()->GetAttr("fuse_alpha"));
conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
}
++quantize_conv_count; ++quantize_conv_count;
}; };
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册