未验证 提交 84109c64 编写于 作者: S sangoly 提交者: GitHub

[sangoly] paddle-lite step rnn new (#19100)

* step rnn

* disable ci
上级 b57eac73
...@@ -22,7 +22,9 @@ before_install: ...@@ -22,7 +22,9 @@ before_install:
script: script:
- | - |
# 43min timeout # 43min timeout
paddle/scripts/paddle_docker_build.sh ${JOB} #paddle/scripts/paddle_docker_build.sh ${JOB}
###
echo 0;
if [ $? -eq 0 ] || [ $? -eq 142 ]; then true; else exit 1; fi; if [ $? -eq 0 ] || [ $? -eq 142 ]; then true; else exit 1; fi;
notifications: notifications:
email: email:
......
...@@ -16,8 +16,10 @@ cmake_minimum_required(VERSION 3.0) ...@@ -16,8 +16,10 @@ cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
#add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions("-Wall -g")
include(system) include(system)
include(cross_compiling/preproject)
project(paddle CXX C) project(paddle CXX C)
message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: " message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
...@@ -41,7 +43,9 @@ if(WIN32) ...@@ -41,7 +43,9 @@ if(WIN32)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
endif(WIN32) endif(WIN32)
find_package(CUDA QUIET) if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
find_package(CUDA QUIET)
endif()
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
...@@ -79,11 +83,35 @@ option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VER ...@@ -79,11 +83,35 @@ option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VER
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON) option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
# PY_VERSION if(ANDROID OR IOS OR ARMLINUX)
if(NOT PY_VERSION) set(WITH_GPU OFF CACHE STRING
set(PY_VERSION 2.7) "Disable GPU when cross-compiling for Android and iOS" FORCE)
set(WITH_DSO OFF CACHE STRING
"Disable DSO when cross-compiling for Android and iOS" FORCE)
set(WITH_AVX OFF CACHE STRING
"Disable AVX when cross-compiling for Android and iOS" FORCE)
set(WITH_PYTHON OFF CACHE STRING
"Disable PYTHON when cross-compiling for Android and iOS" FORCE)
set(WITH_RDMA OFF CACHE STRING
"Disable RDMA when cross-compiling for Android and iOS" FORCE)
set(WITH_MKL OFF CACHE STRING
"Disable MKL when cross-compiling for Android and iOS" FORCE)
endif() endif()
set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
# for lite, both server and mobile framework.
option(WITH_LITE "Enable lite framework" OFF)
option(LITE_WITH_JAVA "Enable Java JNI lib in lite mode" OFF)
option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON)
option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF)
option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF)
option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -92,6 +120,42 @@ if(NOT CMAKE_BUILD_TYPE) ...@@ -92,6 +120,42 @@ if(NOT CMAKE_BUILD_TYPE)
FORCE) FORCE)
endif() endif()
include_directories("${PADDLE_SOURCE_DIR}")
# for mobile
if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
message(STATUS "Building the mobile framework")
include(cross_compiling/postproject)
# include the necessary thirdparty dependencies
include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog
include(external/gtest) # download, build, install gtest
#include(external/zlib) # download, build, install gtest
include(external/protobuf) # download, build, install protobuf
include(external/eigen) # download eigen3
include(ccache) # set ccache for compilation
# for opencl
if (LITE_WITH_OPENCL)
include(external/opencl-headers)
include(external/opencl-clhpp)
endif()
include(generic) # simplify cmake module
include(configure) # add paddle env configuration
add_subdirectory(paddle)
return()
endif()
# PY_VERSION
if(NOT PY_VERSION)
set(PY_VERSION 2.7)
endif()
set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
if (APPLE) if (APPLE)
set(WITH_MKL OFF CACHE STRING set(WITH_MKL OFF CACHE STRING
"Disable MKL for building on mac" FORCE) "Disable MKL for building on mac" FORCE)
...@@ -102,16 +166,12 @@ if (WIN32) ...@@ -102,16 +166,12 @@ if (WIN32)
"Disable DISTRIBUTE when compiling for Windows" FORCE) "Disable DISTRIBUTE when compiling for Windows" FORCE)
endif() endif()
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
"A path setting fluid shared and static libraries") "A path setting fluid shared and static libraries")
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries") "A path setting fluid inference shared and static libraries")
set(THIRD_PARTY_BUILD_TYPE Release)
set(WITH_MKLML ${WITH_MKL}) set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN) if (NOT DEFINED WITH_MKLDNN)
...@@ -185,7 +245,6 @@ if(WITH_BRPC_RDMA) ...@@ -185,7 +245,6 @@ if(WITH_BRPC_RDMA)
endif() endif()
endif() endif()
include(external/threadpool) include(external/threadpool)
include(flags) # set paddle compile flags include(flags) # set paddle compile flags
include(cudnn) # set cudnn libraries, must before configure include(cudnn) # set cudnn libraries, must before configure
...@@ -234,7 +293,6 @@ include(coveralls) # set code coverage ...@@ -234,7 +293,6 @@ include(coveralls) # set code coverage
include(inference_lib) # add paddle fluid inference libraries include(inference_lib) # add paddle fluid inference libraries
include_directories("${PADDLE_SOURCE_DIR}")
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
find_package(HIP) find_package(HIP)
......
# A image for building paddle binaries # A image for building paddle binaries
# Use cuda devel base image for both cpu and gpu environment # Use cuda devel base image for both cpu and gpu environment
# When you modify it, please be aware of cudnn-runtime version # When you modify it, please be aware of cudnn-runtime version
# and libcudnn.so.x in paddle/scripts/docker/build.sh
FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04 FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com> MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
...@@ -75,7 +76,7 @@ RUN curl -s -q https://glide.sh/get | sh ...@@ -75,7 +76,7 @@ RUN curl -s -q https://glide.sh/get | sh
# 2. Manually add ~IPluginFactory() in IPluginFactory class of NvInfer.h, otherwise, it couldn't work in paddle. # 2. Manually add ~IPluginFactory() in IPluginFactory class of NvInfer.h, otherwise, it couldn't work in paddle.
# See https://github.com/PaddlePaddle/Paddle/issues/10129 for details. # See https://github.com/PaddlePaddle/Paddle/issues/10129 for details.
RUN wget -q https://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \ RUN wget -q https://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \ tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
cp -rf /usr/local/TensorRT/include /usr && \ cp -rf /usr/local/TensorRT/include /usr && \
cp -rf /usr/local/TensorRT/lib /usr cp -rf /usr/local/TensorRT/lib /usr
...@@ -92,17 +93,17 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 ...@@ -92,17 +93,17 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# specify sphinx version as 1.5.6 and remove -U option for [pip install -U # specify sphinx version as 1.5.6 and remove -U option for [pip install -U
# sphinx-rtd-theme] since -U option will cause sphinx being updated to newest # sphinx-rtd-theme] since -U option will cause sphinx being updated to newest
# version(1.7.1 for now), which causes building documentation failed. # version(1.7.1 for now), which causes building documentation failed.
RUN pip3 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \ RUN pip3 --no-cache-dir install -U wheel && \
pip3 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \ pip3 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip3 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \ pip3 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
pip3.6 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \ pip3.6 --no-cache-dir install -U wheel && \
pip3.6 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \ pip3.6 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip3.6 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \ pip3.6 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
pip3.7 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \ pip3.7 --no-cache-dir install -U wheel && \
pip3.7 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \ pip3.7 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip3.7 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \ pip3.7 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
easy_install -U pip && \ easy_install -U pip && \
pip --no-cache-dir install -U pip setuptools wheel py-cpuinfo==5.0.0 && \ pip --no-cache-dir install -U pip setuptools wheel && \
pip --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \ pip --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark pip --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark
......
...@@ -98,11 +98,9 @@ We provide [English](http://www.paddlepaddle.org/documentation/docs/en/1.4/begin ...@@ -98,11 +98,9 @@ We provide [English](http://www.paddlepaddle.org/documentation/docs/en/1.4/begin
We appreciate your contributions! We appreciate your contributions!
## Communication ## Ask Questions
- [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc. You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
- QQ discussion group: 432676488 (PaddlePaddle).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Copyright and License ## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE). PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
...@@ -80,11 +80,9 @@ pip install paddlepaddle-gpu==1.4.1.post85 ...@@ -80,11 +80,9 @@ pip install paddlepaddle-gpu==1.4.1.post85
欢迎您的贡献! 欢迎您的贡献!
## 交流与反馈 ## 答疑
- 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)来提交问题、报告与建议 欢迎您将问题和bug报告以[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)的形式提交
- QQ群: 432676488 (PaddlePaddle)
- [论坛](http://ai.baidu.com/forum/topic/list/168): 欢迎大家在PaddlePaddle论坛分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围
## 版权和许可证 ## 版权和许可证
PaddlePaddle由[Apache-2.0 license](LICENSE)提供 PaddlePaddle由[Apache-2.0 license](LICENSE)提供
if(NOT WITH_GPU)
return()
endif()
set(ANAKIN_ROOT "/usr" CACHE PATH "ANAKIN ROOT") set(ANAKIN_ROOT "/usr" CACHE PATH "ANAKIN ROOT")
find_path(ANAKIN_INCLUDE_DIR anakin_config.h find_path(ANAKIN_INCLUDE_DIR anakin_config.h
PATHS ${ANAKIN_ROOT} ${ANAKIN_ROOT}/include PATHS ${ANAKIN_ROOT} ${ANAKIN_ROOT}/include
...@@ -12,7 +16,9 @@ find_library(ANAKIN_LIBRARY NAMES libanakin_saber_common.so libanakin.so ...@@ -12,7 +16,9 @@ find_library(ANAKIN_LIBRARY NAMES libanakin_saber_common.so libanakin.so
DOC "Path to ANAKIN library.") DOC "Path to ANAKIN library.")
if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY) if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY)
if(WITH_DSO)
set(ANAKIN_FOUND ON) set(ANAKIN_FOUND ON)
endif(WITH_DSO)
else() else()
set(ANAKIN_FOUND OFF) set(ANAKIN_FOUND OFF)
endif() endif()
...@@ -25,8 +31,3 @@ if(ANAKIN_FOUND) ...@@ -25,8 +31,3 @@ if(ANAKIN_FOUND)
link_directories(${ANAKIN_ROOT}) link_directories(${ANAKIN_ROOT})
add_definitions(-DPADDLE_WITH_ANAKIN) add_definitions(-DPADDLE_WITH_ANAKIN)
endif() endif()
if(ANAKIN_FOUND AND WITH_GPU AND WITH_DSO)
message(STATUS "Compile with anakin subgraph.")
set(ANAKIN_SUBGRAPH ON)
endif()
...@@ -30,7 +30,6 @@ endif(NOT WITH_PROFILER) ...@@ -30,7 +30,6 @@ endif(NOT WITH_PROFILER)
if(WITH_AVX AND AVX_FOUND) if(WITH_AVX AND AVX_FOUND)
set(SIMD_FLAG ${AVX_FLAG}) set(SIMD_FLAG ${AVX_FLAG})
add_definitions(-DPADDLE_WITH_AVX)
elseif(SSE3_FOUND) elseif(SSE3_FOUND)
set(SIMD_FLAG ${SSE3_FLAG}) set(SIMD_FLAG ${SSE3_FLAG})
endif() endif()
...@@ -158,3 +157,33 @@ endif(WITH_BRPC_RDMA) ...@@ -158,3 +157,33 @@ endif(WITH_BRPC_RDMA)
if(ON_INFER) if(ON_INFER)
add_definitions(-DPADDLE_ON_INFERENCE) add_definitions(-DPADDLE_ON_INFERENCE)
endif(ON_INFER) endif(ON_INFER)
if(WITH_WBAES)
add_definitions(-DPADDLE_WITH_WBAES)
endif(WITH_WBAES)
# for lite
# TODO(Superjomn) not work fine with the option
if (LITE_WITH_CUDA)
add_definitions("-DLITE_WITH_CUDA")
endif()
if (LITE_WITH_X86)
add_definitions("-DLITE_WITH_X86")
endif()
if (LITE_WITH_ARM)
add_definitions("-DLITE_WITH_ARM")
endif()
if (LITE_WITH_OPENCL)
add_definitions("-DLITE_WITH_OPENCL")
endif()
if (LITE_WITH_PROFILE)
add_definitions("-DLITE_WITH_PROFILE")
endif()
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK")
endif()
...@@ -26,54 +26,59 @@ if(NOT DEFINED ANDROID_NDK) ...@@ -26,54 +26,59 @@ if(NOT DEFINED ANDROID_NDK)
endif() endif()
endif() endif()
if(ARM_TARGET_LANG STREQUAL "gcc")
# gcc do not need set lang on android
set(ARM_TARGET_LANG "")
endif()
if(NOT DEFINED ANDROID_API_LEVEL) if(NOT DEFINED ANDROID_API_LEVEL)
set(ANDROID_API_LEVEL "22") set(ANDROID_API_LEVEL "22")
endif() endif()
if(NOT DEFINED ANDROID_STL_TYPE) # then check input arm abi
set(ANDROID_STL_TYPE "c++_static" CACHE STRING "stl type") if(ARM_TARGET_ARCH_ABI STREQUAL "armv7hf")
endif() message(FATAL_ERROR "ANDROID does not support hardfp on v7 use armv7 instead.")
# TODO(TJ): enable me
if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a-hf")
message(FATAL_ERROR "Not supported building android armeabi-v7a-hf yet")
endif() endif()
set(ANDROID_ARCH_ABI ${ARM_TARGET_ARCH_ABI} CACHE STRING "Choose Android Arch ABI") set(ANDROID_ARCH_ABI ${ARM_TARGET_ARCH_ABI} CACHE STRING "Choose Android Arch ABI")
if(ARM_TARGET_ARCH_ABI STREQUAL "armv8")
set(ANDROID_ARCH_ABI "arm64-v8a")
endif()
if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a-softfp") if(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(ANDROID_ARCH_ABI "armeabi-v7a") set(ANDROID_ARCH_ABI "armeabi-v7a")
endif() endif()
set(ANDROID_ARCH_ABI_LIST "arm64-v8a" "armeabi-v7a" "armeabi-v6" "armeabi" check_input_var(ANDROID_ARCH_ABI DEFAULT ${ANDROID_ARCH_ABI} LIST "arm64-v8a" "armeabi-v7a"
"mips" "mips64" "x86" "x86_64" "armeabi-v7a-hf") "armeabi-v6" "armeabi" "mips" "mips64" "x86" "x86_64")
set_property(CACHE ANDROID_ARCH_ABI PROPERTY STRINGS ${ANDROID_ARCH_ABI_LIST}) check_input_var(ANDROID_STL_TYPE DEFAULT "c++_static" LIST "c++_static" "gnustl_static")
if(NOT ANDROID_ARCH_ABI IN_LIST ANDROID_ARCH_ABI_LIST)
message(FATAL_ERROR "ANDROID_ARCH_ABI must be in one of ${ANDROID_ARCH_ABI_LIST}")
endif()
if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a")
message(STATUS "armeabi-v7a default use softfp") message(STATUS "armeabi-v7a use softfp by default.")
set(CMAKE_ANDROID_ARM_NEON ON) set(CMAKE_ANDROID_ARM_NEON ON)
message(STATUS "NEON is enabled on arm-v7a with softfp") message(STATUS "NEON is enabled on arm-v7a with softfp.")
endif()
if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a-hf")
set(ANDROID_ARCH_ABI "armeabi-v7a")
set(CMAKE_CXX_FLAGS "-std=c++11 -march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}" )
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" )
message(STATUS "NEON is enabled on arm-v7a with hard float")
endif()
set(ANDROID_STL_TYPE_LITS "gnustl_static" "c++_static")
set_property(CACHE ANDROID_STL_TYPE PROPERTY STRINGS ${ANDROID_STL_TYPE_LITS})
if (NOT ANDROID_STL_TYPE IN_LIST ANDROID_STL_TYPE_LITS)
message(FATAL_ERROR "ANDROID_STL_TYPE must be in one of ${ANDROID_STL_TYPE_LITS}")
endif() endif()
set(CMAKE_SYSTEM_NAME Android) set(CMAKE_SYSTEM_NAME Android)
set(CMAKE_SYSTEM_VERSION ${ANDROID_API_LEVEL}) set(CMAKE_SYSTEM_VERSION ${ANDROID_API_LEVEL})
set(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ARCH_ABI}) set(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ARCH_ABI})
set(CMAKE_ANDROID_NDK ${ANDROID_NDK}) set(CMAKE_ANDROID_NDK ${ANDROID_NDK})
set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION ${ARM_TARGET_LANG})
set(CMAKE_ANDROID_STL_TYPE ${ANDROID_STL_TYPE}) set(CMAKE_ANDROID_STL_TYPE ${ANDROID_STL_TYPE})
if (ARM_TARGET_LANG STREQUAL "clang")
if(ARM_TARGET_ARCH_ABI STREQUAL "armv8")
set(triple aarch64-v8a-linux-android)
elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(triple arm-v7a-linux-android)
else()
message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7")
endif()
set(CMAKE_C_COMPILER clang)
set(CMAKE_C_COMPILER_TARGET ${triple})
set(CMAKE_CXX_COMPILER clang++)
set(CMAKE_CXX_COMPILER_TARGET ${triple})
message(STATUS "CMAKE_CXX_COMPILER_TARGET: ${CMAKE_CXX_COMPILER_TARGET}")
endif()
...@@ -20,38 +20,22 @@ set(ARMLINUX TRUE) ...@@ -20,38 +20,22 @@ set(ARMLINUX TRUE)
add_definitions(-DLITE_WITH_LINUX) add_definitions(-DLITE_WITH_LINUX)
set(CMAKE_SYSTEM_NAME Linux) set(CMAKE_SYSTEM_NAME Linux)
if(ARM_TARGET_ARCH_ABI STREQUAL "arm64-v8a") check_input_var(ARMLINUX_ARCH_ABI DEFAULT ${ARM_TARGET_ARCH_ABI} LIST "armv8" "armv7" "armv7hf")
if(ARMLINUX_ARCH_ABI STREQUAL "armv8")
set(CMAKE_SYSTEM_PROCESSOR aarch64) set(CMAKE_SYSTEM_PROCESSOR aarch64)
set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc") set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc")
set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++") set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++")
set(CMAKE_CXX_FLAGS "-march=armv8-a ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv8-a ${CMAKE_C_FLAGS}")
message(STATUS "NEON is enabled on arm64-v8a")
endif()
if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a"
OR ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a-hf")
message(FATAL_ERROR "Not supported building arm linux arm-v7 yet")
endif() endif()
# TODO(TJ): make sure v7 works if(ARMLINUX_ARCH_ABI STREQUAL "armv7")
if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a")
set(CMAKE_SYSTEM_PROCESSOR arm) set(CMAKE_SYSTEM_PROCESSOR arm)
set(CMAKE_C_COMPILER "arm-linux-gnueabi-gcc") set(CMAKE_C_COMPILER "arm-linux-gnueabi-gcc")
set(CMAKE_CXX_COMPILER "arm-linux-gnueabi-g++") set(CMAKE_CXX_COMPILER "arm-linux-gnueabi-g++")
set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}")
message(STATUS "NEON is enabled on arm-v7a with softfp")
endif() endif()
if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a-hf") if(ARMLINUX_ARCH_ABI STREQUAL "armv7hf")
set(CMAKE_SYSTEM_PROCESSOR arm) set(CMAKE_SYSTEM_PROCESSOR arm)
set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc") set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc")
set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++") set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++")
set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" )
message(STATUS "NEON is enabled on arm-v7a with hard float")
endif() endif()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(NOT ARM_TARGET_LANG STREQUAL "clang")
# only clang need find ar tool
return()
endif()
if(NOT EXISTS "${CMAKE_CXX_COMPILER}")
message(ERROR "Can not find CMAKE_CXX_COMPILER ${CMAKE_CXX_COMPILER}")
endif()
get_filename_component(AR_PATH ${CMAKE_CXX_COMPILER} PATH)
find_file(AR_TOOL NAMES llvm-ar PATHS ${AR_PATH})
if(NOT AR_TOOL)
message(ERROR "Failed to find AR_TOOL in ${AR_PATH}")
else()
set(CMAKE_AR ${AR_TOOL})
message(STATUS "Found CMAKE_AR : " ${CMAKE_AR})
endif()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if (ANDROID)
include(cross_compiling/findar)
endif()
if(ARMLINUX)
if(ARMLINUX_ARCH_ABI STREQUAL "armv8")
set(CMAKE_CXX_FLAGS "-march=armv8-a ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv8-a ${CMAKE_C_FLAGS}")
message(STATUS "NEON is enabled on arm64-v8a")
endif()
if(ARMLINUX_ARCH_ABI STREQUAL "armv7")
set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}")
message(STATUS "NEON is enabled on arm-v7a with softfp")
endif()
if(ARMLINUX_ARCH_ABI STREQUAL "armv7hf")
set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" )
message(STATUS "NEON is enabled on arm-v7a with hard float")
endif()
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
if(LITE_WITH_OPENMP)
find_package(OpenMP REQUIRED)
if(OPENMP_FOUND OR OpenMP_CXX_FOUND)
add_definitions(-DARM_WITH_OMP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
message(STATUS "Found OpenMP ${OpenMP_VERSION} ${OpenMP_CXX_VERSION}")
message(STATUS "OpenMP C flags: ${OpenMP_C_FLAGS}")
message(STATUS "OpenMP CXX flags: ${OpenMP_CXX_FLAGS}")
message(STATUS "OpenMP OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
message(STATUS "OpenMP OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}")
else()
message(FATAL_ERROR "Could not found OpenMP!")
endif()
endif()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
return()
endif()
cmake_minimum_required(VERSION 3.10)
# define check function
function(check_input_var VAR_NAME)
set(options "")
set(oneValueArgs "")
set(multiValueArgs DEFAULT LIST)
cmake_parse_arguments(check_input_var "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(var_out "")
if(NOT DEFINED ${VAR_NAME})
set(var_out ${check_input_var_DEFAULT})
else()
set(var_out ${${VAR_NAME}})
endif()
if(NOT var_out IN_LIST check_input_var_LIST)
message(FATAL_ERROR "${VAR_NAME}:${var_out} must be in one of ${check_input_var_LIST}")
endif()
set(${VAR_NAME} ${var_out} PARENT_SCOPE)
endfunction(check_input_var)
check_input_var(ARM_TARGET_OS DEFAULT "android" LIST "android" "armlinux")
check_input_var(ARM_TARGET_ARCH_ABI DEFAULT "armv8" LIST "armv8" "armv7" "armv7hf" "arm64-v8a" "armeabi-v7a")
check_input_var(ARM_TARGET_LANG DEFAULT "gcc" LIST "gcc" "clang")
check_input_var(ARM_TARGET_LIB_TYPE DEFAULT "static" LIST "static" "shared")
message(STATUS "Lite ARM Compile ${ARM_TARGET_OS} with ${ARM_TARGET_ARCH_ABI} ${ARM_TARGET_LANG}")
include(cross_compiling/host)
include(cross_compiling/armlinux)
include(cross_compiling/android)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Default use Release in android" FORCE)
endif()
if(NOT THIRD_PARTY_BUILD_TYPE)
set(THIRD_PARTY_BUILD_TYPE "MinSizeRel" CACHE STRING "Default use MinSizeRel in android" FORCE)
endif()
...@@ -141,10 +141,12 @@ endfunction() ...@@ -141,10 +141,12 @@ endfunction()
message(STATUS "CUDA detected: " ${CUDA_VERSION}) message(STATUS "CUDA detected: " ${CUDA_VERSION})
if (${CUDA_VERSION} LESS 7.0) if (${CUDA_VERSION} LESS 7.0)
set(paddle_known_gpu_archs ${paddle_known_gpu_archs}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
add_definitions("-DPADDLE_CUDA_BINVER=\"60\"")
elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs7}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"70\"")
elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs8}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
...@@ -152,16 +154,18 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x ...@@ -152,16 +154,18 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
# CUDA 8 may complain that sm_20 is no longer supported. Suppress the # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
# warning for now. # warning for now.
list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets") list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs9}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
endif() endif()
add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"")
include_directories(${CUDA_INCLUDE_DIRS}) include_directories(${CUDA_INCLUDE_DIRS})
if(NOT WITH_DSO) if(NOT WITH_DSO)
......
...@@ -96,7 +96,7 @@ if(CUDNN_FOUND) ...@@ -96,7 +96,7 @@ if(CUDNN_FOUND)
endif() endif()
message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. " message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. "
"Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ") "Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ")
endif() endif()
endif() endif()
...@@ -38,3 +38,5 @@ ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL) ...@@ -38,3 +38,5 @@ ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES}) SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc) ADD_DEPENDENCIES(dgc extern_dgc)
LIST(APPEND external_project_dependencies dgc)
...@@ -12,13 +12,6 @@ if(NOT WITH_FAST_MATH) ...@@ -12,13 +12,6 @@ if(NOT WITH_FAST_MATH)
add_definitions(-DEIGEN_FAST_MATH=0) add_definitions(-DEIGEN_FAST_MATH=0)
endif() endif()
if(WIN32)
set(EIGEN_GIT_REPOSITORY https://github.com/wopeizl/eigen-git-mirror)
set(EIGEN_GIT_TAG support_cuda9_win)
else()
set(EIGEN_GIT_REPOSITORY https://github.com/eigenteam/eigen-git-mirror)
set(EIGEN_GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c)
endif()
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
...@@ -36,10 +29,10 @@ else() ...@@ -36,10 +29,10 @@ else()
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "${EIGEN_GIT_REPOSITORY}" GIT_REPOSITORY "https://github.com/eigenteam/eigen-git-mirror"
# eigen on cuda9.1 missing header of math_funtions.hpp # eigen on cuda9.1 missing header of math_funtions.hpp
# https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
GIT_TAG ${EIGEN_GIT_TAG} GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
DOWNLOAD_NAME "eigen" DOWNLOAD_NAME "eigen"
UPDATE_COMMAND "" UPDATE_COMMAND ""
......
...@@ -18,13 +18,32 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags) ...@@ -18,13 +18,32 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags)
SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags)
SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE) SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE)
IF(WIN32) IF(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ELSE(WIN32) ELSE(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ENDIF(WIN32) ENDIF(WIN32)
INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
if(ANDROID)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
"-DCMAKE_ANDROID_NDK_TOOLCHAIN_VERSION=${CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION}" )
endif()
ExternalProject_Add( ExternalProject_Add(
extern_gflags extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
...@@ -32,24 +51,24 @@ ExternalProject_Add( ...@@ -32,24 +51,24 @@ ExternalProject_Add(
GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a
PREFIX ${GFLAGS_SOURCES_DIR} PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DBUILD_STATIC_LIBS=ON
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DBUILD_STATIC_LIBS=ON
-DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF -DBUILD_TESTING=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${OPTIONAL_ARGS}
${EXTERNAL_OPTIONAL_ARGS} ${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
) )
IF(WIN32)
IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib")
add_custom_command(TARGET extern_gflags POST_BUILD
COMMAND cmake -E copy ${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib ${GFLAGS_INSTALL_DIR}/lib/libgflags.lib
)
ENDIF()
ENDIF(WIN32)
ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL) ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES}) SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
ADD_DEPENDENCIES(gflags extern_gflags) ADD_DEPENDENCIES(gflags extern_gflags)
......
...@@ -19,7 +19,7 @@ SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog) ...@@ -19,7 +19,7 @@ SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog)
SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE) SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE)
IF(WIN32) IF(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/glog.lib" CACHE FILEPATH "glog library." FORCE) SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.lib" CACHE FILEPATH "glog library." FORCE)
SET(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4530") SET(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4530")
ELSE(WIN32) ELSE(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." FORCE) SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." FORCE)
...@@ -31,6 +31,25 @@ INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR}) ...@@ -31,6 +31,25 @@ INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR})
SET(GLOG_REPOSITORY "https://github.com/google/glog.git") SET(GLOG_REPOSITORY "https://github.com/google/glog.git")
SET(GLOG_TAG "v0.3.5") SET(GLOG_TAG "v0.3.5")
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
if(ANDROID)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
"-DCMAKE_ANDROID_NDK_TOOLCHAIN_VERSION=${CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION}")
endif()
ExternalProject_Add( ExternalProject_Add(
extern_glog extern_glog
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
...@@ -39,14 +58,7 @@ ExternalProject_Add( ...@@ -39,14 +58,7 @@ ExternalProject_Add(
GIT_TAG ${GLOG_TAG} GIT_TAG ${GLOG_TAG}
PREFIX ${GLOG_SOURCES_DIR} PREFIX ${GLOG_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS ${OPTIONAL_ARGS}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
...@@ -60,6 +72,13 @@ ExternalProject_Add( ...@@ -60,6 +72,13 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
) )
IF(WIN32)
IF(NOT EXISTS "${GLOG_INSTALL_DIR}/lib/libglog.lib")
add_custom_command(TARGET extern_glog POST_BUILD
COMMAND cmake -E copy ${GLOG_INSTALL_DIR}/lib/glog.lib ${GLOG_INSTALL_DIR}/lib/libglog.lib
)
ENDIF()
ENDIF(WIN32)
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
......
...@@ -43,6 +43,26 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) ...@@ -43,6 +43,26 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
SET(GTEST_DEPENDS ${MKLML_PROJECT}) SET(GTEST_DEPENDS ${MKLML_PROJECT})
ENDIF() ENDIF()
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
if(ANDROID)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
"-DCMAKE_ANDROID_NDK_TOOLCHAIN_VERSION=${CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION}"
)
endif()
ExternalProject_Add( ExternalProject_Add(
extern_gtest extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
...@@ -51,14 +71,7 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) ...@@ -51,14 +71,7 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
GIT_TAG "release-1.8.0" GIT_TAG "release-1.8.0"
PREFIX ${GTEST_SOURCES_DIR} PREFIX ${GTEST_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS ${OPTIONAL_ARGS}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_GMOCK=ON -DBUILD_GMOCK=ON
......
...@@ -38,7 +38,6 @@ IF(WIN32) ...@@ -38,7 +38,6 @@ IF(WIN32)
SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib) SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_LIB_DEPS ${MKLML_LIB_DIR}/msvcr120.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
ELSE() ELSE()
#TODO(intel-huying): #TODO(intel-huying):
......
...@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs) ...@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph") SET(NGRAPH_PROJECT "extern_ngraph")
SET(NGRAPH_GIT_TAG "4ec94acc11084a5d53418f565529310fa584899a") SET(NGRAPH_GIT_TAG "127e0dedfaac8c6f2b148cc03bf5f67ac5fbe6fe")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph) SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph) SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include) SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(OPENCL_CLHPP_SRCS_DIR ${THIRD_PARTY_PATH}/opencl-clhpp)
SET(OPENCL_CLHPP_INSTALL_DIR ${THIRD_PARTY_PATH}/install/opencl-clhpp)
SET(OPENCL_CLHPP_INCLUDE_DIR "${OPENCL_CLHPP_INSTALL_DIR}" CACHE PATH "opencl-clhpp include directory." FORCE)
INCLUDE_DIRECTORIES(${OPENCL_CLHPP_INCLUDE_DIR})
ExternalProject_Add(
opencl_clhpp
GIT_REPOSITORY "https://github.com/KhronosGroup/OpenCL-CLHPP.git"
GIT_TAG "v2.0.10"
PREFIX "${OPENCL_CLHPP_SRCS_DIR}"
CMAKE_ARGS -DBUILD_DOCS=OFF
-DBUILD_EXAMPLES=OFF
-DBUILD_TESTS=OFF
-DCMAKE_INSTALL_PREFIX=${OPENCL_CLHPP_INSTALL_DIR}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${OPENCL_CLHPP_INSTALL_DIR}
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
ADD_DEPENDENCIES(opencl_clhpp opencl_headers)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(OPENCL_HEADERS_SRCS_DIR ${THIRD_PARTY_PATH}/opencl-headers)
SET(OPENCL_HEADERS_INCLUDE_DIR "${OPENCL_HEADERS_SRCS_DIR}/src/opencl_headers" CACHE PATH "opencl-headers include directory." FORCE)
INCLUDE_DIRECTORIES(${OPENCL_HEADERS_INCLUDE_DIR})
ExternalProject_Add(
opencl_headers
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/KhronosGroup/OpenCL-Headers.git"
GIT_TAG "c5a4bbeabb10d8ed3d1c651b93aa31737bc473dd"
PREFIX ${OPENCL_HEADERS_SRCS_DIR}
DOWNLOAD_NAME "OpenCL-Headers"
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
...@@ -142,7 +142,6 @@ IF (WIN32) ...@@ -142,7 +142,6 @@ IF (WIN32)
ENDIF(WIN32) ENDIF(WIN32)
if (NOT "${PROTOBUF_ROOT}" STREQUAL "") if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH) find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH)
find_library(PROTOBUF_LIBRARY protobuf libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) find_library(PROTOBUF_LIBRARY protobuf libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH)
find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH)
...@@ -178,12 +177,29 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -178,12 +177,29 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
SET(OPTIONAL_CACHE_ARGS "") SET(OPTIONAL_CACHE_ARGS "")
SET(OPTIONAL_ARGS "") SET(OPTIONAL_ARGS "")
IF(BUILD_FOR_HOST) IF(BUILD_FOR_HOST)
SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF")
ELSE()
SET(OPTIONAL_ARGS SET(OPTIONAL_ARGS
"-DCMAKE_C_COMPILER=${HOST_C_COMPILER}"
"-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}"
"-Dprotobuf_WITH_ZLIB=OFF"
"-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}")
SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}")
ELSE()
# protobuf have compile issue when use android stl c++_static
SET(PROTOBUF_REPO "https://github.com/tensor-tang/protobuf.git")
SET(PROTOBUF_TAG "mobile")
SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF"
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
"-DCMAKE_ANDROID_NDK_TOOLCHAIN_VERSION=${CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION}"
"-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
...@@ -191,25 +207,18 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -191,25 +207,18 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}" "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}")
"-Dprotobuf_WITH_ZLIB=ON"
"-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}"
${EXTERNAL_OPTIONAL_ARGS})
SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}")
ENDIF() ENDIF()
IF(WIN32) IF(WIN32)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64") SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64")
ENDIF() ENDIF()
SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
ExternalProject_Add( ExternalProject_Add(
${TARGET_NAME} ${TARGET_NAME}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PROTOBUF_SOURCES_DIR} PREFIX ${PROTOBUF_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
DEPENDS zlib #DEPENDS zlib
GIT_REPOSITORY ${PROTOBUF_REPO} GIT_REPOSITORY ${PROTOBUF_REPO}
GIT_TAG ${PROTOBUF_TAG} GIT_TAG ${PROTOBUF_TAG}
CONFIGURE_COMMAND CONFIGURE_COMMAND
...@@ -233,6 +242,13 @@ ENDFUNCTION() ...@@ -233,6 +242,13 @@ ENDFUNCTION()
SET(PROTOBUF_VERSION 3.1.0) SET(PROTOBUF_VERSION 3.1.0)
IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
build_protobuf(protobuf_host TRUE)
LIST(APPEND external_project_dependencies protobuf_host)
SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_host_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE)
ENDIF()
IF(NOT PROTOBUF_FOUND) IF(NOT PROTOBUF_FOUND)
build_protobuf(extern_protobuf FALSE) build_protobuf(extern_protobuf FALSE)
...@@ -245,7 +261,12 @@ IF(NOT PROTOBUF_FOUND) ...@@ -245,7 +261,12 @@ IF(NOT PROTOBUF_FOUND)
SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY} SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY}
CACHE FILEPATH "protoc library." FORCE) CACHE FILEPATH "protoc library." FORCE)
IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf)
ELSE()
SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE} SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE) CACHE FILEPATH "protobuf executable." FORCE)
PROMPT_PROTOBUF_LIB(extern_protobuf) PROMPT_PROTOBUF_LIB(extern_protobuf)
ENDIF()
ENDIF(NOT PROTOBUF_FOUND) ENDIF(NOT PROTOBUF_FOUND)
...@@ -29,9 +29,9 @@ INCLUDE(ExternalProject) ...@@ -29,9 +29,9 @@ INCLUDE(ExternalProject)
SET(PSLIB_PROJECT "extern_pslib") SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL)) IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url") MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "0.1.1" CACHE STRING "" FORCE) SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE) SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/ps/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE) SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF() ENDIF()
MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}") MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib") SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
......
...@@ -53,7 +53,12 @@ ExternalProject_Add( ...@@ -53,7 +53,12 @@ ExternalProject_Add(
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
) )
IF(WIN32) IF(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib") IF(NOT EXISTS "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
add_custom_command(TARGET extern_snappy POST_BUILD
COMMAND cmake -E copy ${SNAPPY_INSTALL_DIR}/lib/snappy.lib ${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib
)
ENDIF()
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
else(WIN32) else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a") set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32) endif (WIN32)
......
...@@ -64,7 +64,12 @@ ExternalProject_Add( ...@@ -64,7 +64,12 @@ ExternalProject_Add(
-DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
) )
IF(WIN32) IF(WIN32)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" IF(NOT EXISTS "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}")
add_custom_command(TARGET extern_warpctc POST_BUILD
COMMAND cmake -E copy ${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} ${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}
)
ENDIF()
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE) CACHE FILEPATH "Warp-ctc Library" FORCE)
else(WIN32) else(WIN32)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
......
...@@ -56,7 +56,12 @@ else() ...@@ -56,7 +56,12 @@ else()
endif() endif()
if (WIN32) if (WIN32)
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/xxhash.lib") IF(NOT EXISTS "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
add_custom_command(TARGET extern_xxhash POST_BUILD
COMMAND cmake -E copy ${XXHASH_INSTALL_DIR}/lib/xxhash.lib ${XXHASH_INSTALL_DIR}/lib/libxxhash.lib
)
ENDIF()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
else() else()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a") set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
endif () endif ()
......
...@@ -44,7 +44,12 @@ ExternalProject_Add( ...@@ -44,7 +44,12 @@ ExternalProject_Add(
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
) )
IF(WIN32) IF(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib" CACHE FILEPATH "zlib library." FORCE) IF(NOT EXISTS "${ZLIB_INSTALL_DIR}/lib/libz.lib")
add_custom_command(TARGET extern_zlib POST_BUILD
COMMAND cmake -E copy ${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib ${ZLIB_INSTALL_DIR}/lib/libz.lib
)
ENDIF()
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.lib" CACHE FILEPATH "zlib library." FORCE)
ELSE(WIN32) ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE) SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32) ENDIF(WIN32)
......
...@@ -93,7 +93,10 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) ...@@ -93,7 +93,10 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt") set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl")
if (NOT ANDROID)
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -lrt")
endif()
endif(NOT APPLE) endif(NOT APPLE)
set_property(GLOBAL PROPERTY FLUID_MODULES "") set_property(GLOBAL PROPERTY FLUID_MODULES "")
...@@ -363,10 +366,11 @@ function(cc_binary TARGET_NAME) ...@@ -363,10 +366,11 @@ function(cc_binary TARGET_NAME)
target_link_libraries(${TARGET_NAME} ${os_dependency_modules}) target_link_libraries(${TARGET_NAME} ${os_dependency_modules})
endfunction(cc_binary) endfunction(cc_binary)
function(cc_test_build TARGET_NAME) function(cc_test TARGET_NAME)
if(WITH_TESTING) if(WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS}) add_executable(${TARGET_NAME} ${cc_test_SRCS})
if(WIN32) if(WIN32)
...@@ -379,18 +383,12 @@ function(cc_test_build TARGET_NAME) ...@@ -379,18 +383,12 @@ function(cc_test_build TARGET_NAME)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog) target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME}) common_link(${TARGET_NAME})
endif()
endfunction()
function(cc_test_run TARGET_NAME)
if(WITH_TESTING)
set(oneValueArgs "")
set(multiValueArgs COMMAND ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME} add_test(NAME ${TARGET_NAME}
COMMAND ${cc_test_COMMAND} COMMAND ${TARGET_NAME} ${cc_test_ARGS}
ARGS ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
...@@ -398,21 +396,46 @@ function(cc_test_run TARGET_NAME) ...@@ -398,21 +396,46 @@ function(cc_test_run TARGET_NAME)
# No unit test should exceed 10 minutes. # No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
endfunction() endfunction(cc_test)
function(cc_test TARGET_NAME) # cc_test without default dependencies
function(raw_cc_test TARGET_NAME)
if(WITH_TESTING) if(WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS) set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_test_build(${TARGET_NAME} add_executable(${TARGET_NAME} ${cc_test_SRCS})
SRCS ${cc_test_SRCS} if(WIN32)
DEPS ${cc_test_DEPS}) if("${cc_test_DEPS};" MATCHES "python;")
cc_test_run(${TARGET_NAME} list(REMOVE_ITEM cc_test_DEPS python)
COMMAND ${TARGET_NAME} target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES})
ARGS ${cc_test_ARGS})
endif() endif()
endfunction(cc_test) endif(WIN32)
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} lite_gtest_main gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME})
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif()
endfunction(raw_cc_test)
function(_lite_cc_test args)
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
message(STATUS "building lite raw test: ${args}")
raw_cc_test(${args} ${ARGN})
else()
message(STATUS "building lite heavy test: ${args}")
cc_test(${args} ${ARGN})
endif()
endfunction()
function(nv_library TARGET_NAME) function(nv_library TARGET_NAME)
if (WITH_GPU) if (WITH_GPU)
...@@ -465,6 +488,7 @@ endfunction(nv_binary) ...@@ -465,6 +488,7 @@ endfunction(nv_binary)
function(nv_test TARGET_NAME) function(nv_test TARGET_NAME)
if (WITH_GPU AND WITH_TESTING) if (WITH_GPU AND WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
...@@ -474,6 +498,9 @@ function(nv_test TARGET_NAME) ...@@ -474,6 +498,9 @@ function(nv_test TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME}) common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
...@@ -716,7 +743,7 @@ function(py_proto_compile TARGET_NAME) ...@@ -716,7 +743,7 @@ function(py_proto_compile TARGET_NAME)
cmake_parse_arguments(py_proto_compile "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(py_proto_compile "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(py_srcs) set(py_srcs)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS}) protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs} protobuf) add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction() endfunction()
function(py_test TARGET_NAME) function(py_test TARGET_NAME)
......
# Bundle several static libraries into one.
function(bundle_static_library tgt_name bundled_tgt_name fake_target)
list(APPEND static_libs ${tgt_name})
function(_recursively_collect_dependencies input_target)
set(_input_link_libraries LINK_LIBRARIES)
get_target_property(_input_type ${input_target} TYPE)
if (${_input_type} STREQUAL "INTERFACE_LIBRARY")
set(_input_link_libraries INTERFACE_LINK_LIBRARIES)
endif()
get_target_property(public_dependencies ${input_target} ${_input_link_libraries})
foreach(dependency IN LISTS public_dependencies)
if(TARGET ${dependency})
get_target_property(alias ${dependency} ALIASED_TARGET)
if (TARGET ${alias})
set(dependency ${alias})
endif()
get_target_property(_type ${dependency} TYPE)
if (${_type} STREQUAL "STATIC_LIBRARY")
list(APPEND static_libs ${dependency})
endif()
get_property(library_already_added
GLOBAL PROPERTY _${tgt_name}_static_bundle_${dependency})
if (NOT library_already_added)
set_property(GLOBAL PROPERTY _${tgt_name}_static_bundle_${dependency} ON)
_recursively_collect_dependencies(${dependency})
endif()
endif()
endforeach()
set(static_libs ${static_libs} PARENT_SCOPE)
endfunction()
_recursively_collect_dependencies(${tgt_name})
list(REMOVE_DUPLICATES static_libs)
set(bundled_tgt_full_name
${CMAKE_BINARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${bundled_tgt_name}${CMAKE_STATIC_LIBRARY_SUFFIX})
message(STATUS "+++++ bundled_tgt_full_name: ${bundled_tgt_full_name}")
file(WRITE ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in
"CREATE ${bundled_tgt_full_name}\n" )
foreach(tgt IN LISTS static_libs)
file(APPEND ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in
"ADDLIB $<TARGET_FILE:${tgt}>\n")
endforeach()
file(APPEND ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in "SAVE\n")
file(APPEND ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in "END\n")
file(GENERATE
OUTPUT ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar
INPUT ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in)
set(ar_tool ${CMAKE_AR})
if (CMAKE_INTERPROCEDURAL_OPTIMIZATION)
set(ar_tool ${CMAKE_CXX_COMPILER_AR})
endif()
add_custom_command(
COMMAND ${ar_tool} -M < ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar
OUTPUT ${bundled_tgt_full_name}
COMMENT "Bundling ${bundled_tgt_name}"
VERBATIM)
add_custom_target(${fake_target} ALL DEPENDS ${bundled_tgt_full_name})
add_dependencies(${fake_target} ${tgt_name})
add_library(${bundled_tgt_name} STATIC IMPORTED)
set_target_properties(${bundled_tgt_name}
PROPERTIES
IMPORTED_LOCATION ${bundled_tgt_full_name}
INTERFACE_INCLUDE_DIRECTORIES $<TARGET_PROPERTY:${tgt_name},INTERFACE_INCLUDE_DIRECTORIES>)
add_dependencies(${bundled_tgt_name} ${fake_target})
endfunction()
...@@ -110,7 +110,7 @@ function(op_library TARGET) ...@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op") "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
......
...@@ -3,6 +3,9 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION}) ...@@ -3,6 +3,9 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set(tmp_version "HEAD") set(tmp_version "HEAD")
set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?") set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+") set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
# set(LATEST_PADDLE_VERSION "latest")
set(LATEST_PADDLE_VERSION "0.0.0")
while ("${PADDLE_VERSION}" STREQUAL "") while ("${PADDLE_VERSION}" STREQUAL "")
# Check current branch name # Check current branch name
execute_process( execute_process(
...@@ -23,8 +26,8 @@ while ("${PADDLE_VERSION}" STREQUAL "") ...@@ -23,8 +26,8 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}") if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}")
# Check the tag is a correct version # Check the tag is a correct version
if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}") if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
# if no tag was found, set PADDLE_VERSION to 0.0.0 to represent latest # if no tag was found, set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "0.0.0") set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}") elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME}) string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
else() # otherwise, get the previous git tag name. else() # otherwise, get the previous git tag name.
...@@ -42,19 +45,19 @@ while ("${PADDLE_VERSION}" STREQUAL "") ...@@ -42,19 +45,19 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}") if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME}) string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME})
else() else()
set(PADDLE_VERSION "0.0.0") set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
endif() endif()
else() else()
# otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest # otherwise, we always set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "0.0.0") set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
endif() endif()
endif() endif()
else() else()
set(PADDLE_VERSION "0.0.0") set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
message(WARNING "Cannot add paddle version from git tag") message(WARNING "Cannot add paddle version from git tag")
endif() endif()
else() else()
set(PADDLE_VERSION "0.0.0") set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
message(WARNING "Cannot add paddle version for wrong git branch result") message(WARNING "Cannot add paddle version for wrong git branch result")
endif() endif()
endwhile() endwhile()
......
add_subdirectory(scripts) # to limit the mobile dependencies
add_subdirectory(testing) if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") add_subdirectory(scripts)
add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
endif()
add_subdirectory(fluid) add_subdirectory(fluid)
此差异已折叠。
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) # for mobile
add_subdirectory(lite)
return()
endif()
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
...@@ -6,7 +10,8 @@ add_subdirectory(operators) ...@@ -6,7 +10,8 @@ add_subdirectory(operators)
add_subdirectory(string) add_subdirectory(string)
add_subdirectory(recordio) add_subdirectory(recordio)
add_subdirectory(pybind) add_subdirectory(pybind)
add_subdirectory(train)
# NOTE: please add subdirectory inference at last. # NOTE: please add subdirectory inference at last.
add_subdirectory(inference) add_subdirectory(inference)
add_subdirectory(train)
add_subdirectory(lite)
...@@ -29,8 +29,7 @@ add_subdirectory(io) ...@@ -29,8 +29,7 @@ add_subdirectory(io)
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto) proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto) proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
data_feed_proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -125,7 +124,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co ...@@ -125,7 +124,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type) shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type data_feed_proto)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
...@@ -174,20 +173,20 @@ endif() ...@@ -174,20 +173,20 @@ endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector) cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer) graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto) graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
...@@ -202,10 +201,10 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS ...@@ -202,10 +201,10 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc pipeline_trainer.cc executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc section_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
device_worker_factory.cc data_set.cc dataset_factory.cc data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto feed_fetch_method graph_to_program_pass data_feed_proto
...@@ -226,8 +225,6 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) ...@@ -226,8 +225,6 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc ) cc_test(tuple_test SRCS tuple_test.cc )
cc_test(inlined_vector_test SRCS inlined_vector_test.cc)
if (NOT WIN32) if (NOT WIN32)
cc_test(rw_lock_test SRCS rw_lock_test.cc) cc_test(rw_lock_test SRCS rw_lock_test.cc)
endif (NOT WIN32) endif (NOT WIN32)
......
...@@ -85,9 +85,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -85,9 +85,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
} }
DataFeedDesc data_feed_desc; DataFeedDesc data_feed_desc;
bool success = data_feed_desc.ParseFromString(data_feed_desc_str); google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
PADDLE_ENFORCE(success, "Fail to parse DataFeedDesc from string:\n%s", &data_feed_desc);
data_feed_desc_str.c_str());
actual_thread_num_ = thread_num; actual_thread_num_ = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
......
...@@ -95,11 +95,6 @@ class BlockingQueue { ...@@ -95,11 +95,6 @@ class BlockingQueue {
return q_.size(); return q_.size();
} }
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
std::deque<T>().swap(q_);
}
private: private:
std::mutex mutex_; std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
......
...@@ -20,9 +20,6 @@ limitations under the License. */ ...@@ -20,9 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX #ifdef _LINUX
#include <stdio_ext.h> #include <stdio_ext.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif #endif
#include <utility> #include <utility>
#include "gflags/gflags.h" #include "gflags/gflags.h"
...@@ -90,13 +87,6 @@ void DataFeed::CheckStart() { ...@@ -90,13 +87,6 @@ void DataFeed::CheckStart() {
PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet."); PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet.");
} }
void DataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
feed_vec_[i] = scope.FindVar(use_slots_[i])->GetMutable<LoDTensor>();
}
}
template <typename T> template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size); PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
...@@ -168,7 +158,6 @@ InMemoryDataFeed<T>::InMemoryDataFeed() { ...@@ -168,7 +158,6 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
mutex_for_update_memory_data_ = nullptr; mutex_for_update_memory_data_ = nullptr;
this->file_idx_ = nullptr; this->file_idx_ = nullptr;
this->mutex_for_pick_file_ = nullptr; this->mutex_for_pick_file_ = nullptr;
fleet_send_sleep_seconds_ = 2;
} }
template <typename T> template <typename T>
...@@ -377,7 +366,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -377,7 +366,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::vector<T*>> send_vec(trainer_num_); std::vector<std::vector<T*>> send_vec(trainer_num_);
std::vector<int> send_index(trainer_num_); std::vector<int> send_index(trainer_num_);
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_ + 1; uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
for (auto& vec : send_vec) { for (auto& vec : send_vec) {
vec.reserve(reserve_len); vec.reserve(reserve_len);
} }
...@@ -388,34 +377,47 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -388,34 +377,47 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
auto interval = GetMemoryDataInterval(); auto interval = GetMemoryDataInterval();
VLOG(3) << "global shuffle data from [" << interval.first << ", " VLOG(3) << "global shuffle data from [" << interval.first << ", "
<< interval.second << "), thread_id=" << thread_id_; << interval.second << "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) {
for (int64_t i = interval.first; i < interval.second; // if get ins id, can also use hash
i += fleet_send_batch_size_) { // std::string ins_id = memory_data_[i].ins_id;
for (int64_t j = 0; j < fleet_send_batch_size_ && i + j < interval.second; int64_t random_num = rand_r(&rand_seed);
++j) {
int64_t random_num = fleet_ptr->LocalRandomEngine()();
int64_t node_id = random_num % trainer_num_; int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i + j])); send_vec[node_id].push_back(&((*memory_data_)[i]));
} if (i % fleet_send_batch_size_ == 0 && i != 0) {
total_status.clear(); // shuffle the sequence of sending to avoid network timeout error
std::shuffle(send_index.begin(), send_index.end(), std::random_shuffle(send_index.begin(), send_index.end());
fleet_ptr->LocalRandomEngine());
for (int index = 0; index < send_index.size(); ++index) { for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index]; int j = send_index[index];
if (send_vec[j].size() == 0) { std::string send_str;
continue; SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
<< ", ins num=" << send_vec[j].size() << " to node_id=" << j
<< ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
send_vec[j].clear();
total_status.push_back(std::move(ret));
} }
}
}
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
if (send_vec[j].size() != 0) {
std::string send_str; std::string send_str;
SerializeIns(send_vec[j], &send_str); SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
<< ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str); auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
total_status.push_back(std::move(ret)); total_status.push_back(std::move(ret));
send_vec[j].clear(); }
std::vector<T*>().swap(send_vec[j]);
} }
for (auto& t : total_status) { for (auto& t : total_status) {
t.wait(); t.wait();
} }
sleep(fleet_send_sleep_seconds_);
}
VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_; VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
#endif #endif
} }
...@@ -434,24 +436,6 @@ std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() { ...@@ -434,24 +436,6 @@ std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
return std::make_pair(start, end); return std::make_pair(start, end);
} }
template <typename T>
int64_t InMemoryDataFeed<T>::GetChannelDataSize() {
if (cur_channel_ == 0) {
return shuffled_ins_->Size();
} else {
return shuffled_ins_out_->Size();
}
}
template <typename T>
void InMemoryDataFeed<T>::ReleaseChannelData() {
if (cur_channel_ == 0) {
shuffled_ins_->Clear();
} else {
shuffled_ins_out_->Clear();
}
}
// explicit instantiation // explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>; template class InMemoryDataFeed<std::vector<MultiSlotType>>;
...@@ -487,17 +471,17 @@ void MultiSlotDataFeed::Init( ...@@ -487,17 +471,17 @@ void MultiSlotDataFeed::Init(
use_slots_is_dense_.push_back(slot.is_dense()); use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape; std::vector<int> local_shape;
if (slot.is_dense()) { if (slot.is_dense()) {
for (size_t j = 0; j < slot.shape_size(); ++j) { for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(j) > 0) { if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(j); total_dims_without_inductive_[i] *= slot.shape(i);
} }
if (slot.shape(j) == -1) { if (slot.shape(i) == -1) {
inductive_shape_index_[i] = j; inductive_shape_index_[i] = i;
} }
} }
} }
for (size_t j = 0; j < slot.shape_size(); ++j) { for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(j)); local_shape.push_back(slot.shape(i));
} }
use_slots_shape_.push_back(local_shape); use_slots_shape_.push_back(local_shape);
} }
...@@ -821,24 +805,22 @@ void MultiSlotInMemoryDataFeed::Init( ...@@ -821,24 +805,22 @@ void MultiSlotInMemoryDataFeed::Init(
all_slots_[i] = slot.name(); all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type(); all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1; use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) { if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]); use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense()); use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape; std::vector<int> local_shape;
if (slot.is_dense()) { if (slot.is_dense()) {
for (size_t j = 0; j < slot.shape_size(); ++j) { for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(j) > 0) { if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(j); total_dims_without_inductive_[i] *= slot.shape(i);
} }
if (slot.shape(j) == -1) { if (slot.shape(i) == -1) {
inductive_shape_index_[i] = j; inductive_shape_index_[i] = i;
} }
} }
} }
for (size_t j = 0; j < slot.shape_size(); ++j) { for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(j)); local_shape.push_back(slot.shape(i));
} }
use_slots_shape_.push_back(local_shape); use_slots_shape_.push_back(local_shape);
} }
...@@ -1019,205 +1001,5 @@ void MultiSlotInMemoryDataFeed::DeserializeIns( ...@@ -1019,205 +1001,5 @@ void MultiSlotInMemoryDataFeed::DeserializeIns(
fleet_ptr->Deserialize(ins, str); fleet_ptr->Deserialize(ins, str);
} }
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec_[i].GetType();
const auto& offset = ins_vec_[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec_[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
const auto& feasign = ins_vec_[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int64_t total_dims = 1;
for (const auto e : use_slots_shape_[i]) {
total_dims *= e;
}
PADDLE_ENFORCE(
total_dims == total_instance,
"The actual data size of slot[%s] doesn't match its declaration",
use_slots_[i].c_str());
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
}
template <typename T>
int PrivateInstantDataFeed<T>::Next() {
if (ParseOneMiniBatch()) {
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
Postprocess();
std::string filename;
if (!PickOneFile(&filename)) {
return -1;
}
if (!Preprocess(filename)) {
return -1;
}
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data");
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
template <typename T>
void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
multi_inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) == -1) {
multi_inductive_shape_index_[i].push_back(j);
}
}
}
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
}
feed_vec_.resize(use_slots_.size());
ins_vec_.resize(use_slots_.size());
finish_init_ = true;
}
template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());
struct stat sb;
fstat(fd_, &sb);
end_ = static_cast<size_t>(sb.st_size);
buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno));
offset_ = 0;
return true;
}
bool MultiSlotFileInstantDataFeed::Postprocess() {
if (buffer_ != nullptr) {
munmap(buffer_, end_);
buffer_ = nullptr;
}
if (fd_ != -1) {
close(fd_);
fd_ = -1;
end_ = 0;
offset_ = 0;
}
return true;
}
bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
if (offset_ == end_) {
return false;
}
batch_size_ = 0;
while (batch_size_ < default_batch_size_ && offset_ < end_) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.");
offset_ += sizeof(uint16_t);
if (idx != -1) {
int inductive_size = multi_inductive_shape_index_[i].size();
if (UNLIKELY(batch_size_ == 0)) {
ins_vec_[idx].Init(all_slots_type_[i], default_batch_size_ * num);
ins_vec_[idx].InitOffset(default_batch_size_);
uint64_t* inductive_shape =
reinterpret_cast<uint64_t*>(buffer_ + offset_);
for (int inductive_id = 0; inductive_id < inductive_size;
++inductive_id) {
use_slots_shape_[i][multi_inductive_shape_index_[i][inductive_id]] =
static_cast<int>(*(inductive_shape + inductive_id));
}
}
num -= inductive_size;
offset_ += sizeof(uint64_t) * inductive_size;
if (type == 'f') {
ins_vec_[idx].AppendValues(
reinterpret_cast<float*>(buffer_ + offset_), num);
offset_ += num * sizeof(float);
} else if (type == 'u') {
ins_vec_[idx].AppendValues(
reinterpret_cast<uint64_t*>(buffer_ + offset_), num);
offset_ += num * sizeof(uint64_t);
}
} else {
if (type == 'f') {
offset_ += num * sizeof(float);
} else if (type == 'u') {
offset_ += num * sizeof(uint64_t);
}
}
}
++batch_size_;
// OPTIMIZE: It is better to insert check codes between instances for format
// checking
}
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_");
return true;
}
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,7 +59,7 @@ class DataFeed { ...@@ -59,7 +59,7 @@ class DataFeed {
file_idx_ = nullptr; file_idx_ = nullptr;
} }
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0; virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
PADDLE_THROW("This function(CheckFile) is not implemented."); PADDLE_THROW("This function(CheckFile) is not implemented.");
} }
...@@ -84,9 +84,6 @@ class DataFeed { ...@@ -84,9 +84,6 @@ class DataFeed {
// This function is used for binding feed_vec memory // This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name); virtual void AddFeedVar(Variable* var, const std::string& name);
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default // This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {} virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default // This function will do nothing at default
...@@ -118,9 +115,6 @@ class DataFeed { ...@@ -118,9 +115,6 @@ class DataFeed {
virtual void FillChannelToMemoryData() {} virtual void FillChannelToMemoryData() {}
// This function will do nothing at default // This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) {} virtual void PutInsToChannel(const std::string& ins_str) {}
virtual int64_t GetChannelDataSize() { return 0; }
// This function will do nothing at default
virtual void ReleaseChannelData() {}
protected: protected:
// The following three functions are used to check if it is executed in this // The following three functions are used to check if it is executed in this
...@@ -151,8 +145,6 @@ class DataFeed { ...@@ -151,8 +145,6 @@ class DataFeed {
std::vector<std::vector<int>> use_slots_shape_; std::vector<std::vector<int>> use_slots_shape_;
std::vector<int> inductive_shape_index_; std::vector<int> inductive_shape_index_;
std::vector<int> total_dims_without_inductive_; std::vector<int> total_dims_without_inductive_;
// For the inductive shape passed within data
std::vector<std::vector<int>> multi_inductive_shape_index_;
std::vector<int> std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_ use_slots_index_; // -1: not used; >=0: the index of use_slots_
...@@ -178,6 +170,7 @@ class PrivateQueueDataFeed : public DataFeed { ...@@ -178,6 +170,7 @@ class PrivateQueueDataFeed : public DataFeed {
public: public:
PrivateQueueDataFeed() {} PrivateQueueDataFeed() {}
virtual ~PrivateQueueDataFeed() {} virtual ~PrivateQueueDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start(); virtual bool Start();
virtual int Next(); virtual int Next();
...@@ -216,7 +209,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -216,7 +209,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public: public:
InMemoryDataFeed(); InMemoryDataFeed();
virtual ~InMemoryDataFeed() {} virtual ~InMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0; virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start(); virtual bool Start();
virtual int Next(); virtual int Next();
virtual void SetMemoryData(void* memory_data); virtual void SetMemoryData(void* memory_data);
...@@ -231,8 +224,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -231,8 +224,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
virtual void LocalShuffle(); virtual void LocalShuffle();
virtual void GlobalShuffle(); virtual void GlobalShuffle();
virtual int64_t GetChannelDataSize();
virtual void ReleaseChannelData();
protected: protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
...@@ -257,9 +248,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -257,9 +248,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_; std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_; std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_; int64_t fleet_send_batch_size_;
// sleep after send is to slow down sending data, but it's trick,
// should be removed later.
int64_t fleet_send_sleep_seconds_;
}; };
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed // This class define the data type of instance(ins_vec) in MultiSlotDataFeed
...@@ -267,25 +255,16 @@ class MultiSlotType { ...@@ -267,25 +255,16 @@ class MultiSlotType {
public: public:
MultiSlotType() {} MultiSlotType() {}
~MultiSlotType() {} ~MultiSlotType() {}
void Init(const std::string& type, size_t reserved_size = 0) { void Init(const std::string& type) {
CheckType(type); CheckType(type);
if (type_[0] == 'f') { if (type_[0] == 'f') {
float_feasign_.clear(); float_feasign_.clear();
if (reserved_size) {
float_feasign_.reserve(reserved_size);
}
} else if (type_[0] == 'u') { } else if (type_[0] == 'u') {
uint64_feasign_.clear(); uint64_feasign_.clear();
if (reserved_size) {
uint64_feasign_.reserve(reserved_size);
}
} }
type_ = type; type_ = type;
} }
void InitOffset(size_t max_batch_size = 0) { void InitOffset() {
if (max_batch_size > 0) {
offset_.reserve(max_batch_size + 1);
}
offset_.resize(1); offset_.resize(1);
// LoDTensor' lod is counted from 0, the size of lod // LoDTensor' lod is counted from 0, the size of lod
// is one size larger than the size of data. // is one size larger than the size of data.
...@@ -301,16 +280,6 @@ class MultiSlotType { ...@@ -301,16 +280,6 @@ class MultiSlotType {
CheckUint64(); CheckUint64();
uint64_feasign_.push_back(v); uint64_feasign_.push_back(v);
} }
void CopyValues(const float* input, size_t size) {
CheckFloat();
float_feasign_.resize(size);
memcpy(float_feasign_.data(), input, size * sizeof(float));
}
void CopyValues(const uint64_t* input, size_t size) {
CheckUint64();
uint64_feasign_.resize(size);
memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
}
void AddIns(const MultiSlotType& ins) { void AddIns(const MultiSlotType& ins) {
if (ins.GetType()[0] == 'f') { // float if (ins.GetType()[0] == 'f') { // float
CheckFloat(); CheckFloat();
...@@ -324,22 +293,11 @@ class MultiSlotType { ...@@ -324,22 +293,11 @@ class MultiSlotType {
uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end()); uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
} }
} }
void AppendValues(const uint64_t* input, size_t size) {
CheckUint64();
offset_.push_back(offset_.back() + size);
uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
}
void AppendValues(const float* input, size_t size) {
CheckFloat();
offset_.push_back(offset_.back() + size);
float_feasign_.insert(float_feasign_.end(), input, input + size);
}
const std::vector<float>& GetFloatData() const { return float_feasign_; } const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; } std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; } const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; } std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; } const std::string& GetType() const { return type_; }
size_t GetBatchSize() { return offset_.size() - 1; }
std::string& MutableType() { return type_; } std::string& MutableType() { return type_; }
std::string DebugString() { std::string DebugString() {
...@@ -389,7 +347,7 @@ class MultiSlotDataFeed ...@@ -389,7 +347,7 @@ class MultiSlotDataFeed
public: public:
MultiSlotDataFeed() {} MultiSlotDataFeed() {}
virtual ~MultiSlotDataFeed() {} virtual ~MultiSlotDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc); virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename); virtual bool CheckFile(const char* filename);
// virtual void ReadThread(); // virtual void ReadThread();
...@@ -408,7 +366,7 @@ class MultiSlotInMemoryDataFeed ...@@ -408,7 +366,7 @@ class MultiSlotInMemoryDataFeed
public: public:
MultiSlotInMemoryDataFeed() {} MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {} virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc); virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected: protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins, virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
...@@ -423,54 +381,5 @@ class MultiSlotInMemoryDataFeed ...@@ -423,54 +381,5 @@ class MultiSlotInMemoryDataFeed
const std::string& str); const std::string& str);
}; };
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
public:
PrivateInstantDataFeed() {}
virtual ~PrivateInstantDataFeed() {}
void Init(const DataFeedDesc& data_feed_desc) override;
bool Start() override { return true; }
int Next() override;
protected:
// The batched data buffer
std::vector<MultiSlotType> ins_vec_;
// This function is used to preprocess with a given filename, e.g. open it or
// mmap
virtual bool Preprocess(const std::string& filename) = 0;
// This function is used to postprocess system resource such as closing file
// NOTICE: Ensure that it is safe to call before Preprocess
virtual bool Postprocess() = 0;
// The reading and parsing method.
virtual bool ParseOneMiniBatch() = 0;
// This function is used to put ins_vec to feed_vec
virtual void PutToFeedVec();
};
class MultiSlotFileInstantDataFeed
: public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotFileInstantDataFeed() {}
virtual ~MultiSlotFileInstantDataFeed() {}
protected:
int fd_{-1};
char* buffer_{nullptr};
size_t end_{0};
size_t offset_{0};
bool Preprocess(const std::string& filename) override;
bool Postprocess() override;
bool ParseOneMiniBatch() override;
};
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -64,8 +64,5 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed( ...@@ -64,8 +64,5 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed); REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,13 +13,11 @@ ...@@ -13,13 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -147,6 +145,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -147,6 +145,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef, PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: %s", in.type()); "Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format = auto out_format =
...@@ -157,21 +156,14 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -157,21 +156,14 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
if (in_format != out_format) { if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type); void* in_data = GetDataFromTensor(in, in_type);
const std::string key = platform::ReorderMKLDNNHandler::GetHash( auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
in_tz, in_format, out_format, std::to_string(in_type));
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, auto in_memory =
cpu_engine, key); memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data); platform::Reorder(in_memory, out_memory);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out_format, expected_kernel_type.place_);
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*reorder_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} else { } else {
out->ShareDataWith(in); out->ShareDataWith(in);
} }
......
...@@ -141,9 +141,6 @@ template <typename T> ...@@ -141,9 +141,6 @@ template <typename T>
void DatasetImpl<T>::ReleaseMemory() { void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin"; VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_); std::vector<T>().swap(memory_data_);
for (int i = 0; i < readers_.size(); ++i) {
readers_[i]->ReleaseChannelData();
}
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end"; VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
} }
...@@ -181,10 +178,8 @@ void DatasetImpl<T>::GlobalShuffle() { ...@@ -181,10 +178,8 @@ void DatasetImpl<T>::GlobalShuffle() {
if (readers_.size() == 0) { if (readers_.size() == 0) {
CreateReaders(); CreateReaders();
} }
auto fleet_ptr = FleetWrapper::GetInstance(); // if it is not InMemory, memory_data_ is empty
// local shuffle all data before global shuffle std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::shuffle(memory_data_.begin(), memory_data_.end(),
fleet_ptr->LocalRandomEngine());
VLOG(3) << "start global shuffle threads"; VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads; std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
...@@ -265,20 +260,6 @@ void DatasetImpl<T>::DestroyReaders() { ...@@ -265,20 +260,6 @@ void DatasetImpl<T>::DestroyReaders() {
} }
} }
template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() {
return memory_data_.size();
}
template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
int64_t sum = 0;
for (int i = 0; i < readers_.size(); ++i) {
sum += readers_[i]->GetChannelDataSize();
}
return sum;
}
template <typename T> template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id, int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) { const std::string& msg) {
...@@ -286,7 +267,7 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id, ...@@ -286,7 +267,7 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length(); << ", client_id=" << client_id << ", msg length=" << msg.length();
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_; int64_t index = rand_r(&rand_seed) % thread_num_;
VLOG(3) << "ramdom index=" << index; VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg); readers_[index]->PutInsToChannel(msg);
#endif #endif
......
...@@ -85,10 +85,6 @@ class Dataset { ...@@ -85,10 +85,6 @@ class Dataset {
virtual void CreateReaders() = 0; virtual void CreateReaders() = 0;
// destroy readers // destroy readers
virtual void DestroyReaders() = 0; virtual void DestroyReaders() = 0;
// get memory data size
virtual int64_t GetMemoryDataSize() = 0;
// get shuffle data size
virtual int64_t GetShuffleDataSize() = 0;
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
...@@ -131,8 +127,6 @@ class DatasetImpl : public Dataset { ...@@ -131,8 +127,6 @@ class DatasetImpl : public Dataset {
virtual void GlobalShuffle(); virtual void GlobalShuffle();
virtual void CreateReaders(); virtual void CreateReaders();
virtual void DestroyReaders(); virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize();
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
......
...@@ -93,6 +93,6 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS ...@@ -93,6 +93,6 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_elewise_add_act_pass multi_batch_merge_pass fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
record_skip_memory_opt_vars_pass) record_skip_memory_opt_vars_pass)
...@@ -35,9 +35,16 @@ namespace details { ...@@ -35,9 +35,16 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLCommunicator *ctxs) const platform::NCCLContextMap *ctxs)
: NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) { : OpHandleBase(node),
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
} }
#else #else
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
...@@ -64,9 +71,7 @@ void AllReduceOpHandle::RunAllReduceFuncs( ...@@ -64,9 +71,7 @@ void AllReduceOpHandle::RunAllReduceFuncs(
if (FLAGS_sync_nccl_allreduce) { if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) { for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto *nccl_ctxs = auto &nccl_ctx = nccl_ctxs_->at(dev_id);
nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto stream = nccl_ctx.stream(); auto stream = nccl_ctx.stream();
cudaError_t e_sync = cudaStreamSynchronize(stream); cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) { if (e_sync != 0) {
...@@ -129,12 +134,21 @@ void AllReduceOpHandle::RunImpl() { ...@@ -129,12 +134,21 @@ void AllReduceOpHandle::RunImpl() {
numel = static_cast<size_t>(lod_tensor.numel()); numel = static_cast<size_t>(lod_tensor.numel());
} }
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
<< ", dev_id:" << dev_id << ", dtype:" << dtype
<< ", place:" << p;
all_reduce_calls.emplace_back([=] { all_reduce_calls.emplace_back([=] {
NCCLAllReduce(p, buffer, buffer, numel, PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
static_cast<ncclDataType_t>(dtype), ncclSum); buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
}); });
} }
VLOG(10) << "allreduce size:" << numel * SizeOfType(lod_tensors[0]->type());
RunAllReduceFuncs(all_reduce_calls); RunAllReduceFuncs(all_reduce_calls);
#else #else
PADDLE_THROW("Not compiled with CUDA"); PADDLE_THROW("Not compiled with CUDA");
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
...@@ -29,15 +28,13 @@ namespace paddle { ...@@ -29,15 +28,13 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) class AllReduceOpHandle : public OpHandleBase {
class AllReduceOpHandle : public NCCLOpHandleBase {
public: public:
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes, AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLCommunicator *ctxs); const platform::NCCLContextMap *ctxs);
#else #else
class AllReduceOpHandle : public OpHandleBase {
public:
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes, AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places);
#endif #endif
...@@ -49,17 +46,13 @@ class AllReduceOpHandle : public OpHandleBase { ...@@ -49,17 +46,13 @@ class AllReduceOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32)) std::vector<Scope *> local_scopes_;
// NCCLOpHandleBase already have these attributes.
// Will polish it by class inheritance framework.
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void RunAllReduceFuncs( void RunAllReduceFuncs(
const std::vector<std::function<void()>> &all_reduce_calls); const std::vector<std::function<void()>> &all_reduce_calls);
const platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
}; };
......
...@@ -51,7 +51,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -51,7 +51,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
VLOG(3) << "ProcessGraph"; VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx; RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx; RpcCtxMap recv_varname_to_ctx;
for (auto &node : graphs[0]->Nodes()) { for (auto i = 0; i < graphs.size(); ++i) {
std::vector<ir::Node *> nodes_to_delete;
for (auto &node : graphs[i]->Nodes()) {
VLOG(3) << "node name " << node->Name(); VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) { if (node && node->IsOp()) {
if (node->Name() == "send") { if (node->Name() == "send") {
...@@ -64,8 +66,10 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -64,8 +66,10 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("sections")); node->Op()->GetNullableAttr("sections"));
auto trainer_id = auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id")); boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( send_varname_to_ctx[send_var_name] =
send_var_name, send_varnames, epmap, height_section, trainer_id); operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section,
trainer_id);
VLOG(3) << "find and init an send op: " VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name]; << send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") { } else if (node->Name() == "recv") {
...@@ -76,14 +80,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -76,14 +80,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("epmap")); node->Op()->GetNullableAttr("epmap"));
auto trainer_id = auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id")); boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext( recv_varname_to_ctx[recv_var_name] =
recv_var_name, recv_varnames, epmap, {}, trainer_id); operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {}, trainer_id);
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: " VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name]; << recv_varname_to_ctx[recv_var_name];
} }
} }
} }
}
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator"; VLOG(3) << "this is distribute mode, will use communicator";
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <memory> #include <memory>
#include <unordered_set>
#include <utility> #include <utility>
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -27,8 +26,6 @@ limitations under the License. */ ...@@ -27,8 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -49,7 +46,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -49,7 +46,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
: ir::PassBuilder(), strategy_(strategy) { : ir::PassBuilder(), strategy_(strategy) {
// Add a graph viz pass to record a graph. // Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add graph_viz_pass";
auto viz_pass = AppendPass("graph_viz_pass"); auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf( const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph"); "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
...@@ -57,27 +53,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -57,27 +53,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
} }
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass. // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
VLOG(1) << "Add record_skip_memory_opt_vars_pass";
AppendPass("record_skip_memory_opt_vars_pass"); AppendPass("record_skip_memory_opt_vars_pass");
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
VLOG(1) << "Add mkldnn_placement_pass";
AppendPass("mkldnn_placement_pass");
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
if (strategy_.enable_sequential_execution_) { if (strategy_.enable_sequential_execution_) {
VLOG(1) << "Add sequential_execution_pass"; VLOG(5) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass"); AppendPass("sequential_execution_pass");
} }
...@@ -88,7 +67,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -88,7 +67,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add op fusion. // Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) { if (strategy.fuse_relu_depthwise_conv_) {
VLOG(1) << "Add fuse_relu_depthwise_conv_pass"; VLOG(5) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass"); AppendPass("fuse_relu_depthwise_conv_pass");
} }
...@@ -100,19 +79,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -100,19 +79,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add automatically inplace. // Add automatically inplace.
if (strategy_.enable_inplace_) { if (strategy_.enable_inplace_) {
VLOG(1) << "Add inplace_pass"; VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass"); AppendPass("inplace_pass");
} }
if (strategy_.fuse_elewise_add_act_ops_) { if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(1) << "Add fuse_elewise_add_act_pass"; VLOG(5) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass"); AppendPass("fuse_elewise_add_act_pass");
} }
// for single card training, fuse_all_reduce_ops is unnecessary. // for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass. // alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
VLOG(1) << "Add alloc_continuous_space_for_grad_pass"; VLOG(5) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass"); AppendPass("alloc_continuous_space_for_grad_pass");
} }
...@@ -127,11 +106,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -127,11 +106,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// NOTE: fuse_all_xx_ops will count the number of xx operator first, // NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing. // if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused. // Currently, only one type of optimization algorithm can be fused.
VLOG(1) << "Add fuse_adam_op_pass"; VLOG(5) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass"); AppendPass("fuse_adam_op_pass");
VLOG(1) << "Add fuse_sgd_op_pass"; VLOG(5) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass"); AppendPass("fuse_sgd_op_pass");
VLOG(1) << "Add fuse_momentum_op_pass"; VLOG(5) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass"); AppendPass("fuse_momentum_op_pass");
} }
} }
...@@ -161,7 +140,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -161,7 +140,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars // A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface. // , so fetchlist should be set persistable before call the Run interface.
if (strategy_.memory_optimize_) { if (strategy_.memory_optimize_) {
VLOG(1) << "Add memory_optimize_pass"; VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass"); AppendPass("memory_optimize_pass");
} }
...@@ -169,22 +148,26 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -169,22 +148,26 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// all original and fused operators. But no operators can be enabled this // all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass. // attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) { if (strategy_.cache_runtime_context_) {
VLOG(1) << "Add runtime_context_cache_pass"; VLOG(5) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass"); AppendPass("runtime_context_cache_pass");
} }
if (strategy_.cache_expected_kernel_) {
VLOG(10) << "Add expected_kernel_cache_pass";
AppendPass("expected_kernel_cache_pass");
}
AppendMultiDevPass(strategy_); AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing. // first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(1) << "Add fuse_all_reduce_op_pass"; VLOG(5) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass"); AppendPass("fuse_all_reduce_op_pass");
} }
// Add a graph print pass to record a graph with device info. // Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add multi_devices_print_pass";
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass"); auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
const std::string graph_path = const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(), string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
...@@ -200,22 +183,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -200,22 +183,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (!strategy_.enable_parallel_graph_ && if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) || (SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) { strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(1) << "Add all_reduce_deps_pass"; VLOG(5) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass"); AppendPass("all_reduce_deps_pass");
} }
if (strategy_.enable_backward_optimizer_op_deps_) {
VLOG(1) << "Add backward_op_deps_pass";
AppendPass("backward_optimizer_op_deps_pass");
}
if (strategy_.remove_unnecessary_lock_) { if (strategy_.remove_unnecessary_lock_) {
VLOG(1) << "Add modify_op_lock_and_record_event_pass"; VLOG(5) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass"); AppendPass("modify_op_lock_and_record_event_pass");
} }
// Verify that the graph is correct for multi-device executor. // Verify that the graph is correct for multi-device executor.
VLOG(1) << "Add multi_devices_check_pass";
AppendPass("multi_devices_check_pass"); AppendPass("multi_devices_check_pass");
} }
...@@ -224,19 +201,18 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -224,19 +201,18 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
ir::Pass *multi_devices_pass = nullptr; ir::Pass *multi_devices_pass = nullptr;
if (strategy_.async_mode_) { if (strategy_.async_mode_) {
VLOG(1) << "Add async_multi_devices_pass";
multi_devices_pass = AppendPass("async_multi_devices_pass").get(); multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) { } else if (strategy_.is_distribution_) {
VLOG(1) VLOG(5)
<< "Add dist_multi_devices_pass, multi device parameter server mode"; << "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(1) << "Add all_reduce_mode_multi_devices_pass"; VLOG(5) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass = multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get(); AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(1) << "Add reduce_mode_multi_devices_pass"; VLOG(5) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else { } else {
PADDLE_THROW("Unknown reduce strategy."); PADDLE_THROW("Unknown reduce strategy.");
...@@ -273,7 +249,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -273,7 +249,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const size_t &nranks, const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, const bool use_cuda,
platform::NCCLCommunicator *nccl_ctxs) const { platform::NCCLContextMap *nccl_ctxs) const {
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
...@@ -295,9 +271,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -295,9 +271,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Set<size_t>(ir::kNRanks, new size_t(nranks)); pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif #endif
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" || } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" || pass->Type() == "fuse_adam_op_pass" ||
...@@ -311,12 +287,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -311,12 +287,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
&local_scopes); &local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") { if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
#endif #endif
} }
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") { } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
...@@ -329,14 +302,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -329,14 +302,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
LOG(INFO) << "set enable_sequential_execution:" LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_; << enable_sequential_execution_;
} else if (pass->Type() == "all_reduce_deps_pass") { } else if (pass->Type() == "all_reduce_deps_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
#endif
LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this) LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_; << ", num_trainers:" << num_trainers_;
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") { } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
...@@ -348,9 +313,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -348,9 +313,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "inplace_pass") { } else if (pass->Type() == "inplace_pass") {
pass->Erase(ir::kUseCuda); pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda)); pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
} }
VLOG(3) << "Start Apply Pass " << pass->Type(); VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph); graph = pass->Apply(graph);
...@@ -377,7 +339,6 @@ USE_PASS(multi_devices_print_pass); ...@@ -377,7 +339,6 @@ USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass); USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass); USE_PASS(all_reduce_deps_pass);
USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass); USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass); USE_PASS(lock_free_optimize_pass);
...@@ -388,7 +349,5 @@ USE_PASS(fuse_sgd_op_pass); ...@@ -388,7 +349,5 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass); USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
...@@ -80,8 +79,6 @@ struct BuildStrategy { ...@@ -80,8 +79,6 @@ struct BuildStrategy {
bool fuse_all_reduce_ops_{false}; bool fuse_all_reduce_ops_{false};
bool enable_backward_optimizer_op_deps_{false};
bool fuse_relu_depthwise_conv_{false}; bool fuse_relu_depthwise_conv_{false};
bool sync_batch_norm_{false}; bool sync_batch_norm_{false};
...@@ -111,18 +108,7 @@ struct BuildStrategy { ...@@ -111,18 +108,7 @@ struct BuildStrategy {
bool remove_unnecessary_lock_{true}; bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false}; bool cache_runtime_context_{false};
std::unordered_set<std::string> mkldnn_enabled_op_types_; bool cache_expected_kernel_{true};
size_t nccl_comm_num_{1};
// The picture is here:
// https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
bool use_hierarchical_allreduce_{false};
// Nccl ranks in a node when use hierarchical allreduce, it's setted to gpu
// cards' number in most cases.
size_t hierarchical_allreduce_inter_nranks_{0};
// Nccl ranks bewteen nodes when use hierarchical allreduce, it's setted to
// nodes number.
size_t hierarchical_allreduce_exter_nranks_{0};
// NOTE: // NOTE:
// Before you add new options, think if it's a general strategy that works // Before you add new options, think if it's a general strategy that works
...@@ -149,7 +135,7 @@ struct BuildStrategy { ...@@ -149,7 +135,7 @@ struct BuildStrategy {
const size_t &nranks, const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, const bool use_cuda,
platform::NCCLCommunicator *nccl_ctxs) const; platform::NCCLContextMap *nccl_ctxs) const;
#else #else
const bool use_cuda) const; const bool use_cuda) const;
#endif #endif
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif #endif
...@@ -66,7 +65,6 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() { ...@@ -66,7 +65,6 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() { void EagerDeletionOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
Scope *exec_scope = nullptr; Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages; std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) { for (auto &name : var_names_) {
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -44,97 +43,35 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -44,97 +43,35 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
bootstrap_ops_.emplace_back(op); bootstrap_ops_.emplace_back(op);
} }
} }
PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
PrepareAtomicOpDeps(); PrepareAtomicOpDeps();
} }
FeedFetchList FastThreadedSSAGraphExecutor::Run( FeedFetchList FastThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>> std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
op_deps = atomic_op_deps_.get(); op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps(); PrepareAtomicOpDeps();
size_t num_ops = op_deps->size();
paddle::framework::FeedFetchList fetches; paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<OpHandleBase *> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops; std::vector<OpHandleBase *> ready_fetch_ops;
exception_.Clear();
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops);
event.reset(nullptr);
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If the num_threads is 1, we can record the order of operator's
// execution in the first iteration, and in subsequent iterations,
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
size_t num_complete = 0;
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
int remaining = 0;
while (true) {
remaining = remaining_;
if (remaining == 0) {
break;
}
for (int i = 0; i < remaining; ++i) {
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
}
num_complete += num_comp;
}
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
}
void FastThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops) {
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
for (size_t i = 0; i < fetch_tensors.size(); ++i) { for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors.at(i); auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars->find(var_name); auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(), PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable(%s).(Perhaps the main_program " "Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)", "is not set to ParallelExecutor)",
var_name); var_name);
...@@ -143,8 +80,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -143,8 +80,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node = ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_); auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
fetch_ops->emplace_back(op); fetch_ops.emplace_back(op);
for (auto &p : places_) { for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p)); op->SetDeviceContext(p, fetch_ctxs_.Get(p));
...@@ -157,22 +94,55 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -157,22 +94,55 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep; (*op_deps)[op] = dep;
if (dep == 0) { if (dep == 0) {
ready_fetch_ops->emplace_back(op); ready_fetch_ops.emplace_back(op);
} }
} }
size_t num_complete = 0;
remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
int remaining = 0;
while (true) {
remaining = remaining_;
if (remaining == 0) {
break;
}
for (int i = 0; i < remaining; ++i) {
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ClearFetchOp(graph_, &fetch_ops);
exception_.ReThrow();
}
}
num_complete += num_comp;
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
} }
bool FastThreadedSSAGraphExecutor::RunOp( bool FastThreadedSSAGraphExecutor::RunOp(
OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q, OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete) { size_t *complete) {
RunOpSync(op); try {
if (LIKELY(!exception_.IsCaught())) {
if (LIKELY(!strategy_.dry_run_)) { if (LIKELY(!strategy_.dry_run_)) {
RecordOps(op); op->Run(strategy_.use_cuda_);
} }
++(*complete); ++(*complete);
return true; return true;
} else { } catch (...) {
exception_.Catch(std::current_exception());
--remaining_; --remaining_;
complete_q->Push(-1UL); complete_q->Push(-1UL);
return false; return false;
...@@ -224,7 +194,6 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -224,7 +194,6 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
complete_q->Push(complete); complete_q->Push(complete);
}); });
} }
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = prepare_pool_.enqueue([&] { atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>; auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
...@@ -237,44 +206,6 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { ...@@ -237,44 +206,6 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
} }
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; } const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
void FastThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_.ReThrow();
}
void FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_.Catch(std::current_exception());
}
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -60,8 +60,6 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -60,8 +60,6 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool pool_; ::ThreadPool pool_;
::ThreadPool prepare_pool_; ::ThreadPool prepare_pool_;
std::vector<OpHandleBase *> traced_ops_;
bool RunOp(OpHandleBase *op, bool RunOp(OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete); size_t *complete);
...@@ -71,22 +69,6 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -71,22 +69,6 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::shared_ptr<BlockingQueue<size_t>> &complete_q); const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps(); void PrepareAtomicOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>>
*fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -44,10 +44,17 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>> ...@@ -44,10 +44,17 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
FusedAllReduceOpHandle::FusedAllReduceOpHandle( FusedAllReduceOpHandle::FusedAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const size_t num_of_all_reduce, const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
const platform::NCCLCommunicator *ctxs) const platform::NCCLContextMap *ctxs)
: NCCLOpHandleBase(node, places, ctxs), : OpHandleBase(node),
local_scopes_(local_scopes), local_scopes_(local_scopes),
num_of_all_reduce_(num_of_all_reduce) { places_(places),
num_of_all_reduce_(num_of_all_reduce),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
} }
#else #else
...@@ -160,14 +167,17 @@ void FusedAllReduceOpHandle::RunImpl() { ...@@ -160,14 +167,17 @@ void FusedAllReduceOpHandle::RunImpl() {
auto &p = places_[i]; auto &p = places_[i];
void *buffer = const_cast<void *>(lod_tensor_data.at(i)); void *buffer = const_cast<void *>(lod_tensor_data.at(i));
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] { all_reduce_calls.emplace_back([=] {
NCCLAllReduce(p, buffer, buffer, numel, PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
static_cast<ncclDataType_t>(nccl_dtype), ncclSum); buffer, buffer, numel, static_cast<ncclDataType_t>(nccl_dtype),
ncclSum, comm, stream));
}); });
} }
VLOG(10) << "fusedallreduce size:" << numel * SizeOfType(dtype);
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) { if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device // Do not use NCCLGroup when manage NCCL by per thread per device
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
...@@ -29,15 +28,14 @@ namespace paddle { ...@@ -29,15 +28,14 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct FusedAllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const size_t num_of_all_reduce, const size_t num_of_all_reduce,
const platform::NCCLCommunicator *ctxs); const platform::NCCLContextMap *ctxs);
#else #else
struct FusedAllReduceOpHandle : public OpHandleBase {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
...@@ -54,12 +52,11 @@ struct FusedAllReduceOpHandle : public OpHandleBase { ...@@ -54,12 +52,11 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
private: private:
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
// NCCLOpHandleBase already have these attributes.
// Will polish it by class inheritance framework.
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
#endif
size_t num_of_all_reduce_; size_t num_of_all_reduce_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs_;
#endif
// Check the dtype of the input // Check the dtype of the input
void GetDTypeAndNumel( void GetDTypeAndNumel(
......
...@@ -45,7 +45,6 @@ constexpr char kGraphVars[] = "vars"; ...@@ -45,7 +45,6 @@ constexpr char kGraphVars[] = "vars";
constexpr char kPlaces[] = "places"; constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes"; constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs"; constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars; typedef std::unordered_set<VarHandleBase *> GraphDepVars;
......
...@@ -20,7 +20,7 @@ namespace framework { ...@@ -20,7 +20,7 @@ namespace framework {
namespace details { namespace details {
std::string OpHandleBase::DebugString() const { std::string OpHandleBase::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << Name() << "("; ss << "(";
for (auto *var : inputs_) { for (auto *var : inputs_) {
ss << var->DebugString() << ", "; ss << var->DebugString() << ", ";
} }
...@@ -187,11 +187,6 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { ...@@ -187,11 +187,6 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
std::function<void()> method = callback; std::function<void()> method = callback;
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
method = [method, p, this]() { method = [method, p, this]() {
VLOG(10) << "cudadevicecontext:"
<< static_cast<platform::CUDADeviceContext *>(p.second)
<< ", dev_id:"
<< boost::get<platform::CUDAPlace>(p.first).device;
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent( static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device), events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method); method);
......
...@@ -95,7 +95,6 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ...@@ -95,7 +95,6 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass = auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass"); ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
for (size_t i = 0; i < graphs_.size(); ++i) { for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release())); graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
} }
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -30,8 +29,6 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, ...@@ -30,8 +29,6 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_(place) {} place_(place) {}
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place(); auto &p = static_cast<VarHandle *>(in)->place();
if (ir::IsControlDepVar(*in->Node())) { if (ir::IsControlDepVar(*in->Node())) {
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include <string> #include <string>
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -67,7 +67,6 @@ struct ScaleLossGradFunctor { ...@@ -67,7 +67,6 @@ struct ScaleLossGradFunctor {
}; };
void ScaleLossGradOpHandle::RunImpl() { void ScaleLossGradOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
// Doesn't wait any event // Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name(); std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
...@@ -36,10 +36,26 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ...@@ -36,10 +36,26 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
FeedFetchList ScopeBufferedSSAGraphExecutor::Run( FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) { if (drop_scope_counter_ == 0) {
platform::RecordEvent e("InitLocalExeScopes"); // Create local scopes.
PrepareLocalExeScopes(); for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
} }
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
}
std::vector<framework::LoDTensor> fetch_data; std::vector<framework::LoDTensor> fetch_data;
std::exception_ptr eptr = nullptr; std::exception_ptr eptr = nullptr;
try { try {
...@@ -48,7 +64,9 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -48,7 +64,9 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
eptr = std::current_exception(); eptr = std::current_exception();
} }
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun");
++drop_scope_counter_; ++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
DropLocalExeScopes(); DropLocalExeScopes();
} }
...@@ -60,41 +78,17 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -60,41 +78,17 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
} }
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() { void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0; drop_scope_counter_ = 0;
for (auto p : places_) { for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
} }
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {
auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName); auto &local_scope =
if (local_scope_var != nullptr) { *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
auto &local_scope = *local_scope_var->GetMutable<Scope *>();
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
scope->EraseVars({std::string(details::kLocalExecScopeName)});
VLOG(3) << "Drop local execution scope: " << local_scope; VLOG(3) << "Drop local execution scope: " << local_scope;
} }
}
}
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
} }
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() { bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <ThreadPool.h>
#include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -52,8 +51,6 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -52,8 +51,6 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
bool NeedCreateLocalExeScope(); bool NeedCreateLocalExeScope();
void PrepareLocalExeScopes();
private: private:
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -30,7 +30,7 @@ namespace details { ...@@ -30,7 +30,7 @@ namespace details {
SparseAllReduceOpHandle::SparseAllReduceOpHandle( SparseAllReduceOpHandle::SparseAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLCommunicator *ctxs, bool is_encoded, int nranks) const platform::NCCLContextMap *ctxs, bool is_encoded, int nranks)
: AllReduceOpHandle(node, local_scopes, places, ctxs), : AllReduceOpHandle(node, local_scopes, places, ctxs),
is_encoded_(is_encoded), is_encoded_(is_encoded),
nranks_(nranks) { nranks_(nranks) {
...@@ -102,8 +102,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -102,8 +102,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel; out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device; int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false); auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto stream = nccl_ctx.stream(); auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_; auto comm = nccl_ctx.comm_;
......
...@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle { ...@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
SparseAllReduceOpHandle(ir::Node *node, SparseAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLCommunicator *ctxs, const platform::NCCLContextMap *ctxs,
bool is_encoded = false, int nranks = -1); bool is_encoded = false, int nranks = -1);
std::string Name() const override; std::string Name() const override;
......
...@@ -19,13 +19,10 @@ namespace framework { ...@@ -19,13 +19,10 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) { void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<FetchOpHandle*>(op),
"The input ops of ClearFetchOp function should be FetchOpHandle.");
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
......
...@@ -38,7 +38,7 @@ class SSAGraphExecutor { ...@@ -38,7 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops); void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -53,40 +53,27 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -53,40 +53,27 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare")); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get(); std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
CopyOpDeps(); CopyOpDeps();
VLOG(10) << "ThreadedSSAGraphExecutor::Run"; VLOG(10) << "ThreadedSSAGraphExecutor::Run";
std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars( std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
new BlockingQueue<VarHandleBase *>); new BlockingQueue<VarHandleBase *>);
auto &pending_ops = op_deps->pending_ops_; auto &pending_ops = op_deps->pending_ops_;
auto &pending_vars = op_deps->pending_vars_; auto &pending_vars = op_deps->pending_vars_;
auto &ready_ops = op_deps->ready_ops_; auto &ready_ops = op_deps->ready_ops_;
size_t num_ops = op_deps->num_ops_;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<OpHandleBase *> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<VarHandleBase *> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
&pending_ops, &pending_vars, &fetch_data); &pending_ops, &pending_vars, &fetch_data);
exception_holder_.Clear();
event.reset(nullptr);
// Step 3. Execution
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If the num_threads is 1, we can record the order of operator's
// execution in the first iteration, and in subsequent iterations,
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
RunOp(ready_vars, op); RunOp(ready_vars, op);
...@@ -95,7 +82,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -95,7 +82,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
}; };
// Clean run context // Clean run context
run_op_futures_.clear(); run_op_futures_.clear();
exception_holder_.Clear();
event.reset(nullptr);
// Step 3. Execution
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
// Keep loop until all vars are ready. // Keep loop until all vars are ready.
...@@ -105,11 +94,14 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -105,11 +94,14 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
bool timeout; bool timeout;
auto cur_ready_vars = ready_vars->PopAll(1, &timeout); auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
if (exception_holder_.IsCaught()) { ClearFetchOp(graph_, &fetch_ops);
ExecutionFinal(&fetch_ops); exception_holder_.ReThrow();
} else { } else {
continue; continue;
} }
...@@ -129,8 +121,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -129,8 +121,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
} }
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
}
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
...@@ -147,7 +137,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -147,7 +137,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<OpHandleBase *> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops, std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
...@@ -253,9 +243,6 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() { ...@@ -253,9 +243,6 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
InsertPendingOp(&pending_ops, op); InsertPendingOp(&pending_ops, op);
} }
} }
op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
for (auto ready_var : ready_vars) { for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var); pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) { for (auto *op : ready_var->PendingOps()) {
...@@ -277,7 +264,6 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() { ...@@ -277,7 +264,6 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
op_deps_->pending_vars_.end()); op_deps_->pending_vars_.end());
op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(), op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
op_deps_->ready_ops_.end()); op_deps_->ready_ops_.end());
op_deps->num_ops_ = op_deps_->num_ops_;
return std::unique_ptr<OpDependentData>(op_deps); return std::unique_ptr<OpDependentData>(op_deps);
}); });
} }
...@@ -286,35 +272,6 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -286,35 +272,6 @@ void ThreadedSSAGraphExecutor::RunOp(
const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q, const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op) { details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
RunOpSync(op);
try {
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
};
if (pool_) {
run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
RecordOps(op);
}
void ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try { try {
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
...@@ -323,21 +280,16 @@ void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { ...@@ -323,21 +280,16 @@ void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
op->Run(strategy_.use_cuda_); op->Run(strategy_.use_cuda_);
} }
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
} }
} };
if (pool_) {
void ThreadedSSAGraphExecutor::ExecutionFinal( run_op_futures_.emplace_back(pool_->enqueue(op_run));
std::vector<OpHandleBase *> *fetch_ops) { } else {
VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it"; op_run();
ClearFetchOp(graph_, fetch_ops);
exception_holder_.ReThrow();
}
void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
} }
} }
} // namespace details } // namespace details
......
...@@ -44,7 +44,6 @@ struct OpDependentData { ...@@ -44,7 +44,6 @@ struct OpDependentData {
std::unordered_map<OpHandleBase *, size_t> pending_ops_; std::unordered_map<OpHandleBase *, size_t> pending_ops_;
std::unordered_set<VarHandleBase *> pending_vars_; std::unordered_set<VarHandleBase *> pending_vars_;
std::unordered_set<OpHandleBase *> ready_ops_; std::unordered_set<OpHandleBase *> ready_ops_;
size_t num_ops_{0};
}; };
class ThreadedSSAGraphExecutor : public SSAGraphExecutor { class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...@@ -81,7 +80,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -81,7 +80,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::list<std::future<void>> run_op_futures_; std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_; ::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<OpHandleBase *> traced_ops_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops, void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const; OpHandleBase *op_instance) const;
...@@ -91,7 +89,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -91,7 +89,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps(const std::vector<std::string> &fetch_tensors, void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
std::vector<OpHandleBase *> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops, std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
...@@ -99,16 +97,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -99,16 +97,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList *fetch_data); FeedFetchList *fetch_data);
void PrepareOpDeps(); void PrepareOpDeps();
void CopyOpDeps(); void CopyOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
}; };
} // namespace details } // namespace details
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <atomic>
#include <fstream> #include <fstream>
#include <map> #include <map>
#include <memory> #include <memory>
...@@ -36,17 +35,9 @@ limitations under the License. */ ...@@ -36,17 +35,9 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#define SEC_LOG \
VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \
<< "]: "
class PullDenseWorker { class PullDenseWorker {
public: public:
virtual ~PullDenseWorker() {} virtual ~PullDenseWorker() {}
...@@ -57,7 +48,6 @@ class PullDenseWorker { ...@@ -57,7 +48,6 @@ class PullDenseWorker {
void IncreaseThreadVersion(int thread_id, uint64_t table_id); void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id); void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec); void Wait(std::vector<::std::future<int32_t>>* status_vec);
void PullDense(bool force_update = false);
static std::shared_ptr<PullDenseWorker> GetInstance() { static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker()); s_instance_.reset(new paddle::framework::PullDenseWorker());
...@@ -102,7 +92,7 @@ class PullDenseWorker { ...@@ -102,7 +92,7 @@ class PullDenseWorker {
// should incorporate different type of device // should incorporate different type of device
class DeviceWorker { class DeviceWorker {
public: public:
DeviceWorker() { use_cvm_ = false; } DeviceWorker() {}
virtual ~DeviceWorker() {} virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0; virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0; virtual void SetDeviceIndex(int tid) = 0;
...@@ -124,7 +114,6 @@ class DeviceWorker { ...@@ -124,7 +114,6 @@ class DeviceWorker {
std::shared_ptr<DataFeed> device_reader_; std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_; int64_t batch_num_;
FetchConfig fetch_config_; FetchConfig fetch_config_;
bool use_cvm_;
}; };
class CPUWorkerBase : public DeviceWorker { class CPUWorkerBase : public DeviceWorker {
...@@ -205,101 +194,5 @@ class DownpourWorker : public HogwildWorker { ...@@ -205,101 +194,5 @@ class DownpourWorker : public HogwildWorker {
std::vector<::std::future<int32_t>> push_dense_status_; std::vector<::std::future<int32_t>> push_dense_status_;
}; };
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
using ScopeQueue = operators::reader::BlockingQueue<Scope*>;
class SyncFunctor {
public:
SyncFunctor(int rank_id, int rank_num, int sync_steps);
virtual ~SyncFunctor() {}
void SetSyncParam(const std::vector<std::string>& sync_param) {
sync_param_ = &sync_param;
}
void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) {
nccl_ctx_map_ = nccl_ctx_map;
}
int operator()(Scope* scope);
static std::vector<Scope*> pipeline_scopes_;
static uint64_t sync_flag_;
protected:
const int rank_id_;
const int rank_num_;
const std::vector<std::string>* sync_param_ = nullptr;
platform::NCCLContextMap* nccl_ctx_map_ = nullptr;
uint64_t sync_signal_;
const int sync_steps_;
int counter_;
void Synchronize();
};
class SectionWorker : public DeviceWorker {
public:
SectionWorker() {}
~SectionWorker() override {}
void Initialize(const TrainerDesc& desc) override;
void BindingDataFeedMemory() override {}
void CreateDeviceResource(const ProgramDesc& main_prog) override{};
void TrainFiles() override;
void TrainFilesWithProfiler() override;
void PrintFetchVars() override {}
const platform::Place& place() const { return place_; }
void SetSectionIndex(int section_id) { section_id_ = section_id; }
void SetDeviceIndex(int tid) override { pipeline_id_ = tid; }
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetVarNames(const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
in_var_names_ = &in_var_names;
out_var_names_ = &out_var_names;
}
void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) {
in_scope_queue_ = in_scope_queue;
out_scope_queue_ = out_scope_queue;
}
void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; }
void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; }
void SetSectionNum(int section_num) { section_num_ = section_num; }
void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; }
void SetNextSectionPlace(const paddle::platform::Place& place) {
next_section_place_ = place;
}
SyncFunctor* sync_func_ = nullptr;
void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }
static std::atomic<int> cpu_id_;
protected:
void AutoSetCPUAffinity(bool reuse);
int section_id_;
int pipeline_id_;
int section_num_;
int pipeline_num_;
int thread_id_;
// This worker will consume scope from in_scope_queue_
// and produce scope to out_scope_queue_
ScopeQueue* in_scope_queue_ = nullptr;
ScopeQueue* out_scope_queue_ = nullptr;
const std::vector<std::string>* in_var_names_ = nullptr;
const std::vector<std::string>* out_var_names_ = nullptr;
std::mutex* worker_count_mutex_ = nullptr;
int* worker_count_ = nullptr;
paddle::platform::Place next_section_place_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -61,8 +61,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker( ...@@ -61,8 +61,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker); REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -63,7 +63,6 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -63,7 +63,6 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance(); fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm();
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
...@@ -140,16 +139,6 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -140,16 +139,6 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
LoD data_lod{tensor_lod}; LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod); tensor_emb->set_lod(data_lod);
for (int index = 0; index < len; ++index) { for (int index = 0; index < len; ++index) {
if (use_cvm_) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(),
sizeof(float) * table.emb_dim());
fea_idx++;
} else {
if (ids[index] == 0u) { if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2, memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
...@@ -160,7 +149,6 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -160,7 +149,6 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
fea_idx++; fea_idx++;
} }
} }
}
} }
void DownpourWorker::TrainFilesWithProfiler() { void DownpourWorker::TrainFilesWithProfiler() {
...@@ -209,9 +197,9 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -209,9 +197,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
uint64_t tid = static_cast<uint64_t>( uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i)); param_.program_config(0).pull_sparse_table_id(i));
TableParameter table; TableParameter table;
for (auto j : param_.sparse_table()) { for (auto i : param_.sparse_table()) {
if (j.table_id() == tid) { if (i.table_id() == tid) {
table = j; table = i;
break; break;
} }
} }
...@@ -271,7 +259,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -271,7 +259,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
fleet_ptr_->PushSparseVarsWithLabelAsync( fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_); &feature_grads_[tid], &push_sparse_status_);
timeline.Pause(); timeline.Pause();
push_sparse_time += timeline.ElapsedSec(); push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
...@@ -379,9 +367,9 @@ void DownpourWorker::TrainFiles() { ...@@ -379,9 +367,9 @@ void DownpourWorker::TrainFiles() {
uint64_t tid = static_cast<uint64_t>( uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i)); param_.program_config(0).pull_sparse_table_id(i));
TableParameter table; TableParameter table;
for (auto j : param_.sparse_table()) { for (auto i : param_.sparse_table()) {
if (j.table_id() == tid) { if (i.table_id() == tid) {
table = j; table = i;
break; break;
} }
} }
...@@ -423,7 +411,7 @@ void DownpourWorker::TrainFiles() { ...@@ -423,7 +411,7 @@ void DownpourWorker::TrainFiles() {
fleet_ptr_->PushSparseVarsWithLabelAsync( fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_); &feature_grads_[tid], &push_sparse_status_);
} }
} }
......
...@@ -122,9 +122,8 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope, ...@@ -122,9 +122,8 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
const std::string& trainer_desc_str) { const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor"; VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc; TrainerDesc trainer_desc;
bool success = trainer_desc.ParseFromString(trainer_desc_str); google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
PADDLE_ENFORCE(success, "Fail to parse TrainerDesc from string:\n%s", &trainer_desc);
trainer_desc_str.c_str());
VLOG(3) << "Going to create trainer, trainer class is " VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name(); << trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer; std::shared_ptr<TrainerBase> trainer;
...@@ -245,12 +244,6 @@ static bool has_fetch_operators( ...@@ -245,12 +244,6 @@ static bool has_fetch_operators(
return fetch_count > 0; return fetch_count > 0;
} }
std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
}
void Executor::Run(const ProgramDesc& program, Scope* scope, void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets, std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets, std::map<std::string, LoDTensor*>* fetch_targets,
...@@ -335,7 +328,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( ...@@ -335,7 +328,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph && ctx->block_id_ == 0) { if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps( paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_); ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
} }
...@@ -375,7 +368,6 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -375,7 +368,6 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars, bool create_local_scope, bool create_vars,
bool keep_kids) { bool keep_kids) {
platform::RecordBlock b(kProgramId);
PADDLE_ENFORCE_NOT_NULL(scope); PADDLE_ENFORCE_NOT_NULL(scope);
Scope* local_scope = scope; Scope* local_scope = scope;
if (create_vars) { if (create_vars) {
...@@ -415,6 +407,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -415,6 +407,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get()); DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
} }
......
...@@ -83,21 +83,6 @@ class Executor { ...@@ -83,21 +83,6 @@ class Executor {
const std::string& feed_holder_name = "feed", const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch");
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars = const std::vector<std::string>& skip_ref_cnt_vars =
...@@ -116,6 +101,15 @@ class Executor { ...@@ -116,6 +101,15 @@ class Executor {
bool create_local_scope = true, bool create_local_scope = true,
bool create_vars = true, bool keep_kids = false); bool create_vars = true, bool keep_kids = false);
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope, void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
......
...@@ -281,16 +281,9 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -281,16 +281,9 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
const std::vector<std::string>& sparse_key_names, const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim, const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status) {
const int batch_size, const bool use_cvm) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
int offset = 2; int offset = 2;
int grad_dim = emb_dim;
if (use_cvm) {
offset = 0;
grad_dim = emb_dim - 2;
}
CHECK_GE(grad_dim, 0);
uint64_t fea_idx = 0u; uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) { for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]); Variable* g_var = scope.FindVar(sparse_grad_names[i]);
...@@ -314,13 +307,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -314,13 +307,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
for (auto& t : *push_values) { for (auto& t : *push_values) {
t.resize(emb_dim + offset); t.resize(emb_dim + offset);
} }
if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
int dim = emb_dim + offset;
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g, g_tensor->numel() / dim, dim);
g_mat.rightCols(grad_dim) *= batch_size;
}
for (auto id_idx = 0u; id_idx < len; ++id_idx) { for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) { if (ids[id_idx] == 0) {
g += emb_dim; g += emb_dim;
...@@ -328,15 +315,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -328,15 +315,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
} }
CHECK(fea_idx < (*push_values).size()); CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size()); CHECK(fea_idx < fea_labels.size());
if (use_cvm) {
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
} else {
memcpy((*push_values)[fea_idx].data() + offset, g, memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim); sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f; (*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]); (*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
}
g += emb_dim; g += emb_dim;
fea_idx++; fea_idx++;
} }
...@@ -355,89 +337,6 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -355,89 +337,6 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif #endif
} }
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::SaveModel does nothing when no pslib";
#endif
}
void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
ret.wait();
#else
VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
#endif
}
void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list,
float decay) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
for (std::string& name : var_list) {
if (name.find("batch_sum") != std::string::npos) {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
VLOG(3) << "prepare shrink dense batch_sum";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
Eigen::Map<Eigen::MatrixXf> mat(g, 1, tensor->numel());
mat *= decay;
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
} else {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
}
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push shrink dense param failed, status[" << status << "]";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
#endif
}
void FleetWrapper::ClientFlush() {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->flush();
ret.wait();
#else
VLOG(0) << "FleetWrapper::ServerFlush does nothing when no pslib";
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) { MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
...@@ -499,24 +398,6 @@ void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) { ...@@ -499,24 +398,6 @@ void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#endif #endif
} }
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
#ifdef PADDLE_WITH_PSLIB
engine_wrapper_t() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
#endif
};
thread_local engine_wrapper_t r;
return r.engine;
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>( template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<std::vector<MultiSlotType>*>&, std::string*); const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>( template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
......
...@@ -55,7 +55,7 @@ namespace framework { ...@@ -55,7 +55,7 @@ namespace framework {
class FleetWrapper { class FleetWrapper {
public: public:
virtual ~FleetWrapper() {} virtual ~FleetWrapper() {}
FleetWrapper() { scale_sparse_gradient_with_batch_size_ = true; } FleetWrapper() {}
// Pull sparse variables from server in Sync mode // Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys // Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values // Param<out>: fea_values
...@@ -99,8 +99,7 @@ class FleetWrapper { ...@@ -99,8 +99,7 @@ class FleetWrapper {
const std::vector<std::string>& sparse_key_names, const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim, const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status);
const int batch_size, const bool use_cvm);
// Push sparse variables to server in Async mode // Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names // Param<In>: scope, table_id, fea_keys, sparse_grad_names
...@@ -129,19 +128,6 @@ class FleetWrapper { ...@@ -129,19 +128,6 @@ class FleetWrapper {
// create client to client connection // create client to client connection
void CreateClient2ClientConnection(); void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
void ShrinkSparseTable(int table_id);
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay);
// register client to client communication // register client to client communication
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc; typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
...@@ -160,9 +146,6 @@ class FleetWrapper { ...@@ -160,9 +146,6 @@ class FleetWrapper {
return s_instance_; return s_instance_;
} }
// this performs better than rand_r, especially large data
std::default_random_engine& LocalRandomEngine();
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_; static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif #endif
...@@ -175,7 +158,6 @@ class FleetWrapper { ...@@ -175,7 +158,6 @@ class FleetWrapper {
protected: protected:
static bool is_initialized_; static bool is_initialized_;
bool scale_sparse_gradient_with_batch_size_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper); DISABLE_COPY_AND_ASSIGN(FleetWrapper);
}; };
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
option optimize_for = LITE_RUNTIME; // option optimize_for = LITE_RUNTIME;
package paddle.framework.proto; package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should // Any incompatible changes to ProgramDesc and its dependencies should
......
...@@ -24,10 +24,9 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) { ...@@ -24,10 +24,9 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param(); param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size()); skip_ops_.resize(param_.skip_ops_size());
for (int i = 0; i < param_.skip_ops_size(); ++i) { for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i); skip_ops_[i] = param_.skip_ops(i);
} }
use_cvm_ = desc.use_cvm();
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
......
...@@ -72,12 +72,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference) ...@@ -72,12 +72,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base) pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base) pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference) pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference) pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
if(ANAKIN_SUBGRAPH) if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference) pass_library(simplify_anakin_priorbox_detection_out_pass inference)
endif() endif()
...@@ -86,23 +86,12 @@ if(WITH_MKLDNN) ...@@ -86,23 +86,12 @@ if(WITH_MKLDNN)
pass_library(depthwise_conv_mkldnn_pass base mkldnn) pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn) pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn) pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn) pass_library(cpu_quantize_squash_pass inference mkldnn)
endif() endif()
if(WITH_NGRAPH)
cc_library(ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge
analysis_helper subgraph_detector graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(APPEND ${pass_file} "USE_PASS(ngraph_subgraph_pass);\n")
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
...@@ -126,8 +115,6 @@ if (WITH_MKLDNN) ...@@ -126,8 +115,6 @@ if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass) cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
......
...@@ -23,16 +23,15 @@ ...@@ -23,16 +23,15 @@
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
DEFINE_double(fuse_parameter_memory_size, -1.0, // MBytes DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is up limited memory size(MB)" "fuse_parameter_memory_size is up limited memory size "
"of one group parameters' gradient which is the input " "of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). " "of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that " "The default value is 0, it means that "
"not set group according to memory_size."); "not set group according to memory_size.");
DEFINE_int32( DEFINE_int32(
fuse_parameter_groups_size, 1, fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the up limited size of one group " "fuse_parameter_groups_size is the size of one group parameters' gradient. "
"parameters' gradient. "
"The default value is a experimental result. If the " "The default value is a experimental result. If the "
"fuse_parameter_groups_size is 1, it means that the groups size is " "fuse_parameter_groups_size is 1, it means that the groups size is "
"the number of parameters' gradient. If the fuse_parameter_groups_size is " "the number of parameters' gradient. If the fuse_parameter_groups_size is "
...@@ -42,9 +41,6 @@ DEFINE_int32( ...@@ -42,9 +41,6 @@ DEFINE_int32(
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// unit of the FLAGS_fuse_parameter_memory_size.
static constexpr double kMB = 1048576.0;
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit // SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// test, because it is invalid that seting 'FLAGS_fuse_parameter_memory_size' // test, because it is invalid that seting 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in unit test. // and 'FLAGS_fuse_parameter_groups_size' in unit test.
...@@ -54,12 +50,15 @@ void SetFuseParameterGroupsSize(int group_size) { ...@@ -54,12 +50,15 @@ void SetFuseParameterGroupsSize(int group_size) {
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; } int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(double memory_size) { void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size; FLAGS_fuse_parameter_memory_size = memory_size;
} }
double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; } uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype = static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL; framework::proto::VarType::Type::VarType_Type_BOOL;
...@@ -84,7 +83,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -84,7 +83,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
} }
if (params_grads.size() == 0) { if (params_grads.size() == 0) {
LOG(WARNING) << "Doesn't find gradients"; VLOG(10) << "Doesn't find gradients";
return; return;
} }
...@@ -170,6 +169,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -170,6 +169,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
details::GroupGradsAndParams *group_grads_params) const { details::GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params); SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params); SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
} }
void SetGroupAccordingToLayers( void SetGroupAccordingToLayers(
...@@ -181,7 +181,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -181,7 +181,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
for (size_t i = 0; i < params_grads.size(); ++i) { for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of("."); auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) { if (pos == std::string::npos) {
layer_params[params_grads[i].first].emplace_back(i); layer_params[std::string(kUnKnow)].emplace_back(i);
} else { } else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i); layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
} }
...@@ -190,7 +190,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -190,7 +190,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
group_grads_params->reserve(layer_params.size()); group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) { for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of("."); auto pos = params_grads[i].first.find_first_of(".");
std::string key = params_grads[i].first; std::string key = kUnKnow;
if (pos != std::string::npos) { if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos); key = params_grads[i].first.substr(0, pos);
} }
...@@ -207,40 +207,21 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -207,40 +207,21 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
} }
VLOG(10) << "SetGroupAccordingToLayers: "; VLOG(10) << "SetGroupAccordingToLayers: ";
if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
}
}
void PrintGroupInfo(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
for (size_t i = 0; i < group_grads_params->size(); ++i) { for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i; VLOG(10) << "group " << i;
std::stringstream out; std::stringstream out;
size_t gps_size = 0; for (auto &p_g : group_grads_params->at(i)) {
for (auto &g_p : group_grads_params->at(i)) { out << "(" << p_g.second << ", " << p_g.first << "), ";
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
gps_size += size;
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
} }
VLOG(10) << out.str() VLOG(10) << out.str();
<< ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << static_cast<double>(gps_size) / kMB
<< "(MB)";
} }
} }
void SetGroupAccordingToMemorySize( void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes, const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const { details::GroupGradsAndParams *group_grads_params) const {
const double group_memory_size = GetFuseParameterMemorySize(); const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size <= 0.0) { if (group_memory_size == 0) {
return; return;
} }
details::GroupGradsAndParams local_group_grads_params; details::GroupGradsAndParams local_group_grads_params;
...@@ -267,14 +248,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -267,14 +248,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(), group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end()); group_grads_params->at(j).end());
++j; ++j;
if (GetFuseParameterGroupsSize() > 1 && if (local_group_memory_size >= group_memory_size) {
group_p_g.size() >
static_cast<size_t>(GetFuseParameterGroupsSize())) {
break;
}
if (static_cast<double>(local_group_memory_size) / kMB >=
group_memory_size) {
break; break;
} }
} }
...@@ -283,10 +257,60 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -283,10 +257,60 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
std::swap(*group_grads_params, local_group_grads_params); std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf( VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size); "SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
details::GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
if (VLOG_IS_ON(10)) { VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
PrintGroupInfo(var_nodes, group_grads_params); group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
} }
} }
......
...@@ -21,8 +21,8 @@ namespace ir { ...@@ -21,8 +21,8 @@ namespace ir {
void SetFuseParameterGroupsSize(int group_size); void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize(); int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(double memory_size); void SetFuseParameterMemorySize(uint64_t memory_size);
double GetFuseParameterMemorySize(); uint64_t GetFuseParameterMemorySize();
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, ...@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
void PrepareParameters(Graph* graph, const Param& param) { void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters // Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr); auto* scope = graph->Get<Scope*>(kParamScopeAttr);
// Create new parameters. // Create new parameters.
scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>(); scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope.Var(param.LSTMBias)->GetMutable<LoDTensor>(); scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope.Var(param.Hidden)->GetMutable<LoDTensor>(); scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope.Var(param.Cell)->GetMutable<LoDTensor>(); scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope.Var(param.AttentionedX)->GetMutable<LoDTensor>(); scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>(); scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope.Var(param.LSTMX)->GetMutable<LoDTensor>(); scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>(); scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \ #define GATE_W(name__) \
auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \ auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1"); \ auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0"); \ auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \ CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \ VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
...@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) { ...@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
GATE_W(c); GATE_W(c);
#undef GATE_W #undef GATE_W
auto* attention_fc_w = scope.FindVar("attention_fc.w_0"); auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_b = scope.FindVar("attention_fc.b_0"); auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_output_w = scope.FindVar("attention_output.w_0"); auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_b = scope.FindVar("attention_output.b_0"); auto* attention_output_b = scope->FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w, CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b); attention_output_b);
auto* lstm_weight = scope.Var(param.LSTMWeight); auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>(); auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope.Var(param.LSTMBias); auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>(); auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias // reshape attention_bias
auto* attention_bias_t = auto* attention_bias_t =
scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>(); scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1); PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]})); attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t = auto* attention_scalar_bias_t =
scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>(); scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize( attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]})); make_ddim({1, attention_scalar_bias_t->dims()[0]}));
......
...@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr); auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x) \ #define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \ const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \ op_desc.SetOutput(#x, {x}); \
scope.Var(x)->GetMutable<LoDTensor>() scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden); OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0); OP_SET_OUT(ReorderedH0);
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <memory>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -78,15 +77,9 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -78,15 +77,9 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8")); desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale")); desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale")); desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
if (base_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale"));
auto elementwise_desc = elementwise_add->Op();
if (elementwise_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
} }
desc.SetType("fc"); desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out}); GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
...@@ -69,15 +69,16 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -69,15 +69,16 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr); auto* scope = graph->Get<Scope*>(kParamScopeAttr);
PADDLE_ENFORCE(scope);
if (with_fc_bias) { if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias // Fusion GRU bias = fcbias + grubias
auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name()); auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor = auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>(); fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var); PADDLE_ENFORCE(fusion_bias_var);
auto* gru_bias_var = scope.FindVar(bias->Name()); auto* gru_bias_var = scope->FindVar(bias->Name());
auto* fc_bias_var = scope.FindVar(fc_bias->Name()); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
PADDLE_ENFORCE(gru_bias_var); PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var); PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>(); const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
...@@ -93,7 +94,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -93,7 +94,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef GET_NODE #undef GET_NODE
#define NEW_IMTERMEDIATE_OUT(key) \ #define NEW_IMTERMEDIATE_OUT(key) \
scope.Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>() scope->Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
NEW_IMTERMEDIATE_OUT(ReorderedH0); NEW_IMTERMEDIATE_OUT(ReorderedH0);
NEW_IMTERMEDIATE_OUT(XX); NEW_IMTERMEDIATE_OUT(XX);
NEW_IMTERMEDIATE_OUT(BatchedInput); NEW_IMTERMEDIATE_OUT(BatchedInput);
......
...@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr); auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x) \ #define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \ const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \ op_desc.SetOutput(#x, {x}); \
scope.Var(x)->GetMutable<LoDTensor>() scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden); OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0); OP_SET_OUT(ReorderedH0);
......
...@@ -26,7 +26,7 @@ namespace framework { ...@@ -26,7 +26,7 @@ namespace framework {
namespace ir { namespace ir {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const { void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale", "tanh"}; std::unordered_set<std::string> act_types = {"relu", "scale"};
graph = FuseActElewiseAdd(graph, act_types); graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types); graph = FuseElewiseAddAct(graph, act_types);
// backward // backward
......
...@@ -26,8 +26,7 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const { ...@@ -26,8 +26,7 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
Scope* FusePassBase::param_scope() const { Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr); return graph_->Get<framework::Scope*>(kParamScopeAttr);
return &scope;
} }
void FusePassBase::AddStatis(int count_of_fused) const { void FusePassBase::AddStatis(int count_of_fused) const {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册