Unverified commit b57eac73 authored by Yan Chunwei, committed by GitHub

Lite/update for x86 (#19027)

Parent fbbd8208
......@@ -19,36 +19,6 @@ set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
include(system)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cmake_minimum_required(VERSION 3.10)
# TODO(TJ): make this a check_default function
if(NOT DEFINED ARM_TARGET_OS)
set(ARM_TARGET_OS "android" CACHE STRING "Choose ARM Target OS")
endif()
set(ARM_TARGET_OS_LIST "android" "armlinux") # TODO: "ios"
set_property(CACHE ARM_TARGET_OS PROPERTY STRINGS ${ARM_TARGET_OS_LIST})
if (NOT ARM_TARGET_OS IN_LIST ARM_TARGET_OS_LIST)
message(FATAL_ERROR "ARM_TARGET_OS must be in one of ${ARM_TARGET_OS_LIST}")
endif()
if(NOT DEFINED ARM_TARGET_ARCH_ABI)
set(ARM_TARGET_ARCH_ABI "arm64-v8a" CACHE STRING "Choose ARM Target ARCH ABI")
endif()
set(ARM_TARGET_ARCH_ABI_LIST "arm64-v8a" "armeabi-v7a" "armeabi-v7a-softfp" "armeabi-v7a-hf")
set_property(CACHE ARM_TARGET_ARCH_ABI PROPERTY STRINGS ${ARM_TARGET_ARCH_ABI_LIST})
if (NOT ARM_TARGET_ARCH_ABI IN_LIST ARM_TARGET_ARCH_ABI_LIST)
message(FATAL_ERROR "ARM_TARGET_ARCH_ABI must be in one of ${ARM_TARGET_ARCH_ABI_LIST}")
endif()
if(NOT DEFINED TARGET_ARCH_ABI)
set(ARCH_ABI "arm64-v8a" CACHE STRING "Choose android platform")
endif()
include(cross_compiling/host)
include(cross_compiling/armlinux)
include(cross_compiling/android)
endif()
project(paddle CXX C)
message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
"${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
......@@ -71,9 +41,7 @@ if(WIN32)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
endif(WIN32)
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
find_package(CUDA QUIET)
endif()
find_package(CUDA QUIET)
find_package(Git REQUIRED)
find_package(Threads REQUIRED)
......@@ -111,79 +79,19 @@ option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VER
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
if(ANDROID OR IOS OR ARMLINUX)
set(WITH_GPU OFF CACHE STRING
"Disable GPU when cross-compiling for Android and iOS" FORCE)
set(WITH_DSO OFF CACHE STRING
"Disable DSO when cross-compiling for Android and iOS" FORCE)
set(WITH_AVX OFF CACHE STRING
"Disable AVX when cross-compiling for Android and iOS" FORCE)
set(WITH_PYTHON OFF CACHE STRING
"Disable PYTHON when cross-compiling for Android and iOS" FORCE)
set(WITH_RDMA OFF CACHE STRING
"Disable RDMA when cross-compiling for Android and iOS" FORCE)
set(WITH_MKL OFF CACHE STRING
"Disable MKL when cross-compiling for Android and iOS" FORCE)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING
"Default use Release in android" FORCE)
endif()
if(NOT THIRD_PARTY_BUILD_TYPE)
set(THIRD_PARTY_BUILD_TYPE "MinSizeRel" CACHE STRING
"Default use MinSizeRel in android" FORCE)
endif()
# PY_VERSION
if(NOT PY_VERSION)
set(PY_VERSION 2.7)
endif()
# For Lite, covering both the server and mobile frameworks.
option(WITH_LITE "Enable lite framework" OFF)
option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF)
option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
"Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
endif()
include_directories("${PADDLE_SOURCE_DIR}")
# for mobile
if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
message(STATUS "Building the mobile framework")
# include the necessary thirdparty dependencies
include(external/gflags) # download, build, install gflags
include(external/glog) # download, build, install glog
include(external/gtest) # download, build, install gtest
#include(external/zlib) # download, build, install zlib
include(external/protobuf) # download, build, install protobuf
include(external/eigen) # download eigen3
include(generic) # simplify cmake module
include(configure) # add paddle env configuration
add_definitions(-std=c++11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_subdirectory(paddle)
return()
"Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
endif()
# PY_VERSION
if(NOT PY_VERSION)
set(PY_VERSION 2.7)
endif()
set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
if (APPLE)
set(WITH_MKL OFF CACHE STRING
"Disable MKL for building on mac" FORCE)
......@@ -194,12 +102,16 @@ if (WIN32)
"Disable DISTRIBUTE when compiling for Windows" FORCE)
endif()
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
"A path setting fluid shared and static libraries")
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries")
set(THIRD_PARTY_BUILD_TYPE Release)
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
......@@ -273,6 +185,7 @@ if(WITH_BRPC_RDMA)
endif()
endif()
include(external/threadpool)
include(flags) # set paddle compile flags
include(cudnn) # set cudnn libraries, must before configure
......@@ -321,6 +234,7 @@ include(coveralls) # set code coverage
include(inference_lib) # add paddle fluid inference libraries
include_directories("${PADDLE_SOURCE_DIR}")
if(WITH_AMD_GPU)
find_package(HIP)
......
# An image for building paddle binaries
# Use cuda devel base image for both cpu and gpu environment
# When you modify it, please be aware of cudnn-runtime version
# and libcudnn.so.x in paddle/scripts/docker/build.sh
FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
......@@ -76,7 +75,7 @@ RUN curl -s -q https://glide.sh/get | sh
# 2. Manually add ~IPluginFactory() to the IPluginFactory class in NvInfer.h; otherwise it won't work in paddle.
# See https://github.com/PaddlePaddle/Paddle/issues/10129 for details.
RUN wget -q https://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
RUN wget -q https://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
cp -rf /usr/local/TensorRT/include /usr && \
cp -rf /usr/local/TensorRT/lib /usr
......@@ -93,17 +92,17 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# Pin the sphinx version to 1.5.6 and remove the -U option from [pip install -U
# sphinx-rtd-theme], since -U would update sphinx to the newest version
# (1.7.1 at the time of writing), which breaks the documentation build.
RUN pip3 --no-cache-dir install -U wheel && \
RUN pip3 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \
pip3 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip3 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
pip3.6 --no-cache-dir install -U wheel && \
pip3.6 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \
pip3.6 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip3.6 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
pip3.7 --no-cache-dir install -U wheel && \
pip3.7 --no-cache-dir install -U wheel py-cpuinfo==5.0.0 && \
pip3.7 --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip3.7 --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark && \
easy_install -U pip && \
pip --no-cache-dir install -U pip setuptools wheel && \
pip --no-cache-dir install -U pip setuptools wheel py-cpuinfo==5.0.0 && \
pip --no-cache-dir install -U docopt PyYAML sphinx==1.5.6 && \
pip --no-cache-dir install sphinx-rtd-theme==0.1.9 recommonmark
......
......@@ -98,9 +98,11 @@ We provide [English](http://www.paddlepaddle.org/documentation/docs/en/1.4/begin
We appreciate your contributions!
## Ask Questions
## Communication
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
- [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 432676488 (PaddlePaddle).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
......@@ -80,9 +80,11 @@ pip install paddlepaddle-gpu==1.4.1.post85
We appreciate your contributions!
## Ask Questions
## Communication and Feedback
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
- You are welcome to submit questions, bug reports, and suggestions via [Github Issues](https://github.com/PaddlePaddle/Paddle/issues)
- QQ group: 432676488 (PaddlePaddle)
- [Forum](http://ai.baidu.com/forum/topic/list/168): share the problems and experiences you encounter while using PaddlePaddle and help keep the forum friendly and constructive
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
if(NOT WITH_GPU)
return()
endif()
set(ANAKIN_ROOT "/usr" CACHE PATH "ANAKIN ROOT")
find_path(ANAKIN_INCLUDE_DIR anakin_config.h
PATHS ${ANAKIN_ROOT} ${ANAKIN_ROOT}/include
......@@ -16,9 +12,7 @@ find_library(ANAKIN_LIBRARY NAMES libanakin_saber_common.so libanakin.so
DOC "Path to ANAKIN library.")
if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY)
if(WITH_DSO)
set(ANAKIN_FOUND ON)
endif(WITH_DSO)
else()
set(ANAKIN_FOUND OFF)
endif()
......@@ -31,3 +25,8 @@ if(ANAKIN_FOUND)
link_directories(${ANAKIN_ROOT})
add_definitions(-DPADDLE_WITH_ANAKIN)
endif()
if(ANAKIN_FOUND AND WITH_GPU AND WITH_DSO)
message(STATUS "Compile with anakin subgraph.")
set(ANAKIN_SUBGRAPH ON)
endif()
......@@ -30,6 +30,7 @@ endif(NOT WITH_PROFILER)
if(WITH_AVX AND AVX_FOUND)
set(SIMD_FLAG ${AVX_FLAG})
add_definitions(-DPADDLE_WITH_AVX)
elseif(SSE3_FOUND)
set(SIMD_FLAG ${SSE3_FLAG})
endif()
......@@ -157,29 +158,3 @@ endif(WITH_BRPC_RDMA)
if(ON_INFER)
add_definitions(-DPADDLE_ON_INFERENCE)
endif(ON_INFER)
if(WITH_WBAES)
add_definitions(-DPADDLE_WITH_WBAES)
endif(WITH_WBAES)
# for lite
# TODO(Superjomn): does not work well with this option
if (LITE_WITH_CUDA)
add_definitions("-DLITE_WITH_CUDA")
endif()
if (LITE_WITH_X86)
add_definitions("-DLITE_WITH_X86")
endif()
if (LITE_WITH_ARM)
add_definitions("-DLITE_WITH_ARM")
endif()
if (LITE_WITH_PROFILE)
add_definitions("-DLITE_WITH_PROFILE")
endif()
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK")
endif()
......@@ -141,12 +141,10 @@ endfunction()
message(STATUS "CUDA detected: " ${CUDA_VERSION})
if (${CUDA_VERSION} LESS 7.0)
set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
add_definitions("-DPADDLE_CUDA_BINVER=\"60\"")
elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"70\"")
elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
......@@ -154,18 +152,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
# CUDA 8 may complain that sm_20 is no longer supported. Suppress the
# warning for now.
list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
endif()
add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"")
include_directories(${CUDA_INCLUDE_DIRS})
if(NOT WITH_DSO)
......
......@@ -96,7 +96,7 @@ if(CUDNN_FOUND)
endif()
message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. "
"Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ")
"Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ")
endif()
endif()
......@@ -38,5 +38,3 @@ ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc)
LIST(APPEND external_project_dependencies dgc)
......@@ -12,6 +12,13 @@ if(NOT WITH_FAST_MATH)
add_definitions(-DEIGEN_FAST_MATH=0)
endif()
if(WIN32)
set(EIGEN_GIT_REPOSITORY https://github.com/wopeizl/eigen-git-mirror)
set(EIGEN_GIT_TAG support_cuda9_win)
else()
set(EIGEN_GIT_REPOSITORY https://github.com/eigenteam/eigen-git-mirror)
set(EIGEN_GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c)
endif()
if(WITH_AMD_GPU)
ExternalProject_Add(
extern_eigen3
......@@ -29,10 +36,10 @@ else()
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/eigenteam/eigen-git-mirror"
GIT_REPOSITORY "${EIGEN_GIT_REPOSITORY}"
# Eigen on CUDA 9.1 is missing the header math_functions.hpp
# https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c
GIT_TAG ${EIGEN_GIT_TAG}
PREFIX ${EIGEN_SOURCE_DIR}
DOWNLOAD_NAME "eigen"
UPDATE_COMMAND ""
......
......@@ -18,31 +18,13 @@ SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags)
SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags)
SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE)
IF(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ELSE(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ENDIF(WIN32)
INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
if(ANDROID)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}")
endif()
ExternalProject_Add(
extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS}
......@@ -50,24 +32,24 @@ ExternalProject_Add(
GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a
PREFIX ${GFLAGS_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DBUILD_STATIC_LIBS=ON
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DBUILD_STATIC_LIBS=ON
-DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${OPTIONAL_ARGS}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib")
add_custom_command(TARGET extern_gflags POST_BUILD
COMMAND cmake -E copy ${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib ${GFLAGS_INSTALL_DIR}/lib/libgflags.lib
)
ENDIF()
ENDIF(WIN32)
ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
ADD_DEPENDENCIES(gflags extern_gflags)
......
......@@ -19,7 +19,7 @@ SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog)
SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE)
IF(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.lib" CACHE FILEPATH "glog library." FORCE)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/glog.lib" CACHE FILEPATH "glog library." FORCE)
SET(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4530")
ELSE(WIN32)
SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." FORCE)
......@@ -31,24 +31,6 @@ INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR})
SET(GLOG_REPOSITORY "https://github.com/google/glog.git")
SET(GLOG_TAG "v0.3.5")
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
if(ANDROID)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}")
endif()
ExternalProject_Add(
extern_glog
${EXTERNAL_PROJECT_LOG_ARGS}
......@@ -57,7 +39,14 @@ ExternalProject_Add(
GIT_TAG ${GLOG_TAG}
PREFIX ${GLOG_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS ${OPTIONAL_ARGS}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
......@@ -71,13 +60,6 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${GLOG_INSTALL_DIR}/lib/libglog.lib")
add_custom_command(TARGET extern_glog POST_BUILD
COMMAND cmake -E copy ${GLOG_INSTALL_DIR}/lib/glog.lib ${GLOG_INSTALL_DIR}/lib/libglog.lib
)
ENDIF()
ENDIF(WIN32)
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
......
......@@ -43,24 +43,6 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
SET(GTEST_DEPENDS ${MKLML_PROJECT})
ENDIF()
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
"-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}"
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}")
if(ANDROID)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS}
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}")
endif()
ExternalProject_Add(
extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS}
......@@ -69,7 +51,14 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
GIT_TAG "release-1.8.0"
PREFIX ${GTEST_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS ${OPTIONAL_ARGS}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_GMOCK=ON
......
......@@ -38,6 +38,7 @@ IF(WIN32)
SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_LIB_DEPS ${MKLML_LIB_DIR}/msvcr120.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
ELSE()
#TODO(intel-huying):
......
......@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph")
SET(NGRAPH_GIT_TAG "127e0dedfaac8c6f2b148cc03bf5f67ac5fbe6fe")
SET(NGRAPH_GIT_TAG "4ec94acc11084a5d53418f565529310fa584899a")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
......
......@@ -142,6 +142,7 @@ IF (WIN32)
ENDIF(WIN32)
if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH)
find_library(PROTOBUF_LIBRARY protobuf libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH)
find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH)
......@@ -177,28 +178,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}"
PARENT_SCOPE)
SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
SET(OPTIONAL_CACHE_ARGS "")
SET(OPTIONAL_ARGS "")
IF(BUILD_FOR_HOST)
SET(OPTIONAL_ARGS
"-DCMAKE_C_COMPILER=${HOST_C_COMPILER}"
"-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}"
"-Dprotobuf_WITH_ZLIB=OFF"
"-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}")
SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}")
SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF")
ELSE()
# protobuf has compile issues when built with the Android STL c++_static
SET(PROTOBUF_REPO "https://github.com/tensor-tang/protobuf.git")
SET(PROTOBUF_TAG "mobile")
SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF"
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
"-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
"-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
SET(OPTIONAL_ARGS
"-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}"
"-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}"
......@@ -206,18 +191,25 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}"
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}")
"-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}"
"-Dprotobuf_WITH_ZLIB=ON"
"-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}"
${EXTERNAL_OPTIONAL_ARGS})
SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}")
ENDIF()
IF(WIN32)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64")
ENDIF()
SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
ExternalProject_Add(
${TARGET_NAME}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PROTOBUF_SOURCES_DIR}
UPDATE_COMMAND ""
#DEPENDS zlib
DEPENDS zlib
GIT_REPOSITORY ${PROTOBUF_REPO}
GIT_TAG ${PROTOBUF_TAG}
CONFIGURE_COMMAND
......@@ -241,13 +233,6 @@ ENDFUNCTION()
SET(PROTOBUF_VERSION 3.1.0)
IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
build_protobuf(protobuf_host TRUE)
LIST(APPEND external_project_dependencies protobuf_host)
SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_host_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE)
ENDIF()
IF(NOT PROTOBUF_FOUND)
build_protobuf(extern_protobuf FALSE)
......@@ -260,12 +245,7 @@ IF(NOT PROTOBUF_FOUND)
SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY}
CACHE FILEPATH "protoc library." FORCE)
IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf)
ELSE()
SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE)
PROMPT_PROTOBUF_LIB(extern_protobuf)
ENDIF()
SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE)
PROMPT_PROTOBUF_LIB(extern_protobuf)
ENDIF(NOT PROTOBUF_FOUND)
......@@ -29,9 +29,9 @@ INCLUDE(ExternalProject)
SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(PSLIB_VER "0.1.1" CACHE STRING "" FORCE)
SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/ps/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
......
......@@ -53,12 +53,7 @@ ExternalProject_Add(
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
add_custom_command(TARGET extern_snappy POST_BUILD
COMMAND cmake -E copy ${SNAPPY_INSTALL_DIR}/lib/snappy.lib ${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib
)
ENDIF()
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib")
else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32)
......
......@@ -64,12 +64,7 @@ ExternalProject_Add(
-DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
)
IF(WIN32)
IF(NOT EXISTS "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}")
add_custom_command(TARGET extern_warpctc POST_BUILD
COMMAND cmake -E copy ${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} ${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}
)
ENDIF()
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
else(WIN32)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
......
......@@ -56,12 +56,7 @@ else()
endif()
if (WIN32)
IF(NOT EXISTS "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
add_custom_command(TARGET extern_xxhash POST_BUILD
COMMAND cmake -E copy ${XXHASH_INSTALL_DIR}/lib/xxhash.lib ${XXHASH_INSTALL_DIR}/lib/libxxhash.lib
)
ENDIF()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib")
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/xxhash.lib")
else()
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
endif ()
......
......@@ -44,12 +44,7 @@ ExternalProject_Add(
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${ZLIB_INSTALL_DIR}/lib/libz.lib")
add_custom_command(TARGET extern_zlib POST_BUILD
COMMAND cmake -E copy ${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib ${ZLIB_INSTALL_DIR}/lib/libz.lib
)
ENDIF()
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.lib" CACHE FILEPATH "zlib library." FORCE)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/zlibstatic.lib" CACHE FILEPATH "zlib library." FORCE)
ELSE(WIN32)
SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE)
ENDIF(WIN32)
......
......@@ -93,10 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl")
if (NOT ANDROID)
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -lrt")
endif()
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt")
endif(NOT APPLE)
set_property(GLOBAL PROPERTY FLUID_MODULES "")
......@@ -366,11 +363,10 @@ function(cc_binary TARGET_NAME)
target_link_libraries(${TARGET_NAME} ${os_dependency_modules})
endfunction(cc_binary)
function(cc_test TARGET_NAME)
function(cc_test_build TARGET_NAME)
if(WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS)
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS})
if(WIN32)
......@@ -383,12 +379,18 @@ function(cc_test TARGET_NAME)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME})
endif()
endfunction()
function(cc_test_run TARGET_NAME)
if(WITH_TESTING)
set(oneValueArgs "")
set(multiValueArgs COMMAND ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
COMMAND ${cc_test_COMMAND}
ARGS ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
......@@ -396,46 +398,21 @@ function(cc_test TARGET_NAME)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif()
endfunction(cc_test)
endfunction()
# cc_test without default dependencies
function(raw_cc_test TARGET_NAME)
function(cc_test TARGET_NAME)
if(WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS})
if(WIN32)
if("${cc_test_DEPS};" MATCHES "python;")
list(REMOVE_ITEM cc_test_DEPS python)
target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES})
endif()
endif(WIN32)
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} lite_gtest_main gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME})
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
cc_test_build(${TARGET_NAME}
SRCS ${cc_test_SRCS}
DEPS ${cc_test_DEPS})
cc_test_run(${TARGET_NAME}
COMMAND ${TARGET_NAME}
ARGS ${cc_test_ARGS})
endif()
endfunction(raw_cc_test)
function(_lite_cc_test args)
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
message(STATUS "building lite raw test: ${args}")
raw_cc_test(${args} ${ARGN})
else()
message(STATUS "building lite heavy test: ${args}")
cc_test(${args} ${ARGN})
endif()
endfunction()
endfunction(cc_test)
function(nv_library TARGET_NAME)
if (WITH_GPU)
......@@ -488,7 +465,6 @@ endfunction(nv_binary)
function(nv_test TARGET_NAME)
if (WITH_GPU AND WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -498,9 +474,6 @@ function(nv_test TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
......@@ -743,7 +716,7 @@ function(py_proto_compile TARGET_NAME)
cmake_parse_arguments(py_proto_compile "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(py_srcs)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs} protobuf)
endfunction()
function(py_test TARGET_NAME)
......
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -3,8 +3,6 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set(tmp_version "HEAD")
set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
set(LATEST_PADDLE_VERSION "latest")
while ("${PADDLE_VERSION}" STREQUAL "")
# Check current branch name
execute_process(
......@@ -25,8 +23,8 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}")
# Check the tag is a correct version
if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
# if no tag was found, set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
# if no tag was found, set PADDLE_VERSION to 0.0.0 to represent latest
set(PADDLE_VERSION "0.0.0")
elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
else() # otherwise, get the previous git tag name.
......@@ -44,19 +42,19 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME})
else()
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
set(PADDLE_VERSION "0.0.0")
endif()
else()
# otherwise, we always set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
# otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest
set(PADDLE_VERSION "0.0.0")
endif()
endif()
else()
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
set(PADDLE_VERSION "0.0.0")
message(WARNING "Cannot add paddle version from git tag")
endif()
else()
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
set(PADDLE_VERSION "0.0.0")
message(WARNING "Cannot add paddle version for wrong git branch result")
endif()
endwhile()
......
# to limit the mobile dependencies
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
add_subdirectory(scripts)
add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
endif()
add_subdirectory(scripts)
add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
add_subdirectory(fluid)
This diff is collapsed.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) # for mobile
add_subdirectory(lite)
return()
endif()
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
......@@ -10,8 +6,7 @@ add_subdirectory(operators)
add_subdirectory(string)
add_subdirectory(recordio)
add_subdirectory(pybind)
add_subdirectory(train)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
add_subdirectory(lite)
add_subdirectory(train)
......@@ -29,7 +29,8 @@ add_subdirectory(io)
proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto
data_feed_proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
......@@ -124,7 +125,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type data_feed_proto)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
......@@ -173,20 +174,20 @@ endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
......@@ -201,10 +202,10 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
fast_threaded_ssa_graph_executor variable_helper)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc pipeline_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc dataset_factory.cc
downpour_worker.cc pull_dense_worker.cc section_worker.cc
device_worker_factory.cc data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
......@@ -225,6 +226,8 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
cc_test(inlined_vector_test SRCS inlined_vector_test.cc)
if (NOT WIN32)
cc_test(rw_lock_test SRCS rw_lock_test.cc)
endif (NOT WIN32)
......
......@@ -85,8 +85,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc);
bool success = data_feed_desc.ParseFromString(data_feed_desc_str);
PADDLE_ENFORCE(success, "Fail to parse DataFeedDesc from string:\n%s",
data_feed_desc_str.c_str());
actual_thread_num_ = thread_num;
int file_cnt = filelist.size();
......
......@@ -95,6 +95,11 @@ class BlockingQueue {
return q_.size();
}
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
std::deque<T>().swap(q_);
}
private:
std::mutex mutex_;
std::condition_variable cv_;
......
......@@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX
#include <stdio_ext.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif
#include <utility>
#include "gflags/gflags.h"
......@@ -87,6 +90,13 @@ void DataFeed::CheckStart() {
PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet.");
}
void DataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
feed_vec_[i] = scope.FindVar(use_slots_[i])->GetMutable<LoDTensor>();
}
}
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
......@@ -158,6 +168,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
mutex_for_update_memory_data_ = nullptr;
this->file_idx_ = nullptr;
this->mutex_for_pick_file_ = nullptr;
fleet_send_sleep_seconds_ = 2;
}
template <typename T>
......@@ -366,7 +377,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::vector<T*>> send_vec(trainer_num_);
std::vector<int> send_index(trainer_num_);
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_ + 1;
for (auto& vec : send_vec) {
vec.reserve(reserve_len);
}
......@@ -377,46 +388,33 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
auto interval = GetMemoryDataInterval();
VLOG(3) << "global shuffle data from [" << interval.first << ", "
<< interval.second << "), thread_id=" << thread_id_;
for (int64_t i = interval.first; i < interval.second; ++i) {
// if get ins id, can also use hash
// std::string ins_id = memory_data_[i].ins_id;
int64_t random_num = rand_r(&rand_seed);
int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i]));
if (i % fleet_send_batch_size_ == 0 && i != 0) {
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
<< ", ins num=" << send_vec[j].size() << " to node_id=" << j
<< ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
send_vec[j].clear();
total_status.push_back(std::move(ret));
}
for (int64_t i = interval.first; i < interval.second;
i += fleet_send_batch_size_) {
for (int64_t j = 0; j < fleet_send_batch_size_ && i + j < interval.second;
++j) {
int64_t random_num = fleet_ptr->LocalRandomEngine()();
int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i + j]));
}
}
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
if (send_vec[j].size() != 0) {
total_status.clear();
std::shuffle(send_index.begin(), send_index.end(),
fleet_ptr->LocalRandomEngine());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
if (send_vec[j].size() == 0) {
continue;
}
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
<< ", thread_id=" << thread_id_;
auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str);
VLOG(3) << "end send, thread_id=" << thread_id_;
total_status.push_back(std::move(ret));
send_vec[j].clear();
}
std::vector<T*>().swap(send_vec[j]);
}
for (auto& t : total_status) {
t.wait();
for (auto& t : total_status) {
t.wait();
}
sleep(fleet_send_sleep_seconds_);
}
VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_;
#endif
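The reworked GlobalShuffle() above walks the interval in fleet_send_batch_size_ chunks: each instance is bucketed to a random trainer, the per-batch destination order is shuffled so that workers do not all hit the same node first, every send of the batch is awaited, and a sleep of fleet_send_sleep_seconds_ throttles the next batch. A simplified, self-contained sketch of one such batch; SendOneBatch and the lambda body are illustrative stand-ins for the fleet RPC:

#include <algorithm>
#include <cstddef>
#include <future>
#include <numeric>
#include <random>
#include <string>
#include <vector>

// One shuffled batch send: bucket payloads to random trainers, randomize the
// destination order, fire the sends asynchronously, then wait for all of them.
void SendOneBatch(const std::vector<std::string>& payloads,
                  std::size_t trainer_num, std::mt19937* rng) {
  std::vector<std::vector<std::string>> buckets(trainer_num);
  for (const auto& p : payloads) {
    buckets[(*rng)() % trainer_num].push_back(p);
  }
  std::vector<std::size_t> order(trainer_num);
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::shuffle(order.begin(), order.end(), *rng);
  std::vector<std::future<void>> pending;
  for (std::size_t node : order) {
    if (buckets[node].empty()) continue;
    pending.push_back(std::async(std::launch::async, [node, &buckets] {
      // Stand-in for serializing buckets[node] and calling
      // fleet_ptr->SendClientToClientMsg(0, node, send_str).
      buckets[node].clear();
    }));
  }
  for (auto& f : pending) {
    f.wait();  // mirrors the per-batch wait on total_status
  }
}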
......@@ -436,6 +434,24 @@ std::pair<int64_t, int64_t> InMemoryDataFeed<T>::GetMemoryDataInterval() {
return std::make_pair(start, end);
}
template <typename T>
int64_t InMemoryDataFeed<T>::GetChannelDataSize() {
if (cur_channel_ == 0) {
return shuffled_ins_->Size();
} else {
return shuffled_ins_out_->Size();
}
}
template <typename T>
void InMemoryDataFeed<T>::ReleaseChannelData() {
if (cur_channel_ == 0) {
shuffled_ins_->Clear();
} else {
shuffled_ins_out_->Clear();
}
}
// explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
......@@ -471,17 +487,17 @@ void MultiSlotDataFeed::Init(
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) > 0) {
total_dims_without_inductive_[i] *= slot.shape(j);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
if (slot.shape(j) == -1) {
inductive_shape_index_[i] = j;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(i));
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
......@@ -805,22 +821,24 @@ void MultiSlotInMemoryDataFeed::Init(
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) > 0) {
total_dims_without_inductive_[i] *= slot.shape(j);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
if (slot.shape(j) == -1) {
inductive_shape_index_[i] = j;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(i));
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
......@@ -1001,5 +1019,205 @@ void MultiSlotInMemoryDataFeed::DeserializeIns(
fleet_ptr->Deserialize(ins, str);
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec_[i].GetType();
const auto& offset = ins_vec_[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec_[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
const auto& feasign = ins_vec_[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int64_t total_dims = 1;
for (const auto e : use_slots_shape_[i]) {
total_dims *= e;
}
PADDLE_ENFORCE(
total_dims == total_instance,
"The actual data size of slot[%s] doesn't match its declaration",
use_slots_[i].c_str());
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
}
template <typename T>
int PrivateInstantDataFeed<T>::Next() {
if (ParseOneMiniBatch()) {
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
Postprocess();
std::string filename;
if (!PickOneFile(&filename)) {
return -1;
}
if (!Preprocess(filename)) {
return -1;
}
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data");
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
template <typename T>
void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
multi_inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) == -1) {
multi_inductive_shape_index_[i].push_back(j);
}
}
}
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
}
feed_vec_.resize(use_slots_.size());
ins_vec_.resize(use_slots_.size());
finish_init_ = true;
}
template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());
struct stat sb;
fstat(fd_, &sb);
end_ = static_cast<size_t>(sb.st_size);
buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno));
offset_ = 0;
return true;
}
bool MultiSlotFileInstantDataFeed::Postprocess() {
if (buffer_ != nullptr) {
munmap(buffer_, end_);
buffer_ = nullptr;
}
if (fd_ != -1) {
close(fd_);
fd_ = -1;
end_ = 0;
offset_ = 0;
}
return true;
}
bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
if (offset_ == end_) {
return false;
}
batch_size_ = 0;
while (batch_size_ < default_batch_size_ && offset_ < end_) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.");
offset_ += sizeof(uint16_t);
if (idx != -1) {
int inductive_size = multi_inductive_shape_index_[i].size();
if (UNLIKELY(batch_size_ == 0)) {
ins_vec_[idx].Init(all_slots_type_[i], default_batch_size_ * num);
ins_vec_[idx].InitOffset(default_batch_size_);
uint64_t* inductive_shape =
reinterpret_cast<uint64_t*>(buffer_ + offset_);
for (int inductive_id = 0; inductive_id < inductive_size;
++inductive_id) {
use_slots_shape_[i][multi_inductive_shape_index_[i][inductive_id]] =
static_cast<int>(*(inductive_shape + inductive_id));
}
}
num -= inductive_size;
offset_ += sizeof(uint64_t) * inductive_size;
if (type == 'f') {
ins_vec_[idx].AppendValues(
reinterpret_cast<float*>(buffer_ + offset_), num);
offset_ += num * sizeof(float);
} else if (type == 'u') {
ins_vec_[idx].AppendValues(
reinterpret_cast<uint64_t*>(buffer_ + offset_), num);
offset_ += num * sizeof(uint64_t);
}
} else {
if (type == 'f') {
offset_ += num * sizeof(float);
} else if (type == 'u') {
offset_ += num * sizeof(uint64_t);
}
}
}
++batch_size_;
// OPTIMIZE: it would be better to insert format-checking code between
// instances
}
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_");
return true;
}
#endif
} // namespace framework
} // namespace paddle
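For reference, the per-slot wire format that MultiSlotFileInstantDataFeed::ParseOneMiniBatch() walks above is: a uint16_t count, then (for dense slots with inductive dimensions) one uint64_t per inductive dimension, then the remaining payload values, float or uint64_t depending on the slot type. A hedged sketch of reading one plain slot (no inductive dims) from a mapped buffer; ReadSlot and its signature are illustrative, not part of the feed API:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Reads one slot record: a uint16_t count followed by `num` payload values of
// type T (float for 'f' slots, uint64_t for 'u' slots). Advances *offset past
// the record and returns the payload count.
template <typename T>
uint16_t ReadSlot(const char* buffer, std::size_t* offset, const T** payload) {
  uint16_t num;
  std::memcpy(&num, buffer + *offset, sizeof(uint16_t));  // alignment-safe
  *offset += sizeof(uint16_t);
  *payload = reinterpret_cast<const T*>(buffer + *offset);
  *offset += static_cast<std::size_t>(num) * sizeof(T);
  return num;
}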
......@@ -59,7 +59,7 @@ class DataFeed {
file_idx_ = nullptr;
}
virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) {
PADDLE_THROW("This function(CheckFile) is not implemented.");
}
......@@ -84,6 +84,9 @@ class DataFeed {
// This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name);
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default
......@@ -115,6 +118,9 @@ class DataFeed {
virtual void FillChannelToMemoryData() {}
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) {}
virtual int64_t GetChannelDataSize() { return 0; }
// This function will do nothing at default
virtual void ReleaseChannelData() {}
protected:
// The following three functions are used to check if it is executed in this
......@@ -145,6 +151,8 @@ class DataFeed {
std::vector<std::vector<int>> use_slots_shape_;
std::vector<int> inductive_shape_index_;
std::vector<int> total_dims_without_inductive_;
// For the inductive shapes passed within the data
std::vector<std::vector<int>> multi_inductive_shape_index_;
std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_
......@@ -170,7 +178,6 @@ class PrivateQueueDataFeed : public DataFeed {
public:
PrivateQueueDataFeed() {}
virtual ~PrivateQueueDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
......@@ -209,7 +216,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
......@@ -224,6 +231,8 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
virtual int64_t GetChannelDataSize();
virtual void ReleaseChannelData();
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
......@@ -248,6 +257,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
// Sleeping after each send slows down data sending; this is a hack and
// should be removed later.
int64_t fleet_send_sleep_seconds_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
......@@ -255,16 +267,25 @@ class MultiSlotType {
public:
MultiSlotType() {}
~MultiSlotType() {}
void Init(const std::string& type) {
void Init(const std::string& type, size_t reserved_size = 0) {
CheckType(type);
if (type_[0] == 'f') {
float_feasign_.clear();
if (reserved_size) {
float_feasign_.reserve(reserved_size);
}
} else if (type_[0] == 'u') {
uint64_feasign_.clear();
if (reserved_size) {
uint64_feasign_.reserve(reserved_size);
}
}
type_ = type;
}
void InitOffset() {
void InitOffset(size_t max_batch_size = 0) {
if (max_batch_size > 0) {
offset_.reserve(max_batch_size + 1);
}
offset_.resize(1);
// LoDTensor's lod is counted from 0; the size of lod is
// one larger than the size of the data.
......@@ -280,6 +301,16 @@ class MultiSlotType {
CheckUint64();
uint64_feasign_.push_back(v);
}
void CopyValues(const float* input, size_t size) {
CheckFloat();
float_feasign_.resize(size);
memcpy(float_feasign_.data(), input, size * sizeof(float));
}
void CopyValues(const uint64_t* input, size_t size) {
CheckUint64();
uint64_feasign_.resize(size);
memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
}
void AddIns(const MultiSlotType& ins) {
if (ins.GetType()[0] == 'f') { // float
CheckFloat();
......@@ -293,11 +324,22 @@ class MultiSlotType {
uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
}
}
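// AppendValues appends one instance's raw values and records its length
// in offset_, mirroring what AddIns does for another MultiSlotType.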
void AppendValues(const uint64_t* input, size_t size) {
CheckUint64();
offset_.push_back(offset_.back() + size);
uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
}
void AppendValues(const float* input, size_t size) {
CheckFloat();
offset_.push_back(offset_.back() + size);
float_feasign_.insert(float_feasign_.end(), input, input + size);
}
const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; }
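// offset_ always holds batch_size + 1 lod entries (it starts at 0),
// hence the -1 below.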
size_t GetBatchSize() { return offset_.size() - 1; }
std::string& MutableType() { return type_; }
std::string DebugString() {
......@@ -347,7 +389,7 @@ class MultiSlotDataFeed
public:
MultiSlotDataFeed() {}
virtual ~MultiSlotDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename);
// virtual void ReadThread();
......@@ -366,7 +408,7 @@ class MultiSlotInMemoryDataFeed
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual void Init(const DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
......@@ -381,5 +423,54 @@ class MultiSlotInMemoryDataFeed
const std::string& str);
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
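// PrivateInstantDataFeed parses mini-batches straight from the input file
// (no in-memory channel): Preprocess opens or maps the file,
// ParseOneMiniBatch fills ins_vec_, and PutToFeedVec copies it into the
// bound feed targets.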
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
public:
PrivateInstantDataFeed() {}
virtual ~PrivateInstantDataFeed() {}
void Init(const DataFeedDesc& data_feed_desc) override;
bool Start() override { return true; }
int Next() override;
protected:
// The batched data buffer
std::vector<MultiSlotType> ins_vec_;
// This function is used to preprocess a given file, e.g. open or mmap it
virtual bool Preprocess(const std::string& filename) = 0;
// This function is used to release system resources, such as closing the file
// NOTICE: It must be safe to call this even before Preprocess has run
virtual bool Postprocess() = 0;
// The reading and parsing method.
virtual bool ParseOneMiniBatch() = 0;
// This function is used to put ins_vec to feed_vec
virtual void PutToFeedVec();
};
class MultiSlotFileInstantDataFeed
: public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotFileInstantDataFeed() {}
virtual ~MultiSlotFileInstantDataFeed() {}
protected:
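// State of the (presumably mmap-based) reader: fd_ and buffer_ hold the
// mapped file, end_ is its total size and offset_ the current parse
// position.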
int fd_{-1};
char* buffer_{nullptr};
size_t end_{0};
size_t offset_{0};
bool Preprocess(const std::string& filename) override;
bool Postprocess() override;
bool ParseOneMiniBatch() override;
};
#endif
} // namespace framework
} // namespace paddle
......@@ -64,5 +64,8 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
} // namespace framework
} // namespace paddle
......@@ -13,11 +13,13 @@
// limitations under the License.
#include "paddle/fluid/framework/data_layout_transform.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#endif
namespace paddle {
......@@ -145,7 +147,6 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
......@@ -156,14 +157,21 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
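// Cache and reuse the MKL-DNN reorder primitive through
// ReorderMKLDNNHandler, keyed by the input shape, the source/destination
// formats and the data type.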
const std::string key = platform::ReorderMKLDNNHandler::GetHash(
in_tz, in_format, out_format, std::to_string(in_type));
auto in_memory =
memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key);
platform::Reorder(in_memory, out_memory);
auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out_format, expected_kernel_type.place_);
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*reorder_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} else {
out->ShareDataWith(in);
}
......
......@@ -141,6 +141,9 @@ template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_);
for (int i = 0; i < readers_.size(); ++i) {
readers_[i]->ReleaseChannelData();
}
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
......@@ -178,8 +181,10 @@ void DatasetImpl<T>::GlobalShuffle() {
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance();
// local shuffle all data before global shuffle
std::shuffle(memory_data_.begin(), memory_data_.end(),
fleet_ptr->LocalRandomEngine());
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
......@@ -260,6 +265,20 @@ void DatasetImpl<T>::DestroyReaders() {
}
}
template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() {
return memory_data_.size();
}
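// Sum the per-reader channel sizes; for non-in-memory data feeds
// GetChannelDataSize() returns 0, so the sum stays 0.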
template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
int64_t sum = 0;
for (int i = 0; i < readers_.size(); ++i) {
sum += readers_[i]->GetChannelDataSize();
}
return sum;
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
......@@ -267,7 +286,7 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = rand_r(&rand_seed) % thread_num_;
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
#endif
......
......@@ -85,6 +85,10 @@ class Dataset {
virtual void CreateReaders() = 0;
// destroy readers
virtual void DestroyReaders() = 0;
// get memory data size
virtual int64_t GetMemoryDataSize() = 0;
// get shuffle data size
virtual int64_t GetShuffleDataSize() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
......@@ -127,6 +131,8 @@ class DatasetImpl : public Dataset {
virtual void GlobalShuffle();
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
......
......@@ -93,6 +93,6 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
record_skip_memory_opt_vars_pass)
......@@ -35,16 +35,9 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
const platform::NCCLCommunicator *ctxs)
: NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
#else
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
......@@ -71,7 +64,9 @@ void AllReduceOpHandle::RunAllReduceFuncs(
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto *nccl_ctxs =
nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto stream = nccl_ctx.stream();
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
......@@ -134,21 +129,12 @@ void AllReduceOpHandle::RunImpl() {
numel = static_cast<size_t>(lod_tensor.numel());
}
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
<< ", dev_id:" << dev_id << ", dtype:" << dtype
<< ", place:" << p;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
NCCLAllReduce(p, buffer, buffer, numel,
static_cast<ncclDataType_t>(dtype), ncclSum);
});
}
VLOG(10) << "allreduce size:" << numel * SizeOfType(lod_tensors[0]->type());
RunAllReduceFuncs(all_reduce_calls);
#else
PADDLE_THROW("Not compiled with CUDA");
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......@@ -28,13 +29,15 @@ namespace paddle {
namespace framework {
namespace details {
class AllReduceOpHandle : public OpHandleBase {
public:
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
class AllReduceOpHandle : public NCCLOpHandleBase {
public:
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
const platform::NCCLCommunicator *ctxs);
#else
class AllReduceOpHandle : public OpHandleBase {
public:
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
......@@ -46,13 +49,17 @@ class AllReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
// NCCLOpHandleBase already has these attributes;
// this will be cleaned up later via the class inheritance hierarchy.
std::vector<platform::Place> places_;
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void RunAllReduceFuncs(
const std::vector<std::function<void()>> &all_reduce_calls);
const platform::NCCLContextMap *nccl_ctxs_;
#endif
};
......
......@@ -51,45 +51,39 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto i = 0; i < graphs.size(); ++i) {
std::vector<ir::Node *> nodes_to_delete;
for (auto &node : graphs[i]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section,
trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {}, trainer_id);
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
for (auto &node : graphs[0]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -26,6 +27,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace framework {
namespace details {
......@@ -46,6 +49,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
: ir::PassBuilder(), strategy_(strategy) {
// Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add graph_viz_pass";
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
......@@ -53,10 +57,27 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
VLOG(1) << "Add record_skip_memory_opt_vars_pass";
AppendPass("record_skip_memory_opt_vars_pass");
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
VLOG(1) << "Add mkldnn_placement_pass";
AppendPass("mkldnn_placement_pass");
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specifies the list of operator types "
"to accelerate with MKLDNN. It is empty by default, which means "
"that all operators supported by MKLDNN will be accelerated, "
"and it should not be set when "
"FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
if (strategy_.enable_sequential_execution_) {
VLOG(5) << "Add sequential_execution_pass";
VLOG(1) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
}
......@@ -67,7 +88,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
VLOG(5) << "Add fuse_relu_depthwise_conv_pass";
VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass");
}
......@@ -79,19 +100,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add automatically inplace.
if (strategy_.enable_inplace_) {
VLOG(5) << "Add inplace_pass";
VLOG(1) << "Add inplace_pass";
AppendPass("inplace_pass");
}
if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(5) << "Add fuse_elewise_add_act_pass";
VLOG(1) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) {
VLOG(5) << "Add alloc_continuous_space_for_grad_pass";
VLOG(1) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}
......@@ -106,11 +127,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// NOTE: fuse_all_xx_ops will count the number of xx operators first;
// if the number is zero, the pass will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(5) << "Add fuse_adam_op_pass";
VLOG(1) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(5) << "Add fuse_sgd_op_pass";
VLOG(1) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
VLOG(5) << "Add fuse_momentum_op_pass";
VLOG(1) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
}
}
......@@ -140,7 +161,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// As a side-effect, memory optimize cannot foresee the fetched vars,
// so the fetch list should be set persistable before calling the Run interface.
if (strategy_.memory_optimize_) {
VLOG(5) << "Add memory_optimize_pass";
VLOG(1) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
......@@ -148,26 +169,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// all original and fused operators. But no operator can have this
// attr enabled if the pass is placed after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(5) << "Add runtime_context_cache_pass";
VLOG(1) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
if (strategy_.cache_expected_kernel_) {
VLOG(10) << "Add expected_kernel_cache_pass";
AppendPass("expected_kernel_cache_pass");
}
AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(5) << "Add fuse_all_reduce_op_pass";
VLOG(1) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass");
}
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add multi_devices_print_pass";
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
......@@ -183,16 +200,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(5) << "Add all_reduce_deps_pass";
VLOG(1) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
if (strategy_.enable_backward_optimizer_op_deps_) {
VLOG(1) << "Add backward_op_deps_pass";
AppendPass("backward_optimizer_op_deps_pass");
}
if (strategy_.remove_unnecessary_lock_) {
VLOG(5) << "Add modify_op_lock_and_record_event_pass";
VLOG(1) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass");
}
// Verify that the graph is correct for multi-device executor.
VLOG(1) << "Add multi_devices_check_pass";
AppendPass("multi_devices_check_pass");
}
......@@ -201,18 +224,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
ir::Pass *multi_devices_pass = nullptr;
if (strategy_.async_mode_) {
VLOG(1) << "Add async_multi_devices_pass";
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(5)
VLOG(1)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(5) << "Add all_reduce_mode_multi_devices_pass";
VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(5) << "Add reduce_mode_multi_devices_pass";
VLOG(1) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
......@@ -249,7 +273,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const {
platform::NCCLCommunicator *nccl_ctxs) const {
#else
const bool use_cuda) const {
#endif
......@@ -271,9 +295,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
......@@ -287,9 +311,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
&local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
#endif
}
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
......@@ -302,6 +329,14 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_;
} else if (pass->Type() == "all_reduce_deps_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
#endif
LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_;
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
......@@ -313,6 +348,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "inplace_pass") {
pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......@@ -339,6 +377,7 @@ USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
......@@ -349,5 +388,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
......@@ -79,6 +80,8 @@ struct BuildStrategy {
bool fuse_all_reduce_ops_{false};
bool enable_backward_optimizer_op_deps_{false};
bool fuse_relu_depthwise_conv_{false};
bool sync_batch_norm_{false};
......@@ -108,7 +111,18 @@ struct BuildStrategy {
bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false};
bool cache_expected_kernel_{true};
std::unordered_set<std::string> mkldnn_enabled_op_types_;
size_t nccl_comm_num_{1};
// The picture is here:
// https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
bool use_hierarchical_allreduce_{false};
// NCCL ranks within a node when using hierarchical allreduce; in most
// cases this is set to the number of GPU cards in a node.
size_t hierarchical_allreduce_inter_nranks_{0};
// NCCL ranks between nodes when using hierarchical allreduce; this is
// set to the number of nodes.
size_t hierarchical_allreduce_exter_nranks_{0};
// NOTE:
// Before you add new options, think if it's a general strategy that works
......@@ -135,7 +149,7 @@ struct BuildStrategy {
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
platform::NCCLCommunicator *nccl_ctxs) const;
#else
const bool use_cuda) const;
#endif
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
......@@ -65,6 +66,7 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) {
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -43,35 +44,97 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
bootstrap_ops_.emplace_back(op);
}
}
PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
PrepareAtomicOpDeps();
}
FeedFetchList FastThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps();
size_t num_ops = op_deps->size();
paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
exception_.Clear();
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops);
event.reset(nullptr);
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If the num_threads is 1, we can record the order of the operators'
// execution in the first iteration, and in subsequent iterations,
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
size_t num_complete = 0;
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
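// -1UL is the sentinel pushed when an op throws; drain the queue until
// every in-flight op has reported before rethrowing the exception.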
if (num_comp == -1UL) {
int remaining = 0;
while (true) {
remaining = remaining_;
if (remaining == 0) {
break;
}
for (int i = 0; i < remaining; ++i) {
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
}
num_complete += num_comp;
}
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
}
void FastThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops) {
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
auto &var_name = fetch_tensors.at(i);
auto fetched_var_it = fetched_vars->find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(),
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)",
var_name);
......@@ -80,8 +143,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
fetch_ops.emplace_back(op);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
......@@ -94,55 +157,22 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
}
size_t num_complete = 0;
remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
int remaining = 0;
while (true) {
remaining = remaining_;
if (remaining == 0) {
break;
}
for (int i = 0; i < remaining; ++i) {
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ClearFetchOp(graph_, &fetch_ops);
exception_.ReThrow();
}
ready_fetch_ops->emplace_back(op);
}
num_complete += num_comp;
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
}
bool FastThreadedSSAGraphExecutor::RunOp(
OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete) {
try {
RunOpSync(op);
if (LIKELY(!exception_.IsCaught())) {
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
RecordOps(op);
}
++(*complete);
return true;
} catch (...) {
exception_.Catch(std::current_exception());
} else {
--remaining_;
complete_q->Push(-1UL);
return false;
......@@ -194,6 +224,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
complete_q->Push(complete);
});
}
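// Build the atomic dependency counters for the next Run() on a background
// thread so preparation overlaps with the current iteration.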
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
......@@ -206,6 +237,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
}
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
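// Trace ops only in single-threaded mode, where the execution order is
// deterministic; FetchOpHandles are recreated on every Run, so skip them.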
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
void FastThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_.ReThrow();
}
void FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_.Catch(std::current_exception());
}
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool pool_;
::ThreadPool prepare_pool_;
std::vector<OpHandleBase *> traced_ops_;
bool RunOp(OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete);
......@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>>
*fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops);
};
} // namespace details
} // namespace framework
......
......@@ -44,17 +44,10 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
FusedAllReduceOpHandle::FusedAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
const platform::NCCLContextMap *ctxs)
: OpHandleBase(node),
const platform::NCCLCommunicator *ctxs)
: NCCLOpHandleBase(node, places, ctxs),
local_scopes_(local_scopes),
places_(places),
num_of_all_reduce_(num_of_all_reduce),
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
num_of_all_reduce_(num_of_all_reduce) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
#else
......@@ -167,17 +160,14 @@ void FusedAllReduceOpHandle::RunImpl() {
auto &p = places_[i];
void *buffer = const_cast<void *>(lod_tensor_data.at(i));
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(nccl_dtype),
ncclSum, comm, stream));
NCCLAllReduce(p, buffer, buffer, numel,
static_cast<ncclDataType_t>(nccl_dtype), ncclSum);
});
}
VLOG(10) << "fusedallreduce size:" << numel * SizeOfType(dtype);
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......@@ -28,14 +29,15 @@ namespace paddle {
namespace framework {
namespace details {
struct FusedAllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const size_t num_of_all_reduce,
const platform::NCCLContextMap *ctxs);
const platform::NCCLCommunicator *ctxs);
#else
struct FusedAllReduceOpHandle : public OpHandleBase {
FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
......@@ -52,11 +54,12 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
private:
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
// NCCLOpHandleBase already has these attributes;
// this will be cleaned up later via the class inheritance hierarchy.
std::vector<platform::Place> places_;
size_t num_of_all_reduce_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs_;
#endif
size_t num_of_all_reduce_;
// Check the dtype of the input
void GetDTypeAndNumel(
......
......@@ -45,6 +45,7 @@ constexpr char kGraphVars[] = "vars";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
......
......@@ -20,7 +20,7 @@ namespace framework {
namespace details {
std::string OpHandleBase::DebugString() const {
std::stringstream ss;
ss << "(";
ss << Name() << "(";
for (auto *var : inputs_) {
ss << var->DebugString() << ", ";
}
......@@ -187,6 +187,11 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
std::function<void()> method = callback;
for (auto &p : dev_ctxes_) {
method = [method, p, this]() {
VLOG(10) << "cudadevicecontext:"
<< static_cast<platform::CUDADeviceContext *>(p.second)
<< ", dev_id:"
<< boost::get<platform::CUDAPlace>(p.first).device;
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
......
......@@ -95,6 +95,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
}
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -29,6 +30,8 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_(place) {}
void RPCOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place();
if (ir::IsControlDepVar(*in->Node())) {
......
......@@ -13,8 +13,8 @@
// limitations under the License.
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include <string>
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -67,6 +67,7 @@ struct ScaleLossGradFunctor {
};
void ScaleLossGradOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -36,26 +36,10 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
platform::RecordEvent e("InitLocalExeScopes");
PrepareLocalExeScopes();
}
std::vector<framework::LoDTensor> fetch_data;
std::exception_ptr eptr = nullptr;
try {
......@@ -64,9 +48,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
eptr = std::current_exception();
}
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun");
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
DropLocalExeScopes();
}
......@@ -78,16 +60,40 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0;
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
VLOG(3) << "Drop local execution scope: " << local_scope;
auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName);
if (local_scope_var != nullptr) {
auto &local_scope = *local_scope_var->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
scope->EraseVars({std::string(details::kLocalExecScopeName)});
VLOG(3) << "Drop local execution scope: " << local_scope;
}
}
}
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
}
......
......@@ -13,7 +13,8 @@
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <list>
#include <memory>
#include <string>
#include <vector>
......@@ -51,6 +52,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
bool NeedCreateLocalExeScope();
void PrepareLocalExeScopes();
private:
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
......
......@@ -30,7 +30,7 @@ namespace details {
SparseAllReduceOpHandle::SparseAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs, bool is_encoded, int nranks)
const platform::NCCLCommunicator *ctxs, bool is_encoded, int nranks)
: AllReduceOpHandle(node, local_scopes, places, ctxs),
is_encoded_(is_encoded),
nranks_(nranks) {
......@@ -102,7 +102,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
......
......@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
SparseAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs,
const platform::NCCLCommunicator *ctxs,
bool is_encoded = false, int nranks = -1);
std::string Name() const override;
......
......@@ -19,10 +19,13 @@ namespace framework {
namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<FetchOpHandle*>(op),
"The input ops of ClearFetchOp function should be FetchOpHandle.");
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
......
......@@ -38,7 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
};
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -53,74 +53,84 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
CopyOpDeps();
VLOG(10) << "ThreadedSSAGraphExecutor::Run";
std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
new BlockingQueue<VarHandleBase *>);
auto &pending_ops = op_deps->pending_ops_;
auto &pending_vars = op_deps->pending_vars_;
auto &ready_ops = op_deps->ready_ops_;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
size_t num_ops = op_deps->num_ops_;
// Step 2. Insert FetchOps
std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> fetch_ops;
std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
&pending_ops, &pending_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
RunOp(ready_vars, op);
}
set.clear();
};
// Clean run context
run_op_futures_.clear();
exception_holder_.Clear();
event.reset(nullptr);
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
// Keep loop until all vars are ready.
run_all_ops(ready_ops);
// 2. Find ready variable
bool timeout;
auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) {
if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If the num_threads is 1, we can record the order of the operators'
// execution in the first iteration, and in subsequent iterations,
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
RunOp(ready_vars, op);
}
set.clear();
};
// Clean run context
run_op_futures_.clear();
while (!pending_vars.empty()) {
// 1. Run All Ready ops
// Keep loop until all vars are ready.
run_all_ops(ready_ops);
// 2. Find ready variable
bool timeout;
auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) {
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
ClearFetchOp(graph_, &fetch_ops);
exception_holder_.ReThrow();
} else {
continue;
if (exception_holder_.IsCaught()) {
ExecutionFinal(&fetch_ops);
} else {
continue;
}
}
}
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
ready_ops.insert(op);
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
ready_ops.insert(op);
}
}
}
}
PADDLE_ENFORCE(ready_ops.empty());
}
PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
......@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,
std::vector<OpHandleBase *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
......@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
InsertPendingOp(&pending_ops, op);
}
}
op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
......@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
op_deps_->pending_vars_.end());
op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
op_deps_->ready_ops_.end());
op_deps->num_ops_ = op_deps_->num_ops_;
return std::unique_ptr<OpDependentData>(op_deps);
});
}
......@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] {
RunOpSync(op);
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
};
if (pool_) {
run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
RecordOps(op);
}
void ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
}
void ThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_holder_.ReThrow();
}
void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
} // namespace details
} // namespace framework
......
......@@ -44,6 +44,7 @@ struct OpDependentData {
std::unordered_map<OpHandleBase *, size_t> pending_ops_;
std::unordered_set<VarHandleBase *> pending_vars_;
std::unordered_set<OpHandleBase *> ready_ops_;
size_t num_ops_{0};
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
......@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<OpHandleBase *> traced_ops_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
......@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
VarHandleBase *var) const;
void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,
std::vector<OpHandleBase *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
......@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList *fetch_data);
void PrepareOpDeps();
void CopyOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
};
} // namespace details
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <fstream>
#include <map>
#include <memory>
......@@ -35,9 +36,17 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
#define SEC_LOG \
VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \
<< "]: "
class PullDenseWorker {
public:
virtual ~PullDenseWorker() {}
......@@ -48,6 +57,7 @@ class PullDenseWorker {
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
void PullDense(bool force_update = false);
static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker());
......@@ -92,7 +102,7 @@ class PullDenseWorker {
// should incorporate different type of device
class DeviceWorker {
public:
DeviceWorker() {}
DeviceWorker() { use_cvm_ = false; }
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
......@@ -114,6 +124,7 @@ class DeviceWorker {
std::shared_ptr<DataFeed> device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
bool use_cvm_;
};
class CPUWorkerBase : public DeviceWorker {
......@@ -194,5 +205,101 @@ class DownpourWorker : public HogwildWorker {
std::vector<::std::future<int32_t>> push_dense_status_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
using ScopeQueue = operators::reader::BlockingQueue<Scope*>;
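// SyncFunctor is invoked once per scope; roughly every sync_steps_ calls
// it appears to synchronize the parameters listed in sync_param_ across
// ranks via nccl_ctx_map_.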
class SyncFunctor {
public:
SyncFunctor(int rank_id, int rank_num, int sync_steps);
virtual ~SyncFunctor() {}
void SetSyncParam(const std::vector<std::string>& sync_param) {
sync_param_ = &sync_param;
}
void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) {
nccl_ctx_map_ = nccl_ctx_map;
}
int operator()(Scope* scope);
static std::vector<Scope*> pipeline_scopes_;
static uint64_t sync_flag_;
protected:
const int rank_id_;
const int rank_num_;
const std::vector<std::string>* sync_param_ = nullptr;
platform::NCCLContextMap* nccl_ctx_map_ = nullptr;
uint64_t sync_signal_;
const int sync_steps_;
int counter_;
void Synchronize();
};
class SectionWorker : public DeviceWorker {
public:
SectionWorker() {}
~SectionWorker() override {}
void Initialize(const TrainerDesc& desc) override;
void BindingDataFeedMemory() override {}
void CreateDeviceResource(const ProgramDesc& main_prog) override{};
void TrainFiles() override;
void TrainFilesWithProfiler() override;
void PrintFetchVars() override {}
const platform::Place& place() const { return place_; }
void SetSectionIndex(int section_id) { section_id_ = section_id; }
void SetDeviceIndex(int tid) override { pipeline_id_ = tid; }
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetVarNames(const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
in_var_names_ = &in_var_names;
out_var_names_ = &out_var_names;
}
void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) {
in_scope_queue_ = in_scope_queue;
out_scope_queue_ = out_scope_queue;
}
void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; }
void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; }
void SetSectionNum(int section_num) { section_num_ = section_num; }
void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; }
void SetNextSectionPlace(const paddle::platform::Place& place) {
next_section_place_ = place;
}
SyncFunctor* sync_func_ = nullptr;
void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }
static std::atomic<int> cpu_id_;
protected:
void AutoSetCPUAffinity(bool reuse);
int section_id_;
int pipeline_id_;
int section_num_;
int pipeline_num_;
int thread_id_;
// This worker will consume scope from in_scope_queue_
// and produce scope to out_scope_queue_
ScopeQueue* in_scope_queue_ = nullptr;
ScopeQueue* out_scope_queue_ = nullptr;
const std::vector<std::string>* in_var_names_ = nullptr;
const std::vector<std::string>* out_var_names_ = nullptr;
std::mutex* worker_count_mutex_ = nullptr;
int* worker_count_ = nullptr;
paddle::platform::Place next_section_place_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif
} // namespace framework
} // namespace paddle
......@@ -61,5 +61,8 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
#endif
} // namespace framework
} // namespace paddle
......@@ -63,6 +63,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm();
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
......@@ -139,14 +140,25 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod);
for (int index = 0; index < len; ++index) {
if (use_cvm_) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(),
sizeof(float) * table.emb_dim());
fea_idx++;
} else {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim());
continue;
}
memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2,
sizeof(float) * table.emb_dim());
fea_idx++;
}
}
}
}
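The branch above encodes the sparse-value layout this worker assumes: each value pulled from the parameter server carries two leading CVM slots (show/click statistics), so non-CVM mode skips them with the "+ 2" offset while CVM mode copies them along with the embedding. A minimal standalone sketch of that copy rule, assuming this layout (CopyEmbedding is a hypothetical helper, not part of the worker):

#include <cstring>
#include <vector>

// Copy one pulled sparse value into the embedding buffer. With use_cvm the
// emb_dim floats copied include the two leading show/click slots; without
// it the copy starts two floats past them.
void CopyEmbedding(const std::vector<float>& value, float* dst, int emb_dim,
                   bool use_cvm) {
  const float* src = use_cvm ? value.data() : value.data() + 2;
  std::memcpy(dst, src, sizeof(float) * emb_dim);
}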
......@@ -197,9 +209,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
for (auto j : param_.sparse_table()) {
if (j.table_id() == tid) {
table = j;
break;
}
}
......@@ -259,7 +271,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_);
timeline.Pause();
push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
......@@ -367,9 +379,9 @@ void DownpourWorker::TrainFiles() {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
for (auto j : param_.sparse_table()) {
if (j.table_id() == tid) {
table = j;
break;
}
}
......@@ -411,7 +423,7 @@ void DownpourWorker::TrainFiles() {
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_);
}
}
......
......@@ -122,8 +122,9 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
bool success = trainer_desc.ParseFromString(trainer_desc_str);
PADDLE_ENFORCE(success, "Fail to parse TrainerDesc from string:\n%s",
trainer_desc_str.c_str());
VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer;
......@@ -244,6 +245,12 @@ static bool has_fetch_operators(
return fetch_count > 0;
}
std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
return Prepare(program, block_id, skip_ref_cnt_vars, force_disable_gc);
}
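PrepareCtxCache is a thin forwarding wrapper over Prepare, so a caller can build the ExecutorPrepareContext once and replay it instead of re-parsing the block on every step. A hedged usage sketch (executor and scope setup elided; RunManySteps is a hypothetical helper):

// Build the prepared context for block 0 once, then reuse it each step.
void RunManySteps(paddle::framework::Executor* executor,
                  const paddle::framework::ProgramDesc& program,
                  paddle::framework::Scope* scope, int num_steps) {
  auto ctx = executor->PrepareCtxCache(program, /*block_id=*/0);
  for (int step = 0; step < num_steps; ++step) {
    executor->RunPreparedContext(ctx.get(), scope);
  }
}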
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
......@@ -328,7 +335,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
if (FLAGS_use_ngraph && ctx->block_id_ == 0) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
......@@ -368,6 +375,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars,
bool keep_kids) {
platform::RecordBlock b(kProgramId);
PADDLE_ENFORCE_NOT_NULL(scope);
Scope* local_scope = scope;
if (create_vars) {
......@@ -407,7 +415,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_);
if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
}
......
......@@ -83,6 +83,21 @@ class Executor {
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
......@@ -101,15 +116,6 @@ class Executor {
bool create_local_scope = true,
bool create_vars = true, bool keep_kids = false);
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
......
......@@ -281,9 +281,16 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status) {
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
int grad_dim = emb_dim;
if (use_cvm) {
offset = 0;
grad_dim = emb_dim - 2;
}
CHECK_GE(grad_dim, 0);
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
......@@ -307,7 +314,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
int dim = emb_dim + offset;
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g, g_tensor->numel() / dim, dim);
g_mat.rightCols(grad_dim) *= batch_size;
}
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += emb_dim;
......@@ -315,10 +328,15 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
}
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
if (use_cvm) {
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
} else {
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
}
g += emb_dim;
fea_idx++;
}
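The Eigen block above views the flat gradient buffer as a row-major [n, emb_dim + offset] matrix and multiplies only the rightmost grad_dim columns by the batch size, leaving the leading CVM slots untouched. A self-contained sketch of the same operation on a plain buffer (ScaleEmbeddingGrads is a hypothetical name):

#include <Eigen/Dense>
#include <vector>

// Scale only the embedding columns of a row-major [n, dim] gradient buffer;
// the leading dim - grad_dim columns (the CVM slots) are left as-is.
void ScaleEmbeddingGrads(std::vector<float>* buf, int dim, int grad_dim,
                         int batch_size) {
  Eigen::Map<
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      g_mat(buf->data(), buf->size() / dim, dim);
  g_mat.rightCols(grad_dim) *= static_cast<float>(batch_size);
}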
......@@ -337,6 +355,89 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::SaveModel does nothing when no pslib";
#endif
}
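Per the mode convention documented in the header (0 = all features, 1 = delta features), a hedged call sketch (the checkpoint path is made up):

// Save a full checkpoint, then later load only the delta on top of it.
auto fleet = paddle::framework::FleetWrapper::GetInstance();
fleet->SaveModel("/path/to/ckpt", /*mode=*/0);  // save all features
fleet->LoadModel("/path/to/ckpt", /*mode=*/1);  // load delta features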
void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
ret.wait();
#else
VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
#endif
}
void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list,
float decay) {
#ifdef PADDLE_WITH_PSLIB
std::vector<paddle::ps::Region> regions;
for (std::string& name : var_list) {
if (name.find("batch_sum") != std::string::npos) {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
VLOG(3) << "prepare shrink dense batch_sum";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
Eigen::Map<Eigen::MatrixXf> mat(g, 1, tensor->numel());
mat *= decay;
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
} else {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
}
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push shrink dense param failed, status[" << status << "]";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib";
#endif
}
void FleetWrapper::ClientFlush() {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->flush();
ret.wait();
#else
VLOG(0) << "FleetWrapper::ServerFlush does nothing when no pslib";
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
......@@ -398,6 +499,24 @@ void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#endif
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
#ifdef PADDLE_WITH_PSLIB
engine_wrapper_t() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
#endif
};
thread_local engine_wrapper_t r;
return r.engine;
}
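Because the engine is thread_local and seeded from the wall clock plus an atomic counter, every thread gets an independent stream without any locking, which is why this beats a shared rand_r under contention. A hedged sampling sketch (SampleUniform is a hypothetical helper):

#include <random>

// Draw a uniform double in [0, 1) from the calling thread's private engine.
double SampleUniform() {
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  return dist(
      paddle::framework::FleetWrapper::GetInstance()->LocalRandomEngine());
}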
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
......
......@@ -55,7 +55,7 @@ namespace framework {
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {}
FleetWrapper() { scale_sparse_gradient_with_batch_size_ = true; }
// Pull sparse variables from server in Sync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<out>: fea_values
......@@ -99,7 +99,8 @@ class FleetWrapper {
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status);
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
......@@ -128,6 +129,19 @@ class FleetWrapper {
// create client to client connection
void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
void ShrinkSparseTable(int table_id);
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay);
// register client to client communication
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
......@@ -146,6 +160,9 @@ class FleetWrapper {
return s_instance_;
}
// this performs better than rand_r, especially for large data
std::default_random_engine& LocalRandomEngine();
#ifdef PADDLE_WITH_PSLIB
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif
......@@ -158,6 +175,7 @@ class FleetWrapper {
protected:
static bool is_initialized_;
bool scale_sparse_gradient_with_batch_size_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
// option optimize_for = LITE_RUNTIME;
option optimize_for = LITE_RUNTIME;
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
......
......@@ -24,9 +24,10 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
for (int i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
use_cvm_ = desc.use_cvm();
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
......
......@@ -72,12 +72,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
if(ANAKIN_FOUND)
if(ANAKIN_SUBGRAPH)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
endif()
......@@ -86,12 +86,23 @@ if(WITH_MKLDNN)
pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
endif()
if(WITH_NGRAPH)
cc_library(ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge
analysis_helper subgraph_detector graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(APPEND ${pass_file} "USE_PASS(ngraph_subgraph_pass);\n")
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
......@@ -115,6 +126,8 @@ if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
......
......@@ -23,15 +23,16 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is up limited memory size "
DEFINE_double(fuse_parameter_memory_size, -1.0,  // MBytes
              "fuse_parameter_memory_size is the upper limit of the memory "
              "size (MB) of one group of parameters' gradients, which form "
              "the input of a communication call (e.g. NCCLAllReduce). "
              "The default value is -1.0, which means that groups are not "
              "set according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the size of one group parameters' gradient. "
fuse_parameter_groups_size, 1,
"fuse_parameter_groups_size is the up limited size of one group "
"parameters' gradient. "
"The default value is a experimental result. If the "
"fuse_parameter_groups_size is 1, it means that the groups size is "
"the number of parameters' gradient. If the fuse_parameter_groups_size is "
......@@ -41,6 +42,9 @@ DEFINE_int32(
namespace paddle {
namespace framework {
namespace ir {
// unit of the FLAGS_fuse_parameter_memory_size.
static constexpr double kMB = 1048576.0;
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because it is invalid to set 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' directly in a unit test.
......@@ -50,15 +54,12 @@ void SetFuseParameterGroupsSize(int group_size) {
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
void SetFuseParameterMemorySize(double memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; }
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
......@@ -83,7 +84,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
LOG(WARNING) << "Doesn't find gradients";
return;
}
......@@ -169,7 +170,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
details::GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void SetGroupAccordingToLayers(
......@@ -181,7 +181,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
layer_params[params_grads[i].first].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
......@@ -190,7 +190,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
std::string key = params_grads[i].first;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
......@@ -207,21 +207,40 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
VLOG(10) << "SetGroupAccordingToLayers: ";
if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
}
}
void PrintGroupInfo(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
size_t gps_size = 0;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
gps_size += size;
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
VLOG(10) << out.str()
<< ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << static_cast<double>(gps_size) / kMB
<< "(MB)";
}
}
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
const double group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size <= 0.0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
......@@ -248,7 +267,14 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
if (GetFuseParameterGroupsSize() > 1 &&
group_p_g.size() >
static_cast<size_t>(GetFuseParameterGroupsSize())) {
break;
}
if (static_cast<double>(local_group_memory_size) / kMB >=
group_memory_size) {
break;
}
}
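With the group-size check folded into this loop, a group is now closed by whichever cap is hit first: the group count limit (when FLAGS_fuse_parameter_groups_size > 1) or the memory budget in MB. A standalone sketch of that rule over plain byte sizes (GroupBySize is a hypothetical helper, not part of the pass):

#include <cstddef>
#include <vector>

// Merge consecutive items into groups, closing a group once either the
// max_group count or the budget_mb memory budget is reached.
std::vector<std::vector<size_t>> GroupBySize(const std::vector<size_t>& bytes,
                                             int max_group, double budget_mb) {
  constexpr double kMB = 1048576.0;
  std::vector<std::vector<size_t>> groups;
  size_t i = 0;
  while (i < bytes.size()) {
    groups.emplace_back();
    size_t acc = 0;
    while (i < bytes.size()) {
      groups.back().push_back(i);
      acc += bytes[i++];
      if (max_group > 1 &&
          groups.back().size() >= static_cast<size_t>(max_group)) {
        break;
      }
      if (budget_mb > 0.0 && static_cast<double>(acc) / kMB >= budget_mb) {
        break;
      }
    }
  }
  return groups;
}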
......@@ -257,60 +283,10 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
details::GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
"SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
}
}
......
......@@ -21,8 +21,8 @@ namespace ir {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
void SetFuseParameterMemorySize(double memory_size);
double GetFuseParameterMemorySize();
} // namespace ir
} // namespace framework
......
......@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create new parameters.
scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope.Var(param.Hidden)->GetMutable<LoDTensor>();
scope.Var(param.Cell)->GetMutable<LoDTensor>();
scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
......@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
GATE_W(c);
#undef GATE_W
auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_b = scope->FindVar("attention_output.b_0");
auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
auto* attention_output_w = scope.FindVar("attention_output.w_0");
auto* attention_output_b = scope.FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);
auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight = scope.Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias = scope.Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias
auto* attention_bias_t =
scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t =
scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));
......
......@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
......@@ -77,9 +78,15 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
if (base_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale"));
auto elementwise_desc = elementwise_add->Op();
if (elementwise_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
......@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
PADDLE_ENFORCE(scope);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias
auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var);
auto* gru_bias_var = scope->FindVar(bias->Name());
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* gru_bias_var = scope.FindVar(bias->Name());
auto* fc_bias_var = scope.FindVar(fc_bias->Name());
PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
......@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef GET_NODE
#define NEW_IMTERMEDIATE_OUT(key) \
scope->Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
scope.Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
NEW_IMTERMEDIATE_OUT(ReorderedH0);
NEW_IMTERMEDIATE_OUT(XX);
NEW_IMTERMEDIATE_OUT(BatchedInput);
......
......@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
......
......@@ -26,7 +26,7 @@ namespace framework {
namespace ir {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
std::unordered_set<std::string> act_types = {"relu", "scale", "tanh"};
graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types);
// backward
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include <unordered_map>
namespace paddle {
namespace framework {
......@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
return &scope;
}
void FusePassBase::AddStatis(int count_of_fused) const {
......@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1,
#else
return FUSE_NATIVE;
#endif
};
}
} // namespace ir
} // namespace framework
......
......@@ -134,6 +134,7 @@ void Graph::ResolveHazard(
ir::Node *dep_var = CreateControlDepVar();
write_op->inputs.push_back(dep_var);
upstream_op->outputs.push_back(dep_var);
VLOG(10) << "add dep_var:" << dep_var->Name();
dep_var->outputs.push_back(write_op);
dep_var->inputs.push_back(upstream_op);
}
......@@ -157,6 +158,7 @@ void Graph::ResolveHazard(
if (has_dep) continue;
ir::Node *dep_var = CreateControlDepVar();
VLOG(10) << "add dep_var:" << dep_var->Name();
read_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(read_op);
write_op->inputs.push_back(dep_var);
......
......@@ -14,7 +14,10 @@
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -785,6 +788,33 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var;
}
PDNode *patterns::ConvBReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu6");
// output
auto *brelu_out_var = pattern->NewNode(brelu_out_repr())
->AsOutput()
->assert_is_op_output("relu6");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var});
return brelu_out_var;
}
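As with the other patterns in this file, the returned PDNode chain is consumed by a GraphPatternDetector inside the corresponding fuse pass; a hedged sketch of the driving code, following the ConvReLU pass convention (graph is the pass's input graph, handler body elided):

// Wire the conv2d + relu6 pattern into a detector and visit each match.
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
                       ->NewNode("conv_brelu_fuse/conv_input")
                       ->AsInput()
                       ->assert_is_op_input("conv2d", "Input");
patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
                                       "conv_brelu_fuse");
conv_brelu_pattern(conv_input);
gpd(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
  // look up the matched conv2d/relu6 nodes and rewrite them into one fused op
});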
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
......@@ -869,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
}
}
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Filter
auto *fc_weight_var = pattern->NewNode(weights_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "W");
// Bias
auto *fc_bias_var = pattern->NewNode(bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "Bias");
// Output
auto *fc_out_var = pattern->NewNode(output_repr())
->AsOutput()
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
return fc_out_var;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
......@@ -1035,12 +1092,12 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return ele_add_grad;
}
// conv_type: conv2d, conv3d, conv2d_transpose
PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
std::string type = is_conv3d ? "conv3d" : "conv2d";
paddle::framework::ir::PDNode *conv_input, std::string conv_type) {
// Create Operators
conv_input->assert_is_op_input(type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables
......@@ -1048,11 +1105,11 @@ PDNode *patterns::ConvBias::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input(type, "Filter");
->assert_is_op_input(conv_type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(type)
->assert_is_only_output_of_op(conv_type)
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
......@@ -1157,6 +1214,57 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var;
}
PDNode *patterns::Concat::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto output_var = pattern->NewNode(concat_out_repr())
->AsOutput()
->assert_is_op_output("concat", "Out");
concat_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ConcatReLU::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto concat_out =
pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
concat_op->LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
PDNode *patterns::ConvConcatReLU::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto concat_out = pattern->NewNode(concat_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("relu", "X");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
conv_op->LinksTo({conv_out});
concat_op->LinksFrom({conv_out}).LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
......@@ -1641,13 +1749,16 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times,
const std::string &quant_type) {
const int kNumFields = 5;
const std::string &quant_type,
const std::string &dequant_type) {
int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
const int kDequantOpWeightScaleOffset = 5;
// there is always exactly one quant op.
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
......@@ -1655,11 +1766,19 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
PDNode *quant_op_out_scale = nullptr;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
kNumFields += 1;
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_nth_input(dequant_type, "Scales", 1)
->AsIntermediate();
} else {
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input(dequant_type, "Scale")
->AsIntermediate();
}
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
......@@ -1680,16 +1799,25 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->assert_is_op_input(dequant_type, "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
->assert_is_op("fake_dequantize_max_abs"));
->assert_is_op(dequant_type));
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->assert_is_op_output(dequant_type, "Out")
->AsOutput());
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes.push_back(pattern
->NewNode(GetNodeName("dequant_channel_scale") +
std::to_string(i))
->assert_is_op_nth_input(dequant_type, "Scales", 0)
->AsInput());
}
}
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
......@@ -1699,8 +1827,14 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale,
nodes[i * kNumFields + kDequantOpWeightScaleOffset]});
} else {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
}
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
......@@ -1737,6 +1871,41 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out->LinksFrom({reshape2_op});
}
void patterns::DeleteQuantDequantOpPattern::operator()() {
auto any_op_out =
pattern->NewNode(any_op_out_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "X")
->AsInput();
auto quant_dequant_op_inscale =
pattern->NewNode(quant_dequant_op_inscale_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "InScale")
->AsInput();
auto quant_dequant_op =
pattern->NewNode(quant_dequant_op_repr())
->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");
auto quant_dequant_out =
pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "Out")
->AsIntermediate();
auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "OutScale")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,3 +14,4 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
cc_library(backward_optimizer_op_deps_pass SRCS backward_optimizer_op_deps_pass.cc DEPS graph graph_helper pass)
(The remaining file diffs in this commit are collapsed.)