Unverified    Commit b57eac73    authored by Yan Chunwei, committed by GitHub

Lite/update for x86 (#19027)

Parent fbbd8208
@@ -19,36 +19,6 @@ set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
 include(system)
-if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    cmake_minimum_required(VERSION 3.10)
-    # TODO(TJ): make as function check_default
-    if(NOT DEFINED ARM_TARGET_OS)
-        set(ARM_TARGET_OS "android" CACHE STRING "Choose ARM Target OS")
-    endif()
-    set(ARM_TARGET_OS_LIST "android" "armlinux") # TODO: "ios"
-    set_property(CACHE ARM_TARGET_OS PROPERTY STRINGS ${ARM_TARGET_OS_LIST})
-    if (NOT ARM_TARGET_OS IN_LIST ARM_TARGET_OS_LIST)
-        message(FATAL_ERROR "ARM_TARGET_OS must be in one of ${ARM_TARGET_OS_LIST}")
-    endif()
-    if(NOT DEFINED ARM_TARGET_ARCH_ABI)
-        set(ARM_TARGET_ARCH_ABI "arm64-v8a" CACHE STRING "Choose ARM Target ARCH ABI")
-    endif()
-    set(ARM_TARGET_ARCH_ABI_LIST "arm64-v8a" "armeabi-v7a" "armeabi-v7a-softfp" "armeabi-v7a-hf")
-    set_property(CACHE ARM_TARGET_ARCH_ABI PROPERTY STRINGS ${ARM_TARGET_ARCH_ABI_LIST})
-    if (NOT ARM_TARGET_ARCH_ABI IN_LIST ARM_TARGET_ARCH_ABI_LIST)
-        message(FATAL_ERROR "ARM_TARGET_ARCH_ABI must be in one of ${ARM_TARGET_ARCH_ABI_LIST}")
-    endif()
-    if(NOT DEFINED TARGET_ARCH_ABI)
-        set(ARCH_ABI "arm64-v8a" CACHE STRING "Choose android platform")
-    endif()
-    include(cross_compiling/host)
-    include(cross_compiling/armlinux)
-    include(cross_compiling/android)
-endif()
 project(paddle CXX C)
 message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
         "${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
@@ -71,9 +41,7 @@ if(WIN32)
     set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
 endif(WIN32)
-if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    find_package(CUDA QUIET)
-endif()
+find_package(CUDA QUIET)
 find_package(Git REQUIRED)
 find_package(Threads REQUIRED)
@@ -111,41 +79,11 @@ option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VER...
 option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
 option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
-if(ANDROID OR IOS OR ARMLINUX)
-    set(WITH_GPU OFF CACHE STRING
-        "Disable GPU when cross-compiling for Android and iOS" FORCE)
-    set(WITH_DSO OFF CACHE STRING
-        "Disable DSO when cross-compiling for Android and iOS" FORCE)
-    set(WITH_AVX OFF CACHE STRING
-        "Disable AVX when cross-compiling for Android and iOS" FORCE)
-    set(WITH_PYTHON OFF CACHE STRING
-        "Disable PYTHON when cross-compiling for Android and iOS" FORCE)
-    set(WITH_RDMA OFF CACHE STRING
-        "Disable RDMA when cross-compiling for Android and iOS" FORCE)
-    set(WITH_MKL OFF CACHE STRING
-        "Disable MKL when cross-compiling for Android and iOS" FORCE)
-    if(NOT CMAKE_BUILD_TYPE)
-        set(CMAKE_BUILD_TYPE "Release" CACHE STRING
-            "Default use Release in android" FORCE)
-    endif()
-    if(NOT THIRD_PARTY_BUILD_TYPE)
-        set(THIRD_PARTY_BUILD_TYPE "MinSizeRel" CACHE STRING
-            "Default use MinSizeRel in android" FORCE)
-    endif()
-endif()
-# for lite, both server and mobile framework.
-option(WITH_LITE "Enable lite framework" OFF)
-option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
-option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
-option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
-option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF)
-option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF)
-set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
-    "A path setting third party libraries download & build directories.")
+# PY_VERSION
+if(NOT PY_VERSION)
+    set(PY_VERSION 2.7)
+endif()
+set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
 # CMAKE_BUILD_TYPE
 if(NOT CMAKE_BUILD_TYPE)
@@ -154,36 +92,6 @@ if(NOT CMAKE_BUILD_TYPE)
         FORCE)
 endif()
-include_directories("${PADDLE_SOURCE_DIR}")
-# for mobile
-if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    message(STATUS "Building the mobile framework")
-    # include the necessary thirdparty dependencies
-    include(external/gflags)    # download, build, install gflags
-    include(external/glog)      # download, build, install glog
-    include(external/gtest)     # download, build, install gtest
-    #include(external/zlib)     # download, build, install gtest
-    include(external/protobuf)  # download, build, install protobuf
-    include(external/eigen)     # download eigen3
-    include(generic)            # simplify cmake module
-    include(configure)          # add paddle env configuration
-    add_definitions(-std=c++11)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
-    add_subdirectory(paddle)
-    return()
-endif()
-# PY_VERSION
-if(NOT PY_VERSION)
-    set(PY_VERSION 2.7)
-endif()
-set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
 if (APPLE)
     set(WITH_MKL OFF CACHE STRING
         "Disable MKL for building on mac" FORCE)
@@ -194,12 +102,16 @@ if (WIN32)
         "Disable DISTRIBUTE when compiling for Windows" FORCE)
 endif()
+set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
+    "A path setting third party libraries download & build directories.")
 set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
     "A path setting fluid shared and static libraries")
 set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
     "A path setting fluid inference shared and static libraries")
+set(THIRD_PARTY_BUILD_TYPE Release)
 set(WITH_MKLML ${WITH_MKL})
 if (NOT DEFINED WITH_MKLDNN)
@@ -273,6 +185,7 @@ if(WITH_BRPC_RDMA)
     endif()
 endif()
 include(external/threadpool)
 include(flags)              # set paddle compile flags
 include(cudnn)              # set cudnn libraries, must before configure
@@ -321,6 +234,7 @@ include(coveralls)          # set code coverage
 include(inference_lib)      # add paddle fluid inference libraries
+include_directories("${PADDLE_SOURCE_DIR}")
 if(WITH_AMD_GPU)
     find_package(HIP)
......
 # A image for building paddle binaries
 # Use cuda devel base image for both cpu and gpu environment
 # When you modify it, please be aware of cudnn-runtime version
-# and libcudnn.so.x in paddle/scripts/docker/build.sh
 FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04
 MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
@@ -76,7 +75,7 @@ RUN curl -s -q https://glide.sh/get | sh
 # 2. Manually add ~IPluginFactory() in IPluginFactory class of NvInfer.h, otherwise, it couldn't work in paddle.
 # See https://github.com/PaddlePaddle/Paddle/issues/10129 for details.
-RUN wget -q https://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
+RUN wget -q https://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
     tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
     cp -rf /usr/local/TensorRT/include /usr && \
     cp -rf /usr/local/TensorRT/lib /usr
@@ -93,17 +92,17 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
 # specify sphinx version as 1.5.6 and remove -U option for [pip install -U
 # sphinx-rtd-theme] since -U option will cause sphinx being updated to newest
 # version(1.7.1 for now), which causes building documentation failed.
-RUN pip3 --no-cache-dir install -U wheel && \
+RUN pip3 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \
     pip3 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
     pip3 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
-    pip3.6 --no-cache-dir install -U wheel && \
+    pip3.6 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \
     pip3.6 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
     pip3.6 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
-    pip3.7 --no-cache-dir install -U wheel && \
+    pip3.7 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \
     pip3.7 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
     pip3.7 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
     easy_install -U pip && \
-    pip --no-cache-dir install -U pip setuptools wheel && \
+    pip --no-cache-dir install -U pip setuptools wheel py-cpuinfo==5.0.0 && \
     pip --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
     pip --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark
......
@@ -98,9 +98,11 @@ We provide [English](http://www.paddlepaddle.org/documentation/docs/en/1.4/begin...
 We appreciate your contributions!
-## Ask Questions
+## Communication
-You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
+- [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc.
+- QQ discussion group: 432676488 (PaddlePaddle).
+- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
 ## Copyright and License
 PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
@@ -80,9 +80,11 @@ pip install paddlepaddle-gpu==1.4.1.post85
 We appreciate your contributions!
-## Ask Questions
+## Communication and Feedback
-You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues)
+- You are welcome to submit questions, bug reports, and suggestions through [Github Issues](https://github.com/PaddlePaddle/Paddle/issues)
+- QQ group: 432676488 (PaddlePaddle)
+- [Forum](http://ai.baidu.com/forum/topic/list/168): you are welcome to share problems and experiences of using PaddlePaddle on the forum and help keep it a friendly community
 ## Copyright and License
 PaddlePaddle is provided under the [Apache-2.0 license](LICENSE)
-if(NOT WITH_GPU)
-  return()
-endif()
 set(ANAKIN_ROOT "/usr" CACHE PATH "ANAKIN ROOT")
 find_path(ANAKIN_INCLUDE_DIR anakin_config.h
   PATHS ${ANAKIN_ROOT} ${ANAKIN_ROOT}/include
@@ -16,9 +12,7 @@ find_library(ANAKIN_LIBRARY NAMES libanakin_saber_common.so libanakin.so
   DOC "Path to ANAKIN library.")
 if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY)
-  if(WITH_DSO)
   set(ANAKIN_FOUND ON)
-  endif(WITH_DSO)
 else()
   set(ANAKIN_FOUND OFF)
 endif()
@@ -31,3 +25,8 @@ if(ANAKIN_FOUND)
   link_directories(${ANAKIN_ROOT})
   add_definitions(-DPADDLE_WITH_ANAKIN)
 endif()
+if(ANAKIN_FOUND AND WITH_GPU AND WITH_DSO)
+  message(STATUS "Compile with anakin subgraph.")
+  set(ANAKIN_SUBGRAPH ON)
+endif()
@@ -30,6 +30,7 @@ endif(NOT WITH_PROFILER)
 if(WITH_AVX AND AVX_FOUND)
   set(SIMD_FLAG ${AVX_FLAG})
+  add_definitions(-DPADDLE_WITH_AVX)
 elseif(SSE3_FOUND)
   set(SIMD_FLAG ${SSE3_FLAG})
 endif()
@@ -157,29 +158,3 @@ endif(WITH_BRPC_RDMA)
 if(ON_INFER)
   add_definitions(-DPADDLE_ON_INFERENCE)
 endif(ON_INFER)
-if(WITH_WBAES)
-  add_definitions(-DPADDLE_WITH_WBAES)
-endif(WITH_WBAES)
-# for lite
-# TODO(Superjomn) not work fine with the option
-if (LITE_WITH_CUDA)
-  add_definitions("-DLITE_WITH_CUDA")
-endif()
-if (LITE_WITH_X86)
-  add_definitions("-DLITE_WITH_X86")
-endif()
-if (LITE_WITH_ARM)
-  add_definitions("-DLITE_WITH_ARM")
-endif()
-if (LITE_WITH_PROFILE)
-  add_definitions("-DLITE_WITH_PROFILE")
-endif()
-if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-  add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK")
-endif()
@@ -141,12 +141,10 @@ endfunction()
 message(STATUS "CUDA detected: " ${CUDA_VERSION})
 if (${CUDA_VERSION} LESS 7.0)
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
-  add_definitions("-DPADDLE_CUDA_BINVER=\"60\"")
 elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
   list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
-  add_definitions("-DPADDLE_CUDA_BINVER=\"70\"")
 elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
@@ -154,18 +152,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
   # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
   # warning for now.
   list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
-  add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
 elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
   list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
-  add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
 elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
   set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
   list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
-  add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
 endif()
+add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"")
 include_directories(${CUDA_INCLUDE_DIRS})
 if(NOT WITH_DSO)
......
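The removed per-release branches each hard-coded a PADDLE_CUDA_BINVER value; the single replacement composes it from variables FindCUDA already provides. A minimal sketch of the effect, assuming CUDA 9.2 is the detected toolkit:

    # With CUDA 9.2, FindCUDA sets CUDA_VERSION_MAJOR to "9" and
    # CUDA_VERSION_MINOR to "2", so the new single call expands to:
    add_definitions("-DPADDLE_CUDA_BINVER=\"92\"")
    # New CUDA releases therefore no longer need an extra elseif branch
    # just to keep the binary-version define current.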
@@ -96,7 +96,7 @@ if(CUDNN_FOUND)
   endif()
   message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. "
-          "Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ")
+          "Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ")
 endif()
 endif()
@@ -38,5 +38,3 @@ ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
 ADD_DEPENDENCIES(dgc extern_dgc)
-LIST(APPEND external_project_dependencies dgc)
@@ -12,6 +12,13 @@ if(NOT WITH_FAST_MATH)
   add_definitions(-DEIGEN_FAST_MATH=0)
 endif()
+if(WIN32)
+    set(EIGEN_GIT_REPOSITORY https://github.com/wopeizl/eigen-git-mirror)
+    set(EIGEN_GIT_TAG support_cuda9_win)
+else()
+    set(EIGEN_GIT_REPOSITORY https://github.com/eigenteam/eigen-git-mirror)
+    set(EIGEN_GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c)
+endif()
 if(WITH_AMD_GPU)
     ExternalProject_Add(
         extern_eigen3
@@ -29,10 +36,10 @@ else()
     ExternalProject_Add(
         extern_eigen3
         ${EXTERNAL_PROJECT_LOG_ARGS}
-        GIT_REPOSITORY  "https://github.com/eigenteam/eigen-git-mirror"
+        GIT_REPOSITORY  "${EIGEN_GIT_REPOSITORY}"
         # eigen on cuda9.1 missing header of math_funtions.hpp
         # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
-        GIT_TAG         917060c364181f33a735dc023818d5a54f60e54c
+        GIT_TAG         ${EIGEN_GIT_TAG}
         PREFIX          ${EIGEN_SOURCE_DIR}
         DOWNLOAD_NAME   "eigen"
         UPDATE_COMMAND  ""
......
@@ -18,31 +18,13 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags)
 SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags)
 SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE)
 IF(WIN32)
-  set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
+  set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
 ELSE(WIN32)
   set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
 ENDIF(WIN32)
 INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
-SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
-                  "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
-                  "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
-                  "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
-                  "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
-                  "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
-                  "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
-                  "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
-if(ANDROID)
-    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
-        "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
-        "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
-        "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
-        "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
-        "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}")
-endif()
 ExternalProject_Add(
     extern_gflags
     ${EXTERNAL_PROJECT_LOG_ARGS}
@@ -50,24 +32,24 @@ ExternalProject_Add(
     GIT_TAG         77592648e3f3be87d6c7123eb81cbad75f9aef5a
     PREFIX          ${GFLAGS_SOURCES_DIR}
     UPDATE_COMMAND  ""
-    CMAKE_ARGS      -DBUILD_STATIC_LIBS=ON
+    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+                    -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+                    -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
+                    -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+                    -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+                    -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+                    -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+                    -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
+                    -DBUILD_STATIC_LIBS=ON
                     -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR}
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                     -DBUILD_TESTING=OFF
                     -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-                    ${OPTIONAL_ARGS}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
-IF(WIN32)
-  IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib")
-    add_custom_command(TARGET extern_gflags POST_BUILD
-        COMMAND cmake -E copy ${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib ${GFLAGS_INSTALL_DIR}/lib/libgflags.lib
-    )
-  ENDIF()
-ENDIF(WIN32)
 ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
 ADD_DEPENDENCIES(gflags extern_gflags)
......
@@ -19,7 +19,7 @@ SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog)
 SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE)
 IF(WIN32)
-  SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.lib" CACHE FILEPATH "glog library." FORCE)
+  SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/glog.lib" CACHE FILEPATH "glog library." FORCE)
   SET(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4530")
 ELSE(WIN32)
   SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." FORCE)
@@ -31,24 +31,6 @@ INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR})
 SET(GLOG_REPOSITORY "https://github.com/google/glog.git")
 SET(GLOG_TAG "v0.3.5")
-SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
-                  "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
-                  "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
-                  "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
-                  "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
-                  "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
-                  "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
-                  "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
-if(ANDROID)
-    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
-        "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
-        "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
-        "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
-        "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
-        "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}")
-endif()
 ExternalProject_Add(
     extern_glog
     ${EXTERNAL_PROJECT_LOG_ARGS}
@@ -57,7 +39,14 @@ ExternalProject_Add(
     GIT_TAG         ${GLOG_TAG}
     PREFIX          ${GLOG_SOURCES_DIR}
     UPDATE_COMMAND  ""
-    CMAKE_ARGS      ${OPTIONAL_ARGS}
+    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+                    -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+                    -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
+                    -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+                    -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+                    -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+                    -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+                    -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
                     -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR}
                     -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
@@ -71,13 +60,6 @@ ExternalProject_Add(
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
-IF(WIN32)
-  IF(NOT EXISTS "${GLOG_INSTALL_DIR}/lib/libglog.lib")
-    add_custom_command(TARGET extern_glog POST_BUILD
-        COMMAND cmake -E copy ${GLOG_INSTALL_DIR}/lib/glog.lib ${GLOG_INSTALL_DIR}/lib/libglog.lib
-    )
-  ENDIF()
-ENDIF(WIN32)
 ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
......
@@ -43,24 +43,6 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
     SET(GTEST_DEPENDS ${MKLML_PROJECT})
   ENDIF()
-  SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
-                    "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
-                    "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
-                    "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
-                    "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
-                    "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
-                    "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
-                    "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
-  if(ANDROID)
-    SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
-        "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
-        "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
-        "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
-        "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
-        "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}")
-  endif()
   ExternalProject_Add(
     extern_gtest
     ${EXTERNAL_PROJECT_LOG_ARGS}
@@ -69,7 +51,14 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
     GIT_TAG         "release-1.8.0"
     PREFIX          ${GTEST_SOURCES_DIR}
     UPDATE_COMMAND  ""
-    CMAKE_ARGS      ${OPTIONAL_ARGS}
+    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+                    -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+                    -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
+                    -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+                    -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+                    -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+                    -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+                    -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
                     -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR}
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                     -DBUILD_GMOCK=ON
......
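gflags, glog, and gtest all receive the same treatment: the shared OPTIONAL_ARGS variable and its Android branch disappear, and the host toolchain settings are spelled out inline per project. A minimal sketch of the resulting pattern, using placeholder names (extern_foo, FOO_REPO) rather than anything from this diff:

    include(ExternalProject)
    ExternalProject_Add(
        extern_foo                                  # placeholder project name
        GIT_REPOSITORY  ${FOO_REPO}                 # placeholder repository
        CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
                        -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                        -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
                        -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
                        -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
    )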
@@ -38,6 +38,7 @@ IF(WIN32)
   SET(MKLML_LIB        ${MKLML_LIB_DIR}/mklml.lib)
   SET(MKLML_IOMP_LIB   ${MKLML_LIB_DIR}/libiomp5md.lib)
   SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
+  SET(MKLML_SHARED_LIB_DEPS ${MKLML_LIB_DIR}/msvcr120.dll)
   SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
 ELSE()
   #TODO(intel-huying):
......
@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
 INCLUDE(ExternalProject)
 SET(NGRAPH_PROJECT       "extern_ngraph")
-SET(NGRAPH_GIT_TAG       "127e0dedfaac8c6f2b148cc03bf5f67ac5fbe6fe")
+SET(NGRAPH_GIT_TAG       "4ec94acc11084a5d53418f565529310fa584899a")
 SET(NGRAPH_SOURCES_DIR   ${THIRD_PARTY_PATH}/ngraph)
 SET(NGRAPH_INSTALL_DIR   ${THIRD_PARTY_PATH}/install/ngraph)
 SET(NGRAPH_INC_DIR       ${NGRAPH_INSTALL_DIR}/include)
......
@@ -142,6 +142,7 @@ IF (WIN32)
 ENDIF(WIN32)
 if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
     find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH)
     find_library(PROTOBUF_LIBRARY protobuf libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH)
     find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH)
@@ -177,28 +178,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
         "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}"
         PARENT_SCOPE)
-    SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
-    SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
     SET(OPTIONAL_CACHE_ARGS "")
     SET(OPTIONAL_ARGS "")
     IF(BUILD_FOR_HOST)
-        SET(OPTIONAL_ARGS
-            "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}"
-            "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}"
-            "-Dprotobuf_WITH_ZLIB=OFF"
-            "-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}")
-        SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}")
+        SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF")
     ELSE()
-        # protobuf have compile issue when use android stl c++_static
-        SET(PROTOBUF_REPO "https://github.com/tensor-tang/protobuf.git")
-        SET(PROTOBUF_TAG "mobile")
-        SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF"
-            "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
-            "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
-            "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
-            "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
-            "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
+        SET(OPTIONAL_ARGS
             "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
             "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
             "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
@@ -206,18 +191,25 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
             "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}"
             "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
             "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
-            "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}")
+            "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
+            "-Dprotobuf_WITH_ZLIB=ON"
+            "-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}"
+            ${EXTERNAL_OPTIONAL_ARGS})
+        SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}")
     ENDIF()
     IF(WIN32)
         SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64")
     ENDIF()
+    SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
+    SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
     ExternalProject_Add(
         ${TARGET_NAME}
         ${EXTERNAL_PROJECT_LOG_ARGS}
         PREFIX          ${PROTOBUF_SOURCES_DIR}
         UPDATE_COMMAND  ""
-        #DEPENDS zlib
+        DEPENDS zlib
         GIT_REPOSITORY  ${PROTOBUF_REPO}
         GIT_TAG         ${PROTOBUF_TAG}
         CONFIGURE_COMMAND
@@ -241,13 +233,6 @@ ENDFUNCTION()
 SET(PROTOBUF_VERSION 3.1.0)
-IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    build_protobuf(protobuf_host TRUE)
-    LIST(APPEND external_project_dependencies protobuf_host)
-    SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_host_PROTOC_EXECUTABLE}
-        CACHE FILEPATH "protobuf executable." FORCE)
-ENDIF()
 IF(NOT PROTOBUF_FOUND)
     build_protobuf(extern_protobuf FALSE)
@@ -260,12 +245,7 @@ IF(NOT PROTOBUF_FOUND)
     SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY}
         CACHE FILEPATH "protoc library." FORCE)
-    IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-        PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf)
-    ELSE()
     SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE}
         CACHE FILEPATH "protobuf executable." FORCE)
     PROMPT_PROTOBUF_LIB(extern_protobuf)
-    ENDIF()
 ENDIF(NOT PROTOBUF_FOUND)
@@ -29,9 +29,9 @@ INCLUDE(ExternalProject)
 SET(PSLIB_PROJECT       "extern_pslib")
 IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
   MESSAGE(STATUS "use pre defined download url")
-  SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
+  SET(PSLIB_VER "0.1.1" CACHE STRING "" FORCE)
   SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
-  SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
+  SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/ps/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
 ENDIF()
 MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
 SET(PSLIB_SOURCE_DIR    "${THIRD_PARTY_PATH}/pslib")
......
@@ -53,12 +53,7 @@ ExternalProject_Add(
                      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 IF(WIN32)
-    IF(NOT EXISTS "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
-        add_custom_command(TARGET extern_snappy POST_BUILD
-            COMMAND cmake -E copy ${SNAPPY_INSTALL_DIR}/lib/snappy.lib ${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib
-        )
-    ENDIF()
-    set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
+    set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib")
 else(WIN32)
     set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
 endif (WIN32)
......
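snappy, warpctc, xxhash, and zlib all drop the same workaround here: rather than copying the Windows artifact to a Unix-style name after the build, the *_LIBRARIES variable now points at the name the upstream build actually produces. A minimal before/after sketch with placeholder names (foo, FOO_INSTALL_DIR):

    # Before (sketch): the build produced foo.lib, a post-build step cloned it
    # to libfoo.lib, and CMake linked the copy. After: link the native artifact.
    if(WIN32)
        set(FOO_LIBRARIES "${FOO_INSTALL_DIR}/lib/foo.lib")   # native MSVC name
    else()
        set(FOO_LIBRARIES "${FOO_INSTALL_DIR}/lib/libfoo.a")  # Unix convention
    endif()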
@@ -64,12 +64,7 @@ ExternalProject_Add(
                      -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
 )
 IF(WIN32)
-  IF(NOT EXISTS "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}")
-    add_custom_command(TARGET extern_warpctc POST_BUILD
-        COMMAND cmake -E copy ${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} ${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}
-    )
-  ENDIF()
-  SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
+  SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
       CACHE FILEPATH "Warp-ctc Library" FORCE)
 else(WIN32)
   SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
......
@@ -56,12 +56,7 @@ else()
 endif()
 if (WIN32)
-    IF(NOT EXISTS "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
-        add_custom_command(TARGET extern_xxhash POST_BUILD
-            COMMAND cmake -E copy ${XXHASH_INSTALL_DIR}/lib/xxhash.lib ${XXHASH_INSTALL_DIR}/lib/libxxhash.lib
-        )
-    ENDIF()
-    set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
+    set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/xxhash.lib")
 else()
     set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
 endif ()
......
@@ -44,12 +44,7 @@ ExternalProject_Add(
                      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 IF(WIN32)
-  IF(NOT EXISTS "${ZLIB_INSTALL_DIR}/lib/libz.lib")
-    add_custom_command(TARGET extern_zlib POST_BUILD
-        COMMAND cmake -E copy ${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib ${ZLIB_INSTALL_DIR}/lib/libz.lib
-    )
-  ENDIF()
-  SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.lib" CACHE FILEPATH "zlib library." FORCE)
+  SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib" CACHE FILEPATH "zlib library." FORCE)
 ELSE(WIN32)
   SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
 ENDIF(WIN32)
......
@@ -93,10 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
 if(NOT APPLE)
   find_package(Threads REQUIRED)
   link_libraries(${CMAKE_THREAD_LIBS_INIT})
-  set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl")
-  if (NOT ANDROID)
-    set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -lrt")
-  endif()
+  set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt")
 endif(NOT APPLE)
 set_property(GLOBAL PROPERTY FLUID_MODULES "")
@@ -366,11 +363,10 @@ function(cc_binary TARGET_NAME)
   target_link_libraries(${TARGET_NAME} ${os_dependency_modules})
 endfunction(cc_binary)
-function(cc_test TARGET_NAME)
+function(cc_test_build TARGET_NAME)
   if(WITH_TESTING)
-    set(options SERIAL)
     set(oneValueArgs "")
-    set(multiValueArgs SRCS DEPS ARGS)
+    set(multiValueArgs SRCS DEPS)
     cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
     add_executable(${TARGET_NAME} ${cc_test_SRCS})
     if(WIN32)
@@ -383,12 +379,18 @@ function(cc_test TARGET_NAME)
     target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog)
     add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
     common_link(${TARGET_NAME})
+  endif()
+endfunction()
+
+function(cc_test_run TARGET_NAME)
+  if(WITH_TESTING)
+    set(oneValueArgs "")
+    set(multiValueArgs COMMAND ARGS)
+    cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
     add_test(NAME ${TARGET_NAME}
-             COMMAND ${TARGET_NAME} ${cc_test_ARGS}
+             COMMAND ${cc_test_COMMAND}
+             ARGS ${cc_test_ARGS}
              WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
-    if (${cc_test_SERIAL})
-      set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
-    endif()
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
@@ -396,46 +398,21 @@ function(cc_test TARGET_NAME)
     # No unit test should exceed 10 minutes.
     set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
   endif()
-endfunction(cc_test)
+endfunction()
-# cc_test without default dependencies
-function(raw_cc_test TARGET_NAME)
+function(cc_test TARGET_NAME)
   if(WITH_TESTING)
-    set(options SERIAL)
     set(oneValueArgs "")
     set(multiValueArgs SRCS DEPS ARGS)
     cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-    add_executable(${TARGET_NAME} ${cc_test_SRCS})
-    if(WIN32)
-      if("${cc_test_DEPS};" MATCHES "python;")
-        list(REMOVE_ITEM cc_test_DEPS python)
-        target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES})
-      endif()
-    endif(WIN32)
-    get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
-    target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} lite_gtest_main gtest gflags glog)
-    add_dependencies(${TARGET_NAME} ${cc_test_DEPS} lite_gtest_main gtest gflags glog)
-    common_link(${TARGET_NAME})
-    add_test(NAME ${TARGET_NAME}
-             COMMAND ${TARGET_NAME} ${cc_test_ARGS}
-             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
-    if (${cc_test_SERIAL})
-      set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
-    endif()
-    # No unit test should exceed 10 minutes.
-    set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
+    cc_test_build(${TARGET_NAME}
+                  SRCS ${cc_test_SRCS}
+                  DEPS ${cc_test_DEPS})
+    cc_test_run(${TARGET_NAME}
+                COMMAND ${TARGET_NAME}
+                ARGS ${cc_test_ARGS})
   endif()
-endfunction(raw_cc_test)
+endfunction(cc_test)
-function(_lite_cc_test args)
-  if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    message(STATUS "building lite raw test: ${args}")
-    raw_cc_test(${args} ${ARGN})
-  else()
-    message(STATUS "building lite heavy test: ${args}")
-    cc_test(${args} ${ARGN})
-  endif()
-endfunction()
 function(nv_library TARGET_NAME)
   if (WITH_GPU)
@@ -488,7 +465,6 @@ endfunction(nv_binary)
 function(nv_test TARGET_NAME)
   if (WITH_GPU AND WITH_TESTING)
-    set(options SERIAL)
     set(oneValueArgs "")
     set(multiValueArgs SRCS DEPS)
     cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -498,9 +474,6 @@ function(nv_test TARGET_NAME)
     add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
     common_link(${TARGET_NAME})
     add_test(${TARGET_NAME} ${TARGET_NAME})
-    if (nv_test_SERIAL)
-      set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
-    endif()
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
@@ -743,7 +716,7 @@ function(py_proto_compile TARGET_NAME)
   cmake_parse_arguments(py_proto_compile "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
   set(py_srcs)
   protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
-  add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
+  add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs} protobuf)
 endfunction()
 function(py_test TARGET_NAME)
......
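A quick illustration of the cc_test split: compilation and test registration can now be driven separately, while the rewritten cc_test wrapper keeps the old call shape for existing callers. A minimal sketch reusing the ddim_test target declared elsewhere in this diff:

    cc_test_build(ddim_test SRCS ddim_test.cc DEPS ddim)  # compile and link only
    cc_test_run(ddim_test COMMAND ddim_test)              # register with ctest
    # Equivalent single call through the preserved wrapper:
    cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)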
@@ -110,7 +110,7 @@ function(op_library TARGET)
     # Define operators that don't need pybind here.
     foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
             "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
-            "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
+            "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op")
         if ("${TARGET}" STREQUAL "${manual_pybind_op}")
             set(pybind_flag 1)
         endif()
......
@@ -3,8 +3,6 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION})
 set(tmp_version "HEAD")
 set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
 set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
-set(LATEST_PADDLE_VERSION "latest")
 while ("${PADDLE_VERSION}" STREQUAL "")
   # Check current branch name
   execute_process(
@@ -25,8 +23,8 @@ while ("${PADDLE_VERSION}" STREQUAL "")
   if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}")
     # Check the tag is a correct version
     if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
-      # if no tag was found, set PADDLE_VERSION to "latest"
-      set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
+      # if no tag was found, set PADDLE_VERSION to 0.0.0 to represent latest
+      set(PADDLE_VERSION "0.0.0")
     elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
       string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
     else()  # otherwise, get the previous git tag name.
@@ -44,19 +42,19 @@ while ("${PADDLE_VERSION}" STREQUAL "")
       if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
         string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME})
       else()
-        set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
+        set(PADDLE_VERSION "0.0.0")
       endif()
     else()
-      # otherwise, we always set PADDLE_VERSION to "latest"
-      set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
+      # otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest
+      set(PADDLE_VERSION "0.0.0")
     endif()
   endif()
     else()
-      set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
+      set(PADDLE_VERSION "0.0.0")
       message(WARNING "Cannot add paddle version from git tag")
     endif()
   else()
-    set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
+    set(PADDLE_VERSION "0.0.0")
     message(WARNING "Cannot add paddle version for wrong git branch result")
   endif()
 endwhile()
......
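Replacing the sentinel string "latest" with "0.0.0" keeps untagged development builds on a well-formed version number. A minimal sketch of why that matters, assuming a consumer compares versions numerically:

    # "0.0.0" orders below every real release, so numeric comparisons remain
    # meaningful on untagged builds; the bare word "latest" is not a version.
    if(PADDLE_VERSION VERSION_LESS "1.0.0")
        message(STATUS "untagged development build: ${PADDLE_VERSION}")
    endif()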
-# to limit the mobile dependencies
-if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-    add_subdirectory(scripts)
-    add_subdirectory(testing)
-    set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
-endif()
+add_subdirectory(scripts)
+add_subdirectory(testing)
+set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
 add_subdirectory(fluid)
(This diff is collapsed.)
-if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)  # for mobile
-    add_subdirectory(lite)
-    return()
-endif()
 add_subdirectory(memory)
 add_subdirectory(platform)
 add_subdirectory(framework)
@@ -10,8 +6,7 @@ add_subdirectory(operators)
 add_subdirectory(string)
 add_subdirectory(recordio)
 add_subdirectory(pybind)
-add_subdirectory(train)
 # NOTE: please add subdirectory inference at last.
 add_subdirectory(inference)
+add_subdirectory(train)
+add_subdirectory(lite)
@@ -29,7 +29,8 @@ add_subdirectory(io)
 proto_library(framework_proto SRCS framework.proto)
 proto_library(data_feed_proto SRCS data_feed.proto)
 proto_library(async_executor_proto SRCS data_feed.proto)
-proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
+proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto
+              data_feed_proto)
 cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@@ -124,7 +125,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co...
 cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
 cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
-    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type data_feed_proto)
+    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@@ -173,20 +174,20 @@ endif()
 cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
 if(WITH_DISTRIBUTE)
-    cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
+    cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
         dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
         data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
-        pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
+        pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
         device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
         lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
         graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
     set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
     set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-    cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
+    cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
         dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
         data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
-        pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
+        pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
        device_context scope framework_proto data_feed_proto trainer_desc_proto glog
        lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
        graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
@@ -201,10 +202,10 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
     fast_threaded_ssa_graph_executor variable_helper)
 cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
-    executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
+    executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc pipeline_trainer.cc
     trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
-    downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
-    data_set.cc dataset_factory.cc
+    downpour_worker.cc pull_dense_worker.cc section_worker.cc
+    device_worker_factory.cc data_set.cc dataset_factory.cc
     DEPS op_registry device_context scope framework_proto
     trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
     feed_fetch_method graph_to_program_pass data_feed_proto
@@ -225,6 +226,8 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
 cc_test(tuple_test SRCS tuple_test.cc )
+cc_test(inlined_vector_test SRCS inlined_vector_test.cc)
 if (NOT WIN32)
     cc_test(rw_lock_test SRCS rw_lock_test.cc)
 endif (NOT WIN32)
......
@@ -85,8 +85,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
  }
  DataFeedDesc data_feed_desc;
-  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
-                                                &data_feed_desc);
  bool success = data_feed_desc.ParseFromString(data_feed_desc_str);
  PADDLE_ENFORCE(success, "Fail to parse DataFeedDesc from string:\n%s",
                 data_feed_desc_str.c_str());
  actual_thread_num_ = thread_num;
  int file_cnt = filelist.size();
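// Editorial note: protobuf's Message::ParseFromString reads the *binary* wire
// format and returns false on failure, while the old TextFormat call parsed
// the text format and its return value was silently dropped. The checked
// pattern, in a minimal sketch (hypothetical message type):
//
//   MyDesc desc;
//   if (!desc.ParseFromString(bytes)) {
//     LOG(FATAL) << "malformed descriptor";
//   }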
......
@@ -95,6 +95,11 @@ class BlockingQueue {
    return q_.size();
  }

  void Clear() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::deque<T>().swap(q_);
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
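// Editorial note on the swap idiom in Clear() above: clear() resets the size
// but, on some implementations, leaves the deque's internal blocks allocated,
// while swapping with a temporary guarantees the storage is released when the
// temporary is destroyed. A minimal, self-contained sketch (not Paddle code):

#include <deque>
#include <iostream>

int main() {
  std::deque<int> q(1 << 20, 42);       // ~1M elements
  q.clear();                            // size() == 0; blocks may linger
  std::deque<int>().swap(q);            // storage freed with the temporary
  std::cout << q.size() << std::endl;   // prints 0
  return 0;
}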
......
@@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX
#include <stdio_ext.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif
#include <utility>
#include "gflags/gflags.h"
@@ -87,6 +90,13 @@ void DataFeed::CheckStart() {
  PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet.");
}
void DataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
feed_vec_[i] = scope.FindVar(use_slots_[i])->GetMutable<LoDTensor>();
}
}
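// Editorial note: AssignFeedVar rebinds every used slot to a tensor owned by a
// caller-provided scope, so a worker can point the reader at its own thread
// scope before running. A hedged usage sketch (hypothetical variable names):
//
//   reader->AssignFeedVar(*thread_scope);  // feed_vec_[i] now aliases the
//                                          // LoDTensors in thread_scope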
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
  PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
@@ -158,6 +168,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
  mutex_for_update_memory_data_ = nullptr;
  this->file_idx_ = nullptr;
  this->mutex_for_pick_file_ = nullptr;
  fleet_send_sleep_seconds_ = 2;
}

template <typename T>
@@ -366,7 +377,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
  auto fleet_ptr = FleetWrapper::GetInstance();
  std::vector<std::vector<T*>> send_vec(trainer_num_);
  std::vector<int> send_index(trainer_num_);
-  uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
  uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_ + 1;
  for (auto& vec : send_vec) {
    vec.reserve(reserve_len);
  }
@@ -377,47 +388,34 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
  auto interval = GetMemoryDataInterval();
  VLOG(3) << "global shuffle data from [" << interval.first << ", "
          << interval.second << "), thread_id=" << thread_id_;
-  for (int64_t i = interval.first; i < interval.second; ++i) {
-    // if get ins id, can also use hash
-    // std::string ins_id = memory_data_[i].ins_id;
-    int64_t random_num = rand_r(&rand_seed);
-    int64_t node_id = random_num % trainer_num_;
-    send_vec[node_id].push_back(&((*memory_data_)[i]));
-    if (i % fleet_send_batch_size_ == 0 && i != 0) {
-      // shuffle the sequence of sending to avoid network timeout error
-      std::random_shuffle(send_index.begin(), send_index.end());
-      for (int index = 0; index < send_index.size(); ++index) {
-        int j = send_index[index];
-        std::string send_str;
-        SerializeIns(send_vec[j], &send_str);
-        VLOG(3) << "send str_length=" << send_str.length()
-                << ", ins num=" << send_vec[j].size() << " to node_id=" << j
-                << ", thread_id=" << thread_id_;
-        auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
-        VLOG(3) << "end send, thread_id=" << thread_id_;
-        send_vec[j].clear();
-        total_status.push_back(std::move(ret));
-      }
-    }
-  }
-  // shuffle the sequence of sending to avoid network timeout error
-  std::random_shuffle(send_index.begin(), send_index.end());
-  for (int index = 0; index < send_index.size(); ++index) {
-    int j = send_index[index];
-    if (send_vec[j].size() != 0) {
-      std::string send_str;
-      SerializeIns(send_vec[j], &send_str);
-      VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
-              << ", thread_id=" << thread_id_;
-      auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
-      VLOG(3) << "end send, thread_id=" << thread_id_;
-      total_status.push_back(std::move(ret));
-    }
-    std::vector<T*>().swap(send_vec[j]);
-  }
  for (int64_t i = interval.first; i < interval.second;
       i += fleet_send_batch_size_) {
    for (int64_t j = 0; j < fleet_send_batch_size_ && i + j < interval.second;
         ++j) {
      int64_t random_num = fleet_ptr->LocalRandomEngine()();
      int64_t node_id = random_num % trainer_num_;
      send_vec[node_id].push_back(&((*memory_data_)[i + j]));
    }
    total_status.clear();
    // shuffle the sequence of sending to avoid network timeout error
    std::shuffle(send_index.begin(), send_index.end(),
                 fleet_ptr->LocalRandomEngine());
    for (int index = 0; index < send_index.size(); ++index) {
      int j = send_index[index];
      if (send_vec[j].size() == 0) {
        continue;
      }
      std::string send_str;
      SerializeIns(send_vec[j], &send_str);
      auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
      total_status.push_back(std::move(ret));
      send_vec[j].clear();
    }
    for (auto& t : total_status) {
      t.wait();
    }
    sleep(fleet_send_sleep_seconds_);
  }
  VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
#endif
}
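// Editorial note: the rewritten loop above replaces the per-instance rand_r()
// scatter with its ad-hoc mid-loop flush by fixed-size windows: fill send_vec
// for one window, send one message per trainer in shuffled order, wait on all
// futures, sleep, repeat. A minimal, self-contained sketch of that
// window-scatter pattern (hypothetical names, not Paddle code):

#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

// Stand-in for FleetWrapper::SendClientToClientMsg (hypothetical).
void Send(int node, const std::vector<int>& batch) {
  std::cout << "send " << batch.size() << " instances to node " << node << "\n";
}

int main() {
  const int64_t kBatch = 4;
  const int64_t kTrainers = 3;
  std::vector<int> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  std::mt19937 engine(7);  // one shared engine, like LocalRandomEngine()
  for (int64_t i = 0; i < static_cast<int64_t>(data.size()); i += kBatch) {
    std::vector<std::vector<int>> send_vec(kTrainers);
    for (int64_t j = 0;
         j < kBatch && i + j < static_cast<int64_t>(data.size()); ++j) {
      send_vec[engine() % kTrainers].push_back(data[i + j]);
    }
    for (int n = 0; n < kTrainers; ++n) {
      if (!send_vec[n].empty()) Send(n, send_vec[n]);  // then wait on futures
    }
  }
  return 0;
}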
@@ -436,6 +434,24 @@ std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
  return std::make_pair(start, end);
}
template <typename T>
int64_t InMemoryDataFeed<T>::GetChannelDataSize() {
if (cur_channel_ == 0) {
return shuffled_ins_->Size();
} else {
return shuffled_ins_out_->Size();
}
}
template <typename T>
void InMemoryDataFeed<T>::ReleaseChannelData() {
if (cur_channel_ == 0) {
shuffled_ins_->Clear();
} else {
shuffled_ins_out_->Clear();
}
}
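// Editorial note: shuffled_ins_ and shuffled_ins_out_ form a double buffer,
// and cur_channel_ selects the side currently being read, so both helpers
// above dispatch on it rather than exposing the queues; ReleaseChannelData
// relies on the BlockingQueue::Clear() added earlier in this patch to
// actually free the storage.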
// explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
@@ -471,17 +487,17 @@ void MultiSlotDataFeed::Init(
      use_slots_is_dense_.push_back(slot.is_dense());
      std::vector<int> local_shape;
      if (slot.is_dense()) {
-        for (size_t i = 0; i < slot.shape_size(); ++i) {
-          if (slot.shape(i) > 0) {
-            total_dims_without_inductive_[i] *= slot.shape(i);
        for (size_t j = 0; j < slot.shape_size(); ++j) {
          if (slot.shape(j) > 0) {
            total_dims_without_inductive_[i] *= slot.shape(j);
          }
-          if (slot.shape(i) == -1) {
-            inductive_shape_index_[i] = i;
          if (slot.shape(j) == -1) {
            inductive_shape_index_[i] = j;
          }
        }
      }
-      for (size_t i = 0; i < slot.shape_size(); ++i) {
-        local_shape.push_back(slot.shape(i));
      for (size_t j = 0; j < slot.shape_size(); ++j) {
        local_shape.push_back(slot.shape(j));
      }
      use_slots_shape_.push_back(local_shape);
    }
@@ -805,22 +821,24 @@ void MultiSlotInMemoryDataFeed::Init(
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
    total_dims_without_inductive_[i] = 1;
    inductive_shape_index_[i] = -1;
    if (slot.is_used()) {
      use_slots_.push_back(all_slots_[i]);
      use_slots_is_dense_.push_back(slot.is_dense());
      std::vector<int> local_shape;
      if (slot.is_dense()) {
-        for (size_t i = 0; i < slot.shape_size(); ++i) {
-          if (slot.shape(i) > 0) {
-            total_dims_without_inductive_[i] *= slot.shape(i);
        for (size_t j = 0; j < slot.shape_size(); ++j) {
          if (slot.shape(j) > 0) {
            total_dims_without_inductive_[i] *= slot.shape(j);
          }
-          if (slot.shape(i) == -1) {
-            inductive_shape_index_[i] = i;
          if (slot.shape(j) == -1) {
            inductive_shape_index_[i] = j;
          }
        }
      }
-      for (size_t i = 0; i < slot.shape_size(); ++i) {
-        local_shape.push_back(slot.shape(i));
      for (size_t j = 0; j < slot.shape_size(); ++j) {
        local_shape.push_back(slot.shape(j));
      }
      use_slots_shape_.push_back(local_shape);
    }
@@ -1001,5 +1019,205 @@ void MultiSlotInMemoryDataFeed::DeserializeIns(
  fleet_ptr->Deserialize(ins, str);
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec_[i].GetType();
const auto& offset = ins_vec_[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec_[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
      // PaddlePaddle has no uint64_t tensor type, so store the data as int64_t
const auto& feasign = ins_vec_[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int64_t total_dims = 1;
for (const auto e : use_slots_shape_[i]) {
total_dims *= e;
}
PADDLE_ENFORCE(
total_dims == total_instance,
"The actual data size of slot[%s] doesn't match its declaration",
use_slots_[i].c_str());
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
}
template <typename T>
int PrivateInstantDataFeed<T>::Next() {
if (ParseOneMiniBatch()) {
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
Postprocess();
std::string filename;
if (!PickOneFile(&filename)) {
return -1;
}
if (!Preprocess(filename)) {
return -1;
}
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data");
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
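// Editorial note: Next() above implements a simple file-rotation protocol:
//   1. try ParseOneMiniBatch on the current file; on success, feed and return;
//   2. otherwise Postprocess() the exhausted file, PickOneFile() the next one
//      (returning -1 when the filelist or a Preprocess attempt is exhausted),
//      Preprocess() it, and parse again -- a freshly opened file that yields
//      no mini-batch is treated as an error by the PADDLE_ENFORCE.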
template <typename T>
void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
multi_inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) == -1) {
multi_inductive_shape_index_[i].push_back(j);
}
}
}
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
}
feed_vec_.resize(use_slots_.size());
ins_vec_.resize(use_slots_.size());
finish_init_ = true;
}
template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());
struct stat sb;
fstat(fd_, &sb);
end_ = static_cast<size_t>(sb.st_size);
buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno));
offset_ = 0;
return true;
}
bool MultiSlotFileInstantDataFeed::Postprocess() {
if (buffer_ != nullptr) {
munmap(buffer_, end_);
buffer_ = nullptr;
}
if (fd_ != -1) {
close(fd_);
fd_ = -1;
end_ = 0;
offset_ = 0;
}
return true;
}
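// Editorial note: Preprocess/Postprocess above are a plain POSIX mmap
// lifecycle -- map the whole file read-only, parse records straight out of
// the mapping, then unmap and close. A minimal, self-contained sketch of the
// same lifecycle (error handling reduced to asserts; not Paddle code):

#include <cassert>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

int main(int argc, char** argv) {
  assert(argc == 2);
  int fd = open(argv[1], O_RDONLY);
  assert(fd != -1);
  struct stat sb;
  assert(fstat(fd, &sb) == 0);
  size_t len = static_cast<size_t>(sb.st_size);
  char* buf = static_cast<char*>(
      mmap(nullptr, len, PROT_READ, MAP_PRIVATE, fd, 0));
  assert(buf != MAP_FAILED);
  // ... parse records out of buf[0 .. len) while advancing an offset ...
  munmap(buf, len);
  close(fd);
  return 0;
}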
bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
if (offset_ == end_) {
return false;
}
batch_size_ = 0;
while (batch_size_ < default_batch_size_ && offset_ < end_) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
      PADDLE_ENFORCE(
          num,
          "The number of ids cannot be zero; you need to pad it in the data "
          "generator, or, if something is wrong with the data, check whether "
          "it contains unresolvable characters.");
offset_ += sizeof(uint16_t);
if (idx != -1) {
int inductive_size = multi_inductive_shape_index_[i].size();
if (UNLIKELY(batch_size_ == 0)) {
ins_vec_[idx].Init(all_slots_type_[i], default_batch_size_ * num);
ins_vec_[idx].InitOffset(default_batch_size_);
uint64_t* inductive_shape =
reinterpret_cast<uint64_t*>(buffer_ + offset_);
for (int inductive_id = 0; inductive_id < inductive_size;
++inductive_id) {
use_slots_shape_[i][multi_inductive_shape_index_[i][inductive_id]] =
static_cast<int>(*(inductive_shape + inductive_id));
}
}
num -= inductive_size;
offset_ += sizeof(uint64_t) * inductive_size;
if (type == 'f') {
ins_vec_[idx].AppendValues(
reinterpret_cast<float*>(buffer_ + offset_), num);
offset_ += num * sizeof(float);
} else if (type == 'u') {
ins_vec_[idx].AppendValues(
reinterpret_cast<uint64_t*>(buffer_ + offset_), num);
offset_ += num * sizeof(uint64_t);
}
} else {
if (type == 'f') {
offset_ += num * sizeof(float);
} else if (type == 'u') {
offset_ += num * sizeof(uint64_t);
}
}
}
++batch_size_;
    // OPTIMIZE: it would be better to insert format checks between instances.
}
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_");
return true;
}
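// Editorial note: the binary record layout ParseOneMiniBatch assumes,
// reconstructed from the parsing code above -- per instance, for every slot
// in order:
//
//   uint16_t num                      // id count; must be non-zero
//   uint64_t shape[k]                 // k = number of -1 (inductive) dims of
//                                     // a dense slot; present per instance,
//                                     // applied only on a batch's first one
//   float / uint64_t values[num - k]  // payload, chosen by the slot type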
#endif
}  // namespace framework
}  // namespace paddle
@@ -59,7 +59,7 @@ class DataFeed {
    file_idx_ = nullptr;
  }
  virtual ~DataFeed() {}
-  virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
  virtual bool CheckFile(const char* filename) {
    PADDLE_THROW("This function(CheckFile) is not implemented.");
  }
@@ -84,6 +84,9 @@ class DataFeed {
  // This function is used for binding feed_vec memory
  virtual void AddFeedVar(Variable* var, const std::string& name);
  // This function is used for binding feed_vec memory in a given scope
  virtual void AssignFeedVar(const Scope& scope);
  // This function does nothing by default
  virtual void SetMemoryData(void* memory_data) {}
  // This function does nothing by default
@@ -115,6 +118,9 @@ class DataFeed {
  virtual void FillChannelToMemoryData() {}
  // This function does nothing by default
  virtual void PutInsToChannel(const std::string& ins_str) {}
  virtual int64_t GetChannelDataSize() { return 0; }
  // This function does nothing by default
  virtual void ReleaseChannelData() {}

 protected:
  // The following three functions are used to check if it is executed in this
@@ -145,6 +151,8 @@ class DataFeed {
  std::vector<std::vector<int>> use_slots_shape_;
  std::vector<int> inductive_shape_index_;
  std::vector<int> total_dims_without_inductive_;
  // For the inductive shape passed within the data
  std::vector<std::vector<int>> multi_inductive_shape_index_;
  std::vector<int>
      use_slots_index_;  // -1: not used; >=0: the index of use_slots_
@@ -170,7 +178,6 @@ class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
-  virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
  virtual bool Start();
  virtual int Next();
@@ -209,7 +216,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
 public:
  InMemoryDataFeed();
  virtual ~InMemoryDataFeed() {}
-  virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
  virtual bool Start();
  virtual int Next();
  virtual void SetMemoryData(void* memory_data);
@@ -224,6 +231,8 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
  virtual void LoadIntoMemory();
  virtual void LocalShuffle();
  virtual void GlobalShuffle();
  virtual int64_t GetChannelDataSize();
  virtual void ReleaseChannelData();

 protected:
  virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
@@ -248,6 +257,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
  std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
  std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
  int64_t fleet_send_batch_size_;
  // sleeping after each send slows down the transfer; this is a trick and
  // should be removed later.
  int64_t fleet_send_sleep_seconds_;
};

// This class defines the data type of instance(ins_vec) in MultiSlotDataFeed
@@ -255,16 +267,25 @@ class MultiSlotType {
 public:
  MultiSlotType() {}
  ~MultiSlotType() {}
-  void Init(const std::string& type) {
  void Init(const std::string& type, size_t reserved_size = 0) {
    CheckType(type);
    if (type_[0] == 'f') {
      float_feasign_.clear();
      if (reserved_size) {
        float_feasign_.reserve(reserved_size);
      }
    } else if (type_[0] == 'u') {
      uint64_feasign_.clear();
      if (reserved_size) {
        uint64_feasign_.reserve(reserved_size);
      }
    }
    type_ = type;
  }
-  void InitOffset() {
  void InitOffset(size_t max_batch_size = 0) {
    if (max_batch_size > 0) {
      offset_.reserve(max_batch_size + 1);
    }
    offset_.resize(1);
    // A LoDTensor's lod is counted from 0, so the size of the lod is one
    // larger than the size of the data.
@@ -280,6 +301,16 @@ class MultiSlotType {
    CheckUint64();
    uint64_feasign_.push_back(v);
  }
  void CopyValues(const float* input, size_t size) {
    CheckFloat();
    float_feasign_.resize(size);
    memcpy(float_feasign_.data(), input, size * sizeof(float));
  }
  void CopyValues(const uint64_t* input, size_t size) {
    CheckUint64();
    uint64_feasign_.resize(size);
    memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
  }
  void AddIns(const MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      CheckFloat();
@@ -293,11 +324,22 @@ class MultiSlotType {
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
  void AppendValues(const uint64_t* input, size_t size) {
    CheckUint64();
    offset_.push_back(offset_.back() + size);
    uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
  }
  void AppendValues(const float* input, size_t size) {
    CheckFloat();
    offset_.push_back(offset_.back() + size);
    float_feasign_.insert(float_feasign_.end(), input, input + size);
  }
  const std::vector<float>& GetFloatData() const { return float_feasign_; }
  std::vector<float>& MutableFloatData() { return float_feasign_; }
  const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
  std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
  const std::string& GetType() const { return type_; }
  size_t GetBatchSize() { return offset_.size() - 1; }
  std::string& MutableType() { return type_; }

  std::string DebugString() {
@@ -347,7 +389,7 @@ class MultiSlotDataFeed
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}
-  virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool CheckFile(const char* filename);
  // virtual void ReadThread();
@@ -366,7 +408,7 @@ class MultiSlotInMemoryDataFeed
 public:
  MultiSlotInMemoryDataFeed() {}
  virtual ~MultiSlotInMemoryDataFeed() {}
-  virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
  virtual void Init(const DataFeedDesc& data_feed_desc);

 protected:
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
@@ -381,5 +423,54 @@ class MultiSlotInMemoryDataFeed
      const std::string& str);
};
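// Editorial note: the reserve hooks added to Init/InitOffset and the raw
// (pointer, size) CopyValues/AppendValues overloads above appear to exist so
// the instant data feed can append straight out of its mmapped buffer without
// per-batch reallocation -- an inference from the call sites in data_feed.cc,
// not a statement from the patch itself.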
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
public:
PrivateInstantDataFeed() {}
virtual ~PrivateInstantDataFeed() {}
void Init(const DataFeedDesc& data_feed_desc) override;
bool Start() override { return true; }
int Next() override;
protected:
// The batched data buffer
std::vector<MultiSlotType> ins_vec_;
  // Preprocess a given file, e.g. open or mmap it
  virtual bool Preprocess(const std::string& filename) = 0;

  // Postprocess system resources, e.g. close the file
  // NOTICE: ensure that it is safe to call before Preprocess
  virtual bool Postprocess() = 0;
// The reading and parsing method.
virtual bool ParseOneMiniBatch() = 0;
// This function is used to put ins_vec to feed_vec
virtual void PutToFeedVec();
};
class MultiSlotFileInstantDataFeed
: public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotFileInstantDataFeed() {}
virtual ~MultiSlotFileInstantDataFeed() {}
protected:
int fd_{-1};
char* buffer_{nullptr};
size_t end_{0};
size_t offset_{0};
bool Preprocess(const std::string& filename) override;
bool Postprocess() override;
bool ParseOneMiniBatch() override;
};
#endif
}  // namespace framework
}  // namespace paddle
@@ -64,5 +64,8 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
}  // namespace framework
}  // namespace paddle
@@ -13,11 +13,13 @@
// limitations under the License.

#include "paddle/fluid/framework/data_layout_transform.h"
#include <string>
#include <vector>

#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#endif

namespace paddle {
@@ -145,7 +147,6 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
  memory::data_type in_type = ToMKLDNNDataType(in.type());
  PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
                 "Input tensor type is not supported: %s", in.type());
-  memory::data_type out_type = in_type;

  auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
  auto out_format =
@@ -156,14 +157,21 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
  if (in_format != out_format) {
    void* in_data = GetDataFromTensor(in, in_type);
-    auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
-    auto in_memory =
-        memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
-    auto out_memory =
-        memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
-    platform::Reorder(in_memory, out_memory);
    const std::string key = platform::ReorderMKLDNNHandler::GetHash(
        in_tz, in_format, out_format, std::to_string(in_type));

    platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
                                           cpu_engine, key);

    auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
    auto reorder_dst_memory_p =
        handler.AcquireDstMemory(out, out_format, expected_kernel_type.place_);
    auto reorder_p =
        handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);

    std::vector<mkldnn::primitive> pipeline;
    pipeline.push_back(*reorder_p);
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
  } else {
    out->ShareDataWith(in);
  }
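// Editorial note: the handler caches the reorder primitive and its memory
// objects under a key built from (dims, source format, destination format,
// dtype), so repeated layout transforms of same-shaped tensors reuse one
// MKL-DNN reorder instead of rebuilding memory objects on every call, as the
// old platform::Reorder helper did.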
......
@@ -141,6 +141,9 @@ template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
  VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
  std::vector<T>().swap(memory_data_);
  for (int i = 0; i < readers_.size(); ++i) {
    readers_[i]->ReleaseChannelData();
  }
  VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
@@ -178,8 +181,10 @@ void DatasetImpl<T>::GlobalShuffle() {
  if (readers_.size() == 0) {
    CreateReaders();
  }
-  // if it is not InMemory, memory_data_ is empty
-  std::random_shuffle(memory_data_.begin(), memory_data_.end());
  auto fleet_ptr = FleetWrapper::GetInstance();
  // locally shuffle all data before the global shuffle
  std::shuffle(memory_data_.begin(), memory_data_.end(),
               fleet_ptr->LocalRandomEngine());
  VLOG(3) << "start global shuffle threads";
  std::vector<std::thread> global_shuffle_threads;
  for (int i = 0; i < thread_num_; ++i) {
@@ -260,6 +265,20 @@ void DatasetImpl<T>::DestroyReaders() {
  }
}

template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() {
  return memory_data_.size();
}

template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
  int64_t sum = 0;
  for (int i = 0; i < readers_.size(); ++i) {
    sum += readers_[i]->GetChannelDataSize();
  }
  return sum;
}

template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
                                      const std::string& msg) {
@@ -267,7 +286,7 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
  VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
          << ", client_id=" << client_id << ", msg length=" << msg.length();
  auto fleet_ptr = FleetWrapper::GetInstance();
-  int64_t index = rand_r(&rand_seed) % thread_num_;
  int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
  VLOG(3) << "random index=" << index;
  readers_[index]->PutInsToChannel(msg);
#endif
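// Editorial note: std::random_shuffle (deprecated in C++14, removed in C++17)
// and the thread-unsafe rand_r seed are replaced above by std::shuffle driven
// by a process-local engine, so shuffling quality no longer depends on
// RAND_MAX and the engine can be seeded per process (an assumption about
// LocalRandomEngine's intent). The replacement pattern, in a minimal sketch:

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::vector<int> data = {1, 2, 3, 4, 5};
  std::mt19937 engine(2019);  // stand-in for fleet_ptr->LocalRandomEngine()
  std::shuffle(data.begin(), data.end(), engine);
  for (int v : data) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}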
......
@@ -85,6 +85,10 @@ class Dataset {
  virtual void CreateReaders() = 0;
  // destroy readers
  virtual void DestroyReaders() = 0;
  // get memory data size
  virtual int64_t GetMemoryDataSize() = 0;
  // get shuffle data size
  virtual int64_t GetShuffleDataSize() = 0;

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
@@ -127,6 +131,8 @@ class DatasetImpl : public Dataset {
  virtual void GlobalShuffle();
  virtual void CreateReaders();
  virtual void DestroyReaders();
  virtual int64_t GetMemoryDataSize();
  virtual int64_t GetShuffleDataSize();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
......
@@ -93,6 +93,6 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
            fuse_elewise_add_act_pass multi_batch_merge_pass
            fuse_relu_depthwise_conv_pass
            memory_optimize_pass lock_free_optimize_pass
-            alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
            alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
            fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
            record_skip_memory_opt_vars_pass)
@@ -35,16 +35,9 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                     const std::vector<Scope *> &local_scopes,
                                     const std::vector<platform::Place> &places,
-                                     const platform::NCCLContextMap *ctxs)
-    : OpHandleBase(node),
-      local_scopes_(local_scopes),
-      places_(places),
-      nccl_ctxs_(ctxs) {
-  if (nccl_ctxs_) {
-    for (auto &p : places_) {
-      this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
-    }
-  }
                                     const platform::NCCLCommunicator *ctxs)
    : NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
#else
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
@@ -71,7 +64,9 @@ void AllReduceOpHandle::RunAllReduceFuncs(
  if (FLAGS_sync_nccl_allreduce) {
    for (auto &p : places_) {
      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
      auto *nccl_ctxs =
          nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
      auto &nccl_ctx = nccl_ctxs->at(dev_id);
      auto stream = nccl_ctx.stream();
      cudaError_t e_sync = cudaStreamSynchronize(stream);
      if (e_sync != 0) {
@@ -134,21 +129,12 @@ void AllReduceOpHandle::RunImpl() {
      numel = static_cast<size_t>(lod_tensor.numel());
    }

-    int dev_id = boost::get<platform::CUDAPlace>(p).device;
-    auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-    auto stream = nccl_ctx.stream();
-    auto comm = nccl_ctx.comm_;
-
-    VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
-             << ", dev_id:" << dev_id << ", dtype:" << dtype
-             << ", place:" << p;
-
    all_reduce_calls.emplace_back([=] {
-      PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-          buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-          comm, stream));
      NCCLAllReduce(p, buffer, buffer, numel,
                    static_cast<ncclDataType_t>(dtype), ncclSum);
    });
  }

-  VLOG(10) << "allreduce size:" << numel * SizeOfType(lod_tensors[0]->type());
  RunAllReduceFuncs(all_reduce_calls);
#else
  PADDLE_THROW("Not compiled with CUDA");
......
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
@@ -28,13 +29,15 @@ namespace paddle {
namespace framework {
namespace details {

-class AllReduceOpHandle : public OpHandleBase {
- public:
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
class AllReduceOpHandle : public NCCLOpHandleBase {
 public:
  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                    const std::vector<platform::Place> &places,
-                    const platform::NCCLContextMap *ctxs);
                    const platform::NCCLCommunicator *ctxs);
#else
class AllReduceOpHandle : public OpHandleBase {
 public:
  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                    const std::vector<platform::Place> &places);
#endif
@@ -46,13 +49,17 @@ class AllReduceOpHandle : public OpHandleBase {
 protected:
  void RunImpl() override;
  std::vector<Scope *> local_scopes_;

#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
  // NCCLOpHandleBase already has these attributes;
  // to be cleaned up once the class inheritance hierarchy is polished.
  std::vector<platform::Place> places_;
#endif

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  void RunAllReduceFuncs(
      const std::vector<std::function<void()>> &all_reduce_calls);
-  const platform::NCCLContextMap *nccl_ctxs_;
#endif
};
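// Editorial note: after this change the CUDA build derives AllReduceOpHandle
// from NCCLOpHandleBase, which owns the places and the NCCLCommunicator and
// hands back a per-run-order context via GetRunEnvNCCLCtx, enabling the
// hierarchical allreduce selected by use_hierarchical_allreduce_ (presumably
// intra-node then inter-node reduction); the non-CUDA build keeps the old
// OpHandleBase base class.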
......
@@ -51,9 +51,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
  VLOG(3) << "ProcessGraph";
  RpcCtxMap send_varname_to_ctx;
  RpcCtxMap recv_varname_to_ctx;
-  for (auto i = 0; i < graphs.size(); ++i) {
-    std::vector<ir::Node *> nodes_to_delete;
-    for (auto &node : graphs[i]->Nodes()) {
  for (auto &node : graphs[0]->Nodes()) {
    VLOG(3) << "node name " << node->Name();
    if (node && node->IsOp()) {
      if (node->Name() == "send") {
@@ -66,10 +64,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
            node->Op()->GetNullableAttr("sections"));
        auto trainer_id =
            boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
-        send_varname_to_ctx[send_var_name] =
-            operators::distributed::RpcContext(send_var_name, send_varnames,
-                                               epmap, height_section,
-                                               trainer_id);
        send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
            send_var_name, send_varnames, epmap, height_section, trainer_id);
        VLOG(3) << "find and init a send op: "
                << send_varname_to_ctx[send_var_name];
      } else if (node->Name() == "recv") {
@@ -80,16 +76,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
            node->Op()->GetNullableAttr("epmap"));
        auto trainer_id =
            boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
-        recv_varname_to_ctx[recv_var_name] =
-            operators::distributed::RpcContext(recv_var_name, recv_varnames,
-                                               epmap, {}, trainer_id);
-        nodes_to_delete.push_back(node);
        recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
            recv_var_name, recv_varnames, epmap, {}, trainer_id);
        VLOG(3) << "find and remove a recv op: "
                << recv_varname_to_ctx[recv_var_name];
      }
    }
  }
-  }

  // init communicator here
  if (send_varname_to_ctx.size() > 0) {
    VLOG(3) << "this is distribute mode, will use communicator";
......
@@ -16,6 +16,7 @@ limitations under the License. */

#include <glog/logging.h>
#include <memory>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
@@ -26,6 +27,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace details {
@@ -46,6 +49,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
      : ir::PassBuilder(), strategy_(strategy) {
    // Add a graph viz pass to record a graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      VLOG(1) << "Add graph_viz_pass";
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
@@ -53,10 +57,27 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    }

    // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
    VLOG(1) << "Add record_skip_memory_opt_vars_pass";
    AppendPass("record_skip_memory_opt_vars_pass");
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
VLOG(1) << "Add mkldnn_placement_pass";
AppendPass("mkldnn_placement_pass");
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
      LOG(WARNING)
          << "mkldnn_enabled_op_types specifies the operator type list to "
             "use MKLDNN acceleration. It is empty by default, which means "
             "that all the operators supported by MKLDNN will be "
             "accelerated. It should not be set when "
             "FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
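// Editorial note: FLAGS_use_mkldnn is a runtime gflag, so on an MKL-DNN build
// the placement pass can be enabled without recompiling, e.g. (hypothetical
// script name, assuming Paddle's usual FLAGS_* environment handling):
//
//   FLAGS_use_mkldnn=true GLOG_v=1 python train.py
//
// while a non-MKL-DNN build now fails fast via the PADDLE_ENFORCE above
// instead of silently ignoring the flag.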
    if (strategy_.enable_sequential_execution_) {
-      VLOG(5) << "Add sequential_execution_pass";
      VLOG(1) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
    }
@@ -67,7 +88,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    // Add op fusion.
    if (strategy.fuse_relu_depthwise_conv_) {
-      VLOG(5) << "Add fuse_relu_depthwise_conv_pass";
      VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
      AppendPass("fuse_relu_depthwise_conv_pass");
    }
@@ -79,19 +100,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    // Add the inplace pass automatically.
    if (strategy_.enable_inplace_) {
-      VLOG(5) << "Add inplace_pass";
      VLOG(1) << "Add inplace_pass";
      AppendPass("inplace_pass");
    }

    if (strategy_.fuse_elewise_add_act_ops_) {
-      VLOG(5) << "Add fuse_elewise_add_act_pass";
      VLOG(1) << "Add fuse_elewise_add_act_pass";
      AppendPass("fuse_elewise_add_act_pass");
    }

    // for single card training, fuse_all_reduce_ops is unnecessary.
    // alloc_continuous_space_for_grad_pass should come before MultiDevPass.
    if (strategy_.fuse_all_reduce_ops_) {
-      VLOG(5) << "Add alloc_continuous_space_for_grad_pass";
      VLOG(1) << "Add alloc_continuous_space_for_grad_pass";
      AppendPass("alloc_continuous_space_for_grad_pass");
    }
@@ -106,11 +127,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
      // NOTE: fuse_all_xx_ops will count the number of xx operators first,
      // and if the number is zero, the pass will do nothing.
      // Currently, only one type of optimization algorithm can be fused.
-      VLOG(5) << "Add fuse_adam_op_pass";
      VLOG(1) << "Add fuse_adam_op_pass";
      AppendPass("fuse_adam_op_pass");
-      VLOG(5) << "Add fuse_sgd_op_pass";
      VLOG(1) << "Add fuse_sgd_op_pass";
      AppendPass("fuse_sgd_op_pass");
-      VLOG(5) << "Add fuse_momentum_op_pass";
      VLOG(1) << "Add fuse_momentum_op_pass";
      AppendPass("fuse_momentum_op_pass");
    }
  }
@@ -140,7 +161,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    // A side effect of that is that memory optimize cannot foresee the
    // fetched vars, so the fetch list should be set persistable before
    // calling the Run interface.
    if (strategy_.memory_optimize_) {
-      VLOG(5) << "Add memory_optimize_pass";
      VLOG(1) << "Add memory_optimize_pass";
      AppendPass("memory_optimize_pass");
    }
@@ -148,26 +169,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    // all original and fused operators. But no operator can have this attr
    // enabled if the pass is placed after MultiDevPass.
    if (strategy_.cache_runtime_context_) {
-      VLOG(5) << "Add runtime_context_cache_pass";
      VLOG(1) << "Add runtime_context_cache_pass";
      AppendPass("runtime_context_cache_pass");
    }
-    if (strategy_.cache_expected_kernel_) {
-      VLOG(10) << "Add expected_kernel_cache_pass";
-      AppendPass("expected_kernel_cache_pass");
-    }

    AppendMultiDevPass(strategy_);

    if (strategy_.fuse_all_reduce_ops_) {
      // NOTE: fuse_all_reduce_ops will count the number of all_reduce
      // operators first, and if the number is zero, it will do nothing.
-      VLOG(5) << "Add fuse_all_reduce_op_pass";
      VLOG(1) << "Add fuse_all_reduce_op_pass";
      AppendPass("fuse_all_reduce_op_pass");
    }

    // Add a graph print pass to record a graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      VLOG(1) << "Add multi_devices_print_pass";
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      const std::string graph_path =
          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
@@ -183,16 +200,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    if (!strategy_.enable_parallel_graph_ &&
        (SeqOnlyAllReduceOps(strategy_) ||
         strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
-      VLOG(5) << "Add all_reduce_deps_pass";
      VLOG(1) << "Add all_reduce_deps_pass";
      AppendPass("all_reduce_deps_pass");
    }

    if (strategy_.enable_backward_optimizer_op_deps_) {
      VLOG(1) << "Add backward_op_deps_pass";
      AppendPass("backward_optimizer_op_deps_pass");
    }

    if (strategy_.remove_unnecessary_lock_) {
-      VLOG(5) << "Add modify_op_lock_and_record_event_pass";
      VLOG(1) << "Add modify_op_lock_and_record_event_pass";
      AppendPass("modify_op_lock_and_record_event_pass");
    }

    // Verify that the graph is correct for multi-device executor.
    VLOG(1) << "Add multi_devices_check_pass";
    AppendPass("multi_devices_check_pass");
  }
@@ -201,18 +224,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    ir::Pass *multi_devices_pass = nullptr;

    if (strategy_.async_mode_) {
      VLOG(1) << "Add async_multi_devices_pass";
      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
    } else if (strategy_.is_distribution_) {
-      VLOG(5)
      VLOG(1)
          << "Add dist_multi_devices_pass, multi device parameter server mode";
      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
    } else {
      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
-        VLOG(5) << "Add all_reduce_mode_multi_devices_pass";
        VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
        multi_devices_pass =
            AppendPass("all_reduce_mode_multi_devices_pass").get();
      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
-        VLOG(5) << "Add reduce_mode_multi_devices_pass";
        VLOG(1) << "Add reduce_mode_multi_devices_pass";
        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
      } else {
        PADDLE_THROW("Unknown reduce strategy.");
@@ -249,7 +273,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                const bool use_cuda,
-                                platform::NCCLContextMap *nccl_ctxs) const {
                                platform::NCCLCommunicator *nccl_ctxs) const {
#else
                                const bool use_cuda) const {
#endif
@@ -271,9 +295,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
      pass->Set<size_t>(ir::kNRanks, new size_t(nranks));

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
               pass->Type() == "fuse_adam_op_pass" ||
@@ -287,9 +311,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                                   &local_scopes);
      if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-        platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
        platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
        pass->Erase(kNCCLCtxs);
-        pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
        pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
        pass->Erase(kUseHierarchicalAllReduce);
        pass->Set<bool>(kUseHierarchicalAllReduce,
                        new bool(use_hierarchical_allreduce_));
#endif
      }
    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
@@ -302,6 +329,14 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
    } else if (pass->Type() == "all_reduce_deps_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
#endif
LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this) LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_; << ", num_trainers:" << num_trainers_;
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") { } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
...@@ -313,6 +348,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -313,6 +348,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "inplace_pass") { } else if (pass->Type() == "inplace_pass") {
pass->Erase(ir::kUseCuda); pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda)); pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
} }
VLOG(3) << "Start Apply Pass " << pass->Type(); VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph); graph = pass->Apply(graph);
...@@ -339,6 +377,7 @@ USE_PASS(multi_devices_print_pass); ...@@ -339,6 +377,7 @@ USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass); USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass); USE_PASS(all_reduce_deps_pass);
USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass); USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass); USE_PASS(lock_free_optimize_pass);
...@@ -349,5 +388,7 @@ USE_PASS(fuse_sgd_op_pass); ...@@ -349,5 +388,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass); USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
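// A minimal, self-contained sketch of the Erase + Set/SetNotOwned attribute
// pattern that the Apply() hunk above relies on: owned attributes are freed
// by the holder, not-owned ones are merely referenced. Standard library
// only; all names are illustrative, not Paddle's.
#include <cassert>
#include <functional>
#include <map>
#include <string>

class AttrHolder {
 public:
  ~AttrHolder() {
    for (auto &kv : attrs_)
      if (kv.second.deleter) kv.second.deleter();
  }
  template <typename T>
  void Set(const std::string &name, T *owned) {  // holder takes ownership
    Erase(name);
    attrs_[name] = {owned, [owned] { delete owned; }};
  }
  template <typename T>
  void SetNotOwned(const std::string &name, T *borrowed) {  // no ownership
    Erase(name);
    attrs_[name] = {borrowed, nullptr};
  }
  void Erase(const std::string &name) {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) return;
    if (it->second.deleter) it->second.deleter();
    attrs_.erase(it);
  }
  template <typename T>
  T &Get(const std::string &name) {
    return *static_cast<T *>(attrs_.at(name).ptr);
  }

 private:
  struct Attr {
    void *ptr;
    std::function<void()> deleter;
  };
  std::map<std::string, Attr> attrs_;
};

int main() {
  AttrHolder pass;
  bool use_hierarchical = true;       // borrowed from the caller
  pass.Set("nranks", new size_t(8));  // owned by the holder
  pass.SetNotOwned("use_hierarchical_allreduce", &use_hierarchical);
  assert(pass.Get<size_t>("nranks") == 8);
  assert(pass.Get<bool>("use_hierarchical_allreduce"));
}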
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
...@@ -79,6 +80,8 @@ struct BuildStrategy { ...@@ -79,6 +80,8 @@ struct BuildStrategy {
bool fuse_all_reduce_ops_{false}; bool fuse_all_reduce_ops_{false};
bool enable_backward_optimizer_op_deps_{false};
bool fuse_relu_depthwise_conv_{false}; bool fuse_relu_depthwise_conv_{false};
bool sync_batch_norm_{false}; bool sync_batch_norm_{false};
...@@ -108,7 +111,18 @@ struct BuildStrategy { ...@@ -108,7 +111,18 @@ struct BuildStrategy {
bool remove_unnecessary_lock_{true}; bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false}; bool cache_runtime_context_{false};
bool cache_expected_kernel_{true}; std::unordered_set<std::string> mkldnn_enabled_op_types_;
size_t nccl_comm_num_{1};
// The picture is here:
// https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
bool use_hierarchical_allreduce_{false};
// NCCL ranks within a node when hierarchical allreduce is used; usually set
// to the number of GPU cards.
size_t hierarchical_allreduce_inter_nranks_{0};
// NCCL ranks across nodes when hierarchical allreduce is used; usually set
// to the number of nodes.
size_t hierarchical_allreduce_exter_nranks_{0};
// NOTE: // NOTE:
// Before you add new options, think if it's a general strategy that works // Before you add new options, think if it's a general strategy that works
...@@ -135,7 +149,7 @@ struct BuildStrategy { ...@@ -135,7 +149,7 @@ struct BuildStrategy {
const size_t &nranks, const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const; platform::NCCLCommunicator *nccl_ctxs) const;
#else #else
const bool use_cuda) const; const bool use_cuda) const;
#endif #endif
......
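// A small sketch of how the two nranks fields above relate, under the stated
// assumption that inter = GPU cards per node and exter = number of nodes, so
// a flat NCCL rank splits into (node, local-GPU) coordinates. Standard
// library only; illustrative, not Paddle code.
#include <cstdio>

int main() {
  const int inter_nranks = 4;  // hierarchical_allreduce_inter_nranks_: GPUs per node
  const int exter_nranks = 2;  // hierarchical_allreduce_exter_nranks_: node count
  const int world = inter_nranks * exter_nranks;
  for (int rank = 0; rank < world; ++rank) {
    int node = rank / inter_nranks;   // which exter (cross-node) group
    int local = rank % inter_nranks;  // position inside the inter (in-node) group
    std::printf("rank %d -> node %d, local gpu %d\n", rank, node, local);
  }
  // Hierarchical allreduce reduces inside each node first (inter group), then
  // all-reduces one representative per node across nodes (exter group), and
  // finally broadcasts the result back inside each node.
}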
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif #endif
...@@ -65,6 +66,7 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() { ...@@ -65,6 +66,7 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() { void EagerDeletionOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
Scope *exec_scope = nullptr; Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages; std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) { for (auto &name : var_names_) {
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -43,62 +44,43 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -43,62 +44,43 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
bootstrap_ops_.emplace_back(op); bootstrap_ops_.emplace_back(op);
} }
} }
PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
PrepareAtomicOpDeps(); PrepareAtomicOpDeps();
} }
FeedFetchList FastThreadedSSAGraphExecutor::Run( FeedFetchList FastThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>> std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
op_deps = atomic_op_deps_.get(); op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps(); PrepareAtomicOpDeps();
size_t num_ops = op_deps->size();
paddle::framework::FeedFetchList fetches; paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops; std::vector<OpHandleBase *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops; std::vector<OpHandleBase *> ready_fetch_ops;
exception_.Clear();
for (auto &fetch_var_name : fetch_tensors) { InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { &fetch_ops, &ready_fetch_ops);
auto it = var_map.find(fetch_var_name); event.reset(nullptr);
if (it != var_map.end()) { if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
fetched_vars[fetch_var_name].push_back(*it->second.rbegin()); // If num_threads is 1, we can record the order in which operators
} // execute in the first iteration, and in subsequent iterations run
} // the recorded operators directly. This strategy makes execution
} // faster.
VLOG(3) << "Run the traced ops.";
for (size_t i = 0; i < fetch_tensors.size(); ++i) { RunTracedOps(traced_ops_);
auto &var_name = fetch_tensors[i]; RunTracedOps(fetch_ops);
auto fetched_var_it = fetched_vars.find(var_name); if (exception_.IsCaught()) {
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), ExecutionFinal(&fetch_ops);
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)",
var_name);
auto &vars = fetched_var_it->second;
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
fetch_ops.emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
} }
} else {
size_t num_complete = 0; traced_ops_.clear();
remaining_ = 0; remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>(); auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) { for (auto op : bootstrap_ops_) {
...@@ -107,6 +89,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -107,6 +89,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
for (auto op : ready_fetch_ops) { for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q); RunOpAsync(op_deps.get(), op, complete_q);
} }
size_t num_complete = 0;
while (num_complete != op_deps->size()) { while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop(); size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) { if (num_comp == -1UL) {
...@@ -121,28 +105,74 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -121,28 +105,74 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
} }
} }
if (exception_.IsCaught()) { if (exception_.IsCaught()) {
ClearFetchOp(graph_, &fetch_ops); ExecutionFinal(&fetch_ops);
exception_.ReThrow();
} }
} }
num_complete += num_comp; num_complete += num_comp;
} }
}
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
return fetches; return fetches;
} }
void FastThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops) {
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors.at(i);
auto fetched_var_it = fetched_vars->find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(),
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)",
var_name);
auto &vars = fetched_var_it->second;
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops->emplace_back(op);
}
}
}
bool FastThreadedSSAGraphExecutor::RunOp( bool FastThreadedSSAGraphExecutor::RunOp(
OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q, OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete) { size_t *complete) {
try { RunOpSync(op);
if (LIKELY(!exception_.IsCaught())) {
if (LIKELY(!strategy_.dry_run_)) { if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_); RecordOps(op);
} }
++(*complete); ++(*complete);
return true; return true;
} catch (...) { } else {
exception_.Catch(std::current_exception());
--remaining_; --remaining_;
complete_q->Push(-1UL); complete_q->Push(-1UL);
return false; return false;
...@@ -194,6 +224,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -194,6 +224,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
complete_q->Push(complete); complete_q->Push(complete);
}); });
} }
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = prepare_pool_.enqueue([&] { atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>; auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
...@@ -206,6 +237,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { ...@@ -206,6 +237,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
} }
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; } const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
void FastThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_.ReThrow();
}
void FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_.Catch(std::current_exception());
}
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
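// A self-contained sketch of the trace-and-replay fast path added above:
// with a single worker thread, record the operator order during the first
// iteration, then replay that order directly and skip dependency tracking.
// All names are illustrative, not Paddle's.
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::function<void()> body;
};

class TracedRunner {
 public:
  void Run(std::vector<Op> &ops) {
    if (traced_.size() == ops.size()) {  // fast path: replay recorded order
      for (Op *op : traced_) op->body();
      return;
    }
    traced_.clear();                     // slow path: schedule and record
    for (Op &op : ops) {
      op.body();
      traced_.push_back(&op);
    }
  }

 private:
  std::vector<Op *> traced_;
};

int main() {
  std::vector<Op> ops = {{"fill", [] { std::puts("fill"); }},
                         {"matmul", [] { std::puts("matmul"); }},
                         {"fetch", [] { std::puts("fetch"); }}};
  TracedRunner runner;
  runner.Run(ops);  // first iteration records the trace
  runner.Run(ops);  // subsequent iterations replay it
}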
...@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool pool_; ::ThreadPool pool_;
::ThreadPool prepare_pool_; ::ThreadPool prepare_pool_;
std::vector<OpHandleBase *> traced_ops_;
bool RunOp(OpHandleBase *op, bool RunOp(OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete); size_t *complete);
...@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::shared_ptr<BlockingQueue<size_t>> &complete_q); const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps(); void PrepareAtomicOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>>
*fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -44,17 +44,10 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>> ...@@ -44,17 +44,10 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
FusedAllReduceOpHandle::FusedAllReduceOpHandle( FusedAllReduceOpHandle::FusedAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const size_t num_of_all_reduce, const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
const platform::NCCLContextMap *ctxs) const platform::NCCLCommunicator *ctxs)
: OpHandleBase(node), : NCCLOpHandleBase(node, places, ctxs),
local_scopes_(local_scopes), local_scopes_(local_scopes),
places_(places), num_of_all_reduce_(num_of_all_reduce) {
num_of_all_reduce_(num_of_all_reduce),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
} }
#else #else
...@@ -167,17 +160,14 @@ void FusedAllReduceOpHandle::RunImpl() { ...@@ -167,17 +160,14 @@ void FusedAllReduceOpHandle::RunImpl() {
auto &p = places_[i]; auto &p = places_[i];
void *buffer = const_cast<void *>(lod_tensor_data.at(i)); void *buffer = const_cast<void *>(lod_tensor_data.at(i));
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] { all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( NCCLAllReduce(p, buffer, buffer, numel,
buffer, buffer, numel, static_cast<ncclDataType_t>(nccl_dtype), static_cast<ncclDataType_t>(nccl_dtype), ncclSum);
ncclSum, comm, stream));
}); });
} }
VLOG(10) << "fusedallreduce size:" << numel * SizeOfType(dtype);
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) { if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device // Do not use NCCLGroup when manage NCCL by per thread per device
......
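// A CPU-only sketch of what the fused all-reduce above buys: instead of one
// reduction call per gradient tensor, the tensors live in one contiguous
// buffer per device and a single call covers them all. Purely illustrative;
// the real code issues ncclAllReduce on CUDA buffers.
#include <cassert>
#include <cstddef>
#include <vector>

// Element-wise sum across per-device contiguous buffers, written back to each.
void FusedAllReduceSum(std::vector<std::vector<float>> &dev_buffers) {
  const std::size_t numel = dev_buffers.front().size();
  std::vector<float> acc(numel, 0.f);
  for (auto &buf : dev_buffers)
    for (std::size_t i = 0; i < numel; ++i) acc[i] += buf[i];
  for (auto &buf : dev_buffers) buf = acc;  // broadcast the reduced result
}

int main() {
  // Two "devices", each with several gradients fused into one 3-float buffer.
  std::vector<std::vector<float>> bufs = {{1.f, 2.f, 3.f}, {10.f, 20.f, 30.f}};
  FusedAllReduceSum(bufs);  // one call per device covers every gradient
  assert(bufs[0][0] == 11.f && bufs[1][2] == 33.f);
}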
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
...@@ -28,14 +29,15 @@ namespace paddle { ...@@ -28,14 +29,15 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct FusedAllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const size_t num_of_all_reduce, const size_t num_of_all_reduce,
const platform::NCCLContextMap *ctxs); const platform::NCCLCommunicator *ctxs);
#else #else
struct FusedAllReduceOpHandle : public OpHandleBase {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
...@@ -52,11 +54,12 @@ struct FusedAllReduceOpHandle : public OpHandleBase { ...@@ -52,11 +54,12 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
private: private:
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
// NCCLOpHandleBase already has these attributes.
// To be cleaned up once the class inheritance hierarchy is polished.
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
size_t num_of_all_reduce_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
size_t num_of_all_reduce_;
// Check the dtype of the input // Check the dtype of the input
void GetDTypeAndNumel( void GetDTypeAndNumel(
......
...@@ -45,6 +45,7 @@ constexpr char kGraphVars[] = "vars"; ...@@ -45,6 +45,7 @@ constexpr char kGraphVars[] = "vars";
constexpr char kPlaces[] = "places"; constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes"; constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs"; constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars; typedef std::unordered_set<VarHandleBase *> GraphDepVars;
......
...@@ -20,7 +20,7 @@ namespace framework { ...@@ -20,7 +20,7 @@ namespace framework {
namespace details { namespace details {
std::string OpHandleBase::DebugString() const { std::string OpHandleBase::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << "("; ss << Name() << "(";
for (auto *var : inputs_) { for (auto *var : inputs_) {
ss << var->DebugString() << ", "; ss << var->DebugString() << ", ";
} }
...@@ -187,6 +187,11 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { ...@@ -187,6 +187,11 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
std::function<void()> method = callback; std::function<void()> method = callback;
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
method = [method, p, this]() { method = [method, p, this]() {
VLOG(10) << "cudadevicecontext:"
<< static_cast<platform::CUDADeviceContext *>(p.second)
<< ", dev_id:"
<< boost::get<platform::CUDAPlace>(p.first).device;
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent( static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device), events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method); method);
......
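// A sketch of the callback-nesting trick in RunAndRecordEvent above: each
// loop iteration wraps the current std::function in another lambda, so one
// call of the final `method` records an event for every device and then runs
// the original callback. Illustrative names only.
#include <cstdio>
#include <functional>
#include <vector>

int main() {
  std::vector<int> device_ids = {0, 1, 2};
  std::function<void()> method = [] { std::puts("run original callback"); };
  for (int dev : device_ids) {
    method = [method, dev] {  // capture the previous wrapper by value
      std::printf("record event on device %d\n", dev);
      method();               // then invoke the inner layer
    };
  }
  method();  // unwinds the chain: device 2, device 1, device 0, callback
}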
...@@ -95,6 +95,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ...@@ -95,6 +95,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass = auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass"); ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
for (size_t i = 0; i < graphs_.size(); ++i) { for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release())); graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -29,6 +30,8 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, ...@@ -29,6 +30,8 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_(place) {} place_(place) {}
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place(); auto &p = static_cast<VarHandle *>(in)->place();
if (ir::IsControlDepVar(*in->Node())) { if (ir::IsControlDepVar(*in->Node())) {
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include <string> #include <string>
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -67,6 +67,7 @@ struct ScaleLossGradFunctor { ...@@ -67,6 +67,7 @@ struct ScaleLossGradFunctor {
}; };
void ScaleLossGradOpHandle::RunImpl() { void ScaleLossGradOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
// Doesn't wait any event // Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name(); std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
...@@ -36,26 +36,10 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ...@@ -36,26 +36,10 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
FeedFetchList ScopeBufferedSSAGraphExecutor::Run( FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) { if (drop_scope_counter_ == 0) {
// Create local scopes. platform::RecordEvent e("InitLocalExeScopes");
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) { PrepareLocalExeScopes();
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
} }
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
}
std::vector<framework::LoDTensor> fetch_data; std::vector<framework::LoDTensor> fetch_data;
std::exception_ptr eptr = nullptr; std::exception_ptr eptr = nullptr;
try { try {
...@@ -64,9 +48,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -64,9 +48,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
eptr = std::current_exception(); eptr = std::current_exception();
} }
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun");
++drop_scope_counter_; ++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
DropLocalExeScopes(); DropLocalExeScopes();
} }
...@@ -78,17 +60,41 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -78,17 +60,41 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
} }
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() { void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0; drop_scope_counter_ = 0;
for (auto p : places_) { for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
} }
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {
auto &local_scope = auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName);
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>(); if (local_scope_var != nullptr) {
auto &local_scope = *local_scope_var->GetMutable<Scope *>();
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
scope->EraseVars({std::string(details::kLocalExecScopeName)});
VLOG(3) << "Drop local execution scope: " << local_scope; VLOG(3) << "Drop local execution scope: " << local_scope;
} }
}
}
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
} }
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() { bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
......
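// A minimal sketch of the scope-buffering policy refactored above: reuse
// per-iteration local state for N runs, then drop and rebuild it, trading
// memory growth against allocation cost. Names are illustrative.
#include <cstddef>
#include <cstdio>
#include <vector>

class ScopeBuffer {
 public:
  explicit ScopeBuffer(std::size_t iters_per_drop)
      : iters_per_drop_(iters_per_drop) {}
  void Run(int step) {
    if (counter_ == 0) Prepare();  // lazily (re)create local state
    scratch_.push_back(step);      // the "local scope" accumulates garbage
    if (++counter_ == iters_per_drop_) Drop();
  }

 private:
  void Prepare() { std::printf("prepare local scopes\n"); }
  void Drop() {
    std::printf("drop local scopes (%zu cached entries)\n", scratch_.size());
    scratch_.clear();
    counter_ = 0;
  }
  std::size_t iters_per_drop_;
  std::size_t counter_ = 0;
  std::vector<int> scratch_;
};

int main() {
  ScopeBuffer buf(/*iters_per_drop=*/3);
  for (int step = 0; step < 7; ++step) buf.Run(step);  // drops after steps 2 and 5
}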
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <ThreadPool.h>
#include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -51,6 +52,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,6 +52,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
bool NeedCreateLocalExeScope(); bool NeedCreateLocalExeScope();
void PrepareLocalExeScopes();
private: private:
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -30,7 +30,7 @@ namespace details { ...@@ -30,7 +30,7 @@ namespace details {
SparseAllReduceOpHandle::SparseAllReduceOpHandle( SparseAllReduceOpHandle::SparseAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs, bool is_encoded, int nranks) const platform::NCCLCommunicator *ctxs, bool is_encoded, int nranks)
: AllReduceOpHandle(node, local_scopes, places, ctxs), : AllReduceOpHandle(node, local_scopes, places, ctxs),
is_encoded_(is_encoded), is_encoded_(is_encoded),
nranks_(nranks) { nranks_(nranks) {
...@@ -102,7 +102,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -102,7 +102,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel; out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device; int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id); auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto stream = nccl_ctx.stream(); auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_; auto comm = nccl_ctx.comm_;
......
...@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle { ...@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
SparseAllReduceOpHandle(ir::Node *node, SparseAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs, const platform::NCCLCommunicator *ctxs,
bool is_encoded = false, int nranks = -1); bool is_encoded = false, int nranks = -1);
std::string Name() const override; std::string Name() const override;
......
...@@ -19,10 +19,13 @@ namespace framework { ...@@ -19,10 +19,13 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) { void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<FetchOpHandle*>(op),
"The input ops of ClearFetchOp function should be FetchOpHandle.");
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
......
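// A sketch of the check added to ClearFetchOp above: the parameter widens to
// the base class so callers can hold generic op lists, while dynamic_cast
// asserts every element really is a fetch op before cleanup. Illustrative
// types, not Paddle's.
#include <cassert>
#include <vector>

struct OpHandleBase { virtual ~OpHandleBase() = default; };
struct FetchOpHandle : OpHandleBase {};
struct ComputeOpHandle : OpHandleBase {};

void ClearFetchOps(std::vector<OpHandleBase *> *fetch_ops) {
  if (fetch_ops->empty()) return;
  for (OpHandleBase *op : *fetch_ops) {
    // Fail fast if a non-fetch op sneaks into the list.
    assert(dynamic_cast<FetchOpHandle *>(op) != nullptr);
    delete op;  // stands in for removing the op's nodes from the graph
  }
  fetch_ops->clear();
}

int main() {
  std::vector<OpHandleBase *> ops = {new FetchOpHandle, new FetchOpHandle};
  ClearFetchOps(&ops);  // a ComputeOpHandle in the list would trip the assert
}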
...@@ -38,7 +38,7 @@ class SSAGraphExecutor { ...@@ -38,7 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops); void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -53,27 +53,40 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -53,27 +53,40 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare")); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get(); std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
CopyOpDeps(); CopyOpDeps();
VLOG(10) << "ThreadedSSAGraphExecutor::Run"; VLOG(10) << "ThreadedSSAGraphExecutor::Run";
std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars( std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
new BlockingQueue<VarHandleBase *>); new BlockingQueue<VarHandleBase *>);
auto &pending_ops = op_deps->pending_ops_; auto &pending_ops = op_deps->pending_ops_;
auto &pending_vars = op_deps->pending_vars_; auto &pending_vars = op_deps->pending_vars_;
auto &ready_ops = op_deps->ready_ops_; auto &ready_ops = op_deps->ready_ops_;
size_t num_ops = op_deps->num_ops_;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<FetchOpHandle *> fetch_ops; std::vector<OpHandleBase *> fetch_ops;
std::unordered_set<VarHandleBase *> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
&pending_ops, &pending_vars, &fetch_data); &pending_ops, &pending_vars, &fetch_data);
exception_holder_.Clear();
event.reset(nullptr);
// Step 3. Execution
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If num_threads is 1, we can record the order in which operators
// execute in the first iteration, and in subsequent iterations run
// the recorded operators directly. This strategy makes execution
// faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
RunOp(ready_vars, op); RunOp(ready_vars, op);
...@@ -82,9 +95,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -82,9 +95,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
}; };
// Clean run context // Clean run context
run_op_futures_.clear(); run_op_futures_.clear();
exception_holder_.Clear();
event.reset(nullptr);
// Step 3. Execution
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
// Keep loop until all vars are ready. // Keep loop until all vars are ready.
...@@ -94,14 +105,11 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -94,14 +105,11 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
bool timeout; bool timeout;
auto cur_ready_vars = ready_vars->PopAll(1, &timeout); auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
ClearFetchOp(graph_, &fetch_ops); if (exception_holder_.IsCaught()) {
exception_holder_.ReThrow(); ExecutionFinal(&fetch_ops);
} else { } else {
continue; continue;
} }
...@@ -121,6 +129,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -121,6 +129,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
} }
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
}
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
...@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops, std::vector<OpHandleBase *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops, std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
...@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() { ...@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
InsertPendingOp(&pending_ops, op); InsertPendingOp(&pending_ops, op);
} }
} }
op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
for (auto ready_var : ready_vars) { for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var); pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) { for (auto *op : ready_var->PendingOps()) {
...@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() { ...@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
op_deps_->pending_vars_.end()); op_deps_->pending_vars_.end());
op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(), op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
op_deps_->ready_ops_.end()); op_deps_->ready_ops_.end());
op_deps->num_ops_ = op_deps_->num_ops_;
return std::unique_ptr<OpDependentData>(op_deps); return std::unique_ptr<OpDependentData>(op_deps);
}); });
} }
...@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q, const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op) { details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
RunOpSync(op);
try { try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << " Signal posted"; VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
} }
}; };
if (pool_) { if (pool_) {
run_op_futures_.emplace_back(pool_->enqueue(op_run)); run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else { } else {
op_run(); op_run();
} }
RecordOps(op);
}
void ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
}
void ThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_holder_.ReThrow();
}
void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
} }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
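// A self-contained sketch of the exception-holder pattern both executors use:
// worker threads catch into a shared std::exception_ptr, and the coordinating
// thread rethrows after the workers finish. Names are illustrative.
#include <exception>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <thread>

class ExceptionHolder {
 public:
  void Catch(std::exception_ptr e) {
    std::lock_guard<std::mutex> guard(mu_);
    if (!eptr_) eptr_ = e;  // keep only the first caught exception
  }
  bool IsCaught() const { return eptr_ != nullptr; }
  void ReThrow() {
    if (eptr_) std::rethrow_exception(eptr_);
  }

 private:
  std::exception_ptr eptr_;
  std::mutex mu_;
};

int main() {
  ExceptionHolder holder;
  std::thread worker([&] {
    try {
      throw std::runtime_error("op failed on a worker thread");
    } catch (...) {
      holder.Catch(std::current_exception());  // never throw across threads
    }
  });
  worker.join();
  try {
    if (holder.IsCaught()) holder.ReThrow();  // surfaces on the main thread
  } catch (const std::exception &e) {
    std::cout << "caught: " << e.what() << "\n";
  }
}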
...@@ -44,6 +44,7 @@ struct OpDependentData { ...@@ -44,6 +44,7 @@ struct OpDependentData {
std::unordered_map<OpHandleBase *, size_t> pending_ops_; std::unordered_map<OpHandleBase *, size_t> pending_ops_;
std::unordered_set<VarHandleBase *> pending_vars_; std::unordered_set<VarHandleBase *> pending_vars_;
std::unordered_set<OpHandleBase *> ready_ops_; std::unordered_set<OpHandleBase *> ready_ops_;
size_t num_ops_{0};
}; };
class ThreadedSSAGraphExecutor : public SSAGraphExecutor { class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::list<std::future<void>> run_op_futures_; std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_; ::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<OpHandleBase *> traced_ops_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops, void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const; OpHandleBase *op_instance) const;
...@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps(const std::vector<std::string> &fetch_tensors, void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops, std::vector<OpHandleBase *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops, std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
...@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList *fetch_data); FeedFetchList *fetch_data);
void PrepareOpDeps(); void PrepareOpDeps();
void CopyOpDeps(); void CopyOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
}; };
} // namespace details } // namespace details
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <atomic>
#include <fstream> #include <fstream>
#include <map> #include <map>
#include <memory> #include <memory>
...@@ -35,9 +36,17 @@ limitations under the License. */ ...@@ -35,9 +36,17 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#define SEC_LOG \
VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \
<< "]: "
class PullDenseWorker { class PullDenseWorker {
public: public:
virtual ~PullDenseWorker() {} virtual ~PullDenseWorker() {}
...@@ -48,6 +57,7 @@ class PullDenseWorker { ...@@ -48,6 +57,7 @@ class PullDenseWorker {
void IncreaseThreadVersion(int thread_id, uint64_t table_id); void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id); void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec); void Wait(std::vector<::std::future<int32_t>>* status_vec);
void PullDense(bool force_update = false);
static std::shared_ptr<PullDenseWorker> GetInstance() { static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker()); s_instance_.reset(new paddle::framework::PullDenseWorker());
...@@ -92,7 +102,7 @@ class PullDenseWorker { ...@@ -92,7 +102,7 @@ class PullDenseWorker {
// should incorporate different type of device // should incorporate different type of device
class DeviceWorker { class DeviceWorker {
public: public:
DeviceWorker() {} DeviceWorker() { use_cvm_ = false; }
virtual ~DeviceWorker() {} virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0; virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0; virtual void SetDeviceIndex(int tid) = 0;
...@@ -114,6 +124,7 @@ class DeviceWorker { ...@@ -114,6 +124,7 @@ class DeviceWorker {
std::shared_ptr<DataFeed> device_reader_; std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_; int64_t batch_num_;
FetchConfig fetch_config_; FetchConfig fetch_config_;
bool use_cvm_;
}; };
class CPUWorkerBase : public DeviceWorker { class CPUWorkerBase : public DeviceWorker {
...@@ -194,5 +205,101 @@ class DownpourWorker : public HogwildWorker { ...@@ -194,5 +205,101 @@ class DownpourWorker : public HogwildWorker {
std::vector<::std::future<int32_t>> push_dense_status_; std::vector<::std::future<int32_t>> push_dense_status_;
}; };
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
using ScopeQueue = operators::reader::BlockingQueue<Scope*>;
class SyncFunctor {
public:
SyncFunctor(int rank_id, int rank_num, int sync_steps);
virtual ~SyncFunctor() {}
void SetSyncParam(const std::vector<std::string>& sync_param) {
sync_param_ = &sync_param;
}
void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) {
nccl_ctx_map_ = nccl_ctx_map;
}
int operator()(Scope* scope);
static std::vector<Scope*> pipeline_scopes_;
static uint64_t sync_flag_;
protected:
const int rank_id_;
const int rank_num_;
const std::vector<std::string>* sync_param_ = nullptr;
platform::NCCLContextMap* nccl_ctx_map_ = nullptr;
uint64_t sync_signal_;
const int sync_steps_;
int counter_;
void Synchronize();
};
class SectionWorker : public DeviceWorker {
public:
SectionWorker() {}
~SectionWorker() override {}
void Initialize(const TrainerDesc& desc) override;
void BindingDataFeedMemory() override {}
void CreateDeviceResource(const ProgramDesc& main_prog) override{};
void TrainFiles() override;
void TrainFilesWithProfiler() override;
void PrintFetchVars() override {}
const platform::Place& place() const { return place_; }
void SetSectionIndex(int section_id) { section_id_ = section_id; }
void SetDeviceIndex(int tid) override { pipeline_id_ = tid; }
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetVarNames(const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
in_var_names_ = &in_var_names;
out_var_names_ = &out_var_names;
}
void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) {
in_scope_queue_ = in_scope_queue;
out_scope_queue_ = out_scope_queue;
}
void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; }
void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; }
void SetSectionNum(int section_num) { section_num_ = section_num; }
void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; }
void SetNextSectionPlace(const paddle::platform::Place& place) {
next_section_place_ = place;
}
SyncFunctor* sync_func_ = nullptr;
void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }
static std::atomic<int> cpu_id_;
protected:
void AutoSetCPUAffinity(bool reuse);
int section_id_;
int pipeline_id_;
int section_num_;
int pipeline_num_;
int thread_id_;
// This worker consumes scopes from in_scope_queue_
// and produces scopes to out_scope_queue_.
ScopeQueue* in_scope_queue_ = nullptr;
ScopeQueue* out_scope_queue_ = nullptr;
const std::vector<std::string>* in_var_names_ = nullptr;
const std::vector<std::string>* out_var_names_ = nullptr;
std::mutex* worker_count_mutex_ = nullptr;
int* worker_count_ = nullptr;
paddle::platform::Place next_section_place_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
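// A sketch of the SectionWorker wiring declared above: pipeline sections run
// on their own threads and hand work through blocking queues, as
// in_scope_queue_ / out_scope_queue_ do for scopes in the real code. Standard
// library only; illustrative.
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

template <typename T>
class BlockingQueue {
 public:
  void Send(T v) {
    {
      std::lock_guard<std::mutex> g(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }
  T Receive() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop();
    return v;
  }

 private:
  std::queue<T> q_;
  std::mutex mu_;
  std::condition_variable cv_;
};

int main() {
  BlockingQueue<int> scope_queue;  // plays the role of a scope queue
  const int kBatches = 3, kStop = -1;

  std::thread section1([&] {  // e.g. feeds data and runs the first ops
    for (int b = 0; b < kBatches; ++b) scope_queue.Send(b);
    scope_queue.Send(kStop);
  });
  std::thread section2([&] {  // consumes section1's output scopes
    for (int b; (b = scope_queue.Receive()) != kStop;)
      std::printf("section2 got batch %d\n", b);
  });
  section1.join();
  section2.join();
}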
...@@ -61,5 +61,8 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker( ...@@ -61,5 +61,8 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker); REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -63,6 +63,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -63,6 +63,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance(); fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm();
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
...@@ -139,6 +140,16 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -139,6 +140,16 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
LoD data_lod{tensor_lod}; LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod); tensor_emb->set_lod(data_lod);
for (int index = 0; index < len; ++index) { for (int index = 0; index < len; ++index) {
if (use_cvm_) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(),
sizeof(float) * table.emb_dim());
fea_idx++;
} else {
if (ids[index] == 0u) { if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2, memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim()); sizeof(float) * table.emb_dim());
...@@ -149,6 +160,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -149,6 +160,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
fea_idx++; fea_idx++;
} }
} }
}
} }
void DownpourWorker::TrainFilesWithProfiler() { void DownpourWorker::TrainFilesWithProfiler() {
...@@ -197,9 +209,9 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -197,9 +209,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
uint64_t tid = static_cast<uint64_t>( uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i)); param_.program_config(0).pull_sparse_table_id(i));
TableParameter table; TableParameter table;
for (auto i : param_.sparse_table()) { for (auto j : param_.sparse_table()) {
if (i.table_id() == tid) { if (j.table_id() == tid) {
table = i; table = j;
break; break;
} }
} }
...@@ -259,7 +271,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -259,7 +271,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
fleet_ptr_->PushSparseVarsWithLabelAsync( fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_); &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_);
timeline.Pause(); timeline.Pause();
push_sparse_time += timeline.ElapsedSec(); push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
...@@ -367,9 +379,9 @@ void DownpourWorker::TrainFiles() { ...@@ -367,9 +379,9 @@ void DownpourWorker::TrainFiles() {
uint64_t tid = static_cast<uint64_t>( uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i)); param_.program_config(0).pull_sparse_table_id(i));
TableParameter table; TableParameter table;
for (auto i : param_.sparse_table()) { for (auto j : param_.sparse_table()) {
if (i.table_id() == tid) { if (j.table_id() == tid) {
table = i; table = j;
break; break;
} }
} }
...@@ -411,7 +423,7 @@ void DownpourWorker::TrainFiles() { ...@@ -411,7 +423,7 @@ void DownpourWorker::TrainFiles() {
fleet_ptr_->PushSparseVarsWithLabelAsync( fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_); &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_);
} }
} }
......
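// A sketch of the use_cvm branch added to FillSparseValue above, under the
// assumed layout that a fetched feature value is [show, click, emb...]: with
// CVM enabled the two counters are copied as part of the embedding, without
// it the copy starts two floats in (init_value.data() + 2 in the real code).
// Illustrative only.
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const std::size_t emb_dim_no_cvm = 4;
  std::vector<float> fea_value = {/*show*/ 7.f, /*click*/ 3.f,
                                  0.1f, 0.2f, 0.3f, 0.4f};

  // use_cvm == true: counters count toward the embedding (dim = 2 + 4).
  std::vector<float> with_cvm(2 + emb_dim_no_cvm);
  std::memcpy(with_cvm.data(), fea_value.data(),
              sizeof(float) * with_cvm.size());

  // use_cvm == false: skip show/click and keep only the raw embedding.
  std::vector<float> without_cvm(emb_dim_no_cvm);
  std::memcpy(without_cvm.data(), fea_value.data() + 2,
              sizeof(float) * without_cvm.size());

  std::printf("with cvm: %zu floats (first %.0f); without: %zu (first %.1f)\n",
              with_cvm.size(), with_cvm[0], without_cvm.size(), without_cvm[0]);
}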
...@@ -122,8 +122,9 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope, ...@@ -122,8 +122,9 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
const std::string& trainer_desc_str) { const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor"; VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc; TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str, bool success = trainer_desc.ParseFromString(trainer_desc_str);
&trainer_desc); PADDLE_ENFORCE(success, "Fail to parse TrainerDesc from string:\n%s",
trainer_desc_str.c_str());
VLOG(3) << "Going to create trainer, trainer class is " VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name(); << trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer; std::shared_ptr<TrainerBase> trainer;
...@@ -244,6 +245,12 @@ static bool has_fetch_operators( ...@@ -244,6 +245,12 @@ static bool has_fetch_operators(
return fetch_count > 0; return fetch_count > 0;
} }
std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
}
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
@@ -328,7 +335,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
#ifdef PADDLE_WITH_NGRAPH
-if (FLAGS_use_ngraph) {
if (FLAGS_use_ngraph && ctx->block_id_ == 0) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
@@ -368,6 +375,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars,
bool keep_kids) {
platform::RecordBlock b(kProgramId);
PADDLE_ENFORCE_NOT_NULL(scope);
Scope* local_scope = scope;
if (create_vars) {
@@ -407,7 +415,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
}
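The new PrepareCtxCache above is a thin public alias for Prepare, giving callers a stable entry point for the prepare-once / run-many pattern with RunPreparedContext. A minimal usage sketch, assuming a ProgramDesc, Scope, and place are already constructed (names are illustrative):

// Hedged sketch: cache the prepared context once, then reuse it every step.
paddle::framework::Executor executor(place);
auto ctx = executor.PrepareCtxCache(program, /*block_id=*/0);
for (int step = 0; step < num_steps; ++step) {
  // feed input tensors into `scope` here, then run without re-preparing ops
  executor.RunPreparedContext(ctx.get(), &scope,
                              /*create_local_scope=*/true,
                              /*create_vars=*/true);
}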
...
@@ -83,6 +83,21 @@ class Executor {
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
@@ -101,15 +116,6 @@ class Executor {
bool create_local_scope = true,
bool create_vars = true, bool keep_kids = false);
-// This API is very slow.
-void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
-std::map<std::string, const LoDTensor*>* feed_targets,
-std::map<std::string, LoDTensor*>* fetch_targets,
-bool create_local_scope = true,
-bool create_vars = true,
-const std::string& feed_holder_name = "feed",
-const std::string& fetch_holder_name = "fetch");
void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
...
@@ -281,9 +281,16 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
-std::vector<::std::future<int32_t>>* push_sparse_status) {
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
int grad_dim = emb_dim;
if (use_cvm) {
offset = 0;
grad_dim = emb_dim - 2;
}
CHECK_GE(grad_dim, 0);
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
@@ -307,7 +314,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
int dim = emb_dim + offset;
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g, g_tensor->numel() / dim, dim);
g_mat.rightCols(grad_dim) *= batch_size;
}
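The Eigen::Map above views the flat gradient buffer as a row-major matrix with one row per feature, so rightCols(grad_dim) scales only the trailing gradient columns by the batch size and leaves the leading CVM slots alone. A toy sketch of the same idiom, with illustrative values:

// Hedged sketch: in-place scaling of the trailing columns of a flat buffer.
#include <Eigen/Dense>

void ScaleTrailingColumns() {
  float buf[6] = {1.f, 1.f, 0.5f, 0.5f, 1.f, 1.f};  // 2 rows x 3 cols
  Eigen::Map<
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      mat(buf, 2, 3);
  mat.rightCols(2) *= 32;  // e.g. batch_size = 32; column 0 is untouched
  // buf is now {1.f, 32.f, 16.f, 0.5f, 32.f, 32.f}
}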
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += emb_dim;
@@ -315,10 +328,15 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
}
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
if (use_cvm) {
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
} else {
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
}
g += emb_dim;
fea_idx++;
}
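The offset bookkeeping is easiest to see on one concrete pushed row. A hedged illustration with emb_dim = 4 (layout inferred from the memcpy branches above):

// use_cvm == false: offset = 2, row resized to emb_dim + 2 floats
//   row = {1.0f /*show*/, float(label) /*click*/, g[0], g[1], g[2], g[3]}
// use_cvm == true: offset = 0, row is exactly the emb_dim gradient floats
//   row = {g[0], g[1], g[2], g[3]}  // g[0], g[1] carry the CVM slots,
//                                   // hence grad_dim = emb_dim - 2 above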
@@ -337,6 +355,89 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::SaveModel does nothing when no pslib";
#endif
}
void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
ret.wait();
#else
VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
#endif
}
void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list,
float decay) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
for (std::string& name : var_list) {
if (name.find("batch_sum") != std::string::npos) {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
VLOG(3) << "prepare shrink dense batch_sum";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
Eigen::Map<Eigen::MatrixXf> mat(g, 1, tensor->numel());
mat *= decay;
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
} else {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
}
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push shrink dense param failed, status[" << status << "]";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
#endif
}
void FleetWrapper::ClientFlush() {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->flush();
ret.wait();
#else
VLOG(0) << "FleetWrapper::ServerFlush does nothing when no pslib";
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
@@ -398,6 +499,24 @@ void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#endif
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
#ifdef PADDLE_WITH_PSLIB
engine_wrapper_t() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
#endif
};
thread_local engine_wrapper_t r;
return r.engine;
}
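LocalRandomEngine gives every thread its own engine, seeded once per thread from wall-clock time mixed with a process-wide atomic counter, so concurrent samplers neither share state nor repeat seeds. A hedged usage sketch (the distribution and rate are illustrative):

// Hedged sketch: lock-free per-thread sampling.
#include <random>

bool KeepSample(paddle::framework::FleetWrapper* fleet, double sample_rate) {
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  return dist(fleet->LocalRandomEngine()) < sample_rate;  // no mutex needed
}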
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
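Taken together, the new FleetWrapper methods cover the parameter-server model lifecycle: drain outstanding async pushes, save or load either the full model (mode 0) or only the delta (mode 1), and shrink stale entries. A hedged end-of-pass sketch (paths and table ids are illustrative):

// Hedged sketch: typical housekeeping after a training pass.
auto fleet = paddle::framework::FleetWrapper::GetInstance();
fleet->ClientFlush();                      // flush pending push requests
fleet->SaveModel("hdfs:/models/ctr", 0);   // mode 0: save all features
fleet->ShrinkSparseTable(0);               // drop stale sparse features
fleet->LoadModel("hdfs:/models/ctr", 1);   // mode 1: load the diff only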
...
@@ -55,7 +55,7 @@ namespace framework {
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
-FleetWrapper() {}
FleetWrapper() { scale_sparse_gradient_with_batch_size_ = true; }
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values
@@ -99,7 +99,8 @@ class FleetWrapper {
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
-std::vector<::std::future<int32_t>>* push_sparse_status);
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
@@ -128,6 +129,19 @@ class FleetWrapper {
// create client to client connection
void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
void ShrinkSparseTable(int table_id);
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay);
// register client to client communication
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
@@ -146,6 +160,9 @@ class FleetWrapper {
return s_instance_;
}
// this performs better than rand_r, especially for large data
std::default_random_engine& LocalRandomEngine();
#ifdef PADDLE_WITH_PSLIB
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif
@@ -158,6 +175,7 @@ class FleetWrapper {
protected:
static bool is_initialized_;
bool scale_sparse_gradient_with_batch_size_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
-// option optimize_for = LITE_RUNTIME;
option optimize_for = LITE_RUNTIME;
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
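Uncommenting optimize_for = LITE_RUNTIME makes protoc generate classes on MessageLite, which drops descriptors and reflection in exchange for a smaller runtime; that is presumably why the executor hunk above switched from reflection-based TextFormat parsing to the binary ParseFromString, which lite-generated messages still support.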
...
@@ -24,9 +24,10 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size());
-for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
for (int i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
use_cvm_ = desc.use_cvm();
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
...
@@ -72,12 +72,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
-if(ANAKIN_FOUND)
if(ANAKIN_SUBGRAPH)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
endif()
@@ -86,12 +86,23 @@ if(WITH_MKLDNN)
pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
endif()
if(WITH_NGRAPH)
cc_library(ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge
analysis_helper subgraph_detector graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(APPEND ${pass_file} "USE_PASS(ngraph_subgraph_pass);\n")
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
@@ -115,6 +126,8 @@ if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
...
@@ -23,15 +23,16 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
-DEFINE_uint64(fuse_parameter_memory_size, 0,  // 0 KB
DEFINE_double(fuse_parameter_memory_size, -1.0,  // MBytes
-"fuse_parameter_memory_size is up limited memory size "
"fuse_parameter_memory_size is up limited memory size(MB)"
"of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that "
"not set group according to memory_size.");
DEFINE_int32(
-fuse_parameter_groups_size, 3,
fuse_parameter_groups_size, 1,
-"fuse_parameter_groups_size is the size of one group parameters' gradient. "
"fuse_parameter_groups_size is the up limited size of one group "
"parameters' gradient. "
"The default value is a experimental result. If the "
"fuse_parameter_groups_size is 1, it means that the groups size is "
"the number of parameters' gradient. If the fuse_parameter_groups_size is "
@@ -41,6 +42,9 @@ DEFINE_int32(
namespace paddle {
namespace framework {
namespace ir {
// unit of the FLAGS_fuse_parameter_memory_size.
static constexpr double kMB = 1048576.0;
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// test, because it is invalid that setting 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in unit test.
@@ -50,15 +54,12 @@ void SetFuseParameterGroupsSize(int group_size) {
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
-void SetFuseParameterMemorySize(uint64_t memory_size) {
void SetFuseParameterMemorySize(double memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
-uint64_t GetFuseParameterMemorySize() {
-return FLAGS_fuse_parameter_memory_size;
-}
double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; }
-static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
@@ -83,7 +84,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
if (params_grads.size() == 0) {
-VLOG(10) << "Doesn't find gradients";
LOG(WARNING) << "Doesn't find gradients";
return;
}
@@ -169,7 +170,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
details::GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
-SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void SetGroupAccordingToLayers(
@@ -181,7 +181,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
-layer_params[std::string(kUnKnow)].emplace_back(i);
layer_params[params_grads[i].first].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
@@ -190,7 +190,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
-std::string key = kUnKnow;
std::string key = params_grads[i].first;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
@@ -207,21 +207,40 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
VLOG(10) << "SetGroupAccordingToLayers: ";
if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
}
}
void PrintGroupInfo(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
-for (auto &p_g : group_grads_params->at(i)) {
-out << "(" << p_g.second << ", " << p_g.first << "), ";
size_t gps_size = 0;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
gps_size += size;
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
-VLOG(10) << out.str();
VLOG(10) << out.str()
<< ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << static_cast<double>(gps_size) / kMB
<< "(MB)";
}
}
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
-const uint64_t group_memory_size = GetFuseParameterMemorySize();
const double group_memory_size = GetFuseParameterMemorySize();
-if (group_memory_size == 0) {
if (group_memory_size <= 0.0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
@@ -248,69 +267,26 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
-if (local_group_memory_size >= group_memory_size) {
if (GetFuseParameterGroupsSize() > 1 &&
group_p_g.size() >
static_cast<size_t>(GetFuseParameterGroupsSize())) {
break;
}
if (static_cast<double>(local_group_memory_size) / kMB >=
group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
-VLOG(10) << string::Sprintf(
-"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size);
-for (size_t i = 0; i < group_grads_params->size(); ++i) {
-VLOG(10) << "group " << i;
-std::stringstream out;
-for (auto &g_p : group_grads_params->at(i)) {
-auto iter = var_nodes.find(g_p.second);
-PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
-auto shape = iter->second->Var()->GetShape();
-size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
-std::for_each(shape.begin(), shape.end(),
-[&size](const int64_t &n) { size *= n; });
-out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
-}
-VLOG(10) << out.str();
-}
if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
}
}
-void SetGroupAccordingToGroupSize(
-const std::unordered_map<std::string, ir::Node *> &var_nodes,
-details::GroupGradsAndParams *group_grads_params) const {
-if (GetFuseParameterGroupsSize() == 1) {
-return;
-}
-const int group_size = GetFuseParameterGroupsSize() == -1
-? static_cast<int>(group_grads_params->size())
-: GetFuseParameterGroupsSize();
-PADDLE_ENFORCE_GT(group_size, 1);
-size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
-details::GroupGradsAndParams local_group_grads_params;
-local_group_grads_params.reserve(groups);
-size_t j = 0;
-for (size_t i = 0; i < groups; ++i) {
-local_group_grads_params.emplace_back();
-auto &group_p_g = local_group_grads_params.back();
-group_p_g.reserve(group_size);
-while (j < group_grads_params->size()) {
-group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
-group_grads_params->at(j).end());
-++j;
-if (j % group_size == 0) break;
-}
-}
-std::swap(*group_grads_params, local_group_grads_params);
-VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
-group_size);
-for (size_t i = 0; i < group_grads_params->size(); ++i) {
-VLOG(10) << "group " << i;
-std::stringstream out;
-for (auto &p_g : group_grads_params->at(i)) {
-out << "(" << p_g.second << ", " << p_g.first << "), ";
-}
-VLOG(10) << out.str();
-}
-}
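With the flag now expressed in MB, the grouping arithmetic can be checked by hand. A hedged sketch of the per-gradient size computation the pass performs (shape and dtype are illustrative):

// Hedged sketch: memory accounting for one gradient, in the pass's units.
#include <cstddef>
#include <cstdint>
#include <vector>

double GradMemoryMB(const std::vector<int64_t>& shape, size_t dtype_bytes) {
  constexpr double kMB = 1048576.0;  // same unit constant as the pass
  size_t size = dtype_bytes;         // e.g. 4 for float32
  for (int64_t n : shape) size *= n;
  return static_cast<double>(size) / kMB;
}
// GradMemoryMB({1024, 1024}, 4) == 4.0, so FLAGS_fuse_parameter_memory_size
// = 100 packs roughly 25 such gradients into one fused group.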
...
@@ -21,8 +21,8 @@ namespace ir {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
-void SetFuseParameterMemorySize(uint64_t memory_size);
void SetFuseParameterMemorySize(double memory_size);
-uint64_t GetFuseParameterMemorySize();
double GetFuseParameterMemorySize();
} // namespace ir
} // namespace framework
...
@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-auto* scope = graph->Get<Scope*>(kParamScopeAttr);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create new parameters.
-scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
-scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
-scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope.Var(param.Hidden)->GetMutable<LoDTensor>();
-scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope.Var(param.Cell)->GetMutable<LoDTensor>();
-scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
-scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
-scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
-scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \
-auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
-auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1"); \
-auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
GATE_W(c);
#undef GATE_W
-auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
-auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
-auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_w = scope.FindVar("attention_output.w_0");
-auto* attention_output_b = scope->FindVar("attention_output.b_0");
auto* attention_output_b = scope.FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);
-auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight = scope.Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
-auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias = scope.Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias
auto* attention_bias_t =
-scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t =
-scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));
...
@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-auto* scope = graph->Get<Scope*>(kParamScopeAttr);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
-scope->Var(x)->GetMutable<LoDTensor>()
scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
@@ -77,9 +78,15 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
if (base_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale"));
auto elementwise_desc = elementwise_add->Op();
if (elementwise_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
...
@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-auto* scope = graph->Get<Scope*>(kParamScopeAttr);
-PADDLE_ENFORCE(scope);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias
-auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var);
-auto* gru_bias_var = scope->FindVar(bias->Name());
auto* gru_bias_var = scope.FindVar(bias->Name());
-auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* fc_bias_var = scope.FindVar(fc_bias->Name());
PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef GET_NODE
#define NEW_IMTERMEDIATE_OUT(key) \
-scope->Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
scope.Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
NEW_IMTERMEDIATE_OUT(ReorderedH0);
NEW_IMTERMEDIATE_OUT(XX);
NEW_IMTERMEDIATE_OUT(BatchedInput);
...
@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-auto* scope = graph->Get<Scope*>(kParamScopeAttr);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
-scope->Var(x)->GetMutable<LoDTensor>()
scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
...
@@ -26,7 +26,7 @@ namespace framework {
namespace ir {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
-std::unordered_set<std::string> act_types = {"relu", "scale"};
std::unordered_set<std::string> act_types = {"relu", "scale", "tanh"};
graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types);
// backward
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include <unordered_map>
namespace paddle {
namespace framework {
@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
-return graph_->Get<framework::Scope*>(kParamScopeAttr);
auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
return &scope;
}
void FusePassBase::AddStatis(int count_of_fused) const {
@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1,
#else
return FUSE_NATIVE;
#endif
-};
}
...
@@ -134,6 +134,7 @@ void Graph::ResolveHazard(
ir::Node *dep_var = CreateControlDepVar();
write_op->inputs.push_back(dep_var);
upstream_op->outputs.push_back(dep_var);
VLOG(10) << "add dep_var:" << dep_var->Name();
dep_var->outputs.push_back(write_op);
dep_var->inputs.push_back(upstream_op);
}
@@ -157,6 +158,7 @@ void Graph::ResolveHazard(
if (has_dep) continue;
ir::Node *dep_var = CreateControlDepVar();
VLOG(10) << "add dep_var:" << dep_var->Name();
read_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(read_op);
write_op->inputs.push_back(dep_var);
...
@@ -14,7 +14,10 @@
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
@@ -785,6 +788,33 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var;
}
PDNode *patterns::ConvBReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu6");
// output
auto *brelu_out_var = pattern->NewNode(brelu_out_repr())
->AsOutput()
->assert_is_op_output("relu6");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var});
return brelu_out_var;
}
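The ConvBReLU pattern keys on relu6, i.e. ReLU bounded at 6, which is what a fused conv+brelu kernel must reproduce. For reference, a hedged one-liner of the activation itself:

// relu6 ("bounded relu"): clamp the activation to [0, 6].
#include <algorithm>

inline float Relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }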
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
@@ -869,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
}
}
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Filter
auto *fc_weight_var = pattern->NewNode(weights_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "W");
// Bias
auto *fc_bias_var = pattern->NewNode(bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "Bias");
// Output
auto *fc_out_var = pattern->NewNode(output_repr())
->AsOutput()
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
return fc_out_var;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
@@ -1035,12 +1092,12 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return ele_add_grad;
}
// conv_type: conv2d, conv3d, conv2d_transpose
PDNode *patterns::ConvBias::operator()(
-paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
-std::string type = is_conv3d ? "conv3d" : "conv2d";
paddle::framework::ir::PDNode *conv_input, std::string conv_type) {
// Create Operators
-conv_input->assert_is_op_input(type, "Input");
conv_input->assert_is_op_input(conv_type, "Input");
-auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables
@@ -1048,11 +1105,11 @@ PDNode *patterns::ConvBias::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
--->assert_is_op_input(type, "Filter");
->assert_is_op_input(conv_type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
--->assert_is_only_output_of_op(type)
->assert_is_only_output_of_op(conv_type)
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
@@ -1157,6 +1214,57 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var;
}
PDNode *patterns::Concat::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto output_var = pattern->NewNode(concat_out_repr())
->AsOutput()
->assert_is_op_output("concat", "Out");
concat_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ConcatReLU::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto concat_out =
pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
concat_op->LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
PDNode *patterns::ConvConcatReLU::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto concat_out = pattern->NewNode(concat_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("relu", "X");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
conv_op->LinksTo({conv_out});
concat_op->LinksFrom({conv_out}).LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
@@ -1641,13 +1749,16 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times,
-const std::string &quant_type) {
const std::string &quant_type,
const std::string &dequant_type) {
-const int kNumFields = 5;
int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
const int kDequantOpWeightScaleOffset = 5;
// there is always exactly one quant op.
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
@@ -1655,11 +1766,19 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
-auto quant_op_out_scale =
-pattern->NewNode(GetNodeName("quant_op_out_scale"))
-->assert_is_op_output(quant_type, "OutScale")
-->assert_is_op_input("fake_dequantize_max_abs", "Scale")
-->AsIntermediate();
PDNode *quant_op_out_scale = nullptr;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
kNumFields += 1;
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_nth_input(dequant_type, "Scales", 1)
->AsIntermediate();
} else {
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input(dequant_type, "Scale")
->AsIntermediate();
}
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
@@ -1680,16 +1799,25 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
--->assert_is_op_input("fake_dequantize_max_abs", "X")
->assert_is_op_input(dequant_type, "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
--->assert_is_op("fake_dequantize_max_abs"));
->assert_is_op(dequant_type));
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
--->assert_is_op_output("fake_dequantize_max_abs", "Out")
->assert_is_op_output(dequant_type, "Out")
->AsOutput());
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes.push_back(pattern
->NewNode(GetNodeName("dequant_channel_scale") +
std::to_string(i))
->assert_is_op_nth_input(dequant_type, "Scales", 0)
->AsInput());
}
}
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
@@ -1699,8 +1827,14 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale,
nodes[i * kNumFields + kDequantOpWeightScaleOffset]});
} else {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
}
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
@@ -1737,6 +1871,41 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out->LinksFrom({reshape2_op});
}
void patterns::DeleteQuantDequantOpPattern::operator()() {
auto any_op_out =
pattern->NewNode(any_op_out_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "X")
->AsInput();
auto quant_dequant_op_inscale =
pattern->NewNode(quant_dequant_op_inscale_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "InScale")
->AsInput();
auto quant_dequant_op =
pattern->NewNode(quant_dequant_op_repr())
->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");
auto quant_dequant_out =
pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "Out")
->AsIntermediate();
auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "OutScale")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -12,8 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
-#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+#pragma once
 #include <string>
 #include <vector>
@@ -126,5 +125,3 @@ class LockFreeOptimizePass : public Pass {
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
-#endif  // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
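This header now relies on `#pragma once` instead of a macro include guard; both prevent multiple inclusion, the pragma just avoids the collision-prone guard-name boilerplate (it is non-standard, but supported by every toolchain Paddle builds with). For comparison:

```cpp
// Traditional include guard:
#ifndef MY_HEADER_H_
#define MY_HEADER_H_
// ... declarations ...
#endif  // MY_HEADER_H_

// Equivalent with the (near-universally supported) pragma:
#pragma once
// ... declarations ...
```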
@@ -48,8 +48,6 @@ DEFINE_bool(
     "Such as scale, elementwise_add"
     "By default, it's turned off");
-DECLARE_string(memory_optimize_debug);
 namespace paddle {
 namespace framework {
 namespace ir {
@@ -461,13 +459,6 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
       continue;
     }
-    // Debug Interface. Which would be skipped by the pass.
-    if (out_arg == FLAGS_memory_optimize_debug) {
-      VLOG(4) << "Skiped var by force. FLAGS_memory_optimize_debug="
-              << out_node->Name();
-      continue;
-    }
     VLOG(4) << "Rename " << out_node->Name() << " with " << in_node->Name()
             << " in " << op_type;
     RenameInOut(op_node, in_node, out_node);
......
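The deleted lines removed a gflags escape hatch that let a developer exclude one named variable from in-place reuse while debugging. For reference, the general shape of such a flag-driven skip (a minimal, self-contained gflags sketch, not the removed Paddle code):

```cpp
#include <iostream>
#include <string>

#include <gflags/gflags.h>

// A debug-only knob: name one variable that the optimization must skip.
DEFINE_string(memory_optimize_debug, "",
              "If non-empty, skip optimizing the variable with this name.");

bool ShouldSkip(const std::string& var_name) {
  // Force-skip the flagged variable so its buffer stays easy to inspect.
  return !FLAGS_memory_optimize_debug.empty() &&
         var_name == FLAGS_memory_optimize_debug;
}

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << ShouldSkip("fc_0.tmp_1") << "\n";
  return 0;
}
```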
@@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
   // fail since "states" and "ex_states" cannot be found in main block.
   // When memory optimization is enabled, "states", "ex_states" and their
   // gradient should be skipped.
-  auto& ex_states =
+  auto ex_states =
       boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
-  auto& states =
+  auto states =
       boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
   if (op_type == "recurrent") {
     UpdateSkipVarSet(skip_vars, {ex_states, states});
@@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
     UpdateSkipVarSet(
         skip_vars,
         {ToGradVarName(op_desc->Input("parameters")),
-         ToGradVarName(op_desc->Input("input")), ex_states, states,
+         ToGradVarName(op_desc->Input("inputs")), ex_states, states,
          ToGradVarName(ex_states), ToGradVarName(states)});
   }
 }
......
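The `auto&` → `auto` change above is more than style: `GetAttr` appears to return the attribute variant by value, so binding a reference to the result of `boost::get` on that temporary leaves the reference dangling once the statement ends; copying into a plain `auto` is safe. A self-contained illustration with `std::variant` standing in for the boost type (this `GetAttr` is a hypothetical stand-in):

```cpp
#include <string>
#include <variant>
#include <vector>

using Attribute = std::variant<int, std::vector<std::string>>;

// Stand-in for OpDesc::GetAttr: note that it returns by value.
Attribute GetAttr() { return std::vector<std::string>{"h", "c"}; }

int main() {
  // Dangerous: the temporary Attribute dies at the end of the statement,
  // so `bad` would dangle (many compilers warn here).
  // const auto& bad = std::get<std::vector<std::string>>(GetAttr());

  // Safe: copy the vector out of the temporary before it is destroyed.
  auto good = std::get<std::vector<std::string>>(GetAttr());
  return static_cast<int>(good.size());
}
```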
@@ -45,16 +45,14 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
-  std::string type = is_conv3d() ? "conv3d" : "conv2d";
   GraphPatternDetector gpd;
   auto* conv_input =
       gpd.mutable_pattern()
           ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
           ->AsInput()
-          ->assert_is_op_input(type, "Input");
+          ->assert_is_op_input(type(), "Input");
   patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
-  conv_bias_pattern(conv_input, is_conv3d());
+  conv_bias_pattern(conv_input, type());
   int found_conv_bias_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -75,7 +73,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
     // check if fuse can be done and if MKL-DNN should be used
     FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
     if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
-      VLOG(3) << "do not perform conv+bias fuse";
+      VLOG(3) << "do not perform " + type() + "+bias fuse";
       return;
     }
@@ -110,7 +108,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
     desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
     desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
-    desc.SetType(type);
+    desc.SetType(type());
     for (auto& attr : conv->Op()->GetAttrMap()) {
       desc.SetAttr(attr.first, attr.second);
@@ -135,5 +133,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
 }  // namespace paddle
 REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
               paddle::framework::ir::ConvBiasFusePass);
+REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DTransposeBiasFusePass);
 REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
               paddle::framework::ir::Conv3DBiasFusePass);
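Replacing the `is_conv3d()` flag with a virtual `type()` call is what lets one `ApplyImpl` body serve conv2d, conv2d_transpose, and conv3d: each registered subclass only overrides the op-type string. A minimal sketch of that design (the method bodies and exact type strings are assumptions inferred from the registrations above):

```cpp
#include <iostream>
#include <string>

// Base pass holds the shared fuse logic; subclasses only pick the op type.
class ConvBiasFusePass {
 public:
  virtual ~ConvBiasFusePass() = default;
  virtual std::string type() const { return "conv2d"; }
  void Apply() const {
    // The single implementation is parameterized by type() everywhere
    // it previously branched on is_conv3d().
    std::cout << "do not perform " + type() + "+bias fuse\n";
  }
};

class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
 public:
  std::string type() const override { return "conv2d_transpose"; }
};

class Conv3DBiasFusePass : public ConvBiasFusePass {
 public:
  std::string type() const override { return "conv3d"; }
};

int main() {
  Conv3DBiasFusePass pass;
  pass.Apply();  // prints: do not perform conv3d+bias fuse
  return 0;
}
```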
@@ -48,10 +48,17 @@ class CPUQuantizePass : public FusePassBase {
   void QuantizePool(Graph* graph) const;
+  void QuantizeConcat(Graph* graph) const;
   void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
                      double scale_to_one, bool is_unsigned,
                      std::string scale_attr_name = "") const;
+  // quantize all inputs of given name with the same (minimum) scale
+  void QuantizeInputs(Graph* g, Node* op, std::string input_name,
+                      VarQuantScale* scales, bool are_unsigned,
+                      std::string scale_attr_name = "") const;
   void DequantizeOutput(Graph* g, Node* op, Node* output,
                         std::string output_name, double scale_to_one,
                         bool is_unsigned,
......
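The new `QuantizeInputs` helper exists because an op such as concat consumes several tensors that must share a single quantization scale; per its comment, the shared scale is the minimum across the inputs, which keeps every input representable without saturation. A toy sketch of that selection under this assumption (not the pass itself):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Pick one shared scale for all inputs of an op such as concat.
// Taking the minimum per-tensor scale gives the widest shared range,
// so no input is clipped when mapped into int8.
double SharedInputScale(const std::vector<double>& input_scales) {
  assert(!input_scales.empty());
  return *std::min_element(input_scales.begin(), input_scales.end());
}

int main() {
  std::vector<double> scales{0.43, 0.91, 0.57};  // per-input scales
  return SharedInputScale(scales) == 0.43 ? 0 : 1;
}
```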
(The diffs for the remaining files are collapsed.)